Parent
Part of #1243 — DeepSeek NSA kernels for MI350
Description
Build a comprehensive correctness test suite that validates all NSA kernels against a pure PyTorch dense attention reference implementation.
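A minimal sketch of the dense reference follows. It assumes the layouts implied by the mean-pooling row of the test matrix (`Q` as `(B, M, H, D)`, `K`/`V` as `(B, N, G, D)`) and the usual GQA head mapping `h -> h // (H // G)`. It also assumes bottom-right causal alignment, so that the `M < N` decode configs attend to the whole prefix. The name `dense_attention_ref` is hypothetical.

```python
import math
import torch

def dense_attention_ref(q, k, v, causal=True, mask=None):
    """Dense GQA attention reference computed in float64 (hypothetical helper).

    q: (B, M, H, D); k, v: (B, N, G, D) with H % G == 0.
    With causal=True, query i sits at absolute position N - M + i.
    mask: optional bool tensor broadcastable to (B, H, M, N); True = keep.
    """
    B, M, H, D = q.shape
    N, G = k.shape[1], k.shape[2]
    assert H % G == 0, "query heads must be a multiple of KV groups"
    hpg = H // G  # HEADS_PER_GROUP
    q64, k64, v64 = (t.to(torch.float64) for t in (q, k, v))
    # Expand KV heads so query head h reads KV group h // hpg.
    k64 = k64.repeat_interleave(hpg, dim=2)
    v64 = v64.repeat_interleave(hpg, dim=2)
    scores = torch.einsum("bmhd,bnhd->bhmn", q64, k64) / math.sqrt(D)
    keep = torch.ones(M, N, dtype=torch.bool, device=q.device)
    if causal:
        q_pos = torch.arange(M, device=q.device)[:, None] + (N - M)
        k_pos = torch.arange(N, device=q.device)[None, :]
        keep = k_pos <= q_pos
    if mask is not None:
        keep = keep & mask
    scores = scores.masked_fill(~keep, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    # A row with no visible keys yields NaN after softmax; define its output as 0.
    probs = torch.nan_to_num(probs, nan=0.0)
    return torch.einsum("bhmn,bnhd->bmhd", probs, v64)  # (B, M, H, D)
```

Computing in FP64 keeps the reference's own rounding error far below the FP16 tolerances in the matrix. The optional `mask` argument lets the same function serve as the reference for selection and sliding-window attention.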
Test matrix
| Kernel | Reference | Tolerance |
|---|---|---|
| Mean pooling | `torch.mean(K.reshape(B,N//bs,bs,G,D), dim=2)` | exact (FP32), 1e-3 (FP16) |
| Compressed attention | `F.scaled_dot_product_attention(Q, K_cmp, V_cmp)` | 1e-3 abs, 1e-2 rel |
| Top-k selection | `torch.topk(scores, k=block_count)` | exact (integer indices) |
| Selection attention | Dense attention masked to selected blocks | 1e-3 abs, 1e-2 rel |
| Sliding window | `flash_attn_func(..., window_size=...)` | 1e-3 abs |
| Gated combination | `g_cmp*O_cmp + g_slc*O_slc + g_swa*O_swa` | exact (FP32), 1e-4 (FP16) |
| Full NSA pipeline | Dense full attention | 1e-2 rel (aggregate) |
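The rows that are not plain SDPA calls can be expressed as small PyTorch helpers. This sketch is hypothetical (all function names are illustrative) and rests on several assumptions. Pooling blocks are non-overlapping, as in the table's `reshape`. The window covers the `ws` most recent keys including the query's own position, which corresponds to flash-attn's `window_size=(ws - 1, 0)` under that assumption. Gates are per head with shape `(B, M, H)`. Top-k indices are sorted before comparison so that the "exact" check does not depend on the order in which a kernel emits them.

```python
import torch

# Hypothetical reference helpers (names are illustrative).
# Layouts assumed from the table: K, V as (B, N, G, D); Q as (B, M, H, D).

def mean_pool_ref(k, bs):
    """Non-overlapping mean over blocks of bs keys: (B, N, G, D) -> (B, N//bs, G, D)."""
    B, N, G, D = k.shape
    assert N % bs == 0, "N must be a multiple of bs for this reference"
    return torch.mean(k.reshape(B, N // bs, bs, G, D), dim=2)

def topk_blocks_ref(block_scores, block_count):
    """Indices of the block_count highest-scoring blocks (last dim), sorted ascending."""
    k = min(block_count, block_scores.shape[-1])  # clamp: cannot pick more blocks than exist
    idx = torch.topk(block_scores, k=k, dim=-1).indices
    return torch.sort(idx, dim=-1).values  # makes the comparison order-independent

def selection_mask_ref(block_idx, N, bs):
    """Key mask (..., N), True for keys inside any selected block; block_idx is (..., k)."""
    key_block = torch.arange(N, device=block_idx.device) // bs  # block id of each key
    return (key_block[:, None] == block_idx.unsqueeze(-2)).any(dim=-1)

def sliding_window_mask_ref(M, N, ws, device=None):
    """(M, N) mask; query i sits at absolute position N - M + i and sees its ws most recent keys."""
    q_pos = torch.arange(M, device=device)[:, None] + (N - M)
    k_pos = torch.arange(N, device=device)[None, :]
    return (k_pos <= q_pos) & (k_pos > q_pos - ws)

def gated_combine_ref(g_cmp, o_cmp, g_slc, o_slc, g_swa, o_swa):
    """Gates g_*: (B, M, H); branch outputs o_*: (B, M, H, D)."""
    return (g_cmp.unsqueeze(-1) * o_cmp
            + g_slc.unsqueeze(-1) * o_slc
            + g_swa.unsqueeze(-1) * o_swa)
```

The two mask helpers produce boolean masks that a dense masked-attention reference can consume, so selection and sliding-window attention reuse one attention routine instead of each needing its own reference.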
Test configurations
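configs = [
    # B=batch, M=query length, N=KV length, H=query heads, G=KV groups, D=head dim,
    # bs=block_size, bc=block_count (blocks kept by top-k), ws=window_size
    dict(B=1, M=1024, N=1024, H=32, G=4, D=64, bs=32, bc=8, ws=128),
    dict(B=2, M=4096, N=4096, H=64, G=8, D=128, bs=64, bc=16, ws=512),
    dict(B=1, M=64, N=65536, H=128, G=8, D=128, bs=64, bc=16, ws=512),  # decode-like: short query, long KV
    dict(B=1, M=1, N=8192, H=128, G=8, D=128, bs=64, bc=16, ws=512),    # single-token decode
]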
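A small, hypothetical helper pair for the configurations above: `derive` computes quantities a kernel launch would likely need and checks that a config is well formed, and `config_id` builds readable pytest ids. The key meanings (M as query length, N as KV length, G as the number of KV groups) are inferred from the reference shapes in the test matrix rather than stated in the issue.

```python
def derive(cfg):
    """Derived per-config quantities (hypothetical names, for illustration only)."""
    H, G, N, M = cfg["H"], cfg["G"], cfg["N"], cfg["M"]
    bs, bc, ws = cfg["bs"], cfg["bc"], cfg["ws"]
    assert H % G == 0, "H must be a multiple of G for GQA"
    num_blocks = -(-N // bs)          # ceil(N / bs): a trailing partial block still counts
    selected = min(bc, num_blocks)    # cannot select more blocks than exist
    return dict(
        heads_per_group=H // G,
        num_blocks=num_blocks,
        selected_blocks=selected,
        max_selected_keys=min(selected * bs, N),  # last block may be partial
        window_keys=min(ws, N),                   # window cannot exceed the sequence
        single_token_decode=(M == 1),
    )

def config_id(cfg):
    """Compact id such as 'B1-M1-N8192-H128-G8-D128-bs64-bc16-ws512'."""
    return "-".join(f"{key}{value}" for key, value in cfg.items())
```

With pytest, `@pytest.mark.parametrize("cfg", configs, ids=config_id)` would label each case by its full parameter set, which makes a failing shape easy to spot in CI logs.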
Edge cases
- Sequence length not divisible by block_size
- block_count > N // block_size (more blocks requested than available)
- window_size > N (window larger than sequence)
- Causal mask at sequence boundaries
- GQA with HEADS_PER_GROUP=1 (MHA) and HEADS_PER_GROUP=H (MQA)
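The issue does not fix a convention for the trailing partial block in the first edge case. One plausible choice, assumed here, is to average that block over its valid keys only. The sketch below implements this and lists hypothetical small configs, one per edge case, that could be appended to the parametrization. The values are illustrative, not taken from the kernel requirements.

```python
import torch
import torch.nn.functional as F

def mean_pool_ragged_ref(k, bs):
    """Mean pooling for N % bs != 0 (hypothetical helper).

    Assumed convention: the trailing partial block is averaged over its
    valid keys only. k: (B, N, G, D) -> (B, ceil(N / bs), G, D).
    """
    B, N, G, D = k.shape
    nb = -(-N // bs)                          # number of blocks, including a partial one
    pad = nb * bs - N                         # zero keys appended to fill the last block
    k_pad = F.pad(k, (0, 0, 0, 0, 0, pad))    # pad pairs apply to D, then G, then N
    sums = k_pad.reshape(B, nb, bs, G, D).sum(dim=2)
    counts = torch.full((nb,), bs, dtype=k.dtype, device=k.device)
    counts[-1] = bs - pad                     # number of valid keys in the last block
    return sums / counts[None, :, None, None]

# Illustrative configs, one per edge case (values are assumptions, not from the issue).
edge_configs = [
    dict(B=1, M=100, N=100, H=8, G=2, D=64, bs=32, bc=4, ws=64),   # N % bs != 0
    dict(B=1, M=128, N=128, H=8, G=2, D=64, bs=32, bc=16, ws=64),  # bc > N // bs
    dict(B=1, M=128, N=128, H=8, G=2, D=64, bs=32, bc=2, ws=512),  # ws > N
    dict(B=1, M=1, N=96, H=8, G=2, D=64, bs=32, bc=2, ws=64),      # causal: last position only
    dict(B=1, M=128, N=128, H=8, G=8, D=64, bs=32, bc=2, ws=64),   # MHA: H // G == 1
    dict(B=1, M=128, N=128, H=8, G=1, D=64, bs=32, bc=2, ws=64),   # MQA: H // G == H
]
```

When `N` is divisible by `bs`, the helper reduces to the `torch.mean` reference from the test matrix, so it can replace that reference without changing the regular cases.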
Infrastructure
- Use pytest with parametrized fixtures
- Run on both CPU (FP64 reference) and MI350 (FP16 kernel)
- CI integration: add to wave test suite
Depends on