Skip to content

Commit b26428b

Browse files
committed
Added rocBLAS support
1 parent 3e8bb9c commit b26428b

5 files changed

Lines changed: 248 additions & 1 deletion

File tree

configure.ac

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,48 @@ if test "x$enable_cublas" != "xno"; then
146146
fi
147147

148148

149+
# Check for the --enable-rocblas option
# Accepted values: "no" (default), "yes", or the installation prefix of rocBLAS/ROCm
AC_ARG_ENABLE([rocblas],
  [AS_HELP_STRING([--enable-rocblas],
    [Enable rocBLAS support for high-performance matrix operations])],
  [:],
  [enable_rocblas=no])

if test "x$enable_rocblas" != "xno"; then
  AC_MSG_NOTICE([Enabling rocBLAS support])

  # rocBLAS can only be used from a HIP build
  if test "x$enable_hip" != "xyes"; then
    AC_MSG_FAILURE([--enable-rocblas option requires compiling with HIP support])
  fi

  if test "x$enable_rocblas" != "xyes"; then
    # If not directly "yes", treat it as a path to the rocBLAS (ROCm) installation
    CXXFLAGS="$CXXFLAGS -I$enable_rocblas/include"
    # Library search directories are added with -L; -I only affects header lookup.
    # Both lib64 and lib are searched as the library directory name differs between installations.
    LDFLAGS="$LDFLAGS -L$enable_rocblas/lib64 -L$enable_rocblas/lib"
  fi

  # Use a plain assignment: the += operator is a bash extension and fails when
  # configure is executed by a strictly POSIX shell such as dash
  LDFLAGS="$LDFLAGS -lrocblas"

  # Confirm that the header is found and that a rocBLAS symbol links
  AC_LINK_IFELSE(
    [
      AC_LANG_PROGRAM(
        [[#include <rocblas/rocblas.h>]],
        [[rocblas_handle handle; rocblas_create_handle(&handle);]]
      )
    ],
    [
      AC_MSG_RESULT([rocBLAS support available])
      AC_DEFINE([USE_ROCBLAS],[1],[Use rocBLAS])
      AC_DEFINE([USE_BLAS],[1],[Use BLAS])
    ],
    [AC_MSG_RESULT([rocBLAS support not available])
     AC_MSG_FAILURE([rocBLAS support was requested, but the test code did not compile])]
  )

fi
189+
190+
149191
# Check for the --enable-onemkl option
150192
AC_ARG_ENABLE([onemkl],
151193
[AS_HELP_STRING([--enable-onemkl],

include/Accelerator.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ void accelerator_for_body(int dims[thrDim+blockDim],
163163
#define SIMT_ACTIVE
164164
#endif
165165

166+
//using std::min;
167+
166168
#define accelerator_only __device__
167169
#define accelerator __host__ __device__
168170
#define accelerator_inline __host__ __device__ inline

include/BLAS.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ void batchedGEMV(BLASop trans,
119119

120120

121121

122-
/**************************** Non-public API for column-major BLAS functions *************************************
122+
/**************************** Non-public API for column-major BLAS functions ************************************* */
123123

124124
/**
125125
* @brief Generic wrappers around the (strided) batched GEMM functionality

src/Accelerator.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,17 @@ void acceleratorReport(){
7474

7575
int nDevices = 1;
7676
d=hipGetDeviceCount(&nDevices);
77+
78+
const int len=64;
79+
char busid[len];
80+
d = hipDeviceGetPCIBusId(busid, len, device);
7781

7882
for(int w=0;w<world_nrank;w++){
7983
assert( MPI_Barrier(MPI_COMM_WORLD) == MPI_SUCCESS );
8084
if(w == world_rank)
8185
std::cout << "world:" << world_rank << '/' << world_nrank
8286
<< " device:" << device << '/' << nDevices
87+
<< " busID:" << busid
8388
<< std::endl << std::flush;
8489
}
8590
}

src/BLAS.cpp

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,204 @@ void cmGEMM<double>(BLASop transa, BLASop transb,
194194
C, ldc) == CUBLAS_STATUS_SUCCESS );
195195
}
196196

197+
#elif defined(USE_ROCBLAS)
198+
199+
#include <rocblas/rocblas.h>
200+
201+
/**
 * @brief Map a BLASop value onto the corresponding rocblas_operation
 * @param op The requested operation; NoTranspose and Transpose are recognized
 * @return rocblas_operation_none for NoTranspose, rocblas_operation_transpose for Transpose
 * @throws std::runtime_error for any other value of op
 */
static inline rocblas_operation rocblasOpLookup(BLASop op){
  if(op == NoTranspose)
    return rocblas_operation_none;
  if(op == Transpose)
    return rocblas_operation_transpose;

  // Anything else has no rocBLAS equivalent here
  throw std::runtime_error("Unsupported BLASop");
}
211+
212+
struct rocBLAShandleContainer{
213+
rocblas_handle handle;
214+
215+
rocBLAShandleContainer(){
216+
assert( rocblas_create_handle(&handle) == rocblas_status_success );
217+
assert( rocblas_set_stream(handle, computeStream) == rocblas_status_success );
218+
}
219+
};
220+
/**
 * @brief Access the rocBLAS handle shared by the wrappers in this file
 *
 * The handle lives in a function-local static rocBLAShandleContainer, so it is
 * constructed the first time this function is called.
 *
 * @return The handle stored in the static container
 */
static inline rocblas_handle getrocBLAShandle(){
  static rocBLAShandleContainer instance;
  return instance.handle;
}
224+
225+
template<>
226+
void cmBatchedGEMM<float>(BLASop transa,
227+
BLASop transb,
228+
int m, int n, int k,
229+
const float *alpha,
230+
const float *A, int lda,
231+
long long int strideA,
232+
const float *B, int ldb,
233+
long long int strideB,
234+
const float *beta,
235+
float *C, int ldc,
236+
long long int strideC,
237+
int batchCount){
238+
239+
assert( rocblas_sgemm_strided_batched(getrocBLAShandle(), rocblasOpLookup(transa), rocblasOpLookup(transb),
240+
m, n, k,
241+
alpha,
242+
A, lda, strideA,
243+
B, ldb, strideB,
244+
beta,
245+
C, ldc, strideC,
246+
batchCount) == rocblas_status_success );
247+
}
248+
template<>
249+
void cmBatchedGEMM<double>(BLASop transa,
250+
BLASop transb,
251+
int m, int n, int k,
252+
const double *alpha,
253+
const double *A, int lda,
254+
long long int strideA,
255+
const double *B, int ldb,
256+
long long int strideB,
257+
const double *beta,
258+
double *C, int ldc,
259+
long long int strideC,
260+
int batchCount){
261+
assert( rocblas_dgemm_strided_batched(getrocBLAShandle(), rocblasOpLookup(transa), rocblasOpLookup(transb),
262+
m, n, k,
263+
alpha,
264+
A, lda, strideA,
265+
B, ldb, strideB,
266+
beta,
267+
C, ldc, strideC,
268+
batchCount) == rocblas_status_success );
269+
}
270+
271+
template<>
272+
void cmBatchedGEMV<float>(BLASop trans,
273+
int m, int n,
274+
const float *alpha,
275+
const float *A, int lda,
276+
long long int strideA,
277+
const float *x, int incx,
278+
long long int stridex,
279+
const float *beta,
280+
float *y, int incy,
281+
long long int stridey,
282+
int batchCount){
283+
284+
assert( rocblas_sgemv_strided_batched(getrocBLAShandle(),
285+
rocblasOpLookup(trans),
286+
m,n,
287+
alpha,
288+
A, lda, strideA,
289+
x, incx, stridex,
290+
beta,
291+
y, incy, stridey,
292+
batchCount) == rocblas_status_success );
293+
}
294+
295+
template<>
296+
void cmBatchedGEMV<double>(BLASop trans,
297+
int m, int n,
298+
const double *alpha,
299+
const double *A, int lda,
300+
long long int strideA,
301+
const double *x, int incx,
302+
long long int stridex,
303+
const double *beta,
304+
double *y, int incy,
305+
long long int stridey,
306+
int batchCount){
307+
assert( rocblas_dgemv_strided_batched(getrocBLAShandle(),
308+
rocblasOpLookup(trans),
309+
m,n,
310+
alpha,
311+
A, lda, strideA,
312+
x, incx, stridex,
313+
beta,
314+
y, incy, stridey,
315+
batchCount) == rocblas_status_success );
316+
}
317+
318+
template<>
319+
void cmBatchedGEMV<float>(BLASop trans,
320+
int m, int n,
321+
const float *alpha,
322+
const float *const Aarray[], int lda,
323+
const float *const xarray[], int incx,
324+
const float *beta,
325+
float * yarray[], int incy,
326+
int batchCount){
327+
328+
assert( rocblas_sgemv_batched(getrocBLAShandle(), rocblasOpLookup(trans),
329+
m, n,
330+
alpha,
331+
Aarray,lda,
332+
xarray, incx,
333+
beta,
334+
yarray, incy,
335+
batchCount) == rocblas_status_success );
336+
}
337+
338+
template<>
339+
void cmBatchedGEMV<double>(BLASop trans,
340+
int m, int n,
341+
const double *alpha,
342+
const double *const Aarray[], int lda,
343+
const double *const xarray[], int incx,
344+
const double *beta,
345+
double * yarray[], int incy,
346+
int batchCount){
347+
assert( rocblas_dgemv_batched(getrocBLAShandle(), rocblasOpLookup(trans),
348+
m, n,
349+
alpha,
350+
Aarray,lda,
351+
xarray, incx,
352+
beta,
353+
yarray, incy,
354+
batchCount) == rocblas_status_success );
355+
}
356+
357+
template<>
358+
void cmGEMM<float>(BLASop transa, BLASop transb,
359+
int m, int n, int k,
360+
const float *alpha,
361+
const float *A, int lda,
362+
const float *B, int ldb,
363+
const float *beta,
364+
float *C, int ldc){
365+
assert( rocblas_sgemm(getrocBLAShandle(),
366+
rocblasOpLookup(transa), rocblasOpLookup(transb),
367+
m, n, k,
368+
alpha,
369+
A, lda,
370+
B, ldb,
371+
beta,
372+
C, ldc) == rocblas_status_success );
373+
}
374+
375+
template<>
376+
void cmGEMM<double>(BLASop transa, BLASop transb,
377+
int m, int n, int k,
378+
const double *alpha,
379+
const double *A, int lda,
380+
const double *B, int ldb,
381+
const double *beta,
382+
double *C, int ldc){
383+
assert( rocblas_dgemm(getrocBLAShandle(),
384+
rocblasOpLookup(transa), rocblasOpLookup(transb),
385+
m, n, k,
386+
alpha,
387+
A, lda,
388+
B, ldb,
389+
beta,
390+
C, ldc) == rocblas_status_success );
391+
}
392+
393+
394+
197395
#elif defined(USE_ONEMKL)
198396

199397
#undef VERSION

0 commit comments

Comments
 (0)