[wave] NSA: numerical correctness tests vs dense attention reference #1259

@harsh-nod

Parent

Part of #1243 — DeepSeek NSA kernels for MI350

Description

Build a comprehensive correctness test suite that validates all NSA kernels against a pure PyTorch dense attention reference implementation.
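A minimal sketch of what such a pure PyTorch dense reference could look like (the function name, tensor layout `(B, H, seq, D)`, and causal-diagonal alignment are assumptions, not the suite's actual API); it upcasts to FP32 so the same code serves as the high-precision reference path:

```python
import torch

def dense_attention_ref(q, k, v, causal=True):
    """Dense attention reference in pure PyTorch with FP32 accumulation.

    Hypothetical helper: q is (B, H, M, D); k, v are (B, H, N, D).
    """
    q, k, v = q.float(), k.float(), v.float()
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # (B, H, M, N)
    if causal:
        M, N = scores.shape[-2:]
        # KV-cache-style alignment assumed: the last query row sees all keys.
        mask = torch.ones(M, N, dtype=torch.bool, device=q.device).tril(diagonal=N - M)
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```

For `M == N` this agrees with `F.scaled_dot_product_attention(..., is_causal=True)`, which gives a free cross-check of the reference itself.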

Test matrix

| Kernel | Reference | Tolerance |
| --- | --- | --- |
| Mean pooling | `torch.mean(K.reshape(B, N//bs, bs, G, D), dim=2)` | exact (FP32), 1e-3 (FP16) |
| Compressed attention | `F.scaled_dot_product_attention(Q, K_cmp, V_cmp)` | 1e-3 abs, 1e-2 rel |
| Top-k selection | `torch.topk(scores, k=block_count)` | exact (integer indices) |
| Selection attention | Dense attention masked to selected blocks | 1e-3 abs, 1e-2 rel |
| Sliding window | `flash_attn_func(..., window_size=...)` | 1e-3 abs |
| Gated combination | `g_cmp*O_cmp + g_slc*O_slc + g_swa*O_swa` | exact (FP32), 1e-4 (FP16) |
| Full NSA pipeline | Dense full attention | 1e-2 rel (aggregate) |
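As a sanity check on the mean-pooling reference row, a sketch that cross-checks the `reshape`/`mean` formulation against an `avg_pool1d` formulation (the `(B, N, G, D)` key layout is an assumption):

```python
import torch

B, N, G, D, bs = 2, 128, 4, 64, 32
K = torch.randn(B, N, G, D)

# Reference block-mean pooling from the table: (B, N//bs, G, D).
K_cmp = torch.mean(K.reshape(B, N // bs, bs, G, D), dim=2)

# Equivalent avg_pool1d formulation as an independent cross-check.
alt = K.permute(0, 2, 3, 1).reshape(B * G * D, 1, N)
alt = torch.nn.functional.avg_pool1d(alt, bs)
alt = alt.reshape(B, G, D, N // bs).permute(0, 3, 1, 2)

assert torch.allclose(K_cmp, alt, atol=1e-6)
```

Two independent formulations agreeing in FP32 gives confidence in the "exact (FP32)" tolerance for this row.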

Test configurations

```python
configs = [
    dict(B=1, M=1024, N=1024, H=32, G=4, D=64, bs=32, bc=8, ws=128),
    dict(B=2, M=4096, N=4096, H=64, G=8, D=128, bs=64, bc=16, ws=512),
    dict(B=1, M=64, N=65536, H=128, G=8, D=128, bs=64, bc=16, ws=512),  # decode-like
    dict(B=1, M=1, N=8192, H=128, G=8, D=128, bs=64, bc=16, ws=512),    # single-token decode
]
```
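A small helper sketch for turning one of these config dicts into test tensors (the helper name and the GQA layout, `H` query heads sharing `G` KV groups, are assumptions about how the suite would be wired up):

```python
import torch

def make_inputs(cfg, dtype=torch.float16, device="cpu"):
    """Hypothetical helper: build Q/K/V for one test config.

    Assumed GQA layout: queries carry all H heads, keys/values carry G groups.
    """
    B, M, N, H, G, D = (cfg[k] for k in ("B", "M", "N", "H", "G", "D"))
    q = torch.randn(B, H, M, D, dtype=dtype, device=device)
    k = torch.randn(B, G, N, D, dtype=dtype, device=device)
    v = torch.randn(B, G, N, D, dtype=dtype, device=device)
    return q, k, v
```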

Edge cases

  • Sequence length not divisible by block_size
  • block_count > N // block_size (more blocks requested than available)
  • window_size > N (window larger than sequence)
  • Causal mask at sequence boundaries
  • GQA with HEADS_PER_GROUP=1 (MHA) and HEADS_PER_GROUP=H (MQA)
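Two of these edge cases interact: with `N` not divisible by `block_size`, the number of available blocks shrinks, and a reference must clamp `k` before calling `torch.topk` to avoid requesting more blocks than exist. A sketch (ceil-division block count is an assumption about how partial trailing blocks are handled):

```python
import torch

N, block_size, block_count = 100, 32, 8
# Assumed: a partial trailing block still counts, so use ceiling division.
num_blocks = (N + block_size - 1) // block_size  # 4 blocks for N=100, bs=32

scores = torch.randn(2, num_blocks)
# Clamp so topk never requests more blocks than are available.
k = min(block_count, num_blocks)
idx = torch.topk(scores, k=k, dim=-1).indices
```

Without the clamp, `torch.topk` raises an error when `k` exceeds the dimension size, so this is exactly the path the edge-case tests should exercise.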

Infrastructure

  • Use pytest with parametrized fixtures
  • Run on both CPU (FP64 reference) and MI350 (FP16 kernel)
  • CI integration: add to wave test suite
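A sketch of the pytest parametrization, using the mean-pooling check as the example (the test name and id scheme are illustrative, not the suite's actual layout):

```python
import pytest
import torch

CONFIGS = [
    dict(B=1, M=1024, N=1024, H=32, G=4, D=64, bs=32, bc=8, ws=128),
    dict(B=1, M=1, N=8192, H=128, G=8, D=128, bs=64, bc=16, ws=512),
]

@pytest.mark.parametrize("cfg", CONFIGS, ids=lambda c: f"M{c['M']}_N{c['N']}")
def test_mean_pooling(cfg):
    B, N, G, D, bs = (cfg[k] for k in ("B", "N", "G", "D", "bs"))
    K = torch.randn(B, N, G, D, dtype=torch.float32)
    # FP32 path is exact per the tolerance table; an FP16 run would use atol=1e-3.
    ref = torch.mean(K.reshape(B, N // bs, bs, G, D), dim=2)
    assert ref.shape == (B, N // bs, G, D)
```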

Depends on

Metadata


Labels

enhancement (New feature or request), nsa (DeepSeek Native Sparse Attention)
