@@ -13,43 +13,125 @@ namespace propr {
1313 namespace detail {
1414 namespace cuda {
1515
16+ // template <class Config>
17+ // __global__
18+ // void
19+ // lrm_basic(float* __restrict__ d_Y, offset_t d_Y_stride,
20+ // float* __restrict__ d_mean,
21+ // int nb_samples,
22+ // int nb_genes) {
23+ // int i = blockIdx.x * blockDim.x + threadIdx.x;
24+ // int j = blockIdx.y * blockDim.y + threadIdx.y;
25+ // if (i >= nb_genes || j >= i) return;
26+
27+ // float4 accum = {0.0f, 0.0f, 0.0f, 0.0f};
28+ // int k = 0;
29+ // PROPR_UNROLL
30+ // for (; k < (nb_samples/4)*4; k += 4) {
31+ // float4 y_i = thread::load<Config::LoadModifer,float4>(&d_Y[k + i * d_Y_stride]);
32+ // float4 y_j = thread::load<Config::LoadModifer,float4>(&d_Y[k + j * d_Y_stride]);
33+
34+ // accum.x = __logf(__fdividef(y_i.x, y_j.x)) + accum.x;
35+ // accum.y = __logf(__fdividef(y_i.y, y_j.y)) + accum.y;
36+ // accum.z = __logf(__fdividef(y_i.z, y_j.z)) + accum.z;
37+ // accum.w = __logf(__fdividef(y_i.w, y_j.w)) + accum.w;
38+ // }
39+
40+ // accum.x = accum.x + accum.y + accum.z + accum.w;
41+ // for (; k < nb_samples; ++k) {
42+ // float yi = d_Y[k + i * d_Y_stride];
43+ // float yj = d_Y[k + j * d_Y_stride];
44+ // accum.x = __logf(__fdividef(yi, yj)) + accum.x;
45+ // }
46+
47+ // float inv_n = __frcp_rn(static_cast<float>(nb_samples));
48+ // float mean = accum.x * inv_n;
49+ // int pair_index = (i * (i - 1)) / 2 + j;
50+ // d_mean[pair_index] = mean;
51+ // }
52+
// Phase 1 of the log-ratio-mean computation: per-gene mean of log-counts.
//
// Launch geometry: 1-D grid, one thread per gene; callers must cover
// nb_genes threads. No shared memory required.
//
// d_Y         : nb_genes rows of nb_samples counts, row g starting at
//               g * d_Y_stride floats.
// d_mean_log  : output, nb_genes per-gene means of log(count).
//
// NOTE(review): the float4 fast path assumes each row base address
// (g * d_Y_stride floats) is 16-byte aligned — confirm d_Y_stride is a
// multiple of 4. Also assumes nb_samples > 0; a zero count yields a
// NaN/inf mean rather than a trap — TODO confirm callers guarantee this.
template <class Config>
__global__
void
lrm_basic_phase_1(float* __restrict__ d_Y,
                  offset_t d_Y_stride,
                  float* __restrict__ d_mean_log,
                  int nb_samples,
                  int nb_genes) {
    // Clamp counts away from zero so __logf stays finite.
    const float EPS = std::numeric_limits<float>::epsilon();

    const int g = blockIdx.x * blockDim.x + threadIdx.x;
    if (g >= nb_genes) return;

    const offset_t g_offset = static_cast<offset_t>(g) * d_Y_stride;

    // Four independent accumulators give the unrolled loop ILP.
    // Float literals here: the previous `= 0.0` initializers were double
    // literals silently truncated to float.
    float s0 = 0.0f, s1 = 0.0f, s2 = 0.0f, s3 = 0.0f;
    int k = 0;

    // Vectorized main loop over the multiple-of-4 prefix of the row.
    PROPR_UNROLL
    for (; k < (nb_samples / 4) * 4; k += 4) {
        const float4 y = thread::load<Config::LoadModifer, float4>(&d_Y[g_offset + k]);
        s0 += __logf(fmaxf(y.x, EPS));
        s1 += __logf(fmaxf(y.y, EPS));
        s2 += __logf(fmaxf(y.z, EPS));
        s3 += __logf(fmaxf(y.w, EPS));
    }

    // Pairwise-combine the partials, then finish the scalar tail in
    // double so the final mean keeps extra headroom (explicit cast —
    // previously an implicit float-to-double conversion).
    double sum = static_cast<double>((s0 + s1) + (s2 + s3));
    for (; k < nb_samples; ++k) {
        const float y = thread::load<Config::LoadModifer, float>(&d_Y[g_offset + k]);
        sum += static_cast<double>(__logf(fmaxf(y, EPS)));
    }

    const float mean_log = static_cast<float>(sum / static_cast<double>(nb_samples));
    thread::store<Config::StoreModifer, float>(&d_mean_log[g], mean_log);
}
91+
// Phase 2 of the log-ratio-mean computation: expand the per-gene log
// means from phase 1 into pairwise differences over the strict lower
// triangle of the gene x gene plane.
//
// Launch geometry: 2-D grid of square TILE_G x TILE_G blocks
// (TILE_G = Config::P2_Layout::BLK_X) covering nb_genes in both axes.
// Output packing: pair (gi, gj) with gj < gi lands at
// pair_index = gi*(gi-1)/2 + gj in d_mean.
template <class Config>
__global__
void
lrm_basic_phase_2(float* __restrict__ d_mean_log,
                  float* __restrict__ d_mean,
                  int nb_genes) {
    using P2_Layout = typename Config::P2_Layout;
    static_assert(P2_Layout::BLK_X == P2_Layout::BLK_Y, "Tile size must be square");
    constexpr int TILE_G = P2_Layout::BLK_X;

    // Blocks strictly above the diagonal contain no lower-triangle
    // pairs. The predicate is uniform per block, so this early return
    // cannot split the barrier below.
    if (blockIdx.y > blockIdx.x) return;

    const int li = threadIdx.x;
    const int lj = threadIdx.y;
    const int gi = blockIdx.x * TILE_G + li;
    const int gj = blockIdx.y * TILE_G + lj;

    // Stage each block's two mean-log slices once in shared memory; all
    // TILE_G*TILE_G threads then read their pair from there.
    __shared__ float row_means[TILE_G];
    __shared__ float col_means[TILE_G];

    if (lj == 0) {
        row_means[li] = (gi < nb_genes)
            ? thread::load<Config::LoadModifer, float>(&d_mean_log[gi])
            : 0.0f;
    }
    if (li == 0) {
        col_means[lj] = (gj < nb_genes)
            ? thread::load<Config::LoadModifer, float>(&d_mean_log[gj])
            : 0.0f;
    }
    __syncthreads();

    const bool in_lower_triangle =
        (gi < nb_genes) && (gj < nb_genes) && (gj < gi);
    if (in_lower_triangle) {
        const offset_t pair_index =
            (static_cast<offset_t>(gi) * static_cast<offset_t>(gi - 1)) / 2 +
            static_cast<offset_t>(gj);
        thread::store<Config::StoreModifer, float>(&d_mean[pair_index],
                                                   row_means[li] - col_means[lj]);
    }
}
52133
134+
53135 template <class Config >
54136 __global__
55137 void
@@ -354,4 +436,4 @@ namespace propr {
354436
355437 }
356438 }
357- }
439+ }
0 commit comments