diff --git a/exllamav2/exllamav2_ext/cuda/layer_norm.cu b/exllamav2/exllamav2_ext/cuda/layer_norm.cu index b286492e..d48befb4 100644 --- a/exllamav2/exllamav2_ext/cuda/layer_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/layer_norm.cu @@ -5,11 +5,15 @@ #if defined(USE_ROCM) #define NUM_WARPS (1024 / warpSize) #define WARP_SIZE (warpSize) +#define MAX_NUM_WARPS 32 #else #define NUM_WARPS 32 #define WARP_SIZE 32 +#define MAX_NUM_WARPS 32 #endif +#define NUM_THREADS_CONST 1024 + // y = x * w / sqrt(row_mean(x * x) + epsilon) #define BLOCK_SIZE WARP_SIZE @@ -75,7 +79,7 @@ __global__ void layer_norm_kernel // Shuffle to sum across lanes - __shared__ float sums[NUM_WARPS]; + __shared__ float sums[MAX_NUM_WARPS]; for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset); if (lane_id == 0) sums[warp_id] = sum; @@ -198,14 +202,14 @@ void layer_norm_cuda ) { dim3 blockDim, gridDim; - blockDim.x = NUM_THREADS; + blockDim.x = NUM_THREADS_CONST; blockDim.y = 1; gridDim.x = rows; gridDim.y = 1; float r_dim = 1.0f / (float) dim; - int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2); + int blocks_per_warp = DIVIDE(dim, NUM_THREADS_CONST * 2); fp_layer_norm_kernel kernel = pick_layer_norm_kernel(blocks_per_warp); kernel<<>>(x, w, b, y, epsilon, r_dim, rows, dim, add_residual); if (graph) graph->attach_label(stream, label, 0); diff --git a/exllamav2/exllamav2_ext/cuda/rms_norm.cu b/exllamav2/exllamav2_ext/cuda/rms_norm.cu index 94155ade..7ca307aa 100644 --- a/exllamav2/exllamav2_ext/cuda/rms_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/rms_norm.cu @@ -5,11 +5,17 @@ #if defined(USE_ROCM) #define NUM_WARPS (1024 / warpSize) #define WARP_SIZE (warpSize) +#define MAX_NUM_WARPS 32 #else #define NUM_WARPS 32 #define WARP_SIZE 32 +#define MAX_NUM_WARPS 32 #endif +// NUM_WARPS * WARP_SIZE is always 1024 regardless of warp size. +// Use this in host code where warpSize (__device__ variable) is unavailable. +#define NUM_THREADS_CONST 1024 + // y = x * w / sqrt(row_mean(x * x) + epsilon) #define BLOCK_SIZE WARP_SIZE @@ -98,7 +104,7 @@ __global__ void rms_norm_kernel // Shuffle to sum across lanes - __shared__ float sums[NUM_WARPS]; + __shared__ float sums[MAX_NUM_WARPS]; for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset); if (lane_id == 0) sums[warp_id] = sum; @@ -215,14 +221,14 @@ void rms_norm_cuda ) { dim3 blockDim, gridDim; - blockDim.x = NUM_THREADS; + blockDim.x = NUM_THREADS_CONST; blockDim.y = 1; gridDim.x = rows; gridDim.y = 1; float r_dim = 1.0f / (float) dim; - int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2); + int blocks_per_warp = DIVIDE(dim, NUM_THREADS_CONST * 2); fp_rms_norm_kernel kernel = pick_rms_norm_kernel(blocks_per_warp); kernel<<>>(x, w, y, epsilon, r_dim, rows, dim, add_residual, input_fp32, output_fp32); if (graph) graph->attach_label(stream, label, 0);