Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 69 additions & 19 deletions tests/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from parameterized import parameterized, parameterized_class

import torch
import torch.nn.functional as F
from torch.library import opcheck

# from torch.autograd import gradcheck
Expand Down Expand Up @@ -71,23 +72,35 @@ def setUp(self):

@parameterized.expand(
[
# Format: [batch_size, channels, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
[4, 4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 8, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 8, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 8, 4, 4, (12, 24), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 8, 4, 4, (6, 12), (12, 24), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 1, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 1, 4, 1, (2, 4), (2, 4), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
[4, 4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
# Format: [batch_size, channels, channels_out, heads, in_shape, out_shape, grid_in, grid_out, use_qknorm, atol, rtol]
[4, 4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3],
[4, 4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3],
[4, 4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3],
[4, 4, 8, 4, (6, 12), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3],
[4, 8, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3],
[4, 8, 4, 4, (12, 24), (6, 12), "equiangular", "equiangular", False, 1e-5, 1e-3],
[4, 8, 4, 4, (6, 12), (12, 24), "equiangular", "equiangular", False, 1e-5, 1e-3],
[4, 1, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", False, 1e-5, 1e-3],
[4, 1, 4, 1, (2, 4), (2, 4), "equiangular", "equiangular", False, 1e-5, 1e-3],
[4, 4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", False, 1e-5, 1e-3],
[4, 4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", False, 1e-5, 1e-3],
# same cases with QK norm enabled
[4, 4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3],
[4, 4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3],
[4, 4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3],
[4, 4, 8, 4, (6, 12), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3],
[4, 8, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3],
[4, 8, 4, 4, (12, 24), (6, 12), "equiangular", "equiangular", True, 1e-5, 1e-3],
[4, 8, 4, 4, (6, 12), (12, 24), "equiangular", "equiangular", True, 1e-5, 1e-3],
[4, 1, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", True, 1e-5, 1e-3],
[4, 1, 4, 1, (2, 4), (2, 4), "equiangular", "equiangular", True, 1e-5, 1e-3],
[4, 4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", True, 1e-5, 1e-3],
[4, 4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", True, 1e-5, 1e-3],
],
skip_on_empty=True,
)
@unittest.skipUnless(optimized_kernels_is_available(), "skipping test because optimized kernels are not available")
def test_custom_implementation(self, batch_size, channels, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
def test_custom_implementation(self, batch_size, channels, channels_out, heads, in_shape, out_shape, grid_in, grid_out, use_qknorm, atol, rtol, verbose=True):
"""Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""

if (self.device.type == "cuda") and (not cuda_kernels_is_available()):
Expand All @@ -111,10 +124,10 @@ def test_custom_implementation(self, batch_size, channels, channels_out, heads,
inputs_opt = {k: v.detach().clone().to(self.device).requires_grad_() for k, v in inputs_ref.items()}

# reference input and model
model_ref = NeighborhoodAttentionS2(in_channels=channels, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True, optimized_kernel=False).to(self.device)
model_ref = NeighborhoodAttentionS2(in_channels=channels, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True, use_qknorm=use_qknorm, optimized_kernel=False).to(self.device)

# Device model and inputs
model_opt = NeighborhoodAttentionS2(in_channels=channels, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True, optimized_kernel=True).to(self.device)
model_opt = NeighborhoodAttentionS2(in_channels=channels, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True, use_qknorm=use_qknorm, optimized_kernel=True).to(self.device)

# Synchronize parameters of model
model_opt.load_state_dict(model_ref.state_dict())
Expand Down Expand Up @@ -320,10 +333,12 @@ def test_optimized_pt2_compatibility(self, batch_size, channels, heads, in_shape
"q": torch.randn(batch_size, channels, nlat_out, nlon_out, requires_grad=True, device=self.device, dtype=torch.float32),
}

test_inputs = (inputs["k"], inputs["v"], inputs["q"],
att.k_weights, att.v_weights, att.q_weights,
att.k_bias, att.v_bias, att.q_bias,
att.quad_weights, att.psi_col_idx, att.psi_roff_idx,
kw = F.conv2d(inputs["k"], att.k_weights, att.k_bias)
vw = F.conv2d(inputs["v"], att.v_weights, att.v_bias)
qw = F.conv2d(inputs["q"], att.q_weights, att.q_bias) * att.scale

test_inputs = (kw, vw, qw,
att.quad_weights, att.psi_col_idx, att.psi_roff_idx,
att.psi_max_nnz, att.num_heads, nlon_in, nlat_out, nlon_out)

opcheck(torch.ops.attention_kernels._neighborhood_s2_attention_optimized, test_inputs)
Expand Down Expand Up @@ -418,5 +433,40 @@ def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, g
threshold = _perf_test_thresholds[self.device.type]["bwd_ms"]
self.assertTrue(duration <= threshold, msg=f"Backward execution time on device {self.device.type} is too high: {duration:.2f} ms > {threshold:.2f} ms")

def test_wrong_shape_assertions(self):
    """Check that forward rejects inputs whose spatial shapes do not match the module's in/out shapes."""
    batch, chans = 2, 16
    shape_in = (12, 24)
    shape_out = (6, 12)

    # Build an up/downsampling attention module (in_shape != out_shape) so that
    # query-shaped and key/value-shaped tensors are necessarily different sizes.
    model = NeighborhoodAttentionS2(
        in_channels=chans,
        in_shape=shape_in,
        out_shape=shape_out,
        grid_in="equiangular",
        grid_out="equiangular",
        num_heads=1,
        bias=False,
    ).to(self.device)

    # query must carry out_shape; key/value must carry in_shape
    query = torch.randn(batch, chans, *shape_out, device=self.device)
    keyval = torch.randn(batch, chans, *shape_in, device=self.device)

    # Case 1: calling with only the query makes key default to the query tensor,
    # which then has out_shape instead of the required in_shape.
    with self.assertRaises(ValueError):
        model(query)

    # Case 2: key has out_shape (matches query) while value correctly has in_shape.
    with self.assertRaises(ValueError):
        model(query, query, keyval)

    # Case 3: value has out_shape (matches query) while key correctly has in_shape.
    with self.assertRaises(ValueError):
        model(query, keyval, query)


if __name__ == "__main__":
unittest.main()
Loading
Loading