Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions exllamav2/exllamav2_ext/cuda/layer_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
#if defined(USE_ROCM)
#define NUM_WARPS (1024 / warpSize)
#define WARP_SIZE (warpSize)
#define MAX_NUM_WARPS 32
#else
#define NUM_WARPS 32
#define WARP_SIZE 32
#define MAX_NUM_WARPS 32
#endif

#define NUM_THREADS_CONST 1024

// y = x * w / sqrt(row_mean(x * x) + epsilon)

#define BLOCK_SIZE WARP_SIZE
Expand Down Expand Up @@ -75,7 +79,7 @@ __global__ void layer_norm_kernel

// Shuffle to sum across lanes

__shared__ float sums[NUM_WARPS];
__shared__ float sums[MAX_NUM_WARPS];

for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
if (lane_id == 0) sums[warp_id] = sum;
Expand Down Expand Up @@ -198,14 +202,14 @@ void layer_norm_cuda
)
{
dim3 blockDim, gridDim;
blockDim.x = NUM_THREADS;
blockDim.x = NUM_THREADS_CONST;
blockDim.y = 1;
gridDim.x = rows;
gridDim.y = 1;

float r_dim = 1.0f / (float) dim;

int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2);
int blocks_per_warp = DIVIDE(dim, NUM_THREADS_CONST * 2);
fp_layer_norm_kernel kernel = pick_layer_norm_kernel(blocks_per_warp);
kernel<<<gridDim, blockDim, 0, stream>>>(x, w, b, y, epsilon, r_dim, rows, dim, add_residual);
if (graph) graph->attach_label(stream, label, 0);
Expand Down
12 changes: 9 additions & 3 deletions exllamav2/exllamav2_ext/cuda/rms_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@
#if defined(USE_ROCM)
#define NUM_WARPS (1024 / warpSize)
#define WARP_SIZE (warpSize)
#define MAX_NUM_WARPS 32
#else
#define NUM_WARPS 32
#define WARP_SIZE 32
#define MAX_NUM_WARPS 32
#endif

// NUM_WARPS * WARP_SIZE is always 1024 regardless of warp size.
// Use this in host code where warpSize (__device__ variable) is unavailable.
#define NUM_THREADS_CONST 1024

// y = x * w / sqrt(row_mean(x * x) + epsilon)

#define BLOCK_SIZE WARP_SIZE
Expand Down Expand Up @@ -98,7 +104,7 @@ __global__ void rms_norm_kernel

// Shuffle to sum across lanes

__shared__ float sums[NUM_WARPS];
__shared__ float sums[MAX_NUM_WARPS];

for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
if (lane_id == 0) sums[warp_id] = sum;
Expand Down Expand Up @@ -215,14 +221,14 @@ void rms_norm_cuda
)
{
dim3 blockDim, gridDim;
blockDim.x = NUM_THREADS;
blockDim.x = NUM_THREADS_CONST;
blockDim.y = 1;
gridDim.x = rows;
gridDim.y = 1;

float r_dim = 1.0f / (float) dim;

int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2);
int blocks_per_warp = DIVIDE(dim, NUM_THREADS_CONST * 2);
fp_rms_norm_kernel kernel = pick_rms_norm_kernel(blocks_per_warp);
kernel<<<gridDim, blockDim, 0, stream>>>(x, w, y, epsilon, r_dim, rows, dim, add_residual, input_fp32, output_fp32);
if (graph) graph->attach_label(stream, label, 0);
Expand Down