Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 59 additions & 36 deletions tests/kernel/wave/asm/test_waveasm_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,7 @@ def _dbuf_mxfp4_helper(
reorder_workgroups=None,
eliminate_epilogue=False,
linearize_reads=True,
reads_merged=None,
):
"""Shared helper for double-buffered MXFP4 scheduled GEMM tests.

Expand Down Expand Up @@ -1258,10 +1259,18 @@ def _dbuf_mxfp4_helper(
options, gemm, schedule=schedule, dynamic_values=dynamic_values
)

# Verify MLIR contains scaled_mfma operation
# Verify MLIR contains scaled_mfma operation.
assert (
"amdgpu.scaled_mfma" in kernel_info.mlir_text
), "Expected amdgpu.scaled_mfma operation in MLIR"
# Check whether scale reads were merged by merge_contiguous_reads.
if reads_merged is not None:
actually_merged = "vector<1xi8>" not in kernel_info.mlir_text
assert actually_merged == reads_merged, (
f"Scale read merge status changed: expected "
f"{'merged' if reads_merged else 'unmerged'}, "
f"got {'merged' if actually_merged else 'unmerged'}"
)
if dynamic_dims:
expected_idx = r"function_type = \([^)]*index, index, index\) -> \(\)"
expected_msg = "M, N, and K"
Expand Down Expand Up @@ -1336,33 +1345,43 @@ def _dbuf_mxfp4_helper(
)


def _mxfp4_config(shape, block, wave_shape, reads_merged=False):
    """Create one pytest.param for an MXFP4 preshuffle GEMM configuration.

    The param id encodes the block dimensions and wave shape, e.g. a
    (128, 256, 256) block with (1, 4) waves becomes ``128x256x256_w1x4``,
    so parametrized test ids stay readable and unique per configuration.
    """
    bm, bn, bk = block
    waves = "x".join(str(w) for w in wave_shape)
    return pytest.param(
        shape, block, wave_shape, reads_merged, id=f"{bm}x{bn}x{bk}_w{waves}"
    )


@param_bool("dynamic_dims", "dyn")
@param_bool("use_buffer_ops", "bufops")
@param_bool("use_schedule", "sched")
@pytest.mark.parametrize("eliminate_epilogue", [True, False], ids=["ee", "no_ee"])
@pytest.mark.parametrize("output_dtype", ["f32", "bf16"])
@pytest.mark.parametrize(
"shape,block,wave_shape",
"shape,block,wave_shape,reads_merged",
[
pytest.param((1024, 1024, 8192), (128, 256, 256), (1, 4), id="128x256x256"),
pytest.param((1024, 1024, 8192), (128, 32, 256), (2, 2), id="128x32x256"),
pytest.param((896, 640, 8192), (224, 160, 256), (2, 2), id="224x160x256"),
pytest.param((1024, 768, 8192), (256, 192, 256), (1, 4), id="256x192x256"),
pytest.param((1024, 640, 8192), (256, 160, 256), (2, 2), id="256x160x256"),
pytest.param((1024, 896, 8192), (256, 224, 256), (2, 2), id="256x224x256"),
_mxfp4_config((1024, 1024, 8192), (128, 256, 256), (1, 4), reads_merged=True),
_mxfp4_config((1024, 1024, 8192), (128, 256, 256), (4, 1), reads_merged=True),
_mxfp4_config((1024, 1024, 8192), (128, 32, 256), (2, 2)),
_mxfp4_config((896, 640, 8192), (224, 160, 256), (2, 2)),
_mxfp4_config((1024, 768, 8192), (256, 192, 256), (1, 4)),
_mxfp4_config((1024, 768, 8192), (256, 192, 256), (2, 2), reads_merged=True),
_mxfp4_config((1024, 640, 8192), (256, 160, 256), (2, 2)),
_mxfp4_config((1024, 896, 8192), (256, 224, 256), (2, 2)),
],
)
def test_dbuf_4wave_mxfp4_gemm_cpp_backend(
shape,
block,
wave_shape,
reads_merged,
dynamic_dims,
use_buffer_ops,
use_schedule,
eliminate_epilogue,
output_dtype,
compiler,
dump_asm,
request,
):
"""End-to-end test for asymmetric MXFP4 GEMM with 4 waves.

Expand All @@ -1382,27 +1401,34 @@ def test_dbuf_4wave_mxfp4_gemm_cpp_backend(
"not yet supported"
)

def expect_fail(reason, is_crashing=False):
    """Mark test as expected failure; XPASS (strict) catches silent fixes.

    Args:
        reason: Human-readable explanation recorded on the xfail marker.
        is_crashing: When True, abort immediately via ``pytest.xfail``
            instead of attaching a marker, so the crashing code path
            (segfault) never executes and cannot take down the runner.
    """
    if is_crashing:
        # TODO: Some of those tests segfault, crashing the test runner.
        # pytest.xfail() raises right here, so execution never reaches
        # the code that would crash the process.
        pytest.xfail(reason)
    # strict=True turns an unexpected pass (XPASS) into a test failure,
    # so a silently-fixed case forces this expectation to be removed.
    request.node.add_marker(pytest.mark.xfail(reason=reason, strict=True))

# VGPR overflow: 256x224x256 scheduled pipeline exceeds 256 VGPR limit.
if block_id == "256x224x256" and use_schedule:
pytest.xfail("C++ ASM backend exceeds VGPR limit with scheduled pipeline")
expect_fail("C++ ASM backend exceeds VGPR limit with scheduled pipeline")

# VGPR overflow: 224x160x256 scheduled pipeline exceeds 256 VGPR limit
# without epilogue elimination; with ee=True it fits for static dims
# but dynamic dims adds enough extra VGPRs to overflow again.
if block_id == "224x160x256" and use_schedule:
if not eliminate_epilogue:
pytest.xfail(
expect_fail(
"C++ ASM backend exceeds VGPR limit with scheduled pipeline "
"(ee=False) for 224x160x256"
)
elif dynamic_dims:
pytest.xfail(
expect_fail(
"C++ ASM backend exceeds VGPR limit with ee=True + dynamic "
"dims for 224x160x256"
)
else:
# TODO (Gaurav/Sanket): should be passing after all the cherry-picks
pytest.xfail(
expect_fail(
"VGPR overflow: 224x160x256 ee + scheduled pipeline + static "
"dims exceeds register limit (register index is out of range)"
)
Expand All @@ -1411,38 +1437,39 @@ def test_dbuf_4wave_mxfp4_gemm_cpp_backend(
# but dynamic dims adds enough extra VGPRs to overflow again.
if block_id == "256x160x256" and use_schedule:
if not eliminate_epilogue:
pytest.xfail(
expect_fail(
"C++ ASM backend exceeds VGPR limit with scheduled pipeline "
"(ee=False) for 256x160x256"
)
elif dynamic_dims:
pytest.xfail(
expect_fail(
"C++ ASM backend exceeds VGPR limit with ee=True + dynamic "
"dims for 256x160x256"
)
else:
# TODO (Gaurav/Sanket): should be passing after all the cherry-picks
pytest.xfail(
expect_fail(
"VGPR overflow: 256x160x256 ee + scheduled pipeline + static "
"dims exceeds register limit (register index is out of range)"
"dims exceeds register limit (register index is out of range)",
is_crashing=True,
)

# VGPR overflow for 256x192x256: ee=True reduces register pressure
# enough to pass with static dims; ee=False and dynamic dims still overflow.
if block_id == "256x192x256" and use_schedule:
if not eliminate_epilogue:
pytest.xfail(
expect_fail(
"C++ ASM backend exceeds VGPR limit with scheduled pipeline "
"(ee=False); ee=True resolves this for 256x192x256"
)
elif dynamic_dims:
pytest.xfail(
expect_fail(
"C++ ASM backend exceeds VGPR limit with ee=True + dynamic "
"dims for 256x192x256"
)
else:
# TODO (Gaurav/Sanket): should be passing after all the cherry-picks
pytest.xfail(
expect_fail(
"VGPR overflow: 256x192x256 ee + scheduled pipeline + static "
"dims exceeds register limit (register index is out of range)"
)
Expand All @@ -1451,33 +1478,27 @@ def test_dbuf_4wave_mxfp4_gemm_cpp_backend(
# dynamic dims when ee=False; dynamic-dims-only when ee=True).
if block_id == "128x32x256" and use_schedule:
if not eliminate_epilogue:
pytest.xfail(
expect_fail(
"Numerical mismatch on (2,2) wave shape with scheduled "
"pipeline (ee=False)"
)
elif dynamic_dims:
pytest.xfail("Numerical mismatch with dynamic dims on (2,2) wave shape")
expect_fail("Numerical mismatch with dynamic dims on (2,2) wave shape")
else:
# TODO (Gaurav/Sanket): should be passing after all the cherry-picks
pytest.xfail(
expect_fail(
"128x32x256 (2,2) wave shape ee + scheduled pipeline + "
"static dims: numerical mismatch (no_bufops) or SGPR "
"overflow s103 not available (bufops)"
"overflow s103 not available (bufops)",
is_crashing=True,
)

# SGPR overflow: 128x256x256 with ee + scheduled pipeline + buffer ops
# exceeds the 102 SGPR limit (static: s103 not available; dynamic: SRD
# allocation assertion).
# TODO (Gaurav/Sanket): should be passing after all the cherry-picks
if (
block_id == "128x256x256"
and eliminate_epilogue
and use_schedule
and use_buffer_ops
):
pytest.xfail(
"C++ ASM backend exceeds SGPR limit for 128x256x256 with "
"ee + scheduled pipeline + buffer ops"
# VGPR overflow: 128x256x256 with (4,1) wave shape and scheduled pipeline
# requires 341 VGPRs but only 256 are available.
if block_id == "128x256x256" and wave_shape == (4, 1) and use_schedule:
expect_fail(
"C++ ASM backend exceeds VGPR limit (341 needed) for "
"128x256x256 (4,1) with scheduled pipeline",
)

# Linearized reads with dynamic dims produce complex floor/Mod
Expand All @@ -1499,6 +1520,7 @@ def test_dbuf_4wave_mxfp4_gemm_cpp_backend(
wave_shape=wave_shape,
eliminate_epilogue=eliminate_epilogue,
linearize_reads=not skip_linearize,
reads_merged=reads_merged,
)


Expand Down Expand Up @@ -1534,6 +1556,7 @@ def test_dbuf_4wave_mxfp4_dynamic_mn_reorder_cpp_backend(
use_schedule=True,
wave_shape=wave_shape,
reorder_workgroups=True,
reads_merged=True,
)


Expand Down
Loading