Skip to content

Commit 15fc375

Browse files
committed
Faster LRM
1 parent 79ca926 commit 15fc375

File tree

6 files changed

+138
-37
lines changed

6 files changed

+138
-37
lines changed

inst/include/propr/kernels/cuda/detail/lrm.cuh

Lines changed: 108 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,43 +13,125 @@ namespace propr {
1313
namespace detail {
1414
namespace cuda {
1515

16+
// template <class Config>
17+
// __global__
18+
// void
19+
// lrm_basic(float* __restrict__ d_Y, offset_t d_Y_stride,
20+
// float* __restrict__ d_mean,
21+
// int nb_samples,
22+
// int nb_genes) {
23+
// int i = blockIdx.x * blockDim.x + threadIdx.x;
24+
// int j = blockIdx.y * blockDim.y + threadIdx.y;
25+
// if (i >= nb_genes || j >= i) return;
26+
27+
// float4 accum = {0.0f, 0.0f, 0.0f, 0.0f};
28+
// int k = 0;
29+
// PROPR_UNROLL
30+
// for (; k < (nb_samples/4)*4; k += 4) {
31+
// float4 y_i = thread::load<Config::LoadModifer,float4>(&d_Y[k + i * d_Y_stride]);
32+
// float4 y_j = thread::load<Config::LoadModifer,float4>(&d_Y[k + j * d_Y_stride]);
33+
34+
// accum.x = __logf(__fdividef(y_i.x, y_j.x)) + accum.x;
35+
// accum.y = __logf(__fdividef(y_i.y, y_j.y)) + accum.y;
36+
// accum.z = __logf(__fdividef(y_i.z, y_j.z)) + accum.z;
37+
// accum.w = __logf(__fdividef(y_i.w, y_j.w)) + accum.w;
38+
// }
39+
40+
// accum.x = accum.x + accum.y + accum.z + accum.w;
41+
// for (; k < nb_samples; ++k) {
42+
// float yi = d_Y[k + i * d_Y_stride];
43+
// float yj = d_Y[k + j * d_Y_stride];
44+
// accum.x = __logf(__fdividef(yi, yj)) + accum.x;
45+
// }
46+
47+
// float inv_n = __frcp_rn(static_cast<float>(nb_samples));
48+
// float mean = accum.x * inv_n;
49+
// int pair_index = (i * (i - 1)) / 2 + j;
50+
// d_mean[pair_index] = mean;
51+
// }
52+
1653
template <class Config>
1754
__global__
1855
void
19-
lrm_basic(float* __restrict__ d_Y, offset_t d_Y_stride,
20-
float* __restrict__ d_mean,
21-
int nb_samples,
22-
int nb_genes) {
23-
int i = blockIdx.x * blockDim.x + threadIdx.x;
24-
int j = blockIdx.y * blockDim.y + threadIdx.y;
25-
if (i >= nb_genes || j >= i) return;
26-
27-
float4 accum = {0.0f, 0.0f, 0.0f, 0.0f};
56+
lrm_basic_phase_1(float* __restrict__ d_Y,
57+
offset_t d_Y_stride,
58+
float* __restrict__ d_mean_log,
59+
int nb_samples,
60+
int nb_genes) {
61+
const auto EPS = std::numeric_limits<float>::epsilon();
62+
const int g = blockIdx.x * blockDim.x + threadIdx.x;
63+
if (g >= nb_genes) return;
64+
65+
const offset_t g_offset = static_cast<offset_t>(g) * d_Y_stride;
66+
67+
float s0 = 0.0;
68+
float s1 = 0.0;
69+
float s2 = 0.0;
70+
float s3 = 0.0;
2871
int k = 0;
72+
2973
PROPR_UNROLL
30-
for (; k < (nb_samples/4)*4; k += 4) {
31-
float4 y_i = thread::load<Config::LoadModifer,float4>(&d_Y[k + i * d_Y_stride]);
32-
float4 y_j = thread::load<Config::LoadModifer,float4>(&d_Y[k + j * d_Y_stride]);
33-
34-
accum.x = __logf(__fdividef(y_i.x, y_j.x)) + accum.x;
35-
accum.y = __logf(__fdividef(y_i.y, y_j.y)) + accum.y;
36-
accum.z = __logf(__fdividef(y_i.z, y_j.z)) + accum.z;
37-
accum.w = __logf(__fdividef(y_i.w, y_j.w)) + accum.w;
74+
for (; k < (nb_samples / 4) * 4; k += 4) {
75+
const float4 y = thread::load<Config::LoadModifer, float4>(&d_Y[g_offset + k]);
76+
s0 += __logf(fmaxf(y.x, EPS));
77+
s1 += __logf(fmaxf(y.y, EPS));
78+
s2 += __logf(fmaxf(y.z, EPS));
79+
s3 += __logf(fmaxf(y.w, EPS));
3880
}
3981

40-
accum.x = accum.x + accum.y + accum.z + accum.w;
82+
double sum = (s0 + s1) + (s2 + s3);
4183
for (; k < nb_samples; ++k) {
42-
float yi = d_Y[k + i * d_Y_stride];
43-
float yj = d_Y[k + j * d_Y_stride];
44-
accum.x = __logf(__fdividef(yi, yj)) + accum.x;
84+
const float y = thread::load<Config::LoadModifer, float>(&d_Y[g_offset + k]);
85+
sum += static_cast<double>(__logf(fmaxf(y, EPS)));
4586
}
4687

47-
float inv_n = __frcp_rn(static_cast<float>(nb_samples));
48-
float mean = accum.x * inv_n;
49-
int pair_index = (i * (i - 1)) / 2 + j;
50-
d_mean[pair_index] = mean;
88+
const float mean_log = static_cast<float>(sum / static_cast<double>(nb_samples));
89+
thread::store<Config::StoreModifer, float>(&d_mean_log[g], mean_log);
90+
}
91+
92+
template <class Config>
93+
__global__
94+
void
95+
lrm_basic_phase_2(float* __restrict__ d_mean_log,
96+
float* __restrict__ d_mean,
97+
int nb_genes) {
98+
using P2_Layout = typename Config::P2_Layout;
99+
static_assert(P2_Layout::BLK_X == P2_Layout::BLK_Y, "Tile size must be square");
100+
constexpr int TILE_G = P2_Layout::BLK_X;
101+
102+
const int li = threadIdx.x;
103+
const int lj = threadIdx.y;
104+
105+
const int gi = blockIdx.x * TILE_G + li;
106+
const int gj = blockIdx.y * TILE_G + lj;
107+
108+
if (blockIdx.y > blockIdx.x) return;
109+
110+
__shared__ float sh_i[TILE_G], sh_j[TILE_G];
111+
112+
if (lj == 0) {
113+
sh_i[li] = (gi < nb_genes)
114+
? thread::load<Config::LoadModifer, float>(&d_mean_log[gi])
115+
: 0.0f;
116+
}
117+
118+
if (li == 0) {
119+
sh_j[lj] = (gj < nb_genes)
120+
? thread::load<Config::LoadModifer, float>(&d_mean_log[gj])
121+
: 0.0f;
122+
}
123+
124+
__syncthreads();
125+
126+
if (gi < nb_genes && gj < nb_genes && gj < gi) {
127+
const offset_t pair_index =
128+
(static_cast<offset_t>(gi) * static_cast<offset_t>(gi - 1)) / 2 +
129+
static_cast<offset_t>(gj);
130+
thread::store<Config::StoreModifer, float>(&d_mean[pair_index], sh_i[li] - sh_j[lj]);
131+
}
51132
}
52133

134+
53135
template<class Config>
54136
__global__
55137
void
@@ -354,4 +436,4 @@ namespace propr {
354436

355437
}
356438
}
357-
}
439+
}

inst/include/propr/kernels/cuda/dispatch/comparison.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
namespace propr {
66
namespace dispatch {
77
namespace cuda {
8-
int count_less_than (Rcpp::NumericVector& x, double cutoff,propr::propr_context context=DEFAULT_GLOBAL_CONTEXT);
9-
int count_greater_than (Rcpp::NumericVector& x, double cutoff,propr::propr_context context=DEFAULT_GLOBAL_CONTEXT);
10-
int count_less_equal_than (Rcpp::NumericVector& x, double cutoff,propr::propr_context context=DEFAULT_GLOBAL_CONTEXT);
8+
int count_less_than(Rcpp::NumericVector& x, double cutoff, propr::propr_context context=DEFAULT_GLOBAL_CONTEXT);
9+
int count_greater_than(Rcpp::NumericVector& x, double cutoff,propr::propr_context context=DEFAULT_GLOBAL_CONTEXT);
10+
int count_less_equal_than(Rcpp::NumericVector& x, double cutoff,propr::propr_context context=DEFAULT_GLOBAL_CONTEXT);
1111
int count_greater_equal_than(Rcpp::NumericVector& x, double cutoff,propr::propr_context context=DEFAULT_GLOBAL_CONTEXT);
1212
}
1313
}

inst/include/propr/kernels/cuda/traits/lrm.cuh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
namespace propr {
88
namespace cuda {
99
namespace traits {
10-
struct lrm_basic : thread_layout_2d<>{
10+
struct lrm_basic {
11+
using P1_Layout = thread_layout_1d<256>;
12+
using P2_Layout = thread_layout_2d<16,16>;
13+
1114
const static cub::CacheLoadModifier LoadModifer = cub::LOAD_CG;
1215
const static cub::CacheStoreModifier StoreModifer = cub::STORE_CG;
1316
};
@@ -28,4 +31,4 @@ namespace propr {
2831
};
2932
}
3033
}
31-
}
34+
}

src/dispatch/cpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ set(PROPR_SOURCES
77
${CMAKE_CURRENT_SOURCE_DIR}/lrm.cpp
88
${CMAKE_CURRENT_SOURCE_DIR}/lrv.cpp
99
${CMAKE_CURRENT_SOURCE_DIR}/omega.cpp
10+
${CMAKE_CURRENT_SOURCE_DIR}/genewise.cpp
1011
${PROPR_SOURCES}
1112
PARENT_SCOPE
1213
)

src/dispatch/cuda/lrm.cu

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,37 @@ propr::dispatch::cuda::lrm_basic(NumericVector& out, NumericMatrix &Y, propr::pr
2626
float* d_Y;
2727
offset_t stride; d_Y = RcppMatrixToDevice<float>(Y, stride);
2828

29+
float* d_mean_log;
30+
PROPR_CUDA_CHECK(cudaMalloc(&d_mean_log, static_cast<size_t>(N_genes) * sizeof(float)));
31+
2932
float* d_mean;
3033
PROPR_CUDA_CHECK(cudaMalloc(&d_mean, N_pairs * sizeof(float)));
3134

32-
dim3 blockDim(Config::BLK_X, Config::BLK_Y);
33-
dim3 gridDim(propr::ceil_div(N_genes, Config::BLK_X), propr::ceil_div(N_genes, Config::BLK_Y));
35+
dim3 block1(Config::P1_Layout::BLK_X);
36+
dim3 grid1(propr::ceil_div(N_genes, Config::P1_Layout::BLK_X));
37+
38+
dim3 block2(Config::P2_Layout::BLK_X, Config::P2_Layout::BLK_Y);
39+
dim3 grid2(propr::ceil_div(N_genes, Config::P2_Layout::BLK_X),
40+
propr::ceil_div(N_genes, Config::P2_Layout::BLK_Y));
41+
3442

3543
{
3644
PROPR_PROFILE_CUDA("kernel", context.stream);
37-
propr::detail::cuda::lrm_basic<Config><<<gridDim, blockDim, 0, context.stream>>>(
38-
d_Y, stride, d_mean, N_samples, N_genes
45+
propr::detail::cuda::lrm_basic_phase_1<Config><<<grid1, block1, 0, context.stream>>>(
46+
d_Y, stride, d_mean_log, N_samples, N_genes
47+
);
48+
PROPR_CUDA_CHECK(cudaGetLastError());
49+
50+
propr::detail::cuda::lrm_basic_phase_2<Config><<<grid2, block2, 0, context.stream>>>(
51+
d_mean_log, d_mean, N_genes
3952
);
53+
PROPR_CUDA_CHECK(cudaGetLastError());
4054
PROPR_STREAM_SYNCHRONIZE(context);
4155
}
4256

4357
copyToNumericVector(d_mean, out, N_pairs);
4458
PROPR_CUDA_CHECK(cudaFree(d_Y));
59+
PROPR_CUDA_CHECK(cudaFree(d_mean_log));
4560
PROPR_CUDA_CHECK(cudaFree(d_mean));
4661
}
4762

@@ -167,4 +182,4 @@ propr::dispatch::cuda::lrm_alpha_weighted(NumericVector& out,
167182
PROPR_CUDA_CHECK(cudaFree(d_Yfull));
168183
PROPR_CUDA_CHECK(cudaFree(d_Wfull));
169184
PROPR_CUDA_CHECK(cudaFree(d_means));
170-
}
185+
}

src/dispatch/runtime/resolve_backend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Backend resolve_backend(const Rcpp::String& requested) {
2424
}
2525
}
2626

27-
if (req == "cuda") {
27+
if (req == "cuda" || req == "gpu" ) {
2828
if (cuda_is_available()) return Backend::CUDA;
2929
static bool warned = false;
3030
if (!warned) {

0 commit comments

Comments
 (0)