
[WIP: Rebased branch] MXFP4 GEMM through Waveasm #1182

Open
panditsa wants to merge 32 commits into iree-org:main from panditsa:working_256x192_rebased

Conversation

@panditsa (Contributor)

Do not merge; this branch exists to track changes being merged into main.

panditsa and others added 30 commits March 24, 2026 19:30
Add Assumption(K >= 2048) to the MXFP4 preshuffle template. This
enables the IndexMapping simplification pass to prove that
within_nblk (max 1023) < K_PACKED (= K/2 >= 1024), eliminating
dynamic floordiv/mod for the B-data preshuffle mapping.

Key fixes:
- Use subs_idxc to resolve derived symbols (K_PACKED -> K//2) before
  the divisor lower-bound check, so constraint-based bounds on K
  propagate to K_PACKED.
- Extract symbol lower bounds from Assumption(S >= c) and
  Assumption(S > c) constraints.
- Guard subs_idxc for contexts where IndexingContext is unavailable.

The pass now correctly detects that within_nblk < K_PACKED and
rewrites the IndexMapping to eliminate the dynamic floordiv/mod.
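The bound extraction and the resulting fold can be sketched in plain Python. The function name, the tuple encoding of an Assumption, and the `WITHIN_NBLK_MAX` constant below are illustrative, not the actual Wave API:

```python
def lower_bound_from_assumption(assumption):
    """Hypothetical sketch: extract a symbol lower bound from an
    Assumption(S >= c) or Assumption(S > c) style constraint."""
    sym, op, c = assumption
    if op == ">=":
        return sym, c
    if op == ">":
        return sym, c + 1
    return sym, None

# Assumption(K >= 2048); after subs_idxc resolves K_PACKED -> K // 2,
# the divisor's lower bound becomes 1024.
_, k_lb = lower_bound_from_assumption(("K", ">=", 2048))
k_packed_lb = k_lb // 2

# within_nblk is at most 1023, so within_nblk < K_PACKED is provable:
#   within_nblk // K_PACKED -> 0  and  within_nblk % K_PACKED -> within_nblk
WITHIN_NBLK_MAX = 1023
can_fold = WITHIN_NBLK_MAX < k_packed_lb
```

Once `can_fold` is proven, the dynamic floordiv becomes the constant 0 and the mod becomes the numerator itself, so no division instructions are emitted.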

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds annotate_iv_strides pass that statically computes the per-iteration
stride of induction variables through index mappings. This enables the
codegen to fold constant IV strides directly into buffer_load voffset
fields, eliminating runtime address arithmetic.

Key components:
- extract_iv() in symbol_utils: separates IV from complex floor/Mod exprs
  using integer-division decomposition identities
- compute_iv_stride_through_mapping() in mapping_utils: probes mappings
  numerically to determine constant strides through symbolic dimensions
- annotate_iv_strides pass: wires stride analysis into the compilation
  pipeline, attaching results as node metadata
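The numeric probing idea behind compute_iv_stride_through_mapping can be sketched as follows (the helper name and the callable-mapping encoding are illustrative, not the actual implementation):

```python
def probe_iv_stride(mapping, iv_name, env, probes=(0, 1, 2, 3)):
    """Hypothetical sketch of numeric stride probing: evaluate the index
    mapping at consecutive induction-variable values (all other symbols
    held fixed) and return the stride if the per-step difference is
    constant, else None."""
    values = [mapping({**env, iv_name: p}) for p in probes]
    deltas = {b - a for a, b in zip(values, values[1:])}
    return deltas.pop() if len(deltas) == 1 else None

# Example mapping: offset = iv * 128 + tid * 4  ->  stride 128 per iteration
stride = probe_iv_stride(lambda s: s["iv"] * 128 + s["tid"] * 4,
                         "iv", {"tid": 7})
```

A constant result can then be folded into the buffer_load voffset field as an immediate increment; a None result leaves the runtime address arithmetic in place.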

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When pipeline double-buffering creates the loop body, the in-loop scale
reloads from the swapped LDS buffer can be sub-word loads (ds_read_u8,
ds_read_u16) rather than extract_strided_slice from a vector<4xi8>.
This caused _find_mergeable_groups to reject the group entirely, so
_coalesce_vector_iter_args never fired and the scalar scale extraction
pattern (v_bfe_u32 + individual byte loads) persisted in the final asm.

Relax _find_mergeable_groups to accept groups where init values trace
to extract_strided_slice but yield values do not. For such groups,
create a wide vector<4xi8> load from the byte-0 yield load's base
address before the yield, giving the coalesce logic a proper dword
source. The existing opsel replacement then converts the scalar scale
chains to indexed vector<4xf8E8M0FNU> access with scalesIdx, producing
ds_read_b32 + op_sel instead of ds_read_u8/u16 + v_bfe_u32.

Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
The gather_to_lds codegen emitted two redundant layers of OOB
protection: (1) SRD validBytes/num_records hardware clamping via
_compute_branchless_valid_bytes, and (2) a per-load software mask
(_build_mask + select with OOB sentinel 0x7FFFFFFF). Layer 2 cost
~4 VALU instructions and a temporary VGPR per gather_to_lds call
(~32 VALU + ~8 VGPRs per loop phase).

The root cause was that _compute_branchless_valid_bytes used
0x7FFFFFFE (hardware max) as `real_valid`, providing no actual
clamping when the guard condition was true. The software mask was
the only real bounds check.

- _compute_branchless_valid_bytes: when the shape is dynamic
  (static path returns 0x7FFFFFFE fallback), compute the actual
  buffer size at runtime via gen_sympy_index(product(shape) *
  elem_bytes). This makes the SRD num_records reflect the real
  buffer dimensions, enabling hardware clamping for OOB addresses.
- handle_gather_to_lds: skip _build_mask + select when
  valid_bytes_override is set (g2s_guard exists with real bounds).
  Hardware num_records clamping handles all OOB cases.

This matches AITER's approach: raw buffer_load with SRD clamping,
no per-load software bounds checking.
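A simplified model of the hardware behavior the change relies on (the function below is a toy model, not the backend's code): once num_records reflects the real buffer size, lanes whose address falls past it are clamped by the SRD itself.

```python
from math import prod

def buffer_load(offset_bytes, width, memory, num_records_bytes):
    """Toy model of SRD num_records clamping: a lane whose address range
    falls at or beyond num_records reads the hardware OOB value (0)."""
    if offset_bytes + width > num_records_bytes:
        return 0                          # hardware returns 0 for OOB lanes
    return memory[offset_bytes]

# With num_records set to the real buffer size (product(shape) * elem_bytes)
# instead of the 0x7FFFFFFE fallback, no software mask/select is needed.
shape, elem_bytes = (8, 16), 1
num_records = prod(shape) * elem_bytes    # 128 bytes of valid records
mem = list(range(200))
in_bounds = buffer_load(64, 4, mem, num_records)
out_bounds = buffer_load(132, 4, mem, num_records)   # clamped by the SRD
```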

Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
The _merge_scale_byte_loads pass merged sub-word LDS scale loads
into a single vector<4xi8> load, but had two correctness issues
that caused ~25% mismatch on large asymmetric shapes (e.g.
block 256,192,256):

1. Loads were grouped only by memref identity, not by actual base
   address. Loads from the same memref but at different addresses
   (different affine map operands/expressions) were incorrectly
   merged, causing the wide load to read from the wrong
   double-buffer phase.

2. Merging vector<2xi8> loads created size-2 extracts from the
   wide vector<4xi8>, which broke _trace_scale_chain (expects
   vector<4xi8> source) and prevented opsel for those bytes.

Fix: restrict merging to vector<1xi8> loads only, and group by
(memref, prefix indices, affine base expression) so loads at
different base addresses are never merged.
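The stricter grouping key can be sketched like this (the dict encoding of a load and the `group_key` name are illustrative, not the pass's actual data structures):

```python
def group_key(load):
    """Hypothetical sketch of the stricter grouping key: loads are only
    mergeable when they share the memref, the leading (prefix) indices,
    and the affine base expression, i.e. the same base address."""
    return (load["memref"],
            tuple(load["prefix_indices"]),
            load["affine_map"])

a = {"memref": "lds0", "prefix_indices": (0,), "affine_map": "(d0 * 4)"}
b = {"memref": "lds0", "prefix_indices": (0,), "affine_map": "(d0 * 4)"}
c = {"memref": "lds0", "prefix_indices": (0,), "affine_map": "(d0 * 4 + 256)"}
# a and b share a base address and may merge; c addresses the other
# double-buffer phase and must stay in its own group.
```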

Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Scalar kernel arguments (M, N, K) are now kept in dedicated SGPRs
instead of being unconditionally copied to VGPRs. Downstream uniform
arithmetic (ceildiv, floordiv, mod, comparisons) auto-selects SALU
when both operands are scalar, reducing VGPR pressure by 5 registers
(232 → 227) and doubling SALU instruction count in the prologue.

Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
When a buffer_load has voffset = V_ADD_U32(vgpr, sgpr) and soffset = 0,
rewrite to use the VGPR directly as voffset and the SGPR as soffset.
This eliminates one VALU instruction per load by leveraging the hardware
scalar offset: effective_addr = SRD_base + voffset + soffset.
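The peephole can be sketched on a toy IR encoding (the dict shapes and register-name strings below are illustrative, not the backend's IR):

```python
def fold_soffset(load):
    """Hypothetical sketch of the peephole: when voffset is
    V_ADD_U32(vgpr, sgpr) and soffset is 0, use the VGPR directly as
    voffset and the SGPR as soffset, since the hardware already computes
    effective_addr = SRD_base + voffset + soffset."""
    v = load["voffset"]
    if (load["soffset"] == 0 and isinstance(v, dict)
            and v["op"] == "V_ADD_U32"
            and v["lhs"].startswith("v") and v["rhs"].startswith("s")):
        return {**load, "voffset": v["lhs"], "soffset": v["rhs"]}
    return load  # pattern not matched: keep the V_ADD_U32

ld = {"voffset": {"op": "V_ADD_U32", "lhs": "v5", "rhs": "s12"},
      "soffset": 0}
folded = fold_soffset(ld)  # one fewer VALU instruction per load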

Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
…set in assembly emission

Three store translation handlers (handleVectorTransferWrite,
handleRawBufferStore, handleMemRefStore) were passing AGPR-typed values
directly to BUFFER_STORE_DWORD* ops, which require WaveASM_AnyVGPR
operands. This caused verification failures when eliminate_epilogue
stores MFMA accumulator results (which live in AGPRs on gfx950) to
global memory. Insert V_ACCVGPR_READ_B32 before the store in all three
paths, matching the existing pattern in handleVectorStore and
handleVectorMaskedStore.

Additionally, fix the assembly emitter to use "off" syntax instead of
"offen" when the voffset operand is an immediate rather than a VGPR,
since MUBUF instructions require a VGPR when the offen flag is set.

Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
SALU instructions implicitly write the SCC (Scalar Condition Code) flag
on AMDGPU hardware.  When optimizations promote VGPR ops to SGPR (SALU),
the new scalar instructions can silently clobber SCC between a producer
(s_cmp, s_add_u32) and its consumer (s_cbranch_scc*, s_cselect_b32,
s_addc_u32), causing corrupted addresses and memory faults.

This commit adds infrastructure to detect and prevent SCC hazards:

- SCCDef/SCCUse traits (defined but not yet applied to existing ops,
  as adding NativeOpTrait to ODS classes changes generated C++ and
  alters MLIR pass behavior — needs further investigation)
- SALUUnaryWithSCCOp/SALUBinaryWithSCCOp base classes for future
  migration of SCC-clobbering ops away from Pure
- SCC verifier pass (--waveasm-scc-verifier) that walks IR in emission
  order using isa<> checks to detect SCC clobbers between producers
  and consumers, wired into both compilation pipelines
- ScopedCSE guard against future SCCUse-tagged ops

The verifier is intentionally trait-free on existing ops to guarantee
zero codegen impact.  The writesSCC() function enumerates all known
SCC-writing SALU ops by type.
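The verifier's emission-order walk can be sketched as follows; the tuple encoding of an instruction and the explicit intended-producer link are illustrative simplifications of what the pass recovers from the IR:

```python
def find_scc_hazards(insts):
    """Hypothetical sketch of the emission-order verifier walk.  Each
    instruction is (name, defs_scc, uses_scc, intended_producer).  A
    hazard is any SCC write landing between a consumer's intended
    producer and the consumer itself."""
    last_scc_writer = None
    hazards = []
    for name, defs_scc, uses_scc, producer in insts:
        if uses_scc and producer is not None and last_scc_writer != producer:
            hazards.append((last_scc_writer, name))  # SCC was clobbered
        if defs_scc:
            last_scc_writer = name
    return hazards

prog = [
    ("s_cmp_lt_i32",  True,  False, None),
    ("s_add_u32",     True,  False, None),            # silently writes SCC
    ("s_cselect_b32", False, True,  "s_cmp_lt_i32"),  # reads stale SCC!
]
```

Running the sketch on `prog` flags the s_add_u32 between the compare and the select, which is exactly the promoted-SALU failure mode described above.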

Made-with: Cursor
All SCC-clobbering SALU ops now carry the SCCDef trait via their op
class, and the SCC verifier uses hasTrait<SCCDef>() instead of isa<>
enumeration.  Ops that read SCC (s_cselect_b32) carry SCCUse.

Classification:
- SALUBinaryWithSCCOp (NoMemoryEffect + AlwaysSpeculatable + SCCDef):
  bitwise, shifts, min/max, BFE — equivalent to Pure for MLIR passes
  but identifiable as SCC-clobbering by the verifier
- SALUUnaryWithSCCOp (same traits): s_not, s_brev, s_bcnt*, s_ff*,
  s_flbit*, s_abs
- SALUBinaryWithCarryOp + SCCDef: s_add_u32, s_addc_u32, s_sub_u32
- SALUCmpOp + SCCDef: all s_cmp_* variants
- s_cselect_b32: standalone with SCCUse + NoMemoryEffect (not Pure,
  not CSE-eligible — result depends on implicit SCC)
- SALUBinaryOp (Pure, no SCC): s_mul_i32, s_bfm_*, s_pack_*
- SALUUnaryOp (Pure, no SCC): s_mov_b32/b64, s_sext_*

Key insight: replacing Pure with explicit NoMemoryEffect +
AlwaysSpeculatableImplTrait + SCCDef produces bit-identical assembly
while enabling trait-based SCC verification.

Validated: test passes, assembly identical to pre-change baseline.
Made-with: Cursor
- emitOr/emitXor helpers in Handlers.h: S_OR_B32/S_XOR_B32 when both
  operands are scalar, V_OR_B32/V_XOR_B32 otherwise (same pattern as
  emitAnd)
- handleArithOrI: uses emitOr instead of always V_OR_B32
- handleArithXorI: uses emitXor instead of handleBinaryVALU
- handleArithDivUI power-of-2: uses emitLshr (has scalar path)
  instead of always V_LSHRREV_B32
- handleArithRemUI power-of-2: uses emitAnd (has scalar path)
  instead of always V_AND_B32

These are safe handler-level promotions: the SALU path only activates
when both operands are already scalar, and the SCC verifier catches
any hazards.  No assembly change for the current GEMM kernel (operands
are thread-ID-derived VGPRs), but enables SALU emission for future
kernels with scalar OR/XOR/div/rem patterns.
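The opcode-selection pattern shared by these helpers can be sketched in a few lines (register-name strings stand in for the backend's register classes; emitXor is the same shape with the XOR opcodes):

```python
def emit_or(lhs, rhs):
    """Hypothetical sketch of the emitOr selection: pick the SALU opcode
    only when both operands already live in SGPRs; otherwise fall back
    to the VALU form (mirroring the existing emitAnd pattern)."""
    def is_scalar(reg):
        return reg.startswith("s")
    if is_scalar(lhs) and is_scalar(rhs):
        return ("S_OR_B32", lhs, rhs)
    return ("V_OR_B32", lhs, rhs)
```

Because the scalar path only fires when both inputs are already scalar, the helper never forces a VGPR-to-SGPR move; it simply stops forcing the opposite one.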

Made-with: Cursor
- handleArithCmpI: V_CMP VOP3 can take one SGPR via the constant bus;
  only ensureVGPR when BOTH operands are SGPR (saves v_mov_b32 for
  cases like v_cmp_gt_i32 vcc, sN, 0)
- handleArithSelect: don't ensureVGPR cond before V_CMP_NE_U32 when
  the other operand is an immediate (no constant bus conflict)
- AffineHandlers non-overlapping Add: use emitOr (SALU-aware) instead
  of ensureBothVGPR + V_OR_B32

Reduces v_mov_b32 sN->vN instructions from 14 to 12 in the GEMM
prologue.  V_CNDMASK_B32 SGPR optimization was attempted but
cancelled due to VCC constant bus interaction causing test failures.

Made-with: Cursor
- emitScalarCmp: helper that emits S_CMP_* for a given CmpIPredicate.
  Refactors the existing handleArithCmpI ConditionOp path to use it.
- emitMaxI32/emitMinI32/emitMaxU32/emitMinU32: S_MAX/S_MIN when both
  operands are scalar, V_MAX/V_MIN otherwise.  Updates all four
  handleArithMax/MinSI/UI handlers to use these helpers.

The v_max_i32 in the GEMM prologue (ceildiv bound clamping) becomes
s_max_i32, keeping the value in SGPR for downstream scalar chains.

CmpI→Select fusion (s_cmp+s_cselect for all-scalar cmpi) was attempted
but causes memory faults from cascading SGPR register allocation
changes that corrupt buffer addresses.  Needs deeper investigation of
which downstream consumer breaks when the cmpi result is SGPR instead
of VGPR.  Left as TODO with the emitScalarCmp infrastructure in place
for when that investigation is done.

Made-with: Cursor
- Move emitScalarCmp from static in ArithHandlers.cpp to inline in
  Handlers.h so it can be shared across handler files.
- Add TODO in emitSrdNumRecords documenting the s_cmp+s_cselect
  optimization opportunity for the arith.select(arith.cmpi) pattern.
  Multiple implementation approaches were tested (direct PSRegType,
  s_mov copy, s_add copy) but all produce memory faults despite
  correct-looking assembly. Root cause needs IR dump investigation.

Made-with: Cursor
Root cause: PrecoloredSRegOp for initial SRDs (e.g., output buffer at
s[36:39]) had no downstream SSA users — the epilogue referenced them
via RawOps with hardcoded register numbers. CanonicalizerPass DCE'd the
PrecoloredSRegOp, so LinearScanPass never reserved those SGPRs. Under
increased SGPR pressure the allocator assigned loop temps to s36/s37,
clobbering the SRD base address and causing GPU memory faults.

Fix: add DCEProtectOp after each initial SRD's PrecoloredSRegOp in
emitSRDPrologue to prevent DCE and ensure permanent reservation.

Also enables s_cmp + s_cselect path in emitSrdNumRecords for the
branchless g2s guard pattern (arith.select(arith.cmpi(scalar, scalar),
scalar, scalar)), emitting directly into the SRD word 2 PSReg.

Made-with: Cursor
When emitUnsignedFloordiv receives scalar inputs, the Barrett reduction
must run in VALU (no SALU equivalent for float rcp), but the result is
uniform. Add v_readfirstlane_b32 at the end to extract back to SGPR so
downstream arithmetic stays in SALU instead of cascading to VALU.
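The readfirstlane trick is legal because every lane computes the same quotient. The reduction itself can be sketched in exact integer form (a textbook Barrett-style division, not the precise VALU sequence the backend emits):

```python
def barrett_udiv(n, d):
    """Hypothetical sketch of Barrett-style unsigned floordiv: one wide
    multiply by a precomputed reciprocal replaces the divide.  For
    n, d < 2**32 with m = ceil(2**64 / d), (n * m) >> 64 == n // d."""
    m = -(-(1 << 64) // d)        # ceil(2**64 / d), the "magic" multiplier
    return (n * m) >> 64

# On the GPU the reciprocal path runs in VALU (float rcp has no SALU
# equivalent), but with scalar inputs the quotient is uniform across
# lanes, so v_readfirstlane_b32 legally moves it back to an SGPR.
```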

Made-with: Cursor
…elect

When arith.select's condition comes from arith.cmpi with scalar operands
and both true/false values are scalar, emit s_cmp + s_cselect directly
as a fused pair. This bypasses the VGPR boolean materialization from
handleArithCmpI, eliminating dead v_cmp + v_cndmask instructions.

Reduces peak VGPRs from 226 to 224 in the 256x192 MXFP4 GEMM kernel.

Made-with: Cursor
When both cmpi operands are scalar and every user is an arith.select
with scalar true/false values, the cmpi+select fusion in
handleArithSelect will emit s_cmp + s_cselect directly. Skip the VALU
v_cmp + v_cndmask boolean materialization entirely, mapping the cmpi
result to a dummy constant.

Eliminates 4 dead VALU instructions per loop iteration (v_mov_b32,
v_cmp_lt_i32, v_mov_b32, v_cndmask_b32) from the branchless guard.

Made-with: Cursor
…fragmentation

Two complementary approaches to reduce peak VGPR count from interleaved
buffer_load/ds_read register assignments:

1. VGPR Compaction Pass (VGPRCompaction.cpp):
   Post-allocation pass that reassigns physical VGPRs using shortest-first
   greedy coloring. Short-lived values (ds_read results) pack into low
   registers, long-lived values (buffer_load prefetch) go to high registers.
   Handles precolored pinning, sub-element remapping, loop liveness extension,
   and _iterArgPhysRegs metadata updates.

2. Bidirectional Linear Scan (RegAlloc.h, LinearScanRegAlloc.cpp):
   allocRangeFromTop/allocSingleFromTop methods that scan the free bitvector
   from high to low with a capped ceiling (maxPressure, not maxRegs) to avoid
   allocating into the AGPR region. Long-lived VGPR ranges (length > 75th
   percentile) get top-down allocation while short-lived ranges get normal
   bottom-up allocation.

The compaction pass reduces VGPRs by ~4 on the 256x192 MXFP4 GEMM kernel.
The bidirectional approach is currently neutral (threshold needs tuning for
specific kernels) but provides the infrastructure for future improvements.

Made-with: Cursor
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
The _merge_scale_byte_loads pass had three bugs that prevented it from
firing and caused a register pressure regression (257 VGPRs, exceeding
the 256 limit for 4-wave occupancy):

1. Python id() was used to group MLIR Values (memref, indices, affine
   operands), but the Python bindings create fresh wrapper objects per
   access, so loads from the same memref were placed into separate
   groups. Use hash() instead, which delegates to mlirValueHash and
   correctly identifies the same underlying Value.

2. Stripping the trailing constant from an affine map string left
   whitespace ("(expr )>" vs "(expr)>"), preventing the offset-0 load
   from grouping with offset-1/2 loads. Add .rstrip() before the
   closing delimiter.

3. Restricting to vector<1xi8> only (the 3baa31d fix) avoided
   breaking _trace_scale_chain but left vector<2xi8> loads unmerged,
   wasting registers. Re-allow vector<2xi8> loads and replace their
   downstream extract_strided_slice(size=1) users with direct byte
   extracts from the wide vector<4xi8>, preserving the source type
   that _trace_scale_chain requires for opsel.

Result: VGPR count drops from 257 to 254 for the asymmetric
block 256x192x256 dynamic preshuffle-B kernel.
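The wrapper-identity pitfall is easy to reproduce with a toy stand-in for the binding (the class below is illustrative; the real bindings delegate __hash__ to mlirValueHash):

```python
class ValueWrapper:
    """Toy stand-in for the Python MLIR Value binding: a fresh wrapper
    object is created on every access, so id() differs even when the
    underlying C++ Value is the same; __hash__ delegates to the
    underlying identity instead."""
    def __init__(self, raw):
        self.raw = raw           # stands in for the C++ Value pointer
    def __hash__(self):
        return hash(self.raw)
    def __eq__(self, other):
        return self.raw == other.raw

a = ValueWrapper(0x1234)   # two accesses to the same underlying Value
b = ValueWrapper(0x1234)
# Grouping by id() splits a and b into separate groups; grouping by
# hash()/equality correctly identifies them as the same Value.
```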

Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
…) with the loop IV increment S_ADD_U32(iv, step=2) since both compute arg8 + 2. The liveness pass's hasBufferLoadWARHazard only detected WAR hazards for buffer_load ops, so the merged S_ADD_U32 was tied to the IV block_arg and allocated to the same SGPR (s3). The in-place s_add_u32 s3, s3, 2 clobbered the original IV, causing the second unrolled copy to compute arg8+5 instead of arg8+3 for its prefetch guard, zeroing out SRD num_records prematurely.

Fix: Generalized hasBufferLoadWARHazard → hasWARHazard in waveasm/lib/Transforms/Liveness.cpp. The new function detects def/use overlap for ALL iter_args (not just buffer_loads), preventing the tied equivalence class from being formed when the iter_arg's def point precedes uses of the corresponding block_arg. The allocator now assigns the IV increment to a separate register (s43), preserving the original IV (s3) for the second copy's computations.
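The generalized check can be sketched as a simple program-point comparison (the linear point numbering and function name are illustrative):

```python
def has_war_hazard(iter_arg_def_point, block_arg_use_points):
    """Hypothetical sketch of the generalized WAR check: an iter_arg and
    its block_arg must not be tied to one register if the iter_arg's
    value is defined before some later use of the block_arg, because the
    in-place write would clobber the old value."""
    return any(use > iter_arg_def_point for use in block_arg_use_points)

# Merged S_ADD_U32(iv, 2) defines the next IV at point 10, but the second
# unrolled copy still reads the old IV at point 14 -> do not tie; the
# increment must go to a separate register.
hazard = has_war_hazard(10, [3, 7, 14])
```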

Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
…ation identity

_find_mergeable_groups used id(init_src.owner) and id(ysrc.owner) to group
vector<1xi8> iter_args by their defining MLIR Operation. Python MLIR bindings
create fresh wrapper objects on each .owner access, so id() returns different
values for the same underlying C++ Operation. This caused iter_args that should
be coalesced into a single vector<4xi8> to end up in separate groups, leaving
stale vector<1xi8> iter_args that produced incorrect scale factors for
scaled_mfma ops (~50% output mismatch).

This is the same bug pattern fixed in _merge_scale_byte_loads (45a336a).
Switching to hash() delegates to mlirOperationHash which correctly identifies
the same underlying object regardless of Python wrapper identity.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…mediates

Broadens the rematerialization pass to handle two additional patterns
beyond the original V_MOV_B32-from-immediate:

  - V_MOV_B32 with SGPR source (scalar-to-vector address copies)
  - S_MOV_B32 with immediate operands (scalar constant materialization)

This shortens live ranges of pre-loop address setup values, reducing
peak VGPR pressure enough for the new merged A-data+A-scale schedule
to compile within the 256-VGPR hardware limit.

Key design constraint: VGPR-producing ops are never cloned into loop
bodies (preserving the VALU-free loop property that avoids MFMA/VALU
pipeline contention). SALU ops (S_MOV_B32) are exempt since they use
the scalar ALU pipeline.

Also adds dumpPeakPressureInfo calls to computeLiveness for debug-mode
pressure diagnostics.

Validated: old schedule still passes on WaveASM backend.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When the pipeliner shifts the A-scale LDS address by one byte between
init and yield iterations, the extract_strided_slice offsets differ
(e.g. init extracts bytes {1,3} from addr X, yield extracts {0,2}
from addr X+1).  The coalescer was rejecting these groups because
init_off != yield_off, leaving individual vector<1xi8> iter_args that
became v_bfe_u32 VALU instructions inside the loop body.

Two fixes:
1. In _find_mergeable_groups: treat offset-mismatched yields as
   "untraceable" instead of rejecting the group entirely.
2. In _coalesce_vector_iter_args: when the yield value traces to an
   extract_strided_slice from a vector<4xi8>, reuse the source vector
   directly as the merged yield value (no need to construct a new load).

This eliminates the v_bfe_u32 VALU ops from the loop body and enables
the new merged A-data + A-scale interleaved schedule to compile and
pass on the WaveASM backend with zero VALU instructions in the loop.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
panditsa and others added 2 commits March 24, 2026 19:30
- Remove shutil.copy debug dump to /tmp in compile.py
- Restore use_ticketed_waitcnt heuristic (was hardcoded False)
- Remove ~21 bare print() statements from mapping_utils.py
- Remove scratch test file test_sympy_diff.py
- Remove commented-out K >= 2048 assumption
- Fix duplicate section header in Passes.td

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
