@@ -194,6 +194,204 @@ void cmGEMM<double>(BLASop transa, BLASop transb,
194194 C, ldc) == CUBLAS_STATUS_SUCCESS );
195195}
196196
197+ #elif defined(USE_ROCBLAS)
198+
199+ #include < rocblas/rocblas.h>
200+
201+ static inline rocblas_operation rocblasOpLookup (BLASop op){
202+ switch (op){
203+ case NoTranspose:
204+ return rocblas_operation_none;
205+ case Transpose:
206+ return rocblas_operation_transpose;
207+ default :
208+ throw std::runtime_error (" Unsupported BLASop" );
209+ }
210+ }
211+
212+ struct rocBLAShandleContainer {
213+ rocblas_handle handle;
214+
215+ rocBLAShandleContainer (){
216+ assert ( rocblas_create_handle (&handle) == rocblas_status_success );
217+ assert ( rocblas_set_stream (handle, computeStream) == rocblas_status_success );
218+ }
219+ };
220+ static inline rocblas_handle getrocBLAShandle (){
221+ static rocBLAShandleContainer con;
222+ return con.handle ;
223+ }
224+
225+ template <>
226+ void cmBatchedGEMM<float >(BLASop transa,
227+ BLASop transb,
228+ int m, int n, int k,
229+ const float *alpha,
230+ const float *A, int lda,
231+ long long int strideA,
232+ const float *B, int ldb,
233+ long long int strideB,
234+ const float *beta,
235+ float *C, int ldc,
236+ long long int strideC,
237+ int batchCount){
238+
239+ assert ( rocblas_sgemm_strided_batched (getrocBLAShandle (), rocblasOpLookup (transa), rocblasOpLookup (transb),
240+ m, n, k,
241+ alpha,
242+ A, lda, strideA,
243+ B, ldb, strideB,
244+ beta,
245+ C, ldc, strideC,
246+ batchCount) == rocblas_status_success );
247+ }
248+ template <>
249+ void cmBatchedGEMM<double >(BLASop transa,
250+ BLASop transb,
251+ int m, int n, int k,
252+ const double *alpha,
253+ const double *A, int lda,
254+ long long int strideA,
255+ const double *B, int ldb,
256+ long long int strideB,
257+ const double *beta,
258+ double *C, int ldc,
259+ long long int strideC,
260+ int batchCount){
261+ assert ( rocblas_dgemm_strided_batched (getrocBLAShandle (), rocblasOpLookup (transa), rocblasOpLookup (transb),
262+ m, n, k,
263+ alpha,
264+ A, lda, strideA,
265+ B, ldb, strideB,
266+ beta,
267+ C, ldc, strideC,
268+ batchCount) == rocblas_status_success );
269+ }
270+
271+ template <>
272+ void cmBatchedGEMV<float >(BLASop trans,
273+ int m, int n,
274+ const float *alpha,
275+ const float *A, int lda,
276+ long long int strideA,
277+ const float *x, int incx,
278+ long long int stridex,
279+ const float *beta,
280+ float *y, int incy,
281+ long long int stridey,
282+ int batchCount){
283+
284+ assert ( rocblas_sgemv_strided_batched (getrocBLAShandle (),
285+ rocblasOpLookup (trans),
286+ m,n,
287+ alpha,
288+ A, lda, strideA,
289+ x, incx, stridex,
290+ beta,
291+ y, incy, stridey,
292+ batchCount) == rocblas_status_success );
293+ }
294+
295+ template <>
296+ void cmBatchedGEMV<double >(BLASop trans,
297+ int m, int n,
298+ const double *alpha,
299+ const double *A, int lda,
300+ long long int strideA,
301+ const double *x, int incx,
302+ long long int stridex,
303+ const double *beta,
304+ double *y, int incy,
305+ long long int stridey,
306+ int batchCount){
307+ assert ( rocblas_dgemv_strided_batched (getrocBLAShandle (),
308+ rocblasOpLookup (trans),
309+ m,n,
310+ alpha,
311+ A, lda, strideA,
312+ x, incx, stridex,
313+ beta,
314+ y, incy, stridey,
315+ batchCount) == rocblas_status_success );
316+ }
317+
318+ template <>
319+ void cmBatchedGEMV<float >(BLASop trans,
320+ int m, int n,
321+ const float *alpha,
322+ const float *const Aarray[], int lda,
323+ const float *const xarray[], int incx,
324+ const float *beta,
325+ float * yarray[], int incy,
326+ int batchCount){
327+
328+ assert ( rocblas_sgemv_batched (getrocBLAShandle (), rocblasOpLookup (trans),
329+ m, n,
330+ alpha,
331+ Aarray,lda,
332+ xarray, incx,
333+ beta,
334+ yarray, incy,
335+ batchCount) == rocblas_status_success );
336+ }
337+
338+ template <>
339+ void cmBatchedGEMV<double >(BLASop trans,
340+ int m, int n,
341+ const double *alpha,
342+ const double *const Aarray[], int lda,
343+ const double *const xarray[], int incx,
344+ const double *beta,
345+ double * yarray[], int incy,
346+ int batchCount){
347+ assert ( rocblas_dgemv_batched (getrocBLAShandle (), rocblasOpLookup (trans),
348+ m, n,
349+ alpha,
350+ Aarray,lda,
351+ xarray, incx,
352+ beta,
353+ yarray, incy,
354+ batchCount) == rocblas_status_success );
355+ }
356+
357+ template <>
358+ void cmGEMM<float >(BLASop transa, BLASop transb,
359+ int m, int n, int k,
360+ const float *alpha,
361+ const float *A, int lda,
362+ const float *B, int ldb,
363+ const float *beta,
364+ float *C, int ldc){
365+ assert ( rocblas_sgemm (getrocBLAShandle (),
366+ rocblasOpLookup (transa), rocblasOpLookup (transb),
367+ m, n, k,
368+ alpha,
369+ A, lda,
370+ B, ldb,
371+ beta,
372+ C, ldc) == rocblas_status_success );
373+ }
374+
375+ template <>
376+ void cmGEMM<double >(BLASop transa, BLASop transb,
377+ int m, int n, int k,
378+ const double *alpha,
379+ const double *A, int lda,
380+ const double *B, int ldb,
381+ const double *beta,
382+ double *C, int ldc){
383+ assert ( rocblas_dgemm (getrocBLAShandle (),
384+ rocblasOpLookup (transa), rocblasOpLookup (transb),
385+ m, n, k,
386+ alpha,
387+ A, lda,
388+ B, ldb,
389+ beta,
390+ C, ldc) == rocblas_status_success );
391+ }
392+
393+
394+
197395#elif defined(USE_ONEMKL)
198396
199397#undef VERSION
0 commit comments