From 6e5359602c28744193d86acb6056bcd987a530b2 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 30 Mar 2026 06:44:47 -0700 Subject: [PATCH 01/17] working tests except for downsampling --- tests/test_attention.py | 2 +- torch_harmonics/attention/_attention_utils.py | 24 + .../attention/csrc/attention_cuda.cuh | 23 + .../attention/csrc/attention_cuda_bwd.cu | 494 ++++++++++++++++++ .../attention/csrc/attention_cuda_fwd.cu | 252 +++++++++ .../attention/csrc/attention_interface.cpp | 3 + torch_harmonics/distributed/__init__.py | 3 + 7 files changed, 800 insertions(+), 1 deletion(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index d966d47b..96cee076 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -87,7 +87,7 @@ def setUp(self): skip_on_empty=True, ) @unittest.skipUnless(optimized_kernels_is_available(), "skipping test because optimized kernels are not available") - def test_custom_implementation(self, batch_size, channels, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False): + def test_custom_implementation(self, batch_size, channels, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True): """Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation""" if (self.device.type == "cuda") and (not cuda_kernels_is_available()): diff --git a/torch_harmonics/attention/_attention_utils.py b/torch_harmonics/attention/_attention_utils.py index a8c1b58a..b7ef7b67 100644 --- a/torch_harmonics/attention/_attention_utils.py +++ b/torch_harmonics/attention/_attention_utils.py @@ -67,6 +67,30 @@ def _(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.T dq = torch.empty_like(qw) return dk, dv, dq + # fake implementations for ring step ops + @torch.library.register_fake("attention_kernels::forward_ring_step") + def _(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, + y_acc: torch.Tensor, 
alpha_sum_buf: torch.Tensor, qdotk_max_buf: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, row_idx: torch.Tensor, + nlon_in: int, lon_lo_kx: int, lat_halo_start: int, nlat_out: int, nlon_out: int) -> None: + pass + + @torch.library.register_fake("attention_kernels::backward_ring_step_pass1") + def _(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + alpha_sum_buf: torch.Tensor, qdotk_max_buf: torch.Tensor, integral_buf: torch.Tensor, + alpha_k_buf: torch.Tensor, alpha_kvw_buf: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, row_idx: torch.Tensor, + nlon_in: int, lon_lo_kx: int, lat_halo_start: int, nlat_out: int, nlon_out: int) -> None: + pass + + @torch.library.register_fake("attention_kernels::backward_ring_step_pass2") + def _(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + alpha_sum_buf: torch.Tensor, qdotk_max_buf: torch.Tensor, integral_norm_buf: torch.Tensor, + dkx: torch.Tensor, dvx: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, row_idx: torch.Tensor, + nlon_in: int, lon_lo_kx: int, lat_halo_start: int, nlat_out: int, nlon_out: int) -> None: + pass + # forward @torch.library.custom_op("attention_kernels::_neighborhood_s2_attention_optimized", mutates_args=()) def _neighborhood_s2_attention_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, diff --git a/torch_harmonics/attention/csrc/attention_cuda.cuh b/torch_harmonics/attention/csrc/attention_cuda.cuh index 0042bfce..b4818d07 100644 --- a/torch_harmonics/attention/csrc/attention_cuda.cuh +++ b/torch_harmonics/attention/csrc/attention_cuda.cuh @@ -51,4 +51,27 @@ std::tuple s2_attention_bwd_dkvq_cuda(at::Te at::Tensor psi_col_idx, at::Tensor psi_row_off, int64_t nlon_in, int64_t nlat_out, int64_t nlon_out); +void s2_attention_fwd_ring_step_cuda( + at::Tensor kx, at::Tensor vx, at::Tensor qy, + at::Tensor y_acc, at::Tensor 
alpha_sum_buf, at::Tensor qdotk_max_buf, + at::Tensor quad_weights, at::Tensor psi_col_idx, at::Tensor psi_row_off, at::Tensor psi_row_idx, + int64_t nlon_in, int64_t lon_lo_kx, int64_t lat_halo_start, + int64_t nlat_out, int64_t nlon_out); + +void s2_attention_bwd_ring_step_pass1_cuda( + at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy, + at::Tensor alpha_sum_buf, at::Tensor qdotk_max_buf, at::Tensor integral_buf, + at::Tensor alpha_k_buf, at::Tensor alpha_kvw_buf, + at::Tensor quad_weights, at::Tensor psi_col_idx, at::Tensor psi_row_off, at::Tensor psi_row_idx, + int64_t nlon_in, int64_t lon_lo_kx, int64_t lat_halo_start, + int64_t nlat_out, int64_t nlon_out); + +void s2_attention_bwd_ring_step_pass2_cuda( + at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy, + at::Tensor alpha_sum_buf, at::Tensor qdotk_max_buf, at::Tensor integral_norm_buf, + at::Tensor dkx, at::Tensor dvx, + at::Tensor quad_weights, at::Tensor psi_col_idx, at::Tensor psi_row_off, at::Tensor psi_row_idx, + int64_t nlon_in, int64_t lon_lo_kx, int64_t lat_halo_start, + int64_t nlat_out, int64_t nlon_out); + } \ No newline at end of file diff --git a/torch_harmonics/attention/csrc/attention_cuda_bwd.cu b/torch_harmonics/attention/csrc/attention_cuda_bwd.cu index be2b6dd9..a825d4d7 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_bwd.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_bwd.cu @@ -926,4 +926,498 @@ TORCH_LIBRARY_IMPL(attention_kernels, CUDA, m) m.impl("backward", &s2_attention_bwd_dkvq_cuda); } +// BEGIN - backward ring step kernels and functions + +// Pass 1: accumulate softmax statistics across ring steps. +// After all ring steps, finalize dqy in Python using the accumulated state. +// col_idx must have wi pre-shifted by lon_lo_out (see Python __init__ preprocessing). 
+template +__global__ +__launch_bounds__(BDIM_X) +void s2_attn_bwd_ring_step_pass1_generic_vec_k( + int nchan_in, + int nchan_out, + int nlat_halo, + int nlon_kx, + int nlon_in, + int lon_lo_kx, + int lat_halo_start, + int nlat_out, + int nlon_out, + const FLOATV_T *__restrict__ kx, // [batch][nlat_halo][nlon_kx][nchan_in] + const FLOATV_T *__restrict__ vx, // [batch][nlat_halo][nlon_kx][nchan_out] + const FLOATV_T *__restrict__ qy, // [batch][nlat_out][nlon_out][nchan_in] + const FLOATV_T *__restrict__ dy, // [batch][nlat_out][nlon_out][nchan_out] + const int32_t *__restrict__ row_idx, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ col_idx, + const float *__restrict__ quad_weights, + float *__restrict__ alpha_sum_buf, // [batch][nlat_out][nlon_out] (in/out) + float *__restrict__ qdotk_max_buf, // [batch][nlat_out][nlon_out] (in/out) + float *__restrict__ integral_buf, // [batch][nlat_out][nlon_out] unnormalized (in/out) + FLOATV_T *__restrict__ alpha_k_buf, // [batch][nlat_out][nlon_out][nchan_in] (in/out) + FLOATV_T *__restrict__ alpha_kvw_buf // [batch][nlat_out][nlon_out][nchan_in] (in/out) +) { + extern __shared__ __align__(sizeof(float4)) float shext[]; + // sh_alpha_k[nchan_in], sh_alpha_kvw[nchan_in], sh_dy[nchan_out] + FLOATV_T *sh_alpha_k = reinterpret_cast(shext) + threadIdx.y * (2*nchan_in + nchan_out); + FLOATV_T *sh_alpha_kvw = sh_alpha_k + nchan_in; + FLOATV_T *sh_dy = sh_alpha_kvw + nchan_in; + + const int batch = blockIdx.y; + const int wid = blockIdx.x * blockDim.y + threadIdx.y; + if (wid >= nlat_out * nlon_out) return; + + const int tidx = threadIdx.x; + const int h = wid / nlon_out; + const int wo = wid - (h * nlon_out); + const int ho = row_idx[h]; + + kx += int64_t(batch) * nlat_halo * nlon_kx * nchan_in; + vx += int64_t(batch) * nlat_halo * nlon_kx * nchan_out; + qy += int64_t(batch) * nlat_out * nlon_out * nchan_in + + int64_t(ho) * nlon_out * nchan_in + int64_t(wo) * nchan_in; + dy += int64_t(batch) * nlat_out * 
nlon_out * nchan_out + + int64_t(ho) * nlon_out * nchan_out + int64_t(wo) * nchan_out; + + const int64_t out_flat = int64_t(batch) * nlat_out * nlon_out + + int64_t(ho) * nlon_out + wo; + alpha_sum_buf += out_flat; + qdotk_max_buf += out_flat; + integral_buf += out_flat; + alpha_k_buf += out_flat * nchan_in; + alpha_kvw_buf += out_flat * nchan_in; + + // Load current state + float alpha_sum = alpha_sum_buf[0]; + float qdotk_max = qdotk_max_buf[0]; + float integral = integral_buf[0]; + + for (int chan = tidx; chan < nchan_in; chan += WARP_SIZE) { + sh_alpha_k[chan] = alpha_k_buf[chan]; + sh_alpha_kvw[chan] = alpha_kvw_buf[chan]; + } + for (int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + sh_dy[chan] = dy[chan]; + } + +#if __CUDA_ARCH__ < 900 + if constexpr(std::is_same::value) { __syncwarp(); } +#endif + + const int64_t rbeg = row_off[ho]; + const int64_t rend = row_off[ho + 1]; + col_idx += rbeg; + const int rlen = rend - rbeg; + + for (int off = 0; off < rlen; off++) { + const int64_t col = col_idx[off]; + const int hi_global = col / nlon_in; + const int wi = col - (hi_global * nlon_in); + const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; + + if (wip < lon_lo_kx || wip >= lon_lo_kx + nlon_kx) continue; + + const int hi_local = hi_global - lat_halo_start; + if (hi_local < 0 || hi_local >= nlat_halo) continue; + const int wip_local = wip - lon_lo_kx; + + const FLOATV_T *_kx = kx + int64_t(hi_local)*nlon_kx*nchan_in + int64_t(wip_local)*nchan_in; + const FLOATV_T *_vx = vx + int64_t(hi_local)*nlon_kx*nchan_out + int64_t(wip_local)*nchan_out; + + FLOATV_T qdotk_v = __vset(0.0f); + FLOATV_T gdotv_v = __vset(0.0f); + for (int chan = tidx; chan < nchan_in; chan += WARP_SIZE) { + qdotk_v = __vadd(qdotk_v, __vmul(qy[chan], _kx[chan])); + } + for (int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[chan], _vx[chan])); + } + const float qdotk = __warp_sum(__vred(qdotk_v)); + const float gdotv = 
__warp_sum(__vred(gdotv_v)); + + const float qdotk_max_tmp = max(qdotk_max, qdotk); + const float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi_global]; + const float max_correction = expf(qdotk_max - qdotk_max_tmp); + + alpha_sum = alpha_sum * max_correction + alpha_inz; + integral = integral * max_correction + alpha_inz * gdotv; + + const float ainz_gdotv = alpha_inz * gdotv; + for (int chan = tidx; chan < nchan_in; chan += WARP_SIZE) { + const FLOATV_T kxval = _kx[chan]; + sh_alpha_k[chan] = __vadd(__vscale(max_correction, sh_alpha_k[chan]), + __vscale(alpha_inz, kxval)); + sh_alpha_kvw[chan] = __vadd(__vscale(max_correction, sh_alpha_kvw[chan]), + __vscale(ainz_gdotv, kxval)); + } + qdotk_max = qdotk_max_tmp; + } + + // Store updated state + alpha_sum_buf[0] = alpha_sum; + qdotk_max_buf[0] = qdotk_max; + integral_buf[0] = integral; + for (int chan = tidx; chan < nchan_in; chan += WARP_SIZE) { + alpha_k_buf[chan] = sh_alpha_k[chan]; + alpha_kvw_buf[chan] = sh_alpha_kvw[chan]; + } +} + +// Pass 2: scatter dkx/dvx contributions for the current KV chunk. +// Requires FINALIZED state from pass 1: alpha_sum, qdotk_max, integral_norm (= integral/alpha_sum). 
+template +__global__ +__launch_bounds__(BDIM_X) +void s2_attn_bwd_ring_step_pass2_generic_vec_k( + int nchan_in, + int nchan_out, + int nlat_halo, + int nlon_kx, + int nlon_in, + int lon_lo_kx, + int lat_halo_start, + int nlat_out, + int nlon_out, + const FLOATV_T *__restrict__ kx, + const FLOATV_T *__restrict__ vx, + const FLOATV_T *__restrict__ qy, + const FLOATV_T *__restrict__ dy, + const int32_t *__restrict__ row_idx, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ col_idx, + const float *__restrict__ quad_weights, + const float *__restrict__ alpha_sum_buf, // finalized [batch][nlat_out][nlon_out] + const float *__restrict__ qdotk_max_buf, // finalized [batch][nlat_out][nlon_out] + const float *__restrict__ integral_norm_buf, // finalized, normalized [batch][nlat_out][nlon_out] + FLOATV_T *__restrict__ dkx, // [batch][nlat_halo][nlon_kx][nchan_in] (atomically updated) + FLOATV_T *__restrict__ dvx // [batch][nlat_halo][nlon_kx][nchan_out] (atomically updated) +) { + extern __shared__ __align__(sizeof(float4)) float shext[]; + FLOATV_T *sh_qy = reinterpret_cast(shext) + threadIdx.y * (nchan_in + nchan_out); + FLOATV_T *sh_dy = sh_qy + nchan_in; + + const int batch = blockIdx.y; + const int wid = blockIdx.x * blockDim.y + threadIdx.y; + if (wid >= nlat_out * nlon_out) return; + + const int tidx = threadIdx.x; + const int h = wid / nlon_out; + const int wo = wid - (h * nlon_out); + const int ho = row_idx[h]; + + kx += int64_t(batch) * nlat_halo * nlon_kx * nchan_in; + vx += int64_t(batch) * nlat_halo * nlon_kx * nchan_out; + dkx += int64_t(batch) * nlat_halo * nlon_kx * nchan_in; + dvx += int64_t(batch) * nlat_halo * nlon_kx * nchan_out; + + qy += int64_t(batch) * nlat_out * nlon_out * nchan_in + + int64_t(ho) * nlon_out * nchan_in + int64_t(wo) * nchan_in; + dy += int64_t(batch) * nlat_out * nlon_out * nchan_out + + int64_t(ho) * nlon_out * nchan_out + int64_t(wo) * nchan_out; + + const int64_t out_flat = int64_t(batch) * nlat_out * nlon_out 
+ int64_t(ho) * nlon_out + wo; + const float alpha_sum = alpha_sum_buf[out_flat]; + const float qdotk_max = qdotk_max_buf[out_flat]; + const float integral_norm = integral_norm_buf[out_flat]; + const float alpha_sum_inv = 1.0f / alpha_sum; + + for (int chan = tidx; chan < nchan_in; chan += WARP_SIZE) { + sh_qy[chan] = qy[chan]; + } + for (int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + sh_dy[chan] = dy[chan]; + } + +#if __CUDA_ARCH__ < 900 + if constexpr(std::is_same::value) { __syncwarp(); } +#endif + + const int64_t rbeg = row_off[ho]; + const int64_t rend = row_off[ho + 1]; + col_idx += rbeg; + const int rlen = rend - rbeg; + + for (int off = 0; off < rlen; off++) { + const int64_t col = col_idx[off]; + const int hi_global = col / nlon_in; + const int wi = col - (hi_global * nlon_in); + const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; + + if (wip < lon_lo_kx || wip >= lon_lo_kx + nlon_kx) continue; + + const int hi_local = hi_global - lat_halo_start; + if (hi_local < 0 || hi_local >= nlat_halo) continue; + const int wip_local = wip - lon_lo_kx; + + const FLOATV_T *_kx = kx + int64_t(hi_local)*nlon_kx*nchan_in + int64_t(wip_local)*nchan_in; + const FLOATV_T *_vx = vx + int64_t(hi_local)*nlon_kx*nchan_out + int64_t(wip_local)*nchan_out; + FLOATV_T *_dkx = dkx + int64_t(hi_local)*nlon_kx*nchan_in + int64_t(wip_local)*nchan_in; + FLOATV_T *_dvx = dvx + int64_t(hi_local)*nlon_kx*nchan_out + int64_t(wip_local)*nchan_out; + + FLOATV_T qdotk_v = __vset(0.0f); + FLOATV_T gdotv_v = __vset(0.0f); + for (int chan = tidx; chan < nchan_in; chan += WARP_SIZE) { + qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[chan], _kx[chan])); + } + for (int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[chan], _vx[chan])); + } + const float qdotk = __warp_sum(__vred(qdotk_v)); + const float gdotv = __warp_sum(__vred(gdotv_v)); + + const float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi_global]; + const float alpha_mul = 
alpha_inz * alpha_sum_inv; + const float scale_dkx = (gdotv - integral_norm) * alpha_mul; + const float scale_dvx = alpha_mul; + +#if __CUDA_ARCH__ < 900 + float *sh_qy_scl = reinterpret_cast(sh_qy); + float *sh_dy_scl = reinterpret_cast(sh_dy); + float *_dkx_scl = reinterpret_cast(_dkx); + float *_dvx_scl = reinterpret_cast(_dvx); + constexpr int VEC_SIZE = sizeof(FLOATV_T)/sizeof(float); + for (int chan = tidx; chan < nchan_in*VEC_SIZE; chan += WARP_SIZE) { + atomicAdd(_dkx_scl + chan, scale_dkx * sh_qy_scl[chan]); + } + for (int chan = tidx; chan < nchan_out*VEC_SIZE; chan += WARP_SIZE) { + atomicAdd(_dvx_scl + chan, scale_dvx * sh_dy_scl[chan]); + } +#else + for (int chan = tidx; chan < nchan_in; chan += WARP_SIZE) { + atomicAdd(_dkx + chan, __vscale(scale_dkx, sh_qy[chan])); + } + for (int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + atomicAdd(_dvx + chan, __vscale(scale_dvx, sh_dy[chan])); + } +#endif + } +} + +static void s2_attn_bwd_ring_step_pass1_dispatch( + int64_t batch_size, int64_t nchans_in, int64_t nchans_out, + int64_t nlon_in, int64_t nlat_halo, int64_t nlon_kx, + int64_t lon_lo_kx, int64_t lat_halo_start, int64_t nlat_out, int64_t nlon_out, + at::Tensor kxP, at::Tensor vxP, at::Tensor qyP, at::Tensor dyP, + at::Tensor row_idx, at::Tensor row_off, at::Tensor col_idx, at::Tensor quad_weights, + at::Tensor alpha_sum_buf, at::Tensor qdotk_max_buf, at::Tensor integral_buf, + at::Tensor alpha_k_buf, at::Tensor alpha_kvw_buf) { + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + dim3 block(WARP_SIZE, THREADS / WARP_SIZE); + dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size); + + float *_kxp = reinterpret_cast(kxP.data_ptr()); + float *_vxp = reinterpret_cast(vxP.data_ptr()); + float *_qyp = reinterpret_cast(qyP.data_ptr()); + float *_dyp = reinterpret_cast(dyP.data_ptr()); + float *_alpha_sum = reinterpret_cast(alpha_sum_buf.data_ptr()); + float *_qdotk_max = reinterpret_cast(qdotk_max_buf.data_ptr()); + float *_integral = 
reinterpret_cast(integral_buf.data_ptr()); + int32_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); + int64_t *_row_off = reinterpret_cast(row_off.data_ptr()); + int64_t *_col_idx = reinterpret_cast(col_idx.data_ptr()); + float *_quad_weights = reinterpret_cast(quad_weights.data_ptr()); + + constexpr int VEC_SIZE = sizeof(float4) / sizeof(float); + + if (!is_aligned(_kxp) || + !is_aligned(_vxp) || + !is_aligned(_qyp) || + !is_aligned(_dyp) || + (nchans_in % VEC_SIZE) != 0 || + (nchans_out % VEC_SIZE) != 0) { + + size_t shsize = sizeof(float) * (2*nchans_in + nchans_out) * block.y; + float *_alpha_k = reinterpret_cast(alpha_k_buf.data_ptr()); + float *_alpha_kvw = reinterpret_cast(alpha_kvw_buf.data_ptr()); + s2_attn_bwd_ring_step_pass1_generic_vec_k + <<>>( + nchans_in, nchans_out, nlat_halo, nlon_kx, + nlon_in, lon_lo_kx, lat_halo_start, nlat_out, nlon_out, + _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, + _alpha_sum, _qdotk_max, _integral, _alpha_k, _alpha_kvw); + CHECK_ERROR("s2_attn_bwd_ring_step_pass1_generic_vec_k"); + + } else { + + float4 *_kxp4 = reinterpret_cast(_kxp); + float4 *_vxp4 = reinterpret_cast(_vxp); + float4 *_qyp4 = reinterpret_cast(_qyp); + float4 *_dyp4 = reinterpret_cast(_dyp); + float4 *_alpha_k4 = reinterpret_cast(alpha_k_buf.data_ptr()); + float4 *_alpha_kvw4 = reinterpret_cast(alpha_kvw_buf.data_ptr()); + + size_t shsize = sizeof(float4) * (2*(nchans_in/VEC_SIZE) + nchans_out/VEC_SIZE) * block.y; + s2_attn_bwd_ring_step_pass1_generic_vec_k + <<>>( + nchans_in/VEC_SIZE, nchans_out/VEC_SIZE, nlat_halo, nlon_kx, + nlon_in, lon_lo_kx, lat_halo_start, nlat_out, nlon_out, + _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, + _alpha_sum, _qdotk_max, _integral, _alpha_k4, _alpha_kvw4); + CHECK_ERROR("s2_attn_bwd_ring_step_pass1_generic_vec_k"); + } +} + +static void s2_attn_bwd_ring_step_pass2_dispatch( + int64_t batch_size, int64_t nchans_in, int64_t nchans_out, + int64_t nlon_in, int64_t 
nlat_halo, int64_t nlon_kx, + int64_t lon_lo_kx, int64_t lat_halo_start, int64_t nlat_out, int64_t nlon_out, + at::Tensor kxP, at::Tensor vxP, at::Tensor qyP, at::Tensor dyP, + at::Tensor row_idx, at::Tensor row_off, at::Tensor col_idx, at::Tensor quad_weights, + at::Tensor alpha_sum_buf, at::Tensor qdotk_max_buf, at::Tensor integral_norm_buf, + at::Tensor dkxP, at::Tensor dvxP) { + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + dim3 block(WARP_SIZE, THREADS / WARP_SIZE); + dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size); + + float *_kxp = reinterpret_cast(kxP.data_ptr()); + float *_vxp = reinterpret_cast(vxP.data_ptr()); + float *_qyp = reinterpret_cast(qyP.data_ptr()); + float *_dyp = reinterpret_cast(dyP.data_ptr()); + float *_alpha_sum = reinterpret_cast(alpha_sum_buf.data_ptr()); + float *_qdotk_max = reinterpret_cast(qdotk_max_buf.data_ptr()); + float *_integral_n = reinterpret_cast(integral_norm_buf.data_ptr()); + int32_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); + int64_t *_row_off = reinterpret_cast(row_off.data_ptr()); + int64_t *_col_idx = reinterpret_cast(col_idx.data_ptr()); + float *_quad_weights = reinterpret_cast(quad_weights.data_ptr()); + + constexpr int VEC_SIZE = sizeof(float4) / sizeof(float); + + if (!is_aligned(_kxp) || + !is_aligned(_vxp) || + !is_aligned(_qyp) || + !is_aligned(_dyp) || + (nchans_in % VEC_SIZE) != 0 || + (nchans_out % VEC_SIZE) != 0) { + + float *_dkxp = reinterpret_cast(dkxP.data_ptr()); + float *_dvxp = reinterpret_cast(dvxP.data_ptr()); + size_t shsize = sizeof(float) * (nchans_in + nchans_out) * block.y; + s2_attn_bwd_ring_step_pass2_generic_vec_k + <<>>( + nchans_in, nchans_out, nlat_halo, nlon_kx, + nlon_in, lon_lo_kx, lat_halo_start, nlat_out, nlon_out, + _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, + _alpha_sum, _qdotk_max, _integral_n, _dkxp, _dvxp); + CHECK_ERROR("s2_attn_bwd_ring_step_pass2_generic_vec_k"); + + } else { + + float4 *_kxp4 = 
reinterpret_cast(_kxp); + float4 *_vxp4 = reinterpret_cast(_vxp); + float4 *_qyp4 = reinterpret_cast(_qyp); + float4 *_dyp4 = reinterpret_cast(_dyp); + float4 *_dkxp4 = reinterpret_cast(dkxP.data_ptr()); + float4 *_dvxp4 = reinterpret_cast(dvxP.data_ptr()); + size_t shsize = sizeof(float4) * ((nchans_in + nchans_out)/VEC_SIZE) * block.y; + s2_attn_bwd_ring_step_pass2_generic_vec_k + <<>>( + nchans_in/VEC_SIZE, nchans_out/VEC_SIZE, nlat_halo, nlon_kx, + nlon_in, lon_lo_kx, lat_halo_start, nlat_out, nlon_out, + _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, + _alpha_sum, _qdotk_max, _integral_n, _dkxp4, _dvxp4); + CHECK_ERROR("s2_attn_bwd_ring_step_pass2_generic_vec_k"); + } +} + +void s2_attention_bwd_ring_step_pass1_cuda( + at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy, + at::Tensor alpha_sum_buf, at::Tensor qdotk_max_buf, at::Tensor integral_buf, + at::Tensor alpha_k_buf, at::Tensor alpha_kvw_buf, + at::Tensor quad_weights, at::Tensor psi_col_idx, at::Tensor psi_row_off, at::Tensor psi_row_idx, + int64_t nlon_in, int64_t lon_lo_kx, int64_t lat_halo_start, + int64_t nlat_out, int64_t nlon_out) +{ + CHECK_CUDA_INPUT_TENSOR(kx); CHECK_CUDA_INPUT_TENSOR(vx); + CHECK_CUDA_INPUT_TENSOR(qy); CHECK_CUDA_INPUT_TENSOR(dy); + CHECK_CUDA_TENSOR(alpha_sum_buf); CHECK_CUDA_TENSOR(qdotk_max_buf); + CHECK_CUDA_TENSOR(integral_buf); CHECK_CUDA_TENSOR(alpha_k_buf); + CHECK_CUDA_TENSOR(alpha_kvw_buf); CHECK_CUDA_TENSOR(quad_weights); + CHECK_CUDA_TENSOR(psi_col_idx); CHECK_CUDA_TENSOR(psi_row_off); CHECK_CUDA_TENSOR(psi_row_idx); + + const int batch_size = kx.size(0); + const int nlat_halo = kx.size(2); // kx is [B,C,H,W], H is dim 2 + const int nlon_kx = kx.size(3); // W is dim 3 + const size_t nchans_in = qy.size(1); + const size_t nchans_out = vx.size(1); + + torch::Tensor kxP = kx.to(torch::kFloat32); + torch::Tensor vxP = vx.to(torch::kFloat32); + torch::Tensor qyP = qy.to(torch::kFloat32); + torch::Tensor dyP = dy.to(torch::kFloat32); + + 
if (kxP.strides()[1] != 1) { kxP = permute_4D_to0231(kxP); } + if (vxP.strides()[1] != 1) { vxP = permute_4D_to0231(vxP); } + if (qyP.strides()[1] != 1) { qyP = permute_4D_to0231(qyP); } + if (dyP.strides()[1] != 1) { dyP = permute_4D_to0231(dyP); } + + s2_attn_bwd_ring_step_pass1_dispatch( + batch_size, nchans_in, nchans_out, + nlon_in, nlat_halo, nlon_kx, lon_lo_kx, lat_halo_start, + nlat_out, nlon_out, + kxP, vxP, qyP, dyP, + psi_row_idx, psi_row_off, psi_col_idx, quad_weights, + alpha_sum_buf, qdotk_max_buf, integral_buf, alpha_k_buf, alpha_kvw_buf); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void s2_attention_bwd_ring_step_pass2_cuda( + at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy, + at::Tensor alpha_sum_buf, at::Tensor qdotk_max_buf, at::Tensor integral_norm_buf, + at::Tensor dkx, at::Tensor dvx, + at::Tensor quad_weights, at::Tensor psi_col_idx, at::Tensor psi_row_off, at::Tensor psi_row_idx, + int64_t nlon_in, int64_t lon_lo_kx, int64_t lat_halo_start, + int64_t nlat_out, int64_t nlon_out) +{ + CHECK_CUDA_INPUT_TENSOR(kx); CHECK_CUDA_INPUT_TENSOR(vx); + CHECK_CUDA_INPUT_TENSOR(qy); CHECK_CUDA_INPUT_TENSOR(dy); + CHECK_CUDA_TENSOR(alpha_sum_buf); CHECK_CUDA_TENSOR(qdotk_max_buf); + CHECK_CUDA_TENSOR(integral_norm_buf); + CHECK_CUDA_TENSOR(dkx); CHECK_CUDA_TENSOR(dvx); + CHECK_CUDA_TENSOR(quad_weights); + CHECK_CUDA_TENSOR(psi_col_idx); CHECK_CUDA_TENSOR(psi_row_off); CHECK_CUDA_TENSOR(psi_row_idx); + + const int batch_size = kx.size(0); + const int nlat_halo = kx.size(2); // kx is [B,C,H,W], H is dim 2 + const int nlon_kx = kx.size(3); // W is dim 3 + const size_t nchans_in = qy.size(1); + const size_t nchans_out = vx.size(1); + + torch::Tensor kxP = kx.to(torch::kFloat32); + torch::Tensor vxP = vx.to(torch::kFloat32); + torch::Tensor qyP = qy.to(torch::kFloat32); + torch::Tensor dyP = dy.to(torch::kFloat32); + + if (kxP.strides()[1] != 1) { kxP = permute_4D_to0231(kxP); } + if (vxP.strides()[1] != 1) { vxP = permute_4D_to0231(vxP); } + if 
(qyP.strides()[1] != 1) { qyP = permute_4D_to0231(qyP); } + if (dyP.strides()[1] != 1) { dyP = permute_4D_to0231(dyP); } + + // dkx/dvx are already in channels-last format (allocated that way in Python) + s2_attn_bwd_ring_step_pass2_dispatch( + batch_size, nchans_in, nchans_out, + nlon_in, nlat_halo, nlon_kx, lon_lo_kx, lat_halo_start, + nlat_out, nlon_out, + kxP, vxP, qyP, dyP, + psi_row_idx, psi_row_off, psi_col_idx, quad_weights, + alpha_sum_buf, qdotk_max_buf, integral_norm_buf, + dkx, dvx); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +TORCH_LIBRARY_IMPL(attention_kernels, CUDA, m) +{ + m.impl("backward_ring_step_pass1", &s2_attention_bwd_ring_step_pass1_cuda); + m.impl("backward_ring_step_pass2", &s2_attention_bwd_ring_step_pass2_cuda); +} + +// END - backward ring step kernels and functions + } diff --git a/torch_harmonics/attention/csrc/attention_cuda_fwd.cu b/torch_harmonics/attention/csrc/attention_cuda_fwd.cu index 2c7f63ae..16ea96b6 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_fwd.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_fwd.cu @@ -583,4 +583,256 @@ TORCH_LIBRARY_IMPL(attention_kernels, CUDA, m) m.impl("forward", &s2_attention_fwd_cuda); } +// BEGIN - forward ring step kernel and functions +// Ring step variant: processes one KV chunk per call, accumulates into external state buffers. +// col_idx must have wi pre-shifted by lon_lo_out (see Python __init__ preprocessing). + +template +__global__ +__launch_bounds__(BDIM_X) +void s2_attn_fwd_ring_step_generic_vec_k( + int nchan_in, // no. of FLOATV_T elements along channel dim + int nchan_out, // no. 
of FLOATV_T elements along channel dim + int nlat_halo, // number of lat rows in kx/vx chunk (with halo) + int nlon_kx, // number of lon columns in kx/vx chunk + int nlon_in, // GLOBAL nlon_in (for modular arithmetic) + int lon_lo_kx, // global lon start of kx chunk + int lat_halo_start, // global lat index of first row in kx chunk + int nlat_out, // local output lat size + int nlon_out, // local output lon size + const FLOATV_T *__restrict__ kx, // [batch][nlat_halo][nlon_kx][nchan_in] + const FLOATV_T *__restrict__ vx, // [batch][nlat_halo][nlon_kx][nchan_out] + const FLOATV_T *__restrict__ qy, // [batch][nlat_out][nlon_out][nchan_in] + const int32_t *__restrict__ row_idx, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ col_idx, // wi already shifted by lon_lo_out + const float *__restrict__ quad_weights, // [nlat_in_global] + FLOATV_T *__restrict__ y_acc, // [batch][nlat_out][nlon_out][nchan_out] (in/out) + float *__restrict__ alpha_sum_buf, // [batch][nlat_out][nlon_out] (in/out) + float *__restrict__ qdotk_max_buf // [batch][nlat_out][nlon_out] (in/out) +) { + extern __shared__ __align__(sizeof(float4)) float shext[]; + FLOATV_T *shy = reinterpret_cast(shext) + threadIdx.y * nchan_out; + + const int batch = blockIdx.y; + const int wid = blockIdx.x * blockDim.y + threadIdx.y; + if (wid >= nlat_out * nlon_out) return; + + const int tidx = threadIdx.x; + const int h = wid / nlon_out; + const int wo = wid - (h * nlon_out); // LOCAL wo + const int ho = row_idx[h]; + + kx += int64_t(batch) * nlat_halo * nlon_kx * nchan_in; + vx += int64_t(batch) * nlat_halo * nlon_kx * nchan_out; + qy += int64_t(batch) * nlat_out * nlon_out * nchan_in + + int64_t(ho) * nlon_out * nchan_in + + int64_t(wo) * nchan_in; + + const int64_t out_flat = int64_t(batch) * nlat_out * nlon_out + + int64_t(ho) * nlon_out + wo; + y_acc += out_flat * nchan_out; + alpha_sum_buf += out_flat; + qdotk_max_buf += out_flat; + + // Load current state from buffers + float alpha_sum = 
alpha_sum_buf[0]; + float qdotk_max = qdotk_max_buf[0]; + for (int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + shy[chan] = y_acc[chan]; + } + + const int64_t rbeg = row_off[ho]; + const int64_t rend = row_off[ho + 1]; + col_idx += rbeg; + const int rlen = rend - rbeg; + + for (int off = 0; off < rlen; off++) { + const int64_t col = col_idx[off]; + + // col_idx stores hi_global * nlon_in + wi_shifted + // where wi_shifted = (wi_canonical + lon_lo_out) % nlon_in (baked in at Python __init__) + const int hi_global = col / nlon_in; + const int wi = col - (hi_global * nlon_in); + // wip = (wi + wo_local) % nlon_in = (wi_canonical + lon_lo_out + wo_local) % nlon_in + const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; + + // Skip neighbors not in current kx chunk + if (wip < lon_lo_kx || wip >= lon_lo_kx + nlon_kx) continue; + + const int hi_local = hi_global - lat_halo_start; + // Skip neighbors outside the halo-padded lat range (needed for distributed case) + if (hi_local < 0 || hi_local >= nlat_halo) continue; + const int wip_local = wip - lon_lo_kx; + + const FLOATV_T *_kx = kx + int64_t(hi_local) * nlon_kx * nchan_in + int64_t(wip_local) * nchan_in; + const FLOATV_T *_vx = vx + int64_t(hi_local) * nlon_kx * nchan_out + int64_t(wip_local) * nchan_out; + + FLOATV_T qdotkv = __vset(0.f); + for (int chan = tidx; chan < nchan_in; chan += WARP_SIZE) { + qdotkv = __vadd(qdotkv, __vmul(qy[chan], _kx[chan])); + } + float qdotk = __warp_sum(__vred(qdotkv)); + + const float qdotk_max_tmp = max(qdotk_max, qdotk); + const float alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi_global]; + const float exp_save = expf(qdotk_max - qdotk_max_tmp); + + alpha_sum = alpha + alpha_sum * exp_save; + for (int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + shy[chan] = __vadd(__vscale(exp_save, shy[chan]), + __vscale(alpha, _vx[chan])); + } + qdotk_max = qdotk_max_tmp; + } + + // Store updated state back to buffers + alpha_sum_buf[0] = alpha_sum; + 
qdotk_max_buf[0] = qdotk_max; + for (int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + y_acc[chan] = shy[chan]; + } +} + +static void s2_attn_fwd_ring_step_dispatch( + int64_t batch_size, + int64_t nchans_in, + int64_t nchans_out, + int64_t nlon_in, + int64_t nlat_halo, + int64_t nlon_kx, + int64_t lon_lo_kx, + int64_t lat_halo_start, + int64_t nlat_out, + int64_t nlon_out, + at::Tensor kxP, + at::Tensor vxP, + at::Tensor qyP, + at::Tensor row_idx, + at::Tensor row_off, + at::Tensor col_idx, + at::Tensor quad_weights, + at::Tensor y_acc, + at::Tensor alpha_sum_buf, + at::Tensor qdotk_max_buf) { + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + dim3 block(WARP_SIZE, THREADS / WARP_SIZE); + dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size); + + float *_kxp = reinterpret_cast(kxP.data_ptr()); + float *_vxp = reinterpret_cast(vxP.data_ptr()); + float *_qyp = reinterpret_cast(qyP.data_ptr()); + float *_y_acc = reinterpret_cast(y_acc.data_ptr()); + float *_alpha_sum = reinterpret_cast(alpha_sum_buf.data_ptr()); + float *_qdotk_max = reinterpret_cast(qdotk_max_buf.data_ptr()); + int32_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); + int64_t *_row_off = reinterpret_cast(row_off.data_ptr()); + int64_t *_col_idx = reinterpret_cast(col_idx.data_ptr()); + float *_quad_weights = reinterpret_cast(quad_weights.data_ptr()); + + constexpr int VEC_SIZE = sizeof(float4) / sizeof(float); + + if (!is_aligned(_kxp) || + !is_aligned(_vxp) || + !is_aligned(_qyp) || + !is_aligned(_y_acc) || + (nchans_in % VEC_SIZE) != 0 || + (nchans_out % VEC_SIZE) != 0) { + + size_t shsize = sizeof(float) * nchans_out * block.y; + s2_attn_fwd_ring_step_generic_vec_k + <<>>( + nchans_in, nchans_out, nlat_halo, nlon_kx, + nlon_in, lon_lo_kx, lat_halo_start, nlat_out, nlon_out, + _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, + _y_acc, _alpha_sum, _qdotk_max); + CHECK_ERROR("s2_attn_fwd_ring_step_generic_vec_k"); + + } else { + + float4 *_kxp4 = 
reinterpret_cast(_kxp); + float4 *_vxp4 = reinterpret_cast(_vxp); + float4 *_qyp4 = reinterpret_cast(_qyp); + float4 *_yacc4 = reinterpret_cast(_y_acc); + + size_t shsize = sizeof(float4) * (nchans_out / VEC_SIZE) * block.y; + s2_attn_fwd_ring_step_generic_vec_k + <<>>( + nchans_in / VEC_SIZE, nchans_out / VEC_SIZE, + nlat_halo, nlon_kx, nlon_in, lon_lo_kx, lat_halo_start, + nlat_out, nlon_out, + _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, + _yacc4, _alpha_sum, _qdotk_max); + CHECK_ERROR("s2_attn_fwd_ring_step_generic_vec_k"); + } +} + +void s2_attention_fwd_ring_step_cuda( + at::Tensor kx, + at::Tensor vx, + at::Tensor qy, + at::Tensor y_acc, + at::Tensor alpha_sum_buf, + at::Tensor qdotk_max_buf, + at::Tensor quad_weights, + at::Tensor psi_col_idx, + at::Tensor psi_row_off, + at::Tensor psi_row_idx, + int64_t nlon_in, + int64_t lon_lo_kx, + int64_t lat_halo_start, + int64_t nlat_out, + int64_t nlon_out) +{ + CHECK_CUDA_INPUT_TENSOR(kx); + CHECK_CUDA_INPUT_TENSOR(vx); + CHECK_CUDA_INPUT_TENSOR(qy); + CHECK_CUDA_TENSOR(y_acc); + CHECK_CUDA_TENSOR(alpha_sum_buf); + CHECK_CUDA_TENSOR(qdotk_max_buf); + CHECK_CUDA_TENSOR(quad_weights); + CHECK_CUDA_TENSOR(psi_col_idx); + CHECK_CUDA_TENSOR(psi_row_off); + CHECK_CUDA_TENSOR(psi_row_idx); + + const int batch_size = kx.size(0); + const int nlat_halo = kx.size(2); // kx is [B,C,H,W], H is dim 2 + const int nlon_kx = kx.size(3); // W is dim 3 + const size_t nchans_in = qy.size(1); + const size_t nchans_out = vx.size(1); + + torch::Tensor kxP = kx.to(torch::kFloat32); + torch::Tensor vxP = vx.to(torch::kFloat32); + torch::Tensor qyP = qy.to(torch::kFloat32); + + bool kx_is_channels_last = kxP.strides()[1] == 1; + bool vx_is_channels_last = vxP.strides()[1] == 1; + bool qy_is_channels_last = qyP.strides()[1] == 1; + + if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); } + if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); } + if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); } + + 
s2_attn_fwd_ring_step_dispatch( + batch_size, nchans_in, nchans_out, + nlon_in, nlat_halo, nlon_kx, lon_lo_kx, lat_halo_start, + nlat_out, nlon_out, + kxP, vxP, qyP, + psi_row_idx, psi_row_off, psi_col_idx, + quad_weights, + y_acc, alpha_sum_buf, qdotk_max_buf); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +TORCH_LIBRARY_IMPL(attention_kernels, CUDA, m) +{ + m.impl("forward_ring_step", &s2_attention_fwd_ring_step_cuda); +} + +// END - forward ring step kernel and functions + } diff --git a/torch_harmonics/attention/csrc/attention_interface.cpp b/torch_harmonics/attention/csrc/attention_interface.cpp index 89ff8d43..c3e9c422 100644 --- a/torch_harmonics/attention/csrc/attention_interface.cpp +++ b/torch_harmonics/attention/csrc/attention_interface.cpp @@ -56,6 +56,9 @@ namespace attention_kernels { TORCH_LIBRARY(attention_kernels, m) { m.def("forward(Tensor kx, Tensor vx, Tensor qy, Tensor quad_weights, Tensor col_idx, Tensor row_off, int nlon_in, int nlat_out, int nlon_out) -> Tensor", {at::Tag::pt2_compliant_tag}); m.def("backward(Tensor kx, Tensor vx, Tensor qy, Tensor dy, Tensor quad_weights, Tensor col_idx, Tensor row_off, int nlon_in, int nlat_out, int nlon_out) -> (Tensor, Tensor, Tensor)", {at::Tag::pt2_compliant_tag}); + m.def("forward_ring_step(Tensor kx, Tensor vx, Tensor qy, Tensor(a!) y_acc, Tensor(b!) alpha_sum_buf, Tensor(c!) qdotk_max_buf, Tensor quad_weights, Tensor col_idx, Tensor row_off, Tensor row_idx, int nlon_in, int lon_lo_kx, int lat_halo_start, int nlat_out, int nlon_out) -> ()"); + m.def("backward_ring_step_pass1(Tensor kx, Tensor vx, Tensor qy, Tensor dy, Tensor(a!) alpha_sum_buf, Tensor(b!) qdotk_max_buf, Tensor(c!) integral_buf, Tensor(d!) alpha_k_buf, Tensor(e!) 
alpha_kvw_buf, Tensor quad_weights, Tensor col_idx, Tensor row_off, Tensor row_idx, int nlon_in, int lon_lo_kx, int lat_halo_start, int nlat_out, int nlon_out) -> ()"); + m.def("backward_ring_step_pass2(Tensor kx, Tensor vx, Tensor qy, Tensor dy, Tensor alpha_sum_buf, Tensor qdotk_max_buf, Tensor integral_norm_buf, Tensor(a!) dkx, Tensor(b!) dvx, Tensor quad_weights, Tensor col_idx, Tensor row_off, Tensor row_idx, int nlon_in, int lon_lo_kx, int lat_halo_start, int nlat_out, int nlon_out) -> ()"); } } \ No newline at end of file diff --git a/torch_harmonics/distributed/__init__.py b/torch_harmonics/distributed/__init__.py index 5728afca..f457df39 100644 --- a/torch_harmonics/distributed/__init__.py +++ b/torch_harmonics/distributed/__init__.py @@ -63,3 +63,6 @@ # import resampling from .distributed_resample import DistributedResampleS2 + +# import distributed neighborhood attention +from .distributed_attention import DistributedNeighborhoodAttentionS2 From 02443c7af6abd4926f8f61e0a3e6a6ce1b164b4c Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 30 Mar 2026 07:04:13 -0700 Subject: [PATCH 02/17] adding missing files --- tests/test_distributed_attention.py | 216 ++++++ .../distributed/distributed_attention.py | 679 ++++++++++++++++++ 2 files changed, 895 insertions(+) create mode 100644 tests/test_distributed_attention.py create mode 100644 torch_harmonics/distributed/distributed_attention.py diff --git a/tests/test_distributed_attention.py b/tests/test_distributed_attention.py new file mode 100644 index 00000000..8ac58e96 --- /dev/null +++ b/tests/test_distributed_attention.py @@ -0,0 +1,216 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2026 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import unittest +from parameterized import parameterized + +import torch +import torch_harmonics as th +import torch_harmonics.distributed as thd + +from testutils import ( + set_seed, + setup_module, + teardown_module, + setup_class_from_context, + split_tensor_hw, + gather_tensor_hw, + compare_tensors, +) + +# shared state +_DIST_CTX = {} + +def setUpModule(): + setup_module(_DIST_CTX) + +def tearDownModule(): + teardown_module(_DIST_CTX) + + +class TestDistributedNeighborhoodAttention(unittest.TestCase): + """ + Compare serial NeighborhoodAttentionS2 against DistributedNeighborhoodAttentionS2. 
+ + CPU-only runs are skipped: distributed attention requires CUDA (NCCL + custom kernels). + """ + + @classmethod + def setUpClass(cls): + setup_class_from_context(cls, _DIST_CTX) + if not torch.cuda.is_available(): + raise unittest.SkipTest("Distributed neighborhood attention requires CUDA") + + def _split_helper(self, tensor): + return split_tensor_hw( + tensor, + hdim=-2, + wdim=-1, + hsize=self.grid_size_h, + wsize=self.grid_size_w, + hrank=self.hrank, + wrank=self.wrank, + ) + + def _gather_helper_fwd(self, tensor, attn_dist): + return gather_tensor_hw( + tensor, + hdim=-2, + wdim=-1, + hshapes=attn_dist.lat_out_shapes, + wshapes=attn_dist.lon_out_shapes, + hsize=self.grid_size_h, + wsize=self.grid_size_w, + hrank=self.hrank, + wrank=self.wrank, + hgroup=self.h_group, + wgroup=self.w_group, + ) + + def _gather_helper_bwd(self, tensor, attn_dist): + return gather_tensor_hw( + tensor, + hdim=-2, + wdim=-1, + hshapes=attn_dist.lat_in_shapes, + wshapes=attn_dist.lon_in_shapes, + hsize=self.grid_size_h, + wsize=self.grid_size_w, + hrank=self.hrank, + wrank=self.wrank, + hgroup=self.h_group, + wgroup=self.w_group, + ) + + @parameterized.expand( + [ + # nlat_in, nlon_in, nlat_out, nlon_out, batch_size, in_channels, num_heads, k_channels, out_channels, grid_in, grid_out, atol, rtol + [64, 128, 64, 128, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4], + [64, 128, 64, 128, 2, 16, 2, None, None, "equiangular", "equiangular", 1e-5, 1e-4], + [64, 128, 64, 128, 2, 16, 1, 8, 8, "equiangular", "equiangular", 1e-5, 1e-4], + # [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4], + [65, 128, 65, 128, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4], + ], + skip_on_empty=True, + ) + def test_distributed_neighborhood_attention( + self, + nlat_in, nlon_in, nlat_out, nlon_out, + batch_size, in_channels, num_heads, k_channels, out_channels, + grid_in, grid_out, + atol, rtol, + verbose=True, + ): + set_seed(333) + + B, C, 
Hi, Wi, Ho, Wo = batch_size, in_channels, nlat_in, nlon_in, nlat_out, nlon_out + + attn_args = dict( + in_channels=C, + in_shape=(nlat_in, nlon_in), + out_shape=(nlat_out, nlon_out), + grid_in=grid_in, + grid_out=grid_out, + num_heads=num_heads, + bias=True, + k_channels=k_channels, + out_channels=out_channels, + ) + + # build serial and distributed modules with identical weights + attn_serial = th.NeighborhoodAttentionS2(**attn_args).to(self.device) + attn_dist = thd.DistributedNeighborhoodAttentionS2(**attn_args).to(self.device) + + with torch.no_grad(): + attn_dist.k_weights.copy_(attn_serial.k_weights) + attn_dist.v_weights.copy_(attn_serial.v_weights) + attn_dist.q_weights.copy_(attn_serial.q_weights) + attn_dist.proj_weights.copy_(attn_serial.proj_weights) + if attn_args["bias"]: + attn_dist.k_bias.copy_(attn_serial.k_bias) + attn_dist.v_bias.copy_(attn_serial.v_bias) + attn_dist.q_bias.copy_(attn_serial.q_bias) + attn_dist.proj_bias.copy_(attn_serial.proj_bias) + + # Helper: create inputs + inp_full = { + "k": torch.randn(B, C, Hi, Wi, requires_grad=True, device=self.device, dtype=torch.float32), + "v": torch.randn(B, C, Hi, Wi, requires_grad=True, device=self.device, dtype=torch.float32), + "q": torch.randn(B, C, Ho, Wo, requires_grad=True, device=self.device, dtype=torch.float32), + } + + # ---- serial forward ---- + out_full = attn_serial(inp_full["q"], inp_full["k"], inp_full["v"]) + + torch.cuda.synchronize() + + # ---- serial backward ---- + with torch.no_grad(): + ograd_full = torch.randn_like(out_full) + out_full.backward(ograd_full) + + igrad_full = {} + for inp in ["q", "k", "v"]: + igrad_full[inp] = inp_full[inp].grad.clone() + + torch.cuda.synchronize() + + # ---- distributed forward ---- + inp_local = {} + for inp in ["q", "k", "v"]: + inp_local[inp] = self._split_helper(inp_full[inp].detach()) + inp_local[inp].requires_grad_(True) + out_local = attn_dist(inp_local["q"], inp_local["k"], inp_local["v"]) + + torch.cuda.synchronize() + + # ---- 
distributed backward ---- + ograd_local = self._split_helper(ograd_full) + out_local.backward(ograd_local) + + igrad_local = {} + for inp in ["q", "k", "v"]: + igrad_local[inp] = inp_local[inp].grad.clone() + + torch.cuda.synchronize() + + # ---- compare forward ---- + out_gather = self._gather_helper_fwd(out_local, attn_dist) + self.assertTrue(compare_tensors("forward output", out_full, out_gather, atol=atol, rtol=rtol, verbose=verbose)) + + # ---- compare backward ---- + for inp in ["q", "k", "v"]: + igrad_gather = self._gather_helper_bwd(igrad_local[inp], attn_dist) + self.assertTrue(compare_tensors(f"input gradient {inp}", igrad_full[inp], igrad_gather, atol=atol, rtol=rtol, verbose=verbose)) + + +if __name__ == "__main__": + unittest.main() diff --git a/torch_harmonics/distributed/distributed_attention.py b/torch_harmonics/distributed/distributed_attention.py new file mode 100644 index 00000000..07c40f0e --- /dev/null +++ b/torch_harmonics/distributed/distributed_attention.py @@ -0,0 +1,679 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import math +from itertools import accumulate +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.distributed as dist + +from torch_harmonics.attention.attention import NeighborhoodAttentionS2 + +from .utils import polar_group, azimuth_group +from .utils import polar_group_size, polar_group_rank +from .utils import azimuth_group_size, azimuth_group_rank +from .primitives import compute_split_shapes + +from attention_helpers import optimized_kernels_is_available +from torch_harmonics.attention import attention_kernels + + +# --------------------------------------------------------------------------- +# helpers: lat halo and lon ring exchange +# --------------------------------------------------------------------------- + +def _get_group_neighbors(group): + group_size = dist.get_world_size(group) + global_rank = dist.get_rank() + group_ranks = dist.get_process_group_ranks(group) + my_rank_id = group_ranks.index(global_rank) + prev_rank = group_ranks[(my_rank_id - 1) % group_size] + next_rank = group_ranks[(my_rank_id + 1) % group_size] + + return prev_rank, next_rank + +class _LatHaloExchangeFn(torch.autograd.Function): + """Differentiable lat 
halo exchange for polar-distributed tensors. + + Forward: gathers r_lat rows from neighbouring polar ranks and returns a + halo-padded tensor of shape [B, C, H_local + 2*r_lat, W]. + Backward: communicates halo gradient contributions back to their owning + ranks and accumulates them onto the local input gradient. + + Ranks at the polar boundary (rank 0 / rank group_size-1) receive + zero-padding on the missing side in the forward pass; the corresponding + halo-gradient portion is discarded in the backward (no neighbour to send + it to), which is the correct adjoint of padding with zeros. + """ + + @staticmethod + def forward(ctx, x, r_lat, polar_group): + ctx.r_lat = r_lat + ctx.polar_group = polar_group + + group_size = dist.get_world_size(polar_group) + # this needs to be the global rank, not the group rank + global_rank = dist.get_rank() + group_rank = dist.get_rank(polar_group) + ctx.group_size = group_size + ctx.group_rank = group_rank + prev_rank, next_rank = _get_group_neighbors(polar_group) + ctx.prev_rank = prev_rank + ctx.next_rank = next_rank + ctx.H = x.shape[2] + + B, C, H, W = x.shape + device, dtype = x.device, x.dtype + + # setup send buffers + send_top = x[:, :, :r_lat, :].contiguous() # top r_lat rows → rank-1 + send_bot = x[:, :, -r_lat:, :].contiguous() # bottom r_lat rows → rank+1 + + # setup recv buffers + recv_top = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) + recv_bot = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) + + ops = [] + if group_rank > 0: + ops.append(dist.P2POp(dist.isend, send_top, prev_rank, polar_group)) + ops.append(dist.P2POp(dist.irecv, recv_top, prev_rank, polar_group)) + if group_rank < group_size - 1: + ops.append(dist.P2POp(dist.isend, send_bot, next_rank, polar_group)) + ops.append(dist.P2POp(dist.irecv, recv_bot, next_rank, polar_group)) + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + return torch.cat([recv_top, x, recv_bot], dim=2).contiguous() + + 
@staticmethod + def backward(ctx, dout): + r_lat = ctx.r_lat + polar_group = ctx.polar_group + group_size = ctx.group_size + group_rank = ctx.group_rank + H = ctx.H + prev_rank = ctx.prev_rank + next_rank = ctx.next_rank + + B, C, _, W = dout.shape + device, dtype = dout.device, dout.dtype + + # Direct gradient for the local (non-halo) rows. + dx = dout[:, :, r_lat:r_lat + H, :].contiguous().clone() + + # The halo slices carry gradients that belong to neighbouring ranks: + # dout[:, :, :r_lat, :] → came FROM rank-1; send gradient back to rank-1 + # dout[:, :, r_lat + H:, :] → came FROM rank+1; send gradient back to rank+1 + # Simultaneously receive from each neighbour the gradient they owe us + # for the rows we sent them in the forward pass. + send_to_prev = dout[:, :, :r_lat, :].contiguous() + send_to_next = dout[:, :, r_lat + H:, :].contiguous() + + recv_from_prev = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) + recv_from_next = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) + + ops = [] + if group_rank > 0: + ops.append(dist.P2POp(dist.isend, send_to_prev, prev_rank, polar_group)) + ops.append(dist.P2POp(dist.irecv, recv_from_prev, prev_rank, polar_group)) + if group_rank < group_size - 1: + ops.append(dist.P2POp(dist.isend, send_to_next, next_rank, polar_group)) + ops.append(dist.P2POp(dist.irecv, recv_from_next, next_rank, polar_group)) + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # Accumulate gradient contributions for rows we sent in the forward. 
+ # recv_from_prev = gradient for our top r_lat rows (sent as prev rank's recv_bot) + # recv_from_next = gradient for our bottom r_lat rows (sent as next rank's recv_top) + if group_rank > 0: + dx[:, :, :r_lat, :] = dx[:, :, :r_lat, :] + recv_from_prev + if group_rank < group_size - 1: + dx[:, :, H - r_lat:, :] = dx[:, :, H - r_lat:, :] + recv_from_next + + # Gradients for r_lat and polar_group are None (not tensors / non-differentiable) + return dx, None, None + + +def _ring_step(chunk: torch.Tensor, az_group) -> Tuple[torch.Tensor, list]: + """Send chunk to the previous rank, receive the next chunk from the next rank. + + Returns (recv_buf, requests). Call req.wait() before using recv_buf. + """ + send_to, recv_from = _get_group_neighbors(az_group) + recv_buf = torch.empty_like(chunk) + ops = [ + dist.P2POp(dist.isend, chunk, send_to, az_group), + dist.P2POp(dist.irecv, recv_buf, recv_from, az_group), + ] + reqs = dist.batch_isend_irecv(ops) + return recv_buf, reqs + + +# --------------------------------------------------------------------------- +# autograd.Function for the ring-step attention kernel calls +# --------------------------------------------------------------------------- + +def _ring_kv(kw_chunk, vw_chunk, az_group, next_nlon_kw, next_nlon_kv): + """Async send current chunks, receive next chunks with known shapes.""" + send_to, recv_from = _get_group_neighbors(az_group) + B, C_k, H, _ = kw_chunk.shape + B, C_v, H, _ = vw_chunk.shape + recv_kw = torch.empty(B, C_k, H, next_nlon_kw, device=kw_chunk.device, dtype=kw_chunk.dtype) + recv_vw = torch.empty(B, C_v, H, next_nlon_kv, device=vw_chunk.device, dtype=vw_chunk.dtype) + ops = [ + dist.P2POp(dist.isend, kw_chunk, send_to, az_group), + dist.P2POp(dist.irecv, recv_kw, recv_from, az_group), + dist.P2POp(dist.isend, vw_chunk, send_to, az_group), + dist.P2POp(dist.irecv, recv_vw, recv_from, az_group), + ] + reqs = dist.batch_isend_irecv(ops) + return recv_kw, recv_vw, reqs + + +class 
_RingNeighborhoodAttentionFn(torch.autograd.Function): + """Forward ring attention + backward ring for one attention head group. + + kw, vw : [B*nh, C_k/C_v, H_halo, W_local] channels-first, lat-halo-padded + qw : [B*nh, C_k, H_out_local, W_out_local] channels-first + + State buffers use channels-last layout as required by the CUDA kernels: + y_acc : [B, H_out, W_out, C_v] + alpha_k/kvw : [B, H_out, W_out, C_k] + alpha_sum/qdotk_max/integral : [B, H_out, W_out] + """ + + @staticmethod + def forward( + ctx, + kw, vw, qw, + psi_col_idx, psi_roff_idx, psi_row_idx, + quad_weights, + nlon_in: int, + lon_chunk_starts: list, + nlon_kx_list: list, + lat_halo_start: int, + nlat_out_local: int, + nlon_out_local: int, + r_lat: int, + az_group, + az_rank: int, + az_size: int, + ): + B, C_k, H_halo, _ = kw.shape + _, C_v, _, _ = vw.shape + device = kw.device + + # Allocate state buffers in formats expected by the CUDA kernels: + # y_acc: channels-last [B, H, W, C_v]; scalars: [B, H, W] + y_acc = torch.zeros(B, nlat_out_local, nlon_out_local, C_v, + device=device, dtype=torch.float32) + alpha_sum = torch.zeros(B, nlat_out_local, nlon_out_local, + device=device, dtype=torch.float32) + qdotk_max = torch.full ((B, nlat_out_local, nlon_out_local), float('-inf'), + device=device, dtype=torch.float32) + + kw_chunk = kw.contiguous() + vw_chunk = vw.contiguous() + + for step in range(az_size): + src_rank = (az_rank + step) % az_size + lon_lo_kx = lon_chunk_starts[src_rank] + + # Pre-allocate receive buffers for the NEXT chunk (correct shape) + if step < az_size - 1: + next_src = (az_rank + step + 1) % az_size + recv_kw, recv_vw, reqs = _ring_kv( + kw_chunk, vw_chunk, az_group, + nlon_kx_list[next_src], nlon_kx_list[next_src]) + + attention_kernels.forward_ring_step.default( + kw_chunk, vw_chunk, qw, + y_acc, alpha_sum, qdotk_max, + quad_weights, psi_col_idx, psi_roff_idx, psi_row_idx, + nlon_in, lon_lo_kx, lat_halo_start, + nlat_out_local, nlon_out_local, + ) + + if step < az_size - 1: 
+ for req in reqs: + req.wait() + kw_chunk = recv_kw + vw_chunk = recv_vw + + # Finalize: y = y_acc / alpha_sum (both channels-last layout) + y_out = y_acc / alpha_sum.unsqueeze(-1) # [B, H, W, C_v] + y_out = y_out.permute(0, 3, 1, 2).contiguous() # [B, C_v, H, W] + + # Save for backward (kw/vw: channels-first; scalars: [B,H,W]) + ctx.save_for_backward(kw, vw, qw, psi_col_idx, psi_roff_idx, psi_row_idx, + quad_weights, alpha_sum, qdotk_max) + ctx.nlon_in = nlon_in + ctx.lon_chunk_starts = lon_chunk_starts + ctx.nlon_kx_list = nlon_kx_list + ctx.lat_halo_start = lat_halo_start + ctx.nlat_out_local = nlat_out_local + ctx.nlon_out_local = nlon_out_local + ctx.r_lat = r_lat + ctx.az_group = az_group + ctx.az_rank = az_rank + ctx.az_size = az_size + return y_out + + @staticmethod + def backward(ctx, dy): + (kw, vw, qw, + psi_col_idx, psi_roff_idx, psi_row_idx, + quad_weights, + fwd_alpha_sum, fwd_qdotk_max) = ctx.saved_tensors + + nlon_in = ctx.nlon_in + lon_chunk_starts = ctx.lon_chunk_starts + nlon_kx_list = ctx.nlon_kx_list + lat_halo_start = ctx.lat_halo_start + nlat_out_local = ctx.nlat_out_local + nlon_out_local = ctx.nlon_out_local + r_lat = ctx.r_lat + az_group = ctx.az_group + az_rank = ctx.az_rank + az_size = ctx.az_size + + B, C_k, H_halo, _ = kw.shape + _, C_v, _, _ = vw.shape + device = kw.device + + dy_cf = dy.contiguous() # channels-first [B, C_v, H, W] + + # ---------------------------------------------------------------- + # Backward pass 1: re-accumulate {alpha_sum, qdotk_max, integral, + # alpha_k, alpha_kvw} via ring. + # Start fresh (NOT from the saved forward values). 
+ # ---------------------------------------------------------------- + bwd_alpha_sum = torch.zeros(B, nlat_out_local, nlon_out_local, + device=device, dtype=torch.float32) + bwd_qdotk_max = torch.full ((B, nlat_out_local, nlon_out_local), float('-inf'), + device=device, dtype=torch.float32) + integral_buf = torch.zeros_like(bwd_alpha_sum) + alpha_k_buf = torch.zeros(B, nlat_out_local, nlon_out_local, C_k, + device=device, dtype=torch.float32) + alpha_kvw_buf = torch.zeros_like(alpha_k_buf) + + kw_chunk = kw.contiguous() + vw_chunk = vw.contiguous() + + for step in range(az_size): + src_rank = (az_rank + step) % az_size + lon_lo_kx = lon_chunk_starts[src_rank] + + if step < az_size - 1: + next_src = (az_rank + step + 1) % az_size + recv_kw, recv_vw, reqs = _ring_kv( + kw_chunk, vw_chunk, az_group, + nlon_kx_list[next_src], nlon_kx_list[next_src]) + + attention_kernels.backward_ring_step_pass1.default( + kw_chunk, vw_chunk, qw, dy_cf, + bwd_alpha_sum, bwd_qdotk_max, integral_buf, + alpha_k_buf, alpha_kvw_buf, + quad_weights, psi_col_idx, psi_roff_idx, psi_row_idx, + nlon_in, lon_lo_kx, lat_halo_start, + nlat_out_local, nlon_out_local, + ) + + if step < az_size - 1: + for req in reqs: + req.wait() + kw_chunk = recv_kw + vw_chunk = recv_vw + + # Finalize pass-1: normalize integral, compute dqy + # Use the SAVED forward alpha_sum/qdotk_max (same values, but authoritative) + alpha_sum_inv = 1.0 / fwd_alpha_sum # [B, H, W] + integral_norm = integral_buf * alpha_sum_inv # [B, H, W] + alpha_sum_inv_sq = alpha_sum_inv ** 2 + + # dqy[b,h,w,c] = inv_sq*(alpha_sum*alpha_kvw - integral*alpha_k) + dqy_cl = alpha_sum_inv_sq.unsqueeze(-1) * ( + fwd_alpha_sum.unsqueeze(-1) * alpha_kvw_buf + - integral_buf.unsqueeze(-1) * alpha_k_buf + ) # [B, H, W, C_k] + dqy = dqy_cl.permute(0, 3, 1, 2).contiguous() # [B, C_k, H, W] + + # ---------------------------------------------------------------- + # Backward pass 2: scatter dkw/dvw contributions. 
+ # Each GPU computes its contribution to every lon chunk it visits; + # then allreduce across azimuth ranks, extract local chunk. + # TODO: replace allreduce with ring reduce-scatter for efficiency. + # ---------------------------------------------------------------- + kw_chunk = kw.contiguous() + vw_chunk = vw.contiguous() + nlon_in_total = sum(nlon_kx_list) + # Gradient buffers in channels-last as expected by CUDA kernel + dkw_full_cl = torch.zeros(B, H_halo, nlon_in_total, C_k, + device=device, dtype=torch.float32) + dvw_full_cl = torch.zeros(B, H_halo, nlon_in_total, C_v, + device=device, dtype=torch.float32) + + for step in range(az_size): + src_rank = (az_rank + step) % az_size + lon_lo_kx = lon_chunk_starts[src_rank] + nlon_kx = nlon_kx_list[src_rank] + + # Channels-last gradient buffers for this chunk + dkw_chunk_cl = torch.zeros(B, H_halo, nlon_kx, C_k, + device=device, dtype=torch.float32) + dvw_chunk_cl = torch.zeros(B, H_halo, nlon_kx, C_v, + device=device, dtype=torch.float32) + + attention_kernels.backward_ring_step_pass2.default( + kw_chunk, vw_chunk, qw, dy_cf, + fwd_alpha_sum, fwd_qdotk_max, integral_norm, + dkw_chunk_cl, dvw_chunk_cl, + quad_weights, psi_col_idx, psi_roff_idx, psi_row_idx, + nlon_in, lon_lo_kx, lat_halo_start, + nlat_out_local, nlon_out_local, + ) + + dkw_full_cl[:, :, lon_lo_kx:lon_lo_kx + nlon_kx, :].add_(dkw_chunk_cl) + dvw_full_cl[:, :, lon_lo_kx:lon_lo_kx + nlon_kx, :].add_(dvw_chunk_cl) + + if step < az_size - 1: + next_src = (az_rank + step + 1) % az_size + recv_kw, recv_vw, reqs = _ring_kv( + kw_chunk, vw_chunk, az_group, + nlon_kx_list[next_src], nlon_kx_list[next_src]) + for req in reqs: + req.wait() + kw_chunk = recv_kw + vw_chunk = recv_vw + + if az_size > 1 and az_group is not None: + dist.all_reduce(dkw_full_cl, group=az_group) + dist.all_reduce(dvw_full_cl, group=az_group) + + my_lo = lon_chunk_starts[az_rank] + my_nlon = nlon_kx_list[az_rank] + # Extract local chunk and convert channels-last → channels-first + 
dkw_cl = dkw_full_cl[:, :, my_lo:my_lo + my_nlon, :].contiguous()
+        dvw_cl = dvw_full_cl[:, :, my_lo:my_lo + my_nlon, :].contiguous()
+        dkw = dkw_cl.permute(0, 3, 1, 2).contiguous()  # [B, C_k, H_halo, W_local]
+        dvw = dvw_cl.permute(0, 3, 1, 2).contiguous()  # [B, C_v, H_halo, W_local]
+        # No halo stripping: dkw/dvw must match kw/vw shape (= key_halo/value_halo).
+        # The autograd through torch.cat in _exchange_lat_halo extracts the
+        # middle H_in rows as the gradient for key_proj/value_proj.
+
+        # Return exactly one gradient per forward input (17 inputs):
+        # (kw, vw, qw, psi_col, psi_roff, psi_row, quad_weights,
+        #  nlon_in, lon_chunk_starts, nlon_kx_list, lat_halo_start,
+        #  nlat_out_local, nlon_out_local, r_lat, az_group, az_rank, az_size)
+        return dkw, dvw, dqy, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+
+
+# ---------------------------------------------------------------------------
+# Distributed Neighborhood Attention on the 2-sphere
+# ---------------------------------------------------------------------------
+
+class DistributedNeighborhoodAttentionS2(NeighborhoodAttentionS2):
+    """
+    Distributed neighborhood attention on the 2-sphere using a ring exchange
+    strategy for the longitude dimension and halo exchange for the latitude
+    dimension.
+
+    Data is assumed to be split along both the latitude (polar group) and
+    longitude (azimuth group) dimensions. The forward pass uses ring exchange
+    of key/value chunks over the azimuth group so that every output point can
+    attend to its full spherical neighborhood.
+
+    Inherits learnable parameters from :class:`NeighborhoodAttentionS2`.
+    
+ """ + + def __init__( + self, + in_channels: int, + in_shape: Tuple[int, int], + out_shape: Tuple[int, int], + grid_in: Optional[str] = "equiangular", + grid_out: Optional[str] = "equiangular", + num_heads: Optional[int] = 1, + scale: Optional[Union[torch.Tensor, float]] = None, + bias: Optional[bool] = True, + theta_cutoff: Optional[float] = None, + k_channels: Optional[int] = None, + out_channels: Optional[int] = None, + optimized_kernel: Optional[bool] = True, + ): + if not optimized_kernels_is_available(): + raise RuntimeError("Optimized kernels are required to run DistributedNeighborhoodAttentionS2.") + + # initialise base class (builds global psi, creates parameters) + super().__init__( + in_channels, in_shape, out_shape, + grid_in=grid_in, grid_out=grid_out, + num_heads=num_heads, scale=scale, bias=bias, + theta_cutoff=theta_cutoff, k_channels=k_channels, + out_channels=out_channels, + optimized_kernel=True, + ) + + # ---- distributed info ---- + self.comm_size_polar = polar_group_size() + self.comm_rank_polar = polar_group_rank() + self.comm_size_azimuth = azimuth_group_size() + self.comm_rank_azimuth = azimuth_group_rank() + + # split shapes + self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar) + self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth) + self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar) + self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth) + + # local sizes for this rank + self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar] + self.nlon_in_local = self.lon_in_shapes[self.comm_rank_azimuth] + self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar] + self.nlon_out_local = self.lon_out_shapes[self.comm_rank_azimuth] + + # global lon offsets + self.lon_in_starts = list(accumulate([0] + self.lon_in_shapes[:-1])) + self.lon_out_starts = list(accumulate([0] + self.lon_out_shapes[:-1])) + self.lat_out_starts = 
list(accumulate([0] + self.lat_out_shapes[:-1])) + + self.lon_lo_out = self.lon_out_starts[self.comm_rank_azimuth] + self.lat_lo_out = self.lat_out_starts[self.comm_rank_polar] + + # ---- build local psi ---- + # The global psi built by the base class covers all output lat rows. + # We filter to only the rows owned by this rank and shift the wi + # component of col_idx by lon_lo_out so that the kernel can use + # local wo directly without knowing the global lon offset. + self._build_local_psi() + + # ---- lat halo size ---- + # Compute r_lat from the global psi: maximum |hi_global - ho_global| + # over all (ho, hi) pairs in the neighbourhood. + # Use the lat_out_lo of our polar rank to compute ho_global. + self.r_lat = self._compute_r_lat() + + # ----------------------------------------------------------------------- + + def _build_local_psi(self): + """Filter global psi to local output lat rows and shift col_idx wi.""" + + lat_lo = self.lat_lo_out + lat_hi = lat_lo + self.nlat_out_local + + # global psi from the base class (built over all nlat_out rows) + col_idx_global = self.psi_col_idx # [nnz] int64 + row_idx_global = self.psi_row_idx # [nnz] int32 + roff_global = self.psi_roff_idx # [nlat_out+1] int64 + + # psi_row_idx stores the sorted permutation: value is the row index. + # psi_roff_idx[ho] .. psi_roff_idx[ho+1] gives entries for row ho. + # (The row_idx buffer is the *sort order*, not the row indices directly.) + # For the distributed case we rebuild roff for the local rows only. 
+ + # Build local roff: select rows lat_lo..lat_hi-1 + roff_local = roff_global[lat_lo:lat_hi + 1] - roff_global[lat_lo] # offset by first entry + + # Select the corresponding col_idx entries + start = roff_global[lat_lo].item() + end = roff_global[lat_hi].item() + col_idx_local = col_idx_global[start:end].clone() + + # Shift wi by lon_lo_out: + # col stores hi_global * nlon_in + wi_canonical + # We want hi_global * nlon_in + (wi_canonical + lon_lo_out) % nlon_in + nlon_in = self.nlon_in + lon_lo = self.lon_lo_out + hi_global = col_idx_local // nlon_in + wi_canon = col_idx_local - hi_global * nlon_in + wi_shifted = (wi_canon + lon_lo) % nlon_in + col_idx_shifted = hi_global * nlon_in + wi_shifted + + # Build sorted row_idx for local output rows (0-indexed within local range) + # Reuse the serial sort order: just re-sort by nnz per local row + nnz_per_row = (roff_local[1:] - roff_local[:-1]).cpu() + row_idx_local = torch.argsort(nnz_per_row, descending=True).to(torch.int32) + + self.register_buffer("psi_col_idx_local", col_idx_shifted, persistent=False) + self.register_buffer("psi_roff_idx_local", roff_local, persistent=False) + self.register_buffer("psi_row_idx_local", row_idx_local, persistent=False) + + def _compute_r_lat(self) -> int: + """Max lat halo radius needed for this rank's output lat range.""" + + if polar_group_size() == 1: + return 0 + + lat_lo = self.lat_lo_out + lat_hi = lat_lo + self.nlat_out_local + + col_idx = self.psi_col_idx_local + roff = self.psi_roff_idx_local + nlon_in = self.nlon_in + + hi_global = (col_idx // nlon_in).long() + + # output lat indices for each entry + ho_indices = torch.zeros(col_idx.shape[0], dtype=torch.long, device=col_idx.device) + for ho_local in range(self.nlat_out_local): + b = roff[ho_local].item() + e = roff[ho_local + 1].item() + ho_indices[b:e] = ho_local + lat_lo + + r = (hi_global - ho_indices).abs().max().item() if col_idx.numel() > 0 else 0 + return int(r) + + # 
----------------------------------------------------------------------- + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + value: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if key is None: + key = query + if value is None: + value = query + + assert query.dim() == 4 + + # scale query + query_scaled = query * self.scale + + # ---- 1. project to k/v/q ---- + key_proj = nn.functional.conv2d(key, self.k_weights, bias=self.k_bias) + value_proj = nn.functional.conv2d(value, self.v_weights, bias=self.v_bias) + query_proj = nn.functional.conv2d(query_scaled, self.q_weights, bias=self.q_bias) + + # fold num_heads into batch + B, _, H, W = key_proj.shape + key_proj = key_proj.reshape(B * self.num_heads, -1, H, W) + B, _, H, W = value_proj.shape + value_proj = value_proj.reshape(B * self.num_heads, -1, H, W) + B, _, H, W = query_proj.shape + query_proj = query_proj.reshape(B * self.num_heads, -1, H, W) + + Bnh = B # B*nh after reshape + + # ---- 2. lat halo exchange ---- + # key_proj/value_proj: [Bnh, C, H_in_local, W_in_local] + # Use differentiable halo exchange when there is an actual polar split; + # otherwise fall through to the identity (no-op). + if self.r_lat > 0 and self.comm_size_polar > 1: + key_halo = _LatHaloExchangeFn.apply(key_proj, self.r_lat, polar_group()) + value_halo = _LatHaloExchangeFn.apply(value_proj, self.r_lat, polar_group()) + else: + key_halo = key_proj + value_halo = value_proj + + # global lat index of first halo row + lat_in_starts = list(accumulate([0] + self.lat_in_shapes[:-1])) + lat_halo_start = lat_in_starts[self.comm_rank_polar] - self.r_lat + + # ---- 3. 
ring attention ---- + out = _RingNeighborhoodAttentionFn.apply( + key_halo, + value_halo, + query_proj, + self.psi_col_idx_local, + self.psi_roff_idx_local, + self.psi_row_idx_local, + self.quad_weights, + self.nlon_in, + self.lon_in_starts, # lon chunk starts for kv (same as lon_in) + self.lon_in_shapes, # lon chunk sizes for kv + lat_halo_start, + self.nlat_out_local, + self.nlon_out_local, + self.r_lat, + azimuth_group(), + self.comm_rank_azimuth, + self.comm_size_azimuth, + ) # [Bnh, C_v, H_out_local, W_out_local] + + # unfold num_heads + B_nh, C_v, H_out, W_out = out.shape + B_orig = B_nh // self.num_heads + out = out.reshape(B_orig, self.num_heads * C_v, H_out, W_out) + + # ---- 4. output projection ---- + out = nn.functional.conv2d(out, self.proj_weights, bias=self.proj_bias) + + return out From d15fb521a715fe3a5de7abeed1c917e868f8d956 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 30 Mar 2026 09:00:11 -0700 Subject: [PATCH 03/17] slight refactoring --- .../distributed/distributed_attention.py | 153 +----------------- torch_harmonics/distributed/primitives.py | 133 ++++++++++++++- 2 files changed, 138 insertions(+), 148 deletions(-) diff --git a/torch_harmonics/distributed/distributed_attention.py b/torch_harmonics/distributed/distributed_attention.py index 07c40f0e..efe05178 100644 --- a/torch_harmonics/distributed/distributed_attention.py +++ b/torch_harmonics/distributed/distributed_attention.py @@ -28,7 +28,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-import math from itertools import accumulate from typing import Optional, Tuple, Union @@ -38,159 +37,22 @@ from torch_harmonics.attention.attention import NeighborhoodAttentionS2 -from .utils import polar_group, azimuth_group +from .utils import azimuth_group from .utils import polar_group_size, polar_group_rank from .utils import azimuth_group_size, azimuth_group_rank -from .primitives import compute_split_shapes +from .primitives import compute_split_shapes, get_group_neighbors, polar_halo_exchange from attention_helpers import optimized_kernels_is_available from torch_harmonics.attention import attention_kernels -# --------------------------------------------------------------------------- -# helpers: lat halo and lon ring exchange -# --------------------------------------------------------------------------- - -def _get_group_neighbors(group): - group_size = dist.get_world_size(group) - global_rank = dist.get_rank() - group_ranks = dist.get_process_group_ranks(group) - my_rank_id = group_ranks.index(global_rank) - prev_rank = group_ranks[(my_rank_id - 1) % group_size] - next_rank = group_ranks[(my_rank_id + 1) % group_size] - - return prev_rank, next_rank - -class _LatHaloExchangeFn(torch.autograd.Function): - """Differentiable lat halo exchange for polar-distributed tensors. - - Forward: gathers r_lat rows from neighbouring polar ranks and returns a - halo-padded tensor of shape [B, C, H_local + 2*r_lat, W]. - Backward: communicates halo gradient contributions back to their owning - ranks and accumulates them onto the local input gradient. - - Ranks at the polar boundary (rank 0 / rank group_size-1) receive - zero-padding on the missing side in the forward pass; the corresponding - halo-gradient portion is discarded in the backward (no neighbour to send - it to), which is the correct adjoint of padding with zeros. 
- """ - - @staticmethod - def forward(ctx, x, r_lat, polar_group): - ctx.r_lat = r_lat - ctx.polar_group = polar_group - - group_size = dist.get_world_size(polar_group) - # this needs to be the global rank, not the group rank - global_rank = dist.get_rank() - group_rank = dist.get_rank(polar_group) - ctx.group_size = group_size - ctx.group_rank = group_rank - prev_rank, next_rank = _get_group_neighbors(polar_group) - ctx.prev_rank = prev_rank - ctx.next_rank = next_rank - ctx.H = x.shape[2] - - B, C, H, W = x.shape - device, dtype = x.device, x.dtype - - # setup send buffers - send_top = x[:, :, :r_lat, :].contiguous() # top r_lat rows → rank-1 - send_bot = x[:, :, -r_lat:, :].contiguous() # bottom r_lat rows → rank+1 - - # setup recv buffers - recv_top = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) - recv_bot = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) - - ops = [] - if group_rank > 0: - ops.append(dist.P2POp(dist.isend, send_top, prev_rank, polar_group)) - ops.append(dist.P2POp(dist.irecv, recv_top, prev_rank, polar_group)) - if group_rank < group_size - 1: - ops.append(dist.P2POp(dist.isend, send_bot, next_rank, polar_group)) - ops.append(dist.P2POp(dist.irecv, recv_bot, next_rank, polar_group)) - - if ops: - reqs = dist.batch_isend_irecv(ops) - for req in reqs: - req.wait() - - return torch.cat([recv_top, x, recv_bot], dim=2).contiguous() - - @staticmethod - def backward(ctx, dout): - r_lat = ctx.r_lat - polar_group = ctx.polar_group - group_size = ctx.group_size - group_rank = ctx.group_rank - H = ctx.H - prev_rank = ctx.prev_rank - next_rank = ctx.next_rank - - B, C, _, W = dout.shape - device, dtype = dout.device, dout.dtype - - # Direct gradient for the local (non-halo) rows. 
- dx = dout[:, :, r_lat:r_lat + H, :].contiguous().clone() - - # The halo slices carry gradients that belong to neighbouring ranks: - # dout[:, :, :r_lat, :] → came FROM rank-1; send gradient back to rank-1 - # dout[:, :, r_lat + H:, :] → came FROM rank+1; send gradient back to rank+1 - # Simultaneously receive from each neighbour the gradient they owe us - # for the rows we sent them in the forward pass. - send_to_prev = dout[:, :, :r_lat, :].contiguous() - send_to_next = dout[:, :, r_lat + H:, :].contiguous() - - recv_from_prev = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) - recv_from_next = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) - - ops = [] - if group_rank > 0: - ops.append(dist.P2POp(dist.isend, send_to_prev, prev_rank, polar_group)) - ops.append(dist.P2POp(dist.irecv, recv_from_prev, prev_rank, polar_group)) - if group_rank < group_size - 1: - ops.append(dist.P2POp(dist.isend, send_to_next, next_rank, polar_group)) - ops.append(dist.P2POp(dist.irecv, recv_from_next, next_rank, polar_group)) - - if ops: - reqs = dist.batch_isend_irecv(ops) - for req in reqs: - req.wait() - - # Accumulate gradient contributions for rows we sent in the forward. - # recv_from_prev = gradient for our top r_lat rows (sent as prev rank's recv_bot) - # recv_from_next = gradient for our bottom r_lat rows (sent as next rank's recv_top) - if group_rank > 0: - dx[:, :, :r_lat, :] = dx[:, :, :r_lat, :] + recv_from_prev - if group_rank < group_size - 1: - dx[:, :, H - r_lat:, :] = dx[:, :, H - r_lat:, :] + recv_from_next - - # Gradients for r_lat and polar_group are None (not tensors / non-differentiable) - return dx, None, None - - -def _ring_step(chunk: torch.Tensor, az_group) -> Tuple[torch.Tensor, list]: - """Send chunk to the previous rank, receive the next chunk from the next rank. - - Returns (recv_buf, requests). Call req.wait() before using recv_buf. 
- """ - send_to, recv_from = _get_group_neighbors(az_group) - recv_buf = torch.empty_like(chunk) - ops = [ - dist.P2POp(dist.isend, chunk, send_to, az_group), - dist.P2POp(dist.irecv, recv_buf, recv_from, az_group), - ] - reqs = dist.batch_isend_irecv(ops) - return recv_buf, reqs - - # --------------------------------------------------------------------------- # autograd.Function for the ring-step attention kernel calls # --------------------------------------------------------------------------- def _ring_kv(kw_chunk, vw_chunk, az_group, next_nlon_kw, next_nlon_kv): """Async send current chunks, receive next chunks with known shapes.""" - send_to, recv_from = _get_group_neighbors(az_group) + send_to, recv_from = get_group_neighbors(az_group) B, C_k, H, _ = kw_chunk.shape B, C_v, H, _ = vw_chunk.shape recv_kw = torch.empty(B, C_k, H, next_nlon_kw, device=kw_chunk.device, dtype=kw_chunk.dtype) @@ -234,7 +96,7 @@ def forward( az_rank: int, az_size: int, ): - B, C_k, H_halo, _ = kw.shape + B, _, _, _ = kw.shape _, C_v, _, _ = vw.shape device = kw.device @@ -288,7 +150,6 @@ def forward( ctx.lat_halo_start = lat_halo_start ctx.nlat_out_local = nlat_out_local ctx.nlon_out_local = nlon_out_local - ctx.r_lat = r_lat ctx.az_group = az_group ctx.az_rank = az_rank ctx.az_size = az_size @@ -307,7 +168,6 @@ def backward(ctx, dy): lat_halo_start = ctx.lat_halo_start nlat_out_local = ctx.nlat_out_local nlon_out_local = ctx.nlon_out_local - r_lat = ctx.r_lat az_group = ctx.az_group az_rank = ctx.az_rank az_size = ctx.az_size @@ -580,7 +440,6 @@ def _compute_r_lat(self) -> int: return 0 lat_lo = self.lat_lo_out - lat_hi = lat_lo + self.nlat_out_local col_idx = self.psi_col_idx_local roff = self.psi_roff_idx_local @@ -637,8 +496,8 @@ def forward( # Use differentiable halo exchange when there is an actual polar split; # otherwise fall through to the identity (no-op). 
if self.r_lat > 0 and self.comm_size_polar > 1: - key_halo = _LatHaloExchangeFn.apply(key_proj, self.r_lat, polar_group()) - value_halo = _LatHaloExchangeFn.apply(value_proj, self.r_lat, polar_group()) + key_halo = polar_halo_exchange(key_proj, self.r_lat) + value_halo = polar_halo_exchange(value_proj, self.r_lat) else: key_halo = key_proj value_halo = value_proj diff --git a/torch_harmonics/distributed/primitives.py b/torch_harmonics/distributed/primitives.py index ceb343f9..21088f1b 100644 --- a/torch_harmonics/distributed/primitives.py +++ b/torch_harmonics/distributed/primitives.py @@ -35,9 +35,21 @@ from torch.amp import custom_fwd, custom_bwd from .utils import config as thd_config -from .utils import polar_group, azimuth_group, polar_group_size +from .utils import polar_group, azimuth_group, polar_group_size, polar_group_rank from .utils import is_distributed_polar, is_distributed_azimuth + +def get_group_neighbors(group): + group_size = dist.get_world_size(group) + global_rank = dist.get_rank() + group_ranks = dist.get_process_group_ranks(group) + my_rank_id = group_ranks.index(global_rank) + prev_rank = group_ranks[(my_rank_id - 1) % group_size] + next_rank = group_ranks[(my_rank_id + 1) % group_size] + + return prev_rank, next_rank + + def _check_shapes(msg, shapes_gather, shapes_expected): for idx, (size_gather, size_expected) in enumerate(zip(shapes_gather, shapes_expected)): if size_gather != size_expected: @@ -527,3 +539,122 @@ def reduce_from_scatter_to_polar_region(input_, dim_): @torch.compiler.disable() def gather_from_copy_to_polar_region(input_, dim_, shapes_): return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_) + + +# --------------------------------------------------------------------------- +# nearest neighbor exchange algorithms +# --------------------------------------------------------------------------- +class _PolarHaloExchangeFn(torch.autograd.Function): + """Differentiable lat halo exchange for polar-distributed tensors. 
+ + Forward: gathers r_lat rows from neighbouring polar ranks and returns a + halo-padded tensor of shape [B, C, H_local + 2*r_lat, W]. + Backward: communicates halo gradient contributions back to their owning + ranks and accumulates them onto the local input gradient. + + Ranks at the polar boundary (rank 0 / rank group_size-1) receive + zero-padding on the missing side in the forward pass; the corresponding + halo-gradient portion is discarded in the backward (no neighbour to send + it to), which is the correct adjoint of padding with zeros. + """ + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, x, r_lat): + + if not is_distributed_polar(): + return x + + ctx.r_lat = r_lat + group_size = polar_group_size() + group_rank = polar_group_rank() + ctx.group_size = group_size + ctx.group_rank = group_rank + prev_rank, next_rank = get_group_neighbors(polar_group()) + ctx.prev_rank = prev_rank + ctx.next_rank = next_rank + ctx.H = x.shape[2] + + B, C, H, W = x.shape + device, dtype = x.device, x.dtype + + # setup send buffers + send_top = x[:, :, :r_lat, :].contiguous() # top r_lat rows → rank-1 + send_bot = x[:, :, -r_lat:, :].contiguous() # bottom r_lat rows → rank+1 + + # setup recv buffers + recv_top = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) + recv_bot = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) + + ops = [] + if group_rank > 0: + ops.append(dist.P2POp(dist.isend, send_top, prev_rank, polar_group())) + ops.append(dist.P2POp(dist.irecv, recv_top, prev_rank, polar_group())) + if group_rank < group_size - 1: + ops.append(dist.P2POp(dist.isend, send_bot, next_rank, polar_group())) + ops.append(dist.P2POp(dist.irecv, recv_bot, next_rank, polar_group())) + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + return torch.cat([recv_top, x, recv_bot], dim=2).contiguous() + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, dout): + + if not is_distributed_polar(): + return 
dout, None + + r_lat = ctx.r_lat + group_size = ctx.group_size + group_rank = ctx.group_rank + H = ctx.H + prev_rank = ctx.prev_rank + next_rank = ctx.next_rank + + B, C, _, W = dout.shape + device, dtype = dout.device, dout.dtype + + # Direct gradient for the local (non-halo) rows. + dx = dout[:, :, r_lat:r_lat + H, :].contiguous().clone() + + # The halo slices carry gradients that belong to neighbouring ranks: + # dout[:, :, :r_lat, :] → came FROM rank-1; send gradient back to rank-1 + # dout[:, :, r_lat + H:, :] → came FROM rank+1; send gradient back to rank+1 + # Simultaneously receive from each neighbour the gradient they owe us + # for the rows we sent them in the forward pass. + send_to_prev = dout[:, :, :r_lat, :].contiguous() + send_to_next = dout[:, :, r_lat + H:, :].contiguous() + + recv_from_prev = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) + recv_from_next = torch.zeros(B, C, r_lat, W, device=device, dtype=dtype) + + ops = [] + if group_rank > 0: + ops.append(dist.P2POp(dist.isend, send_to_prev, prev_rank, polar_group())) + ops.append(dist.P2POp(dist.irecv, recv_from_prev, prev_rank, polar_group())) + if group_rank < group_size - 1: + ops.append(dist.P2POp(dist.isend, send_to_next, next_rank, polar_group())) + ops.append(dist.P2POp(dist.irecv, recv_from_next, next_rank, polar_group())) + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # Accumulate gradient contributions for rows we sent in the forward. 
+ # recv_from_prev = gradient for our top r_lat rows (sent as prev rank's recv_bot) + # recv_from_next = gradient for our bottom r_lat rows (sent as next rank's recv_top) + if group_rank > 0: + dx[:, :, :r_lat, :] = dx[:, :, :r_lat, :] + recv_from_prev + if group_rank < group_size - 1: + dx[:, :, H - r_lat:, :] = dx[:, :, H - r_lat:, :] + recv_from_next + + # Gradients for r_lat is None (not tensors / non-differentiable) + return dx, None + +def polar_halo_exchange(x, r_lat): + return _PolarHaloExchangeFn.apply(x, r_lat) From db3e88fd710e60f3b555a327f21bd9c20cebf738 Mon Sep 17 00:00:00 2001 From: Mauro Bisson Date: Mon, 30 Mar 2026 14:44:27 -0700 Subject: [PATCH 04/17] Fixed a typo in bwd kernels early-exit checks. --- .../attention/csrc/attention_cuda_bwd.cu | 18 ++++++++++++++++-- .../attention/csrc/attention_cuda_fwd.cu | 16 +++++++++++++++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/torch_harmonics/attention/csrc/attention_cuda_bwd.cu b/torch_harmonics/attention/csrc/attention_cuda_bwd.cu index a825d4d7..0e8fa50b 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_bwd.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_bwd.cu @@ -688,13 +688,27 @@ void launch_spc_attn_bwd(int nloc, // "BDIM_X*nloc" >= nchans_out dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size); size_t shsize = sizeof(FLOATV_T)*(nchans_in+nchans_out) * block.y; // 2 arrays per cta, block.y > 1 iif block.x==32 - +#if 0 + printf("Launching s2_attn_bwd_special_vec_k<%d, %d, %d, %d, float%s><<<(%d, %d), (%d, %d), %zu, ...>>> with:\n" + "\tnchans_in: %d\n" + "\tnchans_out: %d\n" + "\tnlat_in: %d\n" + "\tnlon_in: %d\n" + "\tnlat_out: %d\n" + "\tnlon_out: %d\n", + BDIM_X, BDIM_Y, + (nchans_out >= BDIM_X*(CUR_LOC_SIZE-1) && nchans_out <= BDIM_X* CUR_LOC_SIZE), + CUR_LOC_SIZE, + sizeof(FLOATV_T)==16?"4":"", + grid.x, grid.y, block.x, block.y, shsize, + nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out); +#endif // nloc determines the size of local arrays 
used to store // temporary buffers loc_k__[], loc_vw_[] and loc_kvw[], // of size nchans_in each; // if nchans_out is >= BDIM_X*(nloc-1) and <= BDIM_X*nloc // then we can use the same compile-time known loops used - // for input channels, with the execpetion of testing + // for input channels, with the exception of testing // whether to execute the last iteration based on "nchans_out" // instead of "nchans_in"; in this way as long as the // difference between the number of input and output channels diff --git a/torch_harmonics/attention/csrc/attention_cuda_fwd.cu b/torch_harmonics/attention/csrc/attention_cuda_fwd.cu index 16ea96b6..9e116efb 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_fwd.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_fwd.cu @@ -374,7 +374,21 @@ void launch_spc_attn_fwd(int nloc, // "BDIM_X*nloc" >= nchans_out //size_t shsize = sizeof(FLOATV_T)*nchans_out * block.y; // block.y > 1 iif block.x==32 size_t shsize = sizeof(FLOATV_T)*nchans_in * block.y; // block.y > 1 iif block.x==32 - +#if 0 + printf("Launching s2_attn_fwd_special_vec_k<%d, %d, %d, %d, float%s><<<(%d, %d), (%d, %d), %zu, ...>>> with:\n" + "\tnchans_in: %d\n" + "\tnchans_out: %d\n" + "\tnlat_in: %d\n" + "\tnlon_in: %d\n" + "\tnlat_out: %d\n" + "\tnlon_out: %d\n", + BDIM_X, BDIM_Y, + (nchans_in >= BDIM_X*(CUR_LOC_SIZE-1) && nchans_in <= BDIM_X* CUR_LOC_SIZE), + CUR_LOC_SIZE, + sizeof(FLOATV_T)==16?"4":"", + grid.x, grid.y, block.x, block.y, shsize, + nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out); +#endif // nloc determines the size of local arrays used to store // y vectors, of length nchans_out; // if nchans_in is >= BDIM_X*(nloc-1) and <= BDIM_X*nloc From 48f7510c1ee733646b7e17d9579c90903cb0cfbf Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 6 Apr 2026 00:12:06 -0700 Subject: [PATCH 05/17] distributed attention working again --- .../distributed/distributed_attention.py | 48 ++++++++++++------- 1 file changed, 30 insertions(+), 18 
deletions(-) diff --git a/torch_harmonics/distributed/distributed_attention.py b/torch_harmonics/distributed/distributed_attention.py index efe05178..49464295 100644 --- a/torch_harmonics/distributed/distributed_attention.py +++ b/torch_harmonics/distributed/distributed_attention.py @@ -37,7 +37,7 @@ from torch_harmonics.attention.attention import NeighborhoodAttentionS2 -from .utils import azimuth_group +from .utils import azimuth_group, polar_group from .utils import polar_group_size, polar_group_rank from .utils import azimuth_group_size, azimuth_group_rank from .primitives import compute_split_shapes, get_group_neighbors, polar_halo_exchange @@ -434,28 +434,40 @@ def _build_local_psi(self): self.register_buffer("psi_row_idx_local", row_idx_local, persistent=False) def _compute_r_lat(self) -> int: - """Max lat halo radius needed for this rank's output lat range.""" + """Max lat halo radius needed across all polar ranks. + + Computed locally from the global psi (built identically on every rank + by the base class), so no communication is required. 
+ """ if polar_group_size() == 1: return 0 - lat_lo = self.lat_lo_out - - col_idx = self.psi_col_idx_local - roff = self.psi_roff_idx_local - nlon_in = self.nlon_in - - hi_global = (col_idx // nlon_in).long() - - # output lat indices for each entry - ho_indices = torch.zeros(col_idx.shape[0], dtype=torch.long, device=col_idx.device) - for ho_local in range(self.nlat_out_local): - b = roff[ho_local].item() - e = roff[ho_local + 1].item() - ho_indices[b:e] = ho_local + lat_lo + col_idx = self.psi_col_idx # global, all nlat_out rows + if col_idx.numel() == 0: + return 0 - r = (hi_global - ho_indices).abs().max().item() if col_idx.numel() > 0 else 0 - return int(r) + lat_in_starts = list(accumulate([0] + self.lat_in_shapes[:-1])) + roff = self.psi_roff_idx + + r = 0 + for rank in range(self.comm_size_polar): + lat_in_lo = lat_in_starts[rank] + lat_in_hi = lat_in_lo + self.lat_in_shapes[rank] + lat_out_lo = self.lat_out_starts[rank] + lat_out_hi = lat_out_lo + self.lat_out_shapes[rank] + + start = roff[lat_out_lo].item() + end = roff[lat_out_hi].item() + if start == end: + continue + + hi = (col_idx[start:end] // self.nlon_in).long() + r_top = max(0, lat_in_lo - int(hi.min().item())) + r_bot = max(0, int(hi.max().item()) - (lat_in_hi - 1)) + r = max(r, r_top, r_bot) + + return r # ----------------------------------------------------------------------- From 03b5b2c9bbbe21cb3469914772bdb1b809e56204 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 6 Apr 2026 00:42:40 -0700 Subject: [PATCH 06/17] working downsampling attention --- tests/test_distributed_attention.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/test_distributed_attention.py b/tests/test_distributed_attention.py index 8ac58e96..6eb90e24 100644 --- a/tests/test_distributed_attention.py +++ b/tests/test_distributed_attention.py @@ -95,13 +95,15 @@ def _gather_helper_fwd(self, tensor, attn_dist): wgroup=self.w_group, ) - def _gather_helper_bwd(self, tensor, 
attn_dist): + def _gather_helper_bwd(self, tensor, attn_dist, use_out_shapes=False): + hshapes = attn_dist.lat_out_shapes if use_out_shapes else attn_dist.lat_in_shapes + wshapes = attn_dist.lon_out_shapes if use_out_shapes else attn_dist.lon_in_shapes return gather_tensor_hw( tensor, hdim=-2, wdim=-1, - hshapes=attn_dist.lat_in_shapes, - wshapes=attn_dist.lon_in_shapes, + hshapes=hshapes, + wshapes=wshapes, hsize=self.grid_size_h, wsize=self.grid_size_w, hrank=self.hrank, @@ -116,7 +118,7 @@ def _gather_helper_bwd(self, tensor, attn_dist): [64, 128, 64, 128, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4], [64, 128, 64, 128, 2, 16, 2, None, None, "equiangular", "equiangular", 1e-5, 1e-4], [64, 128, 64, 128, 2, 16, 1, 8, 8, "equiangular", "equiangular", 1e-5, 1e-4], - # [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4], + [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4], [65, 128, 65, 128, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4], ], skip_on_empty=True, @@ -208,7 +210,8 @@ def test_distributed_neighborhood_attention( # ---- compare backward ---- for inp in ["q", "k", "v"]: - igrad_gather = self._gather_helper_bwd(igrad_local[inp], attn_dist) + use_out = (inp == "q") + igrad_gather = self._gather_helper_bwd(igrad_local[inp], attn_dist, use_out_shapes=use_out) self.assertTrue(compare_tensors(f"input gradient {inp}", igrad_full[inp], igrad_gather, atol=atol, rtol=rtol, verbose=verbose)) From 83457eae95c546e2a916a88d4dc0b5af1be2d8dd Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 7 Apr 2026 06:25:43 -0700 Subject: [PATCH 07/17] adding shape checks --- torch_harmonics/attention/attention.py | 7 +++++++ torch_harmonics/distributed/distributed_attention.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/torch_harmonics/attention/attention.py b/torch_harmonics/attention/attention.py index 42841a38..c17335c1 100644 --- 
a/torch_harmonics/attention/attention.py +++ b/torch_harmonics/attention/attention.py @@ -330,6 +330,13 @@ def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value # change this later to allow arbitrary number of batch dims assert (query.dim() == key.dim()) and (key.dim() == value.dim()) and (value.dim() == 4) + if query.shape[-2] != self.nlat_out or query.shape[-1] != self.nlon_out: + raise ValueError(f"query spatial shape {(query.shape[-2], query.shape[-1])} does not match out_shape {(self.nlat_out, self.nlon_out)}") + if key.shape[-2] != self.nlat_in or key.shape[-1] != self.nlon_in: + raise ValueError(f"key spatial shape {(key.shape[-2], key.shape[-1])} does not match in_shape {(self.nlat_in, self.nlon_in)}") + if value.shape[-2] != self.nlat_in or value.shape[-1] != self.nlon_in: + raise ValueError(f"value spatial shape {(value.shape[-2], value.shape[-1])} does not match in_shape {(self.nlat_in, self.nlon_in)}") + # do the scaling query_scaled = query * self.scale diff --git a/torch_harmonics/distributed/distributed_attention.py b/torch_harmonics/distributed/distributed_attention.py index 49464295..a875d386 100644 --- a/torch_harmonics/distributed/distributed_attention.py +++ b/torch_harmonics/distributed/distributed_attention.py @@ -485,6 +485,13 @@ def forward( assert query.dim() == 4 + if query.shape[-2] != self.nlat_out_local or query.shape[-1] != self.nlon_out_local: + raise ValueError(f"query spatial shape {(query.shape[-2], query.shape[-1])} does not match local out_shape {(self.nlat_out_local, self.nlon_out_local)}") + if key.shape[-2] != self.nlat_in_local or key.shape[-1] != self.nlon_in_local: + raise ValueError(f"key spatial shape {(key.shape[-2], key.shape[-1])} does not match local in_shape {(self.nlat_in_local, self.nlon_in_local)}") + if value.shape[-2] != self.nlat_in_local or value.shape[-1] != self.nlon_in_local: + raise ValueError(f"value spatial shape {(value.shape[-2], value.shape[-1])} does not match local 
in_shape {(self.nlat_in_local, self.nlon_in_local)}") + # scale query query_scaled = query * self.scale From efb151cb061e324ae0d3952dd5bca6bade08f6f7 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 7 Apr 2026 06:58:04 -0700 Subject: [PATCH 08/17] added wrong shape assertion test in attention --- tests/test_attention.py | 35 +++++++++++++++++++++++++++++ tests/test_distributed_attention.py | 34 ++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/tests/test_attention.py b/tests/test_attention.py index 96cee076..c4ced0fc 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -418,5 +418,40 @@ def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, g threshold = _perf_test_thresholds[self.device.type]["bwd_ms"] self.assertTrue(duration <= threshold, msg=f"Backward execution time on device {self.device.type} is too high: {duration:.2f} ms > {threshold:.2f} ms") + def test_wrong_shape_assertions(self): + """Verify that forward raises ValueError on spatial-shape mismatches.""" + B, C = 2, 16 + in_shape = (12, 24) + out_shape = (6, 12) + nlat_in, nlon_in = in_shape + nlat_out, nlon_out = out_shape + + model = NeighborhoodAttentionS2( + in_channels=C, + in_shape=in_shape, + out_shape=out_shape, + grid_in="equiangular", + grid_out="equiangular", + num_heads=1, + bias=False, + ).to(self.device) + + q = torch.randn(B, C, nlat_out, nlon_out, device=self.device) + kv = torch.randn(B, C, nlat_in, nlon_in, device=self.device) + + # 1. Self-attention on an up/downsampling module: a single tensor cannot + # simultaneously satisfy in_shape (for k/v) and out_shape (for q). + with self.assertRaises(ValueError): + model(q) # key defaults to query, but key must have in_shape + + # 2. q_shape == k_shape != v_shape: key carries out_shape instead of in_shape. + with self.assertRaises(ValueError): + model(q, q, kv) + + # 3. q_shape == v_shape != k_shape: value carries out_shape instead of in_shape. 
+ with self.assertRaises(ValueError): + model(q, kv, q) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_distributed_attention.py b/tests/test_distributed_attention.py index 6eb90e24..a2eccf7e 100644 --- a/tests/test_distributed_attention.py +++ b/tests/test_distributed_attention.py @@ -215,5 +215,39 @@ def test_distributed_neighborhood_attention( self.assertTrue(compare_tensors(f"input gradient {inp}", igrad_full[inp], igrad_gather, atol=atol, rtol=rtol, verbose=verbose)) + def test_wrong_shape_assertions(self): + """Verify that forward raises ValueError on spatial-shape mismatches.""" + B, C = 2, 16 + in_shape = (64, 128) + out_shape = (32, 64) + + attn = thd.DistributedNeighborhoodAttentionS2( + in_channels=C, + in_shape=in_shape, + out_shape=out_shape, + grid_in="equiangular", + grid_out="equiangular", + num_heads=1, + bias=False, + ).to(self.device) + + # Build correctly-shaped local tensors using the module's own local extents. + q_local = torch.randn(B, C, attn.nlat_out_local, attn.nlon_out_local, device=self.device) + k_local = torch.randn(B, C, attn.nlat_in_local, attn.nlon_in_local, device=self.device) + + # 1. Self-attention on an up/downsampling module: a single tensor cannot + # simultaneously satisfy in_shape (for k/v) and out_shape (for q). + with self.assertRaises(ValueError): + attn(q_local) # key defaults to query, but key must have in_shape + + # 2. q_shape == k_shape != v_shape: key carries out_shape instead of in_shape. + with self.assertRaises(ValueError): + attn(q_local, q_local, k_local) + + # 3. q_shape == v_shape != k_shape: value carries out_shape instead of in_shape. 
+ with self.assertRaises(ValueError): + attn(q_local, k_local, q_local) + + if __name__ == "__main__": unittest.main() From fb46eae73b4fa67457fb39fd58e8ef1d67e5d526 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 13 Apr 2026 02:07:28 -0700 Subject: [PATCH 09/17] adding qknorm --- tests/test_attention.py | 53 +++-- torch_harmonics/attention/_attention_utils.py | 199 +++--------------- torch_harmonics/attention/attention.py | 61 +++++- .../distributed/distributed_attention.py | 28 ++- 4 files changed, 136 insertions(+), 205 deletions(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index c4ced0fc..63d0e12b 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -35,6 +35,7 @@ from parameterized import parameterized, parameterized_class import torch +import torch.nn.functional as F from torch.library import opcheck # from torch.autograd import gradcheck @@ -71,23 +72,35 @@ def setUp(self): @parameterized.expand( [ - # Format: [batch_size, channels, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol] - [4, 4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], - [4, 4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], - [4, 4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], - [4, 4, 8, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], - [4, 8, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], - [4, 8, 4, 4, (12, 24), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], - [4, 8, 4, 4, (6, 12), (12, 24), "equiangular", "equiangular", 1e-5, 1e-3], - [4, 1, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", 1e-5, 1e-3], - [4, 1, 4, 1, (2, 4), (2, 4), "equiangular", "equiangular", 1e-5, 1e-3], - [4, 4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3], - [4, 4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3], + # Format: [batch_size, channels, channels_out, heads, in_shape, out_shape, grid_in, 
grid_out, use_qknorm, atol, rtol] + [4, 4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3], + [4, 4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3], + [4, 4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3], + [4, 4, 8, 4, (6, 12), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3], + [4, 8, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3], + [4, 8, 4, 4, (12, 24), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3], + [4, 8, 4, 4, (6, 12), (12, 24), "equiangular", "equiangular", False, 1e-5, 1e-3], + [4, 1, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", False, 1e-5, 1e-3], + [4, 1, 4, 1, (2, 4), (2, 4), "equiangular", "equiangular", False, 1e-5, 1e-3], + [4, 4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", False, 1e-5, 1e-3], + [4, 4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", False, 1e-5, 1e-3], + # same cases with QK norm enabled + [4, 4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3], + [4, 4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3], + [4, 4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3], + [4, 4, 8, 4, (6, 12), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3], + [4, 8, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3], + [4, 8, 4, 4, (12, 24), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3], + [4, 8, 4, 4, (6, 12), (12, 24), "equiangular", "equiangular", True, 1e-5, 1e-3], + [4, 1, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", True, 1e-5, 1e-3], + [4, 1, 4, 1, (2, 4), (2, 4), "equiangular", "equiangular", True, 1e-5, 1e-3], + [4, 4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", True, 1e-5, 1e-3], + [4, 4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", True, 1e-5, 1e-3], ], skip_on_empty=True, ) @unittest.skipUnless(optimized_kernels_is_available(), "skipping test because optimized kernels 
are not available") - def test_custom_implementation(self, batch_size, channels, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True): + def test_custom_implementation(self, batch_size, channels, channels_out, heads, in_shape, out_shape, grid_in, grid_out, use_qknorm, atol, rtol, verbose=True): """Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation""" if (self.device.type == "cuda") and (not cuda_kernels_is_available()): @@ -111,10 +124,10 @@ def test_custom_implementation(self, batch_size, channels, channels_out, heads, inputs_opt = {k: v.detach().clone().to(self.device).requires_grad_() for k, v in inputs_ref.items()} # reference input and model - model_ref = NeighborhoodAttentionS2(in_channels=channels, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True, optimized_kernel=False).to(self.device) + model_ref = NeighborhoodAttentionS2(in_channels=channels, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True, use_qknorm=use_qknorm, optimized_kernel=False).to(self.device) # Device model and inputs - model_opt = NeighborhoodAttentionS2(in_channels=channels, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True, optimized_kernel=True).to(self.device) + model_opt = NeighborhoodAttentionS2(in_channels=channels, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True, use_qknorm=use_qknorm, optimized_kernel=True).to(self.device) # Synchronize parameters of model model_opt.load_state_dict(model_ref.state_dict()) @@ -320,10 +333,12 @@ def test_optimized_pt2_compatibility(self, batch_size, channels, heads, in_shape "q": torch.randn(batch_size, channels, nlat_out, nlon_out, 
requires_grad=True, device=self.device, dtype=torch.float32), } - test_inputs = (inputs["k"], inputs["v"], inputs["q"], - att.k_weights, att.v_weights, att.q_weights, - att.k_bias, att.v_bias, att.q_bias, - att.quad_weights, att.psi_col_idx, att.psi_roff_idx, + kw = F.conv2d(inputs["k"], att.k_weights, att.k_bias) + vw = F.conv2d(inputs["v"], att.v_weights, att.v_bias) + qw = F.conv2d(inputs["q"], att.q_weights, att.q_bias) * att.scale + + test_inputs = (kw, vw, qw, + att.quad_weights, att.psi_col_idx, att.psi_roff_idx, att.psi_max_nnz, att.num_heads, nlon_in, nlat_out, nlon_out) opcheck(torch.ops.attention_kernels._neighborhood_s2_attention_optimized, test_inputs) diff --git a/torch_harmonics/attention/_attention_utils.py b/torch_harmonics/attention/_attention_utils.py index b7ef7b67..daf298ea 100644 --- a/torch_harmonics/attention/_attention_utils.py +++ b/torch_harmonics/attention/_attention_utils.py @@ -39,8 +39,8 @@ # HELPER ROUTINE FOR BACKWARD setup_context def _setup_context_attention_backward(ctx, inputs, output): - k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs - ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) + kw, vw, qw, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs + ctx.save_for_backward(col_idx, row_off, quad_weights, kw, vw, qw) ctx.nh = nh ctx.max_psi_nnz = max_psi_nnz ctx.nlon_in = nlon_in @@ -93,16 +93,10 @@ def _(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, # forward @torch.library.custom_op("attention_kernels::_neighborhood_s2_attention_optimized", mutates_args=()) - def _neighborhood_s2_attention_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, - wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, - bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + def _neighborhood_s2_attention_optimized(kw: torch.Tensor, vw: 
torch.Tensor, qw: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - kw = F.conv2d(k, weight=wk, bias=bk) - vw = F.conv2d(v, weight=wv, bias=bv) - qw = F.conv2d(q, weight=wq, bias=bq) - # reshape, folding num heads into batch dim B, _, H, W = kw.shape kw = kw.reshape(B*nh, -1, H, W) @@ -130,37 +124,20 @@ def _neighborhood_s2_attention_optimized(k: torch.Tensor, v: torch.Tensor, q: to return output @torch.library.register_fake("attention_kernels::_neighborhood_s2_attention_optimized") - def _(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, - wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, - bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + def _(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (k.shape[0], wv.shape[0], nlat_out, nlon_out) - return torch.empty(out_shape, dtype=k.dtype, device=k.device) + out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out) + return torch.empty(out_shape, dtype=kw.dtype, device=kw.device) def _neighborhood_s2_attention_bwd_optimized(ctx, grad_output): - col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + col_idx, row_off, quad_weights, kw, vw, qw = ctx.saved_tensors nh = ctx.nh max_psi_nnz = ctx.max_psi_nnz nlon_in = ctx.nlon_in nlat_out = ctx.nlat_out nlon_out = ctx.nlon_out - # check if we need the grads at all - k_needs_grad = ctx.needs_input_grad[0] - v_needs_grad = ctx.needs_input_grad[1] - q_needs_grad = ctx.needs_input_grad[2] - wk_needs_grad = ctx.needs_input_grad[3] - wv_needs_grad = ctx.needs_input_grad[4] - wq_needs_grad = ctx.needs_input_grad[5] - bk_needs_grad = ctx.needs_input_grad[6] - bv_needs_grad = ctx.needs_input_grad[7] 
- bq_needs_grad = ctx.needs_input_grad[8] - - kw = F.conv2d(k, weight=wk, bias=bk) - vw = F.conv2d(v, weight=wv, bias=bv) - qw = F.conv2d(q, weight=wq, bias=bq) - # reshape, folding num heads into batch dim B, _, H, W = kw.shape kw = kw.reshape(B*nh, -1, H, W) @@ -186,64 +163,15 @@ def _neighborhood_s2_attention_bwd_optimized(ctx, grad_output): col_idx, row_off, nlon_in, nlat_out, nlon_out) - # weight grads - _, C, H, W = dkw.shape - dkw = dkw.reshape(B, -1, H, W) - dkw = dkw.to(dtype=kw_dtype) - if wk_needs_grad: - dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() - else: - dwk = None - - _, C, H, W = dvw.shape - dvw = dvw.reshape(B, -1, H, W) - dvw = dvw.to(dtype=vw_dtype) - if wv_needs_grad: - dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() - else: - dwv = None - - _, C, H, W = dqw.shape - dqw = dqw.reshape(B, -1, H, W) - dqw = dqw.to(dtype=qw_dtype) - if wq_needs_grad: - dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() - else: - dwq = None - - # input grads - if v_needs_grad: - dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) - else: - dv = None - - if k_needs_grad: - dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) - else: - dk = None - - if q_needs_grad: - dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) - else: - dq = None - - # bias grads: - if bv_needs_grad: - dbv = torch.sum(dvw, dim=(0,2,3)) - else: - dbv = None - - if bk_needs_grad: - dbk = torch.sum(dkw, dim=(0,2,3)) - else: - dbk = None - - if bq_needs_grad: - dbq = torch.sum(dqw, dim=(0,2,3)) - else: - dbq = None + # reshape back to original batch dim and convert back precision + _, _, Hk, Wk = dkw.shape + dkw = dkw.reshape(B, -1, Hk, Wk).to(dtype=kw_dtype) + _, _, Hv, Wv = dvw.shape + dvw = dvw.reshape(B, -1, Hv, Wv).to(dtype=vw_dtype) + _, _, Hq, Wq = dqw.shape + dqw = dqw.reshape(B, -1, Hq, Wq).to(dtype=qw_dtype) - return dk, dv, 
dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + return dkw, dvw, dqw, \ None, None, None, None, None, None, None, None # register backward @@ -535,14 +463,9 @@ def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, return dqy @torch.library.custom_op("attention_kernels::_neighborhood_s2_attention_torch", mutates_args=()) -def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, - wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, - bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], +def _neighborhood_s2_attention_torch(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - kw = F.conv2d(k, weight=wk, bias=bk) - vw = F.conv2d(v, weight=wv, bias=bv) - qw = F.conv2d(q, weight=wq, bias=bq) # reshape, folding num heads into batch dim B, _, H, W = kw.shape @@ -560,41 +483,29 @@ def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch. 
col_idx, row_off, nlon_in, nlat_out, nlon_out) - _, C, H, W = output.shape + _, _, H, W = output.shape output = output.reshape(B, -1, H, W) return output @torch.library.register_fake("attention_kernels::_neighborhood_s2_attention_torch") -def _(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, - wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, - bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], +def _(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (k.shape[0], wv.shape[0], nlat_out, nlon_out) - return torch.empty(out_shape, dtype=k.dtype, device=k.device) + out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out) + return torch.empty(out_shape, dtype=kw.dtype, device=kw.device) def _neighborhood_s2_attention_bwd_torch(ctx, grad_output): - col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + col_idx, row_off, quad_weights, kw, vw, qw = ctx.saved_tensors nh = ctx.nh nlon_in = ctx.nlon_in nlat_out = ctx.nlat_out nlon_out = ctx.nlon_out # check if we need the grads at all - k_needs_grad = ctx.needs_input_grad[0] - v_needs_grad = ctx.needs_input_grad[1] - q_needs_grad = ctx.needs_input_grad[2] - wk_needs_grad = ctx.needs_input_grad[3] - wv_needs_grad = ctx.needs_input_grad[4] - wq_needs_grad = ctx.needs_input_grad[5] - bk_needs_grad = ctx.needs_input_grad[6] - bv_needs_grad = ctx.needs_input_grad[7] - bq_needs_grad = ctx.needs_input_grad[8] - - kw = F.conv2d(k, weight=wk, bias=bk) - vw = F.conv2d(v, weight=wv, bias=bv) - qw = F.conv2d(q, weight=wq, bias=bq) + kw_needs_grad = ctx.needs_input_grad[0] + vw_needs_grad = ctx.needs_input_grad[1] + qw_needs_grad = ctx.needs_input_grad[2] # reshape, folding num heads into batch dim B, _, H, W = kw.shape @@ -606,85 +517,37 @@ def 
_neighborhood_s2_attention_bwd_torch(ctx, grad_output): B, _, H, W = grad_output.shape grad_output = grad_output.reshape(B*nh, -1, H, W) - if v_needs_grad or wv_needs_grad or bv_needs_grad: + if vw_needs_grad: dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output, quad_weights, col_idx, row_off, nlon_in, nlat_out, nlon_out) - _, C, H, W = dvw.shape + _, _, H, W = dvw.shape dvw = dvw.reshape(B, -1, H, W) else: dvw = None - if k_needs_grad or wk_needs_grad or bk_needs_grad: + if kw_needs_grad: dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output, quad_weights, col_idx, row_off, nlon_in, nlat_out, nlon_out) - _, C, H, W = dkw.shape + _, _, H, W = dkw.shape dkw = dkw.reshape(B, -1, H, W) else: dkw = None - if q_needs_grad or wq_needs_grad or bq_needs_grad: + if qw_needs_grad: dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output, quad_weights, col_idx, row_off, nlon_in, nlat_out, nlon_out) - _, C, H, W = dqw.shape + _, _, H, W = dqw.shape dqw = dqw.reshape(B, -1, H, W) else: dqw = None - # input grads - if v_needs_grad: - dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) - else: - dv = None - - if k_needs_grad: - dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) - else: - dk = None - - if q_needs_grad: - dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) - else: - dq = None - - # weight grads - if wv_needs_grad: - dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() - else: - dwv = None - - if wk_needs_grad: - dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() - else: - dwk = None - - if wq_needs_grad: - dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() - else: - dwq = None - - # bias grads: - if bv_needs_grad: - dbv = torch.sum(dvw, dim=(0,2,3)) - else: - dbv = None - - if bk_needs_grad: - dbk = torch.sum(dkw, dim=(0,2,3)) - else: - dbk = None - - if bq_needs_grad: 
- dbq = torch.sum(dqw, dim=(0,2,3)) - else: - dbq = None - - return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + return dkw, dvw, dqw, \ None, None, None, None, None, None, None, None # register backward diff --git a/torch_harmonics/attention/attention.py b/torch_harmonics/attention/attention.py index c17335c1..abe0efb5 100644 --- a/torch_harmonics/attention/attention.py +++ b/torch_harmonics/attention/attention.py @@ -35,6 +35,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from torch_harmonics.quadrature import precompute_latitudes from torch_harmonics.disco.convolution import _precompute_convolution_tensor_s2 @@ -77,6 +78,7 @@ def __init__( grid_in: Optional[str] = "equiangular", grid_out: Optional[str] = "equiangular", scale: Optional[Union[torch.Tensor, float]] = None, + use_qknorm: Optional[bool] = False, bias: Optional[bool] = True, k_channels: Optional[int] = None, out_channels: Optional[int] = None, @@ -131,6 +133,13 @@ def __init__( self.v_bias = None self.proj_bias = None + if use_qknorm: + self.q_norm_weights = nn.Parameter(torch.zeros(self.k_channels // self.num_heads)) + self.k_norm_weights = nn.Parameter(torch.zeros(self.k_channels // self.num_heads)) + else: + self.q_norm_weights = None + self.k_norm_weights = None + def extra_repr(self): return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}" @@ -147,7 +156,7 @@ def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value # change this later to allow arbitrary number of batch dims assert (query.dim() == key.dim()) and (key.dim() == value.dim()) and (value.dim() == 4) - # perform MLP + # perform QKV projections query = nn.functional.conv2d(query, self.q_weights, bias=self.q_bias) key = nn.functional.conv2d(key, self.k_weights, bias=self.k_bias) value = nn.functional.conv2d(value, self.v_weights, bias=self.v_bias) @@ -168,8 
+177,18 @@ def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value B, _, C, H, W = value.shape value = value.permute(0,1,3,4,2).reshape(B, self.num_heads, H*W, C) - # multiply the query, key and value tensors - out = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=self.log_quad_weights, dropout_p=self.drop_rate, scale=self.scale) + if self.q_norm_weights is not None: + query = F.rms_norm(query, normalized_shape=self.q_norm_weights.shape, weight=1 + self.q_norm_weights) + if self.k_norm_weights is not None: + key = F.rms_norm(key, normalized_shape=self.k_norm_weights.shape, weight=1 + self.k_norm_weights) + + # apply scale — if scale is a tensor (e.g. learnable), multiply into query + # directly since SDPA only accepts a float scale + if isinstance(self.scale, torch.Tensor): + query = query * self.scale + out = F.scaled_dot_product_attention(query, key, value, attn_mask=self.log_quad_weights, dropout_p=self.drop_rate, scale=1.0) + else: + out = F.scaled_dot_product_attention(query, key, value, attn_mask=self.log_quad_weights, dropout_p=self.drop_rate, scale=self.scale) # reshape B, _, _, C = out.shape @@ -220,6 +239,7 @@ def __init__( grid_out: Optional[str] = "equiangular", num_heads: Optional[int] = 1, scale: Optional[Union[torch.Tensor, float]] = None, + use_qknorm: Optional[bool] = False, bias: Optional[bool] = True, theta_cutoff: Optional[float] = None, k_channels: Optional[int] = None, @@ -297,7 +317,7 @@ def __init__( if scale is not None: self.scale = scale else: - self.scale = 1 / math.sqrt(self.k_channels) + self.scale = 1 / math.sqrt(self.k_channels // self.num_heads) if bias: self.q_bias = nn.Parameter(torch.zeros(self.k_channels)) @@ -310,6 +330,13 @@ def __init__( self.v_bias = None self.proj_bias = None + if use_qknorm: + self.q_norm_weights = nn.Parameter(torch.zeros(self.k_channels // self.num_heads)) + self.k_norm_weights = nn.Parameter(torch.zeros(self.k_channels // self.num_heads)) + else: + 
self.q_norm_weights = None + self.k_norm_weights = None + if self.optimized_kernel: self.attention_handle = _neighborhood_s2_attention_optimized else: @@ -337,7 +364,25 @@ def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value if value.shape[-2] != self.nlat_in or value.shape[-1] != self.nlon_in: raise ValueError(f"value spatial shape {(value.shape[-2], value.shape[-1])} does not match in_shape {(self.nlat_in, self.nlon_in)}") - # do the scaling + # perform QKV projections + query = nn.functional.conv2d(query, self.q_weights, bias=self.q_bias) + key = nn.functional.conv2d(key, self.k_weights, bias=self.k_bias) + value = nn.functional.conv2d(value, self.v_weights, bias=self.v_bias) + + # perform QK normalization (must come before scale) + if self.q_norm_weights is not None: + B, C, H, W = query.shape + query = query.reshape(B, self.num_heads, -1, H, W).permute(0,1,3,4,2) + query = F.rms_norm(query, normalized_shape=self.q_norm_weights.shape, weight=1 + self.q_norm_weights) + query = query.permute(0,1,4,2,3).reshape(B, C, H, W).contiguous() + + if self.k_norm_weights is not None: + B, C, H, W = key.shape + key = key.reshape(B, self.num_heads, -1, H, W).permute(0,1,3,4,2) + key = F.rms_norm(key, normalized_shape=self.k_norm_weights.shape, weight=1 + self.k_norm_weights) + key = key.permute(0,1,4,2,3).reshape(B, C, H, W).contiguous() + + # scale after normalization query_scaled = query * self.scale # TODO: insert dimension checks for input @@ -345,12 +390,6 @@ def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value key, value, query_scaled, - self.k_weights, - self.v_weights, - self.q_weights, - self.k_bias, - self.v_bias, - self.q_bias, self.quad_weights, self.psi_col_idx, self.psi_roff_idx, diff --git a/torch_harmonics/distributed/distributed_attention.py b/torch_harmonics/distributed/distributed_attention.py index a875d386..49e688ea 100644 --- a/torch_harmonics/distributed/distributed_attention.py +++ 
b/torch_harmonics/distributed/distributed_attention.py @@ -330,6 +330,7 @@ def __init__( grid_out: Optional[str] = "equiangular", num_heads: Optional[int] = 1, scale: Optional[Union[torch.Tensor, float]] = None, + use_qknorm: Optional[bool] = False, bias: Optional[bool] = True, theta_cutoff: Optional[float] = None, k_channels: Optional[int] = None, @@ -343,7 +344,7 @@ def __init__( super().__init__( in_channels, in_shape, out_shape, grid_in=grid_in, grid_out=grid_out, - num_heads=num_heads, scale=scale, bias=bias, + num_heads=num_heads, scale=scale, use_qknorm=use_qknorm, bias=bias, theta_cutoff=theta_cutoff, k_channels=k_channels, out_channels=out_channels, optimized_kernel=True, @@ -492,13 +493,26 @@ def forward( if value.shape[-2] != self.nlat_in_local or value.shape[-1] != self.nlon_in_local: raise ValueError(f"value spatial shape {(value.shape[-2], value.shape[-1])} does not match local in_shape {(self.nlat_in_local, self.nlon_in_local)}") - # scale query - query_scaled = query * self.scale - # ---- 1. 
project to k/v/q ---- - key_proj = nn.functional.conv2d(key, self.k_weights, bias=self.k_bias) - value_proj = nn.functional.conv2d(value, self.v_weights, bias=self.v_bias) - query_proj = nn.functional.conv2d(query_scaled, self.q_weights, bias=self.q_bias) + key_proj = nn.functional.conv2d(key, self.k_weights, bias=self.k_bias) + value_proj = nn.functional.conv2d(value, self.v_weights, bias=self.v_bias) + query_proj = nn.functional.conv2d(query, self.q_weights, bias=self.q_bias) + + # QK normalization (must come before scale) + if self.q_norm_weights is not None: + B, C, H, W = query_proj.shape + query_proj = query_proj.reshape(B, self.num_heads, -1, H, W).permute(0,1,3,4,2) + query_proj = nn.functional.rms_norm(query_proj, normalized_shape=self.q_norm_weights.shape, weight=1 + self.q_norm_weights) + query_proj = query_proj.permute(0,1,4,2,3).reshape(B, C, H, W).contiguous() + + if self.k_norm_weights is not None: + B, C, H, W = key_proj.shape + key_proj = key_proj.reshape(B, self.num_heads, -1, H, W).permute(0,1,3,4,2) + key_proj = nn.functional.rms_norm(key_proj, normalized_shape=self.k_norm_weights.shape, weight=1 + self.k_norm_weights) + key_proj = key_proj.permute(0,1,4,2,3).reshape(B, C, H, W).contiguous() + + # scale after normalization + query_proj = query_proj * self.scale # fold num_heads into batch B, _, H, W = key_proj.shape From 619403b13842414cc2c251cc140a6708beadfdc2 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 13 Apr 2026 02:38:23 -0700 Subject: [PATCH 10/17] added qknorm to attention tests --- tests/test_distributed_attention.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/test_distributed_attention.py b/tests/test_distributed_attention.py index a2eccf7e..1e888564 100644 --- a/tests/test_distributed_attention.py +++ b/tests/test_distributed_attention.py @@ -114,12 +114,18 @@ def _gather_helper_bwd(self, tensor, attn_dist, use_out_shapes=False): @parameterized.expand( [ - # nlat_in, 
nlon_in, nlat_out, nlon_out, batch_size, in_channels, num_heads, k_channels, out_channels, grid_in, grid_out, atol, rtol - [64, 128, 64, 128, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4], - [64, 128, 64, 128, 2, 16, 2, None, None, "equiangular", "equiangular", 1e-5, 1e-4], - [64, 128, 64, 128, 2, 16, 1, 8, 8, "equiangular", "equiangular", 1e-5, 1e-4], - [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4], - [65, 128, 65, 128, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4], + # nlat_in, nlon_in, nlat_out, nlon_out, batch_size, in_channels, num_heads, k_channels, out_channels, grid_in, grid_out, use_qknorm, atol, rtol + [64, 128, 64, 128, 2, 16, 1, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], + [64, 128, 64, 128, 2, 16, 2, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], + [64, 128, 64, 128, 2, 16, 1, 8, 8, "equiangular", "equiangular", False, 1e-5, 1e-4], + [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], + [65, 128, 65, 128, 2, 16, 1, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], + # same cases with QK norm enabled + [64, 128, 64, 128, 2, 16, 1, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], + [64, 128, 64, 128, 2, 16, 2, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], + [64, 128, 64, 128, 2, 16, 1, 8, 8, "equiangular", "equiangular", True, 1e-5, 1e-4], + [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], + [65, 128, 65, 128, 2, 16, 1, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], ], skip_on_empty=True, ) @@ -127,7 +133,7 @@ def test_distributed_neighborhood_attention( self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, in_channels, num_heads, k_channels, out_channels, - grid_in, grid_out, + grid_in, grid_out, use_qknorm, atol, rtol, verbose=True, ): @@ -143,6 +149,7 @@ def test_distributed_neighborhood_attention( grid_out=grid_out, 
num_heads=num_heads, bias=True, + use_qknorm=use_qknorm, k_channels=k_channels, out_channels=out_channels, ) @@ -161,6 +168,9 @@ def test_distributed_neighborhood_attention( attn_dist.v_bias.copy_(attn_serial.v_bias) attn_dist.q_bias.copy_(attn_serial.q_bias) attn_dist.proj_bias.copy_(attn_serial.proj_bias) + if use_qknorm: + attn_dist.q_norm_weights.copy_(attn_serial.q_norm_weights) + attn_dist.k_norm_weights.copy_(attn_serial.k_norm_weights) # Helper: create inputs inp_full = { From 4c2d6a82bad0d4959889de4173ff1cc06591ea8d Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 13 Apr 2026 03:02:10 -0700 Subject: [PATCH 11/17] fixing weight inits --- torch_harmonics/attention/attention.py | 28 +++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torch_harmonics/attention/attention.py b/torch_harmonics/attention/attention.py index abe0efb5..5f95828f 100644 --- a/torch_harmonics/attention/attention.py +++ b/torch_harmonics/attention/attention.py @@ -108,18 +108,18 @@ def __init__( log_quad_weights = torch.log(quad_weights).reshape(1,1,-1) self.register_buffer("log_quad_weights", log_quad_weights, persistent=False) - # learnable parameters - # TODO: double-check that this gives us the correct initialization magnitudes - # the standard MHA uses xavier uniform, NATTEN uses kaiming. 
Let's use that for now + # learnable parameters — Xavier uniform init matching PyTorch MHA convention: + # bound = sqrt(6 / (fan_in + fan_out)) for each projection if self.k_channels % self.num_heads != 0: raise ValueError(f"Please make sure that number of heads {self.num_heads} divides k_channels {self.k_channels} evenly.") if self.out_channels % self.num_heads != 0: raise ValueError(f"Please make sure that number of heads {self.num_heads} divides out_channels {self.out_channels} evenly.") - scale_qkv = math.sqrt(3.0 / self.in_channels) - self.q_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1)) - self.k_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1)) - self.v_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.out_channels, self.in_channels, 1, 1) - 1)) + scale_qk = math.sqrt(6.0 / (self.in_channels + self.k_channels)) + scale_v = math.sqrt(6.0 / (self.in_channels + self.out_channels)) scale_proj = math.sqrt(3.0 / self.out_channels) + self.q_weights = nn.Parameter(scale_qk * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1)) + self.k_weights = nn.Parameter(scale_qk * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1)) + self.v_weights = nn.Parameter(scale_v * (2 * torch.rand(self.out_channels, self.in_channels, 1, 1) - 1)) self.proj_weights = nn.Parameter(scale_proj * (2 * torch.rand(self.out_channels, self.out_channels, 1, 1) - 1)) if bias: @@ -300,18 +300,18 @@ def __init__( self.register_buffer("psi_col_idx", col_idx, persistent=False) self.register_buffer("psi_roff_idx", roff_idx, persistent=False) - # learnable parameters - # TODO: double-check that this gives us the correct initialization magnitudes - # the standard MHA uses xavier uniform, NATTEN uses kaiming. 
Let's use that for now + # learnable parameters — Xavier uniform init matching PyTorch MHA convention: + # bound = sqrt(6 / (fan_in + fan_out)) for each projection if self.k_channels % self.num_heads != 0: raise ValueError(f"Please make sure that number of heads {self.num_heads} divides k_channels {self.k_channels} evenly.") if self.out_channels % self.num_heads != 0: raise ValueError(f"Please make sure that number of heads {self.num_heads} divides out_channels {self.out_channels} evenly.") - scale_qkv = math.sqrt(3.0 / self.in_channels) - self.q_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1)) - self.k_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1)) - self.v_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.out_channels, self.in_channels, 1, 1) - 1)) + scale_qk = math.sqrt(6.0 / (self.in_channels + self.k_channels)) + scale_v = math.sqrt(6.0 / (self.in_channels + self.out_channels)) scale_proj = math.sqrt(3.0 / self.out_channels) + self.q_weights = nn.Parameter(scale_qk * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1)) + self.k_weights = nn.Parameter(scale_qk * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1)) + self.v_weights = nn.Parameter(scale_v * (2 * torch.rand(self.out_channels, self.in_channels, 1, 1) - 1)) self.proj_weights = nn.Parameter(scale_proj * (2 * torch.rand(self.out_channels, self.out_channels, 1, 1) - 1)) if scale is not None: From 8a21178cbd04dbeeaeab0af6ccebe9c4c3a80391 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 13 Apr 2026 05:13:11 -0700 Subject: [PATCH 12/17] adding upsampling test with attention --- tests/test_distributed_attention.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/test_distributed_attention.py b/tests/test_distributed_attention.py index 1e888564..5d254b15 100644 --- a/tests/test_distributed_attention.py +++ 
b/tests/test_distributed_attention.py @@ -115,17 +115,27 @@ def _gather_helper_bwd(self, tensor, attn_dist, use_out_shapes=False): @parameterized.expand( [ # nlat_in, nlon_in, nlat_out, nlon_out, batch_size, in_channels, num_heads, k_channels, out_channels, grid_in, grid_out, use_qknorm, atol, rtol + # same shape tests [64, 128, 64, 128, 2, 16, 1, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], [64, 128, 64, 128, 2, 16, 2, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], [64, 128, 64, 128, 2, 16, 1, 8, 8, "equiangular", "equiangular", False, 1e-5, 1e-4], - [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], [65, 128, 65, 128, 2, 16, 1, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], + # downsampling tests + [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], + [65, 128, 33, 64, 2, 16, 1, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], + [32, 64, 64, 128, 2, 16, 1, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], + [33, 64, 65, 128, 2, 16, 1, None, None, "equiangular", "equiangular", False, 1e-5, 1e-4], # same cases with QK norm enabled [64, 128, 64, 128, 2, 16, 1, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], [64, 128, 64, 128, 2, 16, 2, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], [64, 128, 64, 128, 2, 16, 1, 8, 8, "equiangular", "equiangular", True, 1e-5, 1e-4], - [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], [65, 128, 65, 128, 2, 16, 1, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], + # downsampling tests + [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], + [65, 128, 33, 64, 2, 16, 1, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], + # upsampling tests + [32, 64, 64, 128, 2, 16, 1, None, None, "equiangular", "equiangular", True, 1e-5, 1e-4], + [33, 64, 65, 128, 2, 16, 1, None, None, 
"equiangular", "equiangular", True, 1e-5, 1e-4], ], skip_on_empty=True, ) From 430c405b857d63740ec2ddf8d8c6afa1840a72f6 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 13 Apr 2026 05:22:56 -0700 Subject: [PATCH 13/17] fixed testutils for more stable distributed test teardown --- tests/testutils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/testutils.py b/tests/testutils.py index ed522f1e..da98e5f9 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -127,8 +127,10 @@ def setup_distributed_context(ctx): def teardown_distributed_context(ctx): + device_ids = [ctx.device.index] if ctx.device.type == "cuda" else None + dist.barrier(device_ids=device_ids) thd.finalize() - dist.destroy_process_group(None) + dist.destroy_process_group() return From b1c8a0317c9f1845a2e09130a7310ddc9d224136 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 14 Apr 2026 01:01:51 -0700 Subject: [PATCH 14/17] improving pt2 compatibility --- tests/test_distributed_primitives.py | 127 +++++++++++++ .../distributed/distributed_attention.py | 42 +++-- torch_harmonics/distributed/primitives.py | 175 ++++++++++-------- 3 files changed, 252 insertions(+), 92 deletions(-) diff --git a/tests/test_distributed_primitives.py b/tests/test_distributed_primitives.py index 42706648..57136330 100644 --- a/tests/test_distributed_primitives.py +++ b/tests/test_distributed_primitives.py @@ -33,12 +33,15 @@ from parameterized import parameterized import torch +import torch.distributed as dist import torch_harmonics.distributed as thd from torch_harmonics.distributed import ( compute_split_shapes, scatter_to_polar_region, gather_from_polar_region, distributed_transpose_polar, + reduce_from_polar_region, + reduce_from_azimuth_region, ) from testutils import ( @@ -47,6 +50,7 @@ teardown_module, setup_class_from_context, split_tensor_dim, + compare_tensors, ) _DIST_CTX = {} @@ -204,5 +208,128 @@ def test_transpose_polar(self, B, C, H, W, dim0, dim1): 
self.assertTrue(torch.equal(x_gathered, x_full)) +class TestDistributedReduce(unittest.TestCase): + """ + Test reduce_from_polar_region and reduce_from_azimuth_region. + + Forward: each rank holds a different local tensor. The reduce primitive + must produce the same result as manually gathering all local tensors and + summing them (the "local sum + global sum" path). + + Backward: the adjoint of a broadcast-sum (all_reduce) is the identity – + the upstream gradient flows through each rank unchanged. We verify this + by comparing the computed input gradient against the upstream gradient dy. + """ + + @classmethod + def setUpClass(cls): + setup_class_from_context(cls, _DIST_CTX) + + # ------------------------------------------------------------------ + # Reference helpers – use dist.all_gather so we never call the op + # under test as part of the reference computation. + # ------------------------------------------------------------------ + + def _polar_gather_sum(self, x_local): + """Global sum over the polar group via all_gather + Python sum.""" + polar_size = thd.polar_group_size() + if polar_size > 1: + x_all = [torch.empty_like(x_local) for _ in range(polar_size)] + dist.all_gather(x_all, x_local.detach().contiguous(), group=thd.polar_group()) + return torch.stack(x_all, dim=0).sum(dim=0) + else: + return x_local.detach().clone() + + def _azimuth_gather_sum(self, x_local): + """Global sum over the azimuth group via all_gather + Python sum.""" + az_size = thd.azimuth_group_size() + if az_size > 1: + x_all = [torch.empty_like(x_local) for _ in range(az_size)] + dist.all_gather(x_all, x_local.detach().contiguous(), group=thd.azimuth_group()) + return torch.stack(x_all, dim=0).sum(dim=0) + else: + return x_local.detach().clone() + + # ------------------------------------------------------------------ + # Tests + # ------------------------------------------------------------------ + + @parameterized.expand( + [ + # B, C, H, W + [2, 8, 16, 32], + [1, 4, 7, 13], + [3, 
16, 8, 16], + ], + skip_on_empty=True, + ) + def test_reduce_from_polar_region_fwd_bwd(self, B, C, H, W): + set_seed(333) + + # Give each rank a distinct contribution so that a missing reduce is + # immediately visible: x_local = randn + hrank. + x_local = torch.randn(B, C, H, W, device=self.device) + float(self.hrank) + + # --- Forward reference: gather-and-sum without using the primitive --- + ref = self._polar_gather_sum(x_local) + + # --- Distributed forward --- + x = x_local.clone().requires_grad_(True) + out = reduce_from_polar_region(x) + + self.assertTrue( + compare_tensors("reduce_from_polar_region fwd", ref, out, + atol=1e-5, rtol=1e-4, verbose=True), + "forward output does not match the reference global sum", + ) + + # --- Backward: the gradient must pass through unchanged --- + dy = torch.randn_like(out) + out.backward(dy) + + self.assertTrue( + compare_tensors("reduce_from_polar_region bwd", dy, x.grad, + atol=1e-5, rtol=1e-4, verbose=True), + "input gradient does not match the upstream gradient (expected pass-through)", + ) + + @parameterized.expand( + [ + # B, C, H, W + [2, 8, 16, 32], + [1, 4, 7, 13], + [3, 16, 8, 16], + ], + skip_on_empty=True, + ) + def test_reduce_from_azimuth_region_fwd_bwd(self, B, C, H, W): + set_seed(333) + + x_local = torch.randn(B, C, H, W, device=self.device) + float(self.wrank) + + # --- Forward reference --- + ref = self._azimuth_gather_sum(x_local) + + # --- Distributed forward --- + x = x_local.clone().requires_grad_(True) + out = reduce_from_azimuth_region(x) + + self.assertTrue( + compare_tensors("reduce_from_azimuth_region fwd", ref, out, + atol=1e-5, rtol=1e-4, verbose=True), + "forward output does not match the reference global sum", + ) + + # --- Backward: pass-through --- + dy = torch.randn_like(out) + out.backward(dy) + + self.assertTrue( + compare_tensors("reduce_from_azimuth_region bwd", dy, x.grad, + atol=1e-5, rtol=1e-4, verbose=True), + "input gradient does not match the upstream gradient (expected 
pass-through)", + ) + + if __name__ == "__main__": unittest.main() diff --git a/torch_harmonics/distributed/distributed_attention.py b/torch_harmonics/distributed/distributed_attention.py index 49e688ea..8f527a4c 100644 --- a/torch_harmonics/distributed/distributed_attention.py +++ b/torch_harmonics/distributed/distributed_attention.py @@ -50,6 +50,7 @@ # autograd.Function for the ring-step attention kernel calls # --------------------------------------------------------------------------- +@torch.compiler.disable() def _ring_kv(kw_chunk, vw_chunk, az_group, next_nlon_kw, next_nlon_kv): """Async send current chunks, receive next chunks with known shapes.""" send_to, recv_from = get_group_neighbors(az_group) @@ -81,7 +82,6 @@ class _RingNeighborhoodAttentionFn(torch.autograd.Function): @staticmethod def forward( - ctx, kw, vw, qw, psi_col_idx, psi_roff_idx, psi_row_idx, quad_weights, @@ -134,14 +134,30 @@ def forward( if step < az_size - 1: for req in reqs: req.wait() - kw_chunk = recv_kw - vw_chunk = recv_vw + kw_chunk = recv_kw.clone() + vw_chunk = recv_vw.clone() # Finalize: y = y_acc / alpha_sum (both channels-last layout) y_out = y_acc / alpha_sum.unsqueeze(-1) # [B, H, W, C_v] y_out = y_out.permute(0, 3, 1, 2).contiguous() # [B, C_v, H, W] - # Save for backward (kw/vw: channels-first; scalars: [B,H,W]) + # alpha_sum and qdotk_max are returned so setup_context can save them; + # they are marked non-differentiable there, so backward still only + # receives one gradient argument (dy for y_out). 
+ return y_out, alpha_sum, qdotk_max + + @staticmethod + def setup_context(ctx, inputs, output): + (kw, vw, qw, + psi_col_idx, psi_roff_idx, psi_row_idx, + quad_weights, + nlon_in, lon_chunk_starts, nlon_kx_list, + lat_halo_start, nlat_out_local, nlon_out_local, + r_lat, az_group, az_rank, az_size) = inputs + y_out, alpha_sum, qdotk_max = output + # alpha_sum and qdotk_max are internal accumulators, not true outputs; + # marking them non-differentiable keeps backward's signature as (ctx, dy). + ctx.mark_non_differentiable(alpha_sum, qdotk_max) ctx.save_for_backward(kw, vw, qw, psi_col_idx, psi_roff_idx, psi_row_idx, quad_weights, alpha_sum, qdotk_max) ctx.nlon_in = nlon_in @@ -153,10 +169,10 @@ def forward( ctx.az_group = az_group ctx.az_rank = az_rank ctx.az_size = az_size - return y_out @staticmethod - def backward(ctx, dy): + def backward(ctx, dy, _dalpha_sum, _dqdotk_max): + # _dalpha_sum and _dqdotk_max are always None (non-differentiable outputs) (kw, vw, qw, psi_col_idx, psi_roff_idx, psi_row_idx, quad_weights, @@ -217,8 +233,8 @@ def backward(ctx, dy): if step < az_size - 1: for req in reqs: req.wait() - kw_chunk = recv_kw - vw_chunk = recv_vw + kw_chunk = recv_kw.clone() + vw_chunk = recv_vw.clone() # Finalize pass-1: normalize integral, compute dqy # Use the SAVED forward alpha_sum/qdotk_max (same values, but authoritative) @@ -278,8 +294,8 @@ def backward(ctx, dy): nlon_kx_list[next_src], nlon_kx_list[next_src]) for req in reqs: req.wait() - kw_chunk = recv_kw - vw_chunk = recv_vw + kw_chunk = recv_kw.clone() + vw_chunk = recv_vw.clone() if az_size > 1 and az_group is not None: dist.all_reduce(dkw_full_cl, group=az_group) @@ -299,8 +315,8 @@ def backward(ctx, dy): # Return grads for (kw, vw, qw, psi_col, psi_roff, psi_row, quad_weights, # nlon_in, lon_chunk_starts, nlon_kx_list, lat_halo_start, # nlat_out_local, nlon_out_local, r_lat, - # az_group, az_rank, az_size, polar_group) - return dkw, dvw, dqy, None, None, None, None, None, None, None, None, 
None, None, None, None, None, None, None + # az_group, az_rank, az_size) + return dkw, dvw, dqy, None, None, None, None, None, None, None, None, None, None, None, None, None, None # --------------------------------------------------------------------------- @@ -540,7 +556,7 @@ def forward( lat_halo_start = lat_in_starts[self.comm_rank_polar] - self.r_lat # ---- 3. ring attention ---- - out = _RingNeighborhoodAttentionFn.apply( + out, _, _ = _RingNeighborhoodAttentionFn.apply( key_halo, value_halo, query_proj, diff --git a/torch_harmonics/distributed/primitives.py b/torch_harmonics/distributed/primitives.py index 21088f1b..70624f56 100644 --- a/torch_harmonics/distributed/primitives.py +++ b/torch_harmonics/distributed/primitives.py @@ -32,7 +32,6 @@ import torch import torch.distributed as dist -from torch.amp import custom_fwd, custom_bwd from .utils import config as thd_config from .utils import polar_group, azimuth_group, polar_group_size, polar_group_rank @@ -137,52 +136,51 @@ def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False, class _DistributeTransposeAzimuth(torch.autograd.Function): @staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, x, dims, dim1_split_sizes): - + def forward(x, dims, dim1_split_sizes): # WAR for a potential contig check torch bug for channels last contig tensors x = x.contiguous() - xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group()) - x = torch.cat(xlist, dim=dims[1]).contiguous() + xlist, _, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group()) + return torch.cat(xlist, dim=dims[1]).contiguous() + + @staticmethod + def setup_context(ctx, inputs, output): + x, dims, _ = inputs ctx.dims = dims - ctx.dim0_split_sizes = dim0_split_sizes - - return x + comm_size = dist.get_world_size(group=azimuth_group()) + ctx.dim0_split_sizes = compute_split_shapes(x.shape[dims[0]], comm_size) @staticmethod - @custom_bwd(device_type="cuda") 
def backward(ctx, go): dims = ctx.dims dim0_split_sizes = ctx.dim0_split_sizes + # WAR for a potential contig check torch bug for channels last contig tensors go = go.contiguous() - # WAR for a potential contig check torch bug for channels last contig tensors gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group()) gi = torch.cat(gilist, dim=dims[0]).contiguous() - return gi, None, None - + class _DistributeTransposePolar(torch.autograd.Function): @staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, x, dims, dim1_split_sizes): - - # WAR for a potential contig check torch bug for channels last contig tensors + def forward(x, dims, dim1_split_sizes): + # WAR for a potential contig check torch bug for channels last contig tensors x = x.contiguous() - xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=polar_group()) - x = torch.cat(xlist, dim=dims[1]).contiguous() + xlist, _, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=polar_group()) + return torch.cat(xlist, dim=dims[1]).contiguous() + + @staticmethod + def setup_context(ctx, inputs, output): + x, dims, _ = inputs ctx.dims = dims - ctx.dim0_split_sizes = dim0_split_sizes - return x + comm_size = dist.get_world_size(group=polar_group()) + ctx.dim0_split_sizes = compute_split_shapes(x.shape[dims[0]], comm_size) @staticmethod - @custom_bwd(device_type="cuda") def backward(ctx, go): - dims = ctx.dims dim0_split_sizes = ctx.dim0_split_sizes - # WAR for a potential contig check torch bug for channels last contig tensors + # WAR for a potential contig check torch bug for channels last contig tensors go = go.contiguous() gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=polar_group()) gi = torch.cat(gilist, dim=dims[0]).contiguous() @@ -308,25 +306,25 @@ def _reduce_scatter(input_, dim_, use_fp32=True, group=None): class _CopyToPolarRegion(torch.autograd.Function): - + @staticmethod def symbolic(graph, 
input_): return input_ - + @staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, input_): - + def forward(input_): return input_ - + @staticmethod - @custom_bwd(device_type="cuda") - def backward(ctx, grad_output): + def setup_context(ctx, inputs, output): + pass + @staticmethod + def backward(ctx, grad_output): if is_distributed_polar(): return _reduce(grad_output, group=polar_group()) else: - return grad_output, None + return grad_output class _CopyToAzimuthRegion(torch.autograd.Function): @@ -336,19 +334,19 @@ def symbolic(graph, input_): return input_ @staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, input_): - + def forward(input_): return input_ @staticmethod - @custom_bwd(device_type="cuda") - def backward(ctx, grad_output): + def setup_context(ctx, inputs, output): + pass + @staticmethod + def backward(ctx, grad_output): if is_distributed_azimuth(): return _reduce(grad_output, group=azimuth_group()) else: - return grad_output, None + return grad_output class _ScatterToPolarRegion(torch.autograd.Function): @@ -358,19 +356,20 @@ def symbolic(graph, input_, dim_): return _split(input_, dim_, group=polar_group()) @staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, input_, dim_): + def forward(input_, dim_): if is_distributed_polar(): - ctx.dim = dim_ - ctx.split_shapes = compute_split_shapes( - input_.shape[dim_], polar_group_size() - ) return _split(input_, dim_, group=polar_group()) else: return input_ @staticmethod - @custom_bwd(device_type="cuda") + def setup_context(ctx, inputs, output): + input_, dim_ = inputs + ctx.dim = dim_ + if is_distributed_polar(): + ctx.split_shapes = compute_split_shapes(input_.shape[dim_], polar_group_size()) + + @staticmethod def backward(ctx, grad_output): if is_distributed_polar(): return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None @@ -385,16 +384,18 @@ def symbolic(graph, input_, dim_, shapes_): return _gather(input_, dim_, shapes_, polar_group()) 
@staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, input_, dim_, shapes_): + def forward(input_, dim_, shapes_): if is_distributed_polar(): - ctx.dim = dim_ return _gather(input_, dim_, shapes_, group=polar_group()) else: return input_ @staticmethod - @custom_bwd(device_type="cuda") + def setup_context(ctx, inputs, output): + _, dim_, _ = inputs + ctx.dim = dim_ + + @staticmethod def backward(ctx, grad_output): if is_distributed_polar(): return _split(grad_output, ctx.dim, group=polar_group()), None, None @@ -403,7 +404,7 @@ def backward(ctx, grad_output): class _ReduceFromPolarRegion(torch.autograd.Function): - + @staticmethod def symbolic(graph, input_): if is_distributed_polar(): @@ -412,21 +413,23 @@ def symbolic(graph, input_): return input_ @staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, input_): + def forward(input_): if is_distributed_polar(): return _reduce(input_, group=polar_group()) else: return input_ @staticmethod - @custom_bwd(device_type="cuda") + def setup_context(ctx, inputs, output): + pass + + @staticmethod def backward(ctx, grad_output): return grad_output - + class _ReduceFromAzimuthRegion(torch.autograd.Function): - + @staticmethod def symbolic(graph, input_): if is_distributed_azimuth(): @@ -435,15 +438,17 @@ def symbolic(graph, input_): return input_ @staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, input_): + def forward(input_): if is_distributed_azimuth(): return _reduce(input_, group=azimuth_group()) else: return input_ @staticmethod - @custom_bwd(device_type="cuda") + def setup_context(ctx, inputs, output): + pass + + @staticmethod def backward(ctx, grad_output): return grad_output @@ -458,19 +463,20 @@ def symbolic(graph, input_, dim_): return input_ @staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, input_, dim_): + def forward(input_, dim_): if is_distributed_polar(): - ctx.dim = dim_ - ctx.split_shapes = compute_split_shapes( - input_.shape[dim_], 
polar_group_size() - ) return _reduce_scatter(input_, dim_, group=polar_group()) else: return input_ @staticmethod - @custom_bwd(device_type="cuda") + def setup_context(ctx, inputs, output): + input_, dim_ = inputs + ctx.dim = dim_ + if is_distributed_polar(): + ctx.split_shapes = compute_split_shapes(input_.shape[dim_], polar_group_size()) + + @staticmethod def backward(ctx, grad_output): if is_distributed_polar(): return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None @@ -488,16 +494,18 @@ def symbolic(graph, input_, dim_, shapes_): return input_ @staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, input_, dim_, shapes_): + def forward(input_, dim_, shapes_): if is_distributed_polar(): - ctx.dim = dim_ return _gather(input_, dim_, shapes_, group=polar_group()) else: return input_ @staticmethod - @custom_bwd(device_type="cuda") + def setup_context(ctx, inputs, output): + _, dim_, _ = inputs + ctx.dim = dim_ + + @staticmethod def backward(ctx, grad_output): if is_distributed_polar(): return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None @@ -512,15 +520,19 @@ def distributed_transpose_azimuth(input_, dims_, shapes_): def distributed_transpose_polar(input_, dims_, shapes_): return _DistributeTransposePolar.apply(input_, dims_, shapes_) +@torch.compiler.disable() def copy_to_polar_region(input_): return _CopyToPolarRegion.apply(input_) +@torch.compiler.disable() def copy_to_azimuth_region(input_): return _CopyToAzimuthRegion.apply(input_) +@torch.compiler.disable() def reduce_from_polar_region(input_): return _ReduceFromPolarRegion.apply(input_) +@torch.compiler.disable() def reduce_from_azimuth_region(input_): return _ReduceFromAzimuthRegion.apply(input_) @@ -559,21 +571,14 @@ class _PolarHaloExchangeFn(torch.autograd.Function): """ @staticmethod - @custom_fwd(device_type="cuda") - def forward(ctx, x, r_lat): + def forward(x, r_lat): if not is_distributed_polar(): return x - ctx.r_lat = r_lat 
group_size = polar_group_size() group_rank = polar_group_rank() - ctx.group_size = group_size - ctx.group_rank = group_rank prev_rank, next_rank = get_group_neighbors(polar_group()) - ctx.prev_rank = prev_rank - ctx.next_rank = next_rank - ctx.H = x.shape[2] B, C, H, W = x.shape device, dtype = x.device, x.dtype @@ -602,7 +607,18 @@ def forward(ctx, x, r_lat): return torch.cat([recv_top, x, recv_bot], dim=2).contiguous() @staticmethod - @custom_bwd(device_type="cuda") + def setup_context(ctx, inputs, output): + x, r_lat = inputs + ctx.r_lat = r_lat + ctx.H = x.shape[2] + if is_distributed_polar(): + ctx.group_size = polar_group_size() + ctx.group_rank = polar_group_rank() + prev_rank, next_rank = get_group_neighbors(polar_group()) + ctx.prev_rank = prev_rank + ctx.next_rank = next_rank + + @staticmethod def backward(ctx, dout): if not is_distributed_polar(): @@ -656,5 +672,6 @@ def backward(ctx, dout): # Gradients for r_lat is None (not tensors / non-differentiable) return dx, None +@torch.compiler.disable() def polar_halo_exchange(x, r_lat): return _PolarHaloExchangeFn.apply(x, r_lat) From b9a386e6df1eadc39ef560543db586bc0b8bbef5 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 14 Apr 2026 01:38:01 -0700 Subject: [PATCH 15/17] making distributed transpose more robust --- torch_harmonics/distributed/primitives.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/torch_harmonics/distributed/primitives.py b/torch_harmonics/distributed/primitives.py index 70624f56..138c0bc1 100644 --- a/torch_harmonics/distributed/primitives.py +++ b/torch_harmonics/distributed/primitives.py @@ -35,6 +35,7 @@ from .utils import config as thd_config from .utils import polar_group, azimuth_group, polar_group_size, polar_group_rank +from .utils import azimuth_group_size from .utils import is_distributed_polar, is_distributed_azimuth @@ -137,6 +138,8 @@ class _DistributeTransposeAzimuth(torch.autograd.Function): @staticmethod def forward(x, dims, 
dim1_split_sizes): + if not is_distributed_azimuth(): + return x # WAR for a potential contig check torch bug for channels last contig tensors x = x.contiguous() xlist, _, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group()) @@ -146,11 +149,13 @@ def forward(x, dims, dim1_split_sizes): def setup_context(ctx, inputs, output): x, dims, _ = inputs ctx.dims = dims - comm_size = dist.get_world_size(group=azimuth_group()) - ctx.dim0_split_sizes = compute_split_shapes(x.shape[dims[0]], comm_size) + if is_distributed_azimuth(): + ctx.dim0_split_sizes = compute_split_shapes(x.shape[dims[0]], azimuth_group_size()) @staticmethod def backward(ctx, go): + if not is_distributed_azimuth(): + return go, None, None dims = ctx.dims dim0_split_sizes = ctx.dim0_split_sizes # WAR for a potential contig check torch bug for channels last contig tensors @@ -164,6 +169,8 @@ class _DistributeTransposePolar(torch.autograd.Function): @staticmethod def forward(x, dims, dim1_split_sizes): + if not is_distributed_polar(): + return x # WAR for a potential contig check torch bug for channels last contig tensors x = x.contiguous() xlist, _, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=polar_group()) @@ -173,11 +180,13 @@ def forward(x, dims, dim1_split_sizes): def setup_context(ctx, inputs, output): x, dims, _ = inputs ctx.dims = dims - comm_size = dist.get_world_size(group=polar_group()) - ctx.dim0_split_sizes = compute_split_shapes(x.shape[dims[0]], comm_size) + if is_distributed_polar(): + ctx.dim0_split_sizes = compute_split_shapes(x.shape[dims[0]], polar_group_size()) @staticmethod def backward(ctx, go): + if not is_distributed_polar(): + return go, None, None dims = ctx.dims dim0_split_sizes = ctx.dim0_split_sizes # WAR for a potential contig check torch bug for channels last contig tensors From 482d5727872774c7c6a715b11b3e44b33342c10b Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 16 Apr 2026 04:58:38 -0700 Subject: [PATCH 16/17] updated 
changelog --- Changelog.md | 14 ++++++++++++++ tests/test_distributed_attention.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/Changelog.md b/Changelog.md index b334c442..56f44701 100644 --- a/Changelog.md +++ b/Changelog.md @@ -2,6 +2,20 @@ ## Versioning +### v0.9.1 +* Fixed cross-attention, where output grid size is different from input grid size +* Support for DistributedNeighborhoodAttentionS2. This layer uses a 2-stage kernel to compute the attention per spatial parallel rank and performs an online update using ring exchange. Neighboring points in latitude are gathered using halo exchange +* Added proper shape checks in all attention layers +* Optional QK normalization (`use_qknorm=True`) for `AttentionS2` and `NeighborhoodAttentionS2`, applying per-head RMS normalization to Q and K projections +* Fixed weight initialization in `AttentionS2` and `NeighborhoodAttentionS2`: Q/K/V projections now use correct gain factors when input dim != embedding dim +* Fixed default attention scale in `NeighborhoodAttentionS2`: now divides by `k_channels // num_heads` instead of `k_channels` +* New distributed primitives: differentiable `polar_halo_exchange` and `get_group_neighbors` to support distributed attention +* New ring-step CUDA kernels for distributed attention: forward (`s2_attn_fwd_ring_step`) and two-pass backward (`s2_attn_bwd_ring_step_pass1/2`) +* Improved robustness of distributed transpose and better `torch.compile` compatibility +* Added new tests: + * expanded attention tests for thorough testing of cross-attention, QK normalization, and up/downsampling + * added test comparing distributed attention with serial layer + ### v0.9.0 * New CPU backend (OpenMP-accelerated) for both DISCO convolution and attention layers diff --git a/tests/test_distributed_attention.py b/tests/test_distributed_attention.py index 5d254b15..34b44f05 100644 --- a/tests/test_distributed_attention.py +++ b/tests/test_distributed_attention.py @@ -191,7 +191,7
@@ def test_distributed_neighborhood_attention( # ---- serial forward ---- out_full = attn_serial(inp_full["q"], inp_full["k"], inp_full["v"]) - + torch.cuda.synchronize() # ---- serial backward ---- From 025879dcb0c5b2ae704008466ee12a38e989f221 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 16 Apr 2026 05:49:07 -0700 Subject: [PATCH 17/17] bumping version --- pyproject.toml | 2 +- torch_harmonics/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 455e48b0..e81d7d28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ [tool.setuptools_scm] -fallback_version = "0.9.0" +fallback_version = "0.9.1a" [tool.setuptools.packages.find] include = ["torch_harmonics*"] diff --git a/torch_harmonics/__init__.py b/torch_harmonics/__init__.py index 37ee4e38..a795f60c 100644 --- a/torch_harmonics/__init__.py +++ b/torch_harmonics/__init__.py @@ -29,7 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -__version__ = "0.9.0" +__version__ = "0.9.1a" from .truncation import truncate_sht from .quadrature import QuadratureS2