[WIP: Rebased branch] MXFP4 GEMM through Waveasm #1182
Open
panditsa wants to merge 32 commits into iree-org:main from
Conversation
Add Assumption(K >= 2048) to the MXFP4 preshuffle template. This enables the IndexMapping simplification pass to prove that within_nblk (max 1023) < K_PACKED (= K/2 >= 1024), eliminating the dynamic floordiv/mod for the B-data preshuffle mapping.
Key fixes:
- Use subs_idxc to resolve derived symbols (K_PACKED -> K//2) before the divisor lower-bound check, so constraint-based bounds on K propagate to K_PACKED.
- Extract symbol lower bounds from Assumption(S >= c) and Assumption(S > c) constraints.
- Guard subs_idxc for contexts where IndexingContext is unavailable.
The pass now correctly detects that within_nblk < K_PACKED and rewrites the IndexMapping to eliminate the dynamic floordiv/mod.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
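The bound reasoning the pass performs can be sketched in plain Python. The helper name and constants below are illustrative stand-ins, not the pass's actual API:

```python
def prove_floordiv_zero(num_upper_bound, div_lower_bound):
    """floor(n/d) == 0 and Mod(n, d) == n whenever the numerator's
    upper bound is strictly below the divisor's lower bound
    (both nonnegative).  This is the fact that lets the pass
    delete the dynamic floordiv/mod entirely."""
    return num_upper_bound < div_lower_bound

# Assumption(K >= 2048) gives K a lower bound of 2048; the derived
# symbol K_PACKED = K // 2 then has lower bound 1024 after subs_idxc
# resolves it back to an expression in K.
K_LOWER = 2048
K_PACKED_LOWER = K_LOWER // 2    # 1024
WITHIN_NBLK_UPPER = 1023         # max value of within_nblk

print(prove_floordiv_zero(WITHIN_NBLK_UPPER, K_PACKED_LOWER))  # True
```

Without the assumption (say K >= 1024, so K_PACKED >= 512) the proof fails and the dynamic floordiv/mod must stay.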
Adds an annotate_iv_strides pass that statically computes the per-iteration stride of induction variables through index mappings. This enables the codegen to fold constant IV strides directly into buffer_load voffset fields, eliminating runtime address arithmetic.
Key components:
- extract_iv() in symbol_utils: separates the IV from complex floor/Mod exprs using integer-division decomposition identities
- compute_iv_stride_through_mapping() in mapping_utils: probes mappings numerically to determine constant strides through symbolic dimensions
- annotate_iv_strides pass: wires the stride analysis into the compilation pipeline, attaching results as node metadata
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
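The numeric-probing idea behind compute_iv_stride_through_mapping can be sketched as follows; probe_iv_stride and the sample mappings are hypothetical stand-ins, not the real signatures:

```python
def probe_iv_stride(mapping, iv_name, base_env, step=1, probes=3):
    """Evaluate `mapping` (a callable over a symbol environment) at
    several consecutive IV values.  If the difference between adjacent
    evaluations is the same everywhere, the stride is constant and can
    be folded into the voffset; otherwise report None."""
    deltas = set()
    for i in range(probes):
        env0 = dict(base_env, **{iv_name: i * step})
        env1 = dict(base_env, **{iv_name: (i + 1) * step})
        deltas.add(mapping(env1) - mapping(env0))
    return deltas.pop() if len(deltas) == 1 else None

# hypothetical linear mapping: offset = 64*iv + tid  ->  stride 64
linear = lambda env: 64 * env["iv"] + env["tid"]
print(probe_iv_stride(linear, "iv", {"tid": 5}))   # 64

# a nonlinear mapping has no constant stride
quadratic = lambda env: env["iv"] * env["iv"]
print(probe_iv_stride(quadratic, "iv", {}))        # None
```

Probing at a few points is a heuristic; the real pass works on symbolic index expressions, where a constant difference can be verified rather than sampled.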
When pipeline double-buffering creates the loop body, the in-loop scale reloads from the swapped LDS buffer can be sub-word loads (ds_read_u8, ds_read_u16) rather than extract_strided_slice from a vector<4xi8>. This caused _find_mergeable_groups to reject the group entirely, so _coalesce_vector_iter_args never fired and the scalar scale extraction pattern (v_bfe_u32 + individual byte loads) persisted in the final asm.
Relax _find_mergeable_groups to accept groups where init values trace to extract_strided_slice but yield values do not. For such groups, create a wide vector<4xi8> load from the byte-0 yield load's base address before the yield, giving the coalesce logic a proper dword source. The existing opsel replacement then converts the scalar scale chains to indexed vector<4xf8E8M0FNU> access with scalesIdx, producing ds_read_b32 + op_sel instead of ds_read_u8/u16 + v_bfe_u32.
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
The gather_to_lds codegen emitted two redundant layers of OOB protection: (1) SRD validBytes/num_records hardware clamping via _compute_branchless_valid_bytes, and (2) a per-load software mask (_build_mask + select with OOB sentinel 0x7FFFFFFF). Layer 2 cost ~4 VALU instructions and a temporary VGPR per gather_to_lds call (~32 VALU + ~8 VGPRs per loop phase).
The root cause was that _compute_branchless_valid_bytes used 0x7FFFFFFE (the hardware max) as `real_valid`, providing no actual clamping when the guard condition was true. The software mask was the only real bounds check.
- _compute_branchless_valid_bytes: when the shape is dynamic (the static path returns the 0x7FFFFFFE fallback), compute the actual buffer size at runtime via gen_sympy_index(product(shape) * elem_bytes). This makes the SRD num_records reflect the real buffer dimensions, enabling hardware clamping for OOB addresses.
- handle_gather_to_lds: skip _build_mask + select when valid_bytes_override is set (g2s_guard exists with real bounds). Hardware num_records clamping handles all OOB cases.
This matches AITER's approach: raw buffer_load with SRD clamping, no per-load software bounds checking.
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
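A minimal model of the runtime valid-bytes computation. The real pass emits this product symbolically via gen_sympy_index; compute_valid_bytes is an illustrative helper, not the pass's function:

```python
from math import prod

def compute_valid_bytes(shape, elem_bytes, hw_max=0x7FFFFFFE):
    """Byte count to program into the SRD num_records field so that
    the hardware clamps any address past the end of the buffer.
    Dynamic dims are plain runtime ints in this model."""
    return min(prod(shape) * elem_bytes, hw_max)

# a 256x192 i8 buffer: hardware drops loads past byte 49152
print(compute_valid_bytes((256, 192), 1))   # 49152
```

With num_records set to the true buffer size, an OOB buffer_load returns zero in hardware and the per-load mask/select sequence becomes redundant.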
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
The _merge_scale_byte_loads pass merged sub-word LDS scale loads into a single vector<4xi8> load, but had two correctness issues that caused ~25% mismatch on large asymmetric shapes (e.g. block 256,192,256):
1. Loads were grouped only by memref identity, not by actual base address. Loads from the same memref but at different addresses (different affine map operands/expressions) were incorrectly merged, causing the wide load to read from the wrong double-buffer phase.
2. Merging vector<2xi8> loads created size-2 extracts from the wide vector<4xi8>, which broke _trace_scale_chain (which expects a vector<4xi8> source) and prevented opsel for those bytes.
Fix: restrict merging to vector<1xi8> loads only, and group by (memref, prefix indices, affine base expression) so loads at different base addresses are never merged.
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Scalar kernel arguments (M, N, K) are now kept in dedicated SGPRs instead of being unconditionally copied to VGPRs. Downstream uniform arithmetic (ceildiv, floordiv, mod, comparisons) auto-selects SALU when both operands are scalar, reducing VGPR pressure by 5 registers (232 → 227) and doubling SALU instruction count in the prologue. Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
When a buffer_load has voffset = V_ADD_U32(vgpr, sgpr) and soffset = 0, rewrite to use the VGPR directly as voffset and the SGPR as soffset. This eliminates one VALU instruction per load by leveraging the hardware scalar offset: effective_addr = SRD_base + voffset + soffset. Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
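A toy version of this peephole, operating on tuples instead of real WaveASM instructions (fold_add_into_soffset and the register-string encoding are hypothetical):

```python
def fold_add_into_soffset(voffset_def, soffset):
    """If voffset is defined by V_ADD_U32(vgpr, sgpr) and soffset is 0,
    return (vgpr, sgpr) so the hardware computes
    SRD_base + voffset + soffset itself, saving one VALU add.
    Otherwise leave the load unchanged: (None, original soffset)."""
    opcode, a, b = voffset_def            # e.g. ("v_add_u32", "v3", "s8")
    if opcode == "v_add_u32" and soffset == 0:
        vgpr = a if a.startswith("v") else b
        sgpr = b if b.startswith("s") else a
        if vgpr.startswith("v") and sgpr.startswith("s"):
            return vgpr, sgpr
    return None, soffset

print(fold_add_into_soffset(("v_add_u32", "v3", "s8"), 0))   # ('v3', 's8')
print(fold_add_into_soffset(("v_or_b32", "v3", "s8"), 0))    # (None, 0)
```

The rewrite only fires when soffset is currently 0, since the hardware adds all three terms and a nonzero soffset is already in use.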
…set in assembly emission
Three store translation handlers (handleVectorTransferWrite, handleRawBufferStore, handleMemRefStore) were passing AGPR-typed values directly to BUFFER_STORE_DWORD* ops, which require WaveASM_AnyVGPR operands. This caused verification failures when eliminate_epilogue stores MFMA accumulator results (which live in AGPRs on gfx950) to global memory.
Insert V_ACCVGPR_READ_B32 before the store in all three paths, matching the existing pattern in handleVectorStore and handleVectorMaskedStore.
Additionally, fix the assembly emitter to use "off" syntax instead of "offen" when the voffset operand is an immediate rather than a VGPR, since MUBUF instructions require a VGPR when the offen flag is set.
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
SALU instructions implicitly write the SCC (Scalar Condition Code) flag on AMDGPU hardware. When optimizations promote VGPR ops to SGPR (SALU), the new scalar instructions can silently clobber SCC between a producer (s_cmp, s_add_u32) and its consumer (s_cbranch_scc*, s_cselect_b32, s_addc_u32), causing corrupted addresses and memory faults.
This commit adds infrastructure to detect and prevent SCC hazards:
- SCCDef/SCCUse traits (defined but not yet applied to existing ops; adding NativeOpTrait to ODS classes changes the generated C++ and alters MLIR pass behavior, which needs further investigation)
- SALUUnaryWithSCCOp/SALUBinaryWithSCCOp base classes for future migration of SCC-clobbering ops away from Pure
- SCC verifier pass (--waveasm-scc-verifier) that walks the IR in emission order using isa<> checks to detect SCC clobbers between producers and consumers, wired into both compilation pipelines
- ScopedCSE guard against future SCCUse-tagged ops
The verifier is intentionally trait-free on existing ops to guarantee zero codegen impact. The writesSCC() function enumerates all known SCC-writing SALU ops by type.
Made-with: Cursor
All SCC-clobbering SALU ops now carry the SCCDef trait via their op class, and the SCC verifier uses hasTrait<SCCDef>() instead of isa<> enumeration. Ops that read SCC (s_cselect_b32) carry SCCUse.
Classification:
- SALUBinaryWithSCCOp (NoMemoryEffect + AlwaysSpeculatable + SCCDef): bitwise, shifts, min/max, BFE; equivalent to Pure for MLIR passes but identifiable as SCC-clobbering by the verifier
- SALUUnaryWithSCCOp (same traits): s_not, s_brev, s_bcnt*, s_ff*, s_flbit*, s_abs
- SALUBinaryWithCarryOp + SCCDef: s_add_u32, s_addc_u32, s_sub_u32
- SALUCmpOp + SCCDef: all s_cmp_* variants
- s_cselect_b32: standalone with SCCUse + NoMemoryEffect (not Pure, not CSE-eligible; its result depends on the implicit SCC)
- SALUBinaryOp (Pure, no SCC): s_mul_i32, s_bfm_*, s_pack_*
- SALUUnaryOp (Pure, no SCC): s_mov_b32/b64, s_sext_*
Key insight: replacing Pure with explicit NoMemoryEffect + AlwaysSpeculatableImplTrait + SCCDef produces bit-identical assembly while enabling trait-based SCC verification.
Validated: test passes, assembly identical to pre-change baseline.
Made-with: Cursor
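The verifier's emission-order walk can be modeled in a few lines of Python. The op records and find_scc_hazards are a simplified abstraction: the real verifier discovers producer/consumer pairs from the IR via traits, not an explicit producer index:

```python
def find_scc_hazards(ops):
    """ops: list of dicts with keys 'name', 'defs_scc' (bool), and
    'uses_scc' (index of the intended SCC producer, or None).
    Walk in emission order and report any op that clobbers SCC
    strictly between a producer and its consumer."""
    hazards = []
    for consumer_idx, op in enumerate(ops):
        producer_idx = op.get("uses_scc")
        if producer_idx is None:
            continue
        for mid in range(producer_idx + 1, consumer_idx):
            if ops[mid]["defs_scc"]:
                hazards.append(ops[mid]["name"])
    return hazards

# a promoted s_add_u32 lands between the compare and the cselect
prog = [
    {"name": "s_cmp_lt_i32",  "defs_scc": True,  "uses_scc": None},
    {"name": "s_add_u32",     "defs_scc": True,  "uses_scc": None},
    {"name": "s_cselect_b32", "defs_scc": False, "uses_scc": 0},
]
print(find_scc_hazards(prog))   # ['s_add_u32']
```

Moving the clobbering op before the producer (or after the consumer) makes the hazard list empty, which is what a scheduling fix would do.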
- emitOr/emitXor helpers in Handlers.h: S_OR_B32/S_XOR_B32 when both operands are scalar, V_OR_B32/V_XOR_B32 otherwise (same pattern as emitAnd)
- handleArithOrI: uses emitOr instead of always V_OR_B32
- handleArithXorI: uses emitXor instead of handleBinaryVALU
- handleArithDivUI power-of-2: uses emitLshr (has a scalar path) instead of always V_LSHRREV_B32
- handleArithRemUI power-of-2: uses emitAnd (has a scalar path) instead of always V_AND_B32
These are safe handler-level promotions: the SALU path only activates when both operands are already scalar, and the SCC verifier catches any hazards. No assembly change for the current GEMM kernel (operands are thread-ID-derived VGPRs), but this enables SALU emission for future kernels with scalar OR/XOR/div/rem patterns.
Made-with: Cursor
- handleArithCmpI: V_CMP VOP3 can take one SGPR via the constant bus; only ensureVGPR when BOTH operands are SGPR (saves a v_mov_b32 for cases like v_cmp_gt_i32 vcc, sN, 0)
- handleArithSelect: don't ensureVGPR the cond before V_CMP_NE_U32 when the other operand is an immediate (no constant bus conflict)
- AffineHandlers non-overlapping Add: use emitOr (SALU-aware) instead of ensureBothVGPR + V_OR_B32
Reduces v_mov_b32 sN->vN instructions from 14 to 12 in the GEMM prologue. A V_CNDMASK_B32 SGPR optimization was attempted but cancelled because the VCC constant bus interaction caused test failures.
Made-with: Cursor
- emitScalarCmp: helper that emits S_CMP_* for a given CmpIPredicate. Refactors the existing handleArithCmpI ConditionOp path to use it.
- emitMaxI32/emitMinI32/emitMaxU32/emitMinU32: S_MAX/S_MIN when both operands are scalar, V_MAX/V_MIN otherwise. Updates all four handleArithMax/MinSI/UI handlers to use these helpers.
The v_max_i32 in the GEMM prologue (ceildiv bound clamping) becomes s_max_i32, keeping the value in an SGPR for downstream scalar chains.
CmpI→Select fusion (s_cmp + s_cselect for all-scalar cmpi) was attempted but causes memory faults from cascading SGPR register allocation changes that corrupt buffer addresses. This needs deeper investigation of which downstream consumer breaks when the cmpi result is in an SGPR instead of a VGPR; left as a TODO with the emitScalarCmp infrastructure in place for when that investigation is done.
Made-with: Cursor
- Move emitScalarCmp from static in ArithHandlers.cpp to inline in Handlers.h so it can be shared across handler files.
- Add a TODO in emitSrdNumRecords documenting the s_cmp + s_cselect optimization opportunity for the arith.select(arith.cmpi) pattern. Multiple implementation approaches were tested (direct PSRegType, s_mov copy, s_add copy) but all produce memory faults despite correct-looking assembly. The root cause needs IR dump investigation.
Made-with: Cursor
Root cause: the PrecoloredSRegOp for initial SRDs (e.g., the output buffer at s[36:39]) had no downstream SSA users; the epilogue referenced them via RawOps with hardcoded register numbers. CanonicalizerPass DCE'd the PrecoloredSRegOp, so LinearScanPass never reserved those SGPRs. Under increased SGPR pressure the allocator assigned loop temps to s36/s37, clobbering the SRD base address and causing GPU memory faults.
Fix: add a DCEProtectOp after each initial SRD's PrecoloredSRegOp in emitSRDPrologue to prevent DCE and ensure permanent reservation.
Also enables the s_cmp + s_cselect path in emitSrdNumRecords for the branchless g2s guard pattern (arith.select(arith.cmpi(scalar, scalar), scalar, scalar)), emitting directly into the SRD word 2 PSReg.
Made-with: Cursor
When emitUnsignedFloordiv receives scalar inputs, the Barrett reduction must run in VALU (no SALU equivalent for float rcp), but the result is uniform. Add v_readfirstlane_b32 at the end to extract back to SGPR so downstream arithmetic stays in SALU instead of cascading to VALU. Made-with: Cursor
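The float-rcp division pattern can be emulated in Python to see why the result is uniform: every lane computes the same quotient from the same scalar inputs, so a single v_readfirstlane_b32 at the end recovers it into an SGPR. rcp_floordiv is an illustrative model, not the emitted instruction sequence:

```python
def rcp_floordiv(n, d):
    """Unsigned floor division the way a float-reciprocal sequence
    computes it: multiply by an approximate reciprocal, then fix up
    the quotient.  Models the VALU Barrett-style path for scalar
    inputs; d must be >= 1."""
    approx = 1.0 / d                 # v_rcp_f32-style approximation
    q = int(n * approx)
    while (q + 1) * d <= n:          # fixup: quotient too small
        q += 1
    while q * d > n:                 # fixup: quotient too large
        q -= 1
    return q                         # uniform -> readfirstlane to SGPR

print(rcp_floordiv(100, 7))   # 14
```

Since every input is wave-uniform, the fixup produces the same q in all lanes; extracting it back to an SGPR keeps the downstream ceildiv/mod chain on the SALU.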
…elect When arith.select's condition comes from arith.cmpi with scalar operands and both true/false values are scalar, emit s_cmp + s_cselect directly as a fused pair. This bypasses the VGPR boolean materialization from handleArithCmpI, eliminating dead v_cmp + v_cndmask instructions. Reduces peak VGPRs from 226 to 224 in the 256x192 MXFP4 GEMM kernel. Made-with: Cursor
When both cmpi operands are scalar and every user is an arith.select with scalar true/false values, the cmpi+select fusion in handleArithSelect will emit s_cmp + s_cselect directly. Skip the VALU v_cmp + v_cndmask boolean materialization entirely, mapping the cmpi result to a dummy constant. Eliminates 4 dead VALU instructions per loop iteration (v_mov_b32, v_cmp_lt_i32, v_mov_b32, v_cndmask_b32) from the branchless guard. Made-with: Cursor
…fragmentation
Two complementary approaches to reduce peak VGPR count from interleaved buffer_load/ds_read register assignments:
1. VGPR compaction pass (VGPRCompaction.cpp): a post-allocation pass that reassigns physical VGPRs using shortest-first greedy coloring. Short-lived values (ds_read results) pack into low registers; long-lived values (buffer_load prefetch) go to high registers. Handles precolored pinning, sub-element remapping, loop liveness extension, and _iterArgPhysRegs metadata updates.
2. Bidirectional linear scan (RegAlloc.h, LinearScanRegAlloc.cpp): allocRangeFromTop/allocSingleFromTop methods that scan the free bitvector from high to low with a capped ceiling (maxPressure, not maxRegs) to avoid allocating into the AGPR region. Long-lived VGPR ranges (length > 75th percentile) get top-down allocation while short-lived ranges get normal bottom-up allocation.
The compaction pass reduces VGPRs by ~4 on the 256x192 MXFP4 GEMM kernel. The bidirectional approach is currently neutral (the threshold needs tuning for specific kernels) but provides the infrastructure for future improvements.
Made-with: Cursor
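Shortest-first greedy recoloring can be sketched as follows. compact_vgprs is a simplified model: real live ranges also carry precolored pinning and loop-carried extensions that this sketch ignores:

```python
def compact_vgprs(ranges):
    """ranges: {vreg: (start, end)} inclusive live ranges.
    Color shortest ranges first into the lowest free physical
    register, so short-lived values pack densely at the bottom and
    long-lived values naturally drift upward.
    Returns {vreg: phys_reg}."""
    order = sorted(ranges, key=lambda r: ranges[r][1] - ranges[r][0])
    assignment, live_at = {}, {}     # phys -> list of (start, end)
    for vr in order:
        s, e = ranges[vr]
        phys = 0
        # first physical register whose existing ranges don't overlap
        while any(s <= ue and us <= e for us, ue in live_at.get(phys, [])):
            phys += 1
        live_at.setdefault(phys, []).append((s, e))
        assignment[vr] = phys
    return assignment

# two short ds_read results share v0; the long prefetch value moves up
print(compact_vgprs({"prefetch": (0, 100), "a": (0, 5), "b": (10, 15)}))
```

The peak register number drops whenever short ranges that a linear scan had scattered can be folded into the same low registers.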
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
The _merge_scale_byte_loads pass had three bugs that prevented it from
firing and caused a register pressure regression (257 VGPRs, exceeding
the 256 limit for 4-wave occupancy):
1. Python id() was used to group MLIR Values (memref, indices, affine
operands), but the Python bindings create fresh wrapper objects per
access, so loads from the same memref were placed into separate
groups. Use hash() instead, which delegates to mlirValueHash and
correctly identifies the same underlying Value.
2. Stripping the trailing constant from an affine map string left
whitespace ("(expr )>" vs "(expr)>"), preventing the offset-0 load
from grouping with offset-1/2 loads. Add .rstrip() before the
closing delimiter.
3. Restricting to vector<1xi8> only (the 3baa31d fix) avoided
breaking _trace_scale_chain but left vector<2xi8> loads unmerged,
wasting registers. Re-allow vector<2xi8> loads and replace their
downstream extract_strided_slice(size=1) users with direct byte
extracts from the wide vector<4xi8>, preserving the source type
that _trace_scale_chain requires for opsel.
Result: VGPR count drops from 257 to 254 for the asymmetric
block 256x192x256 dynamic preshuffle-B kernel.
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
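The id()-vs-hash() pitfall above is easy to reproduce outside MLIR with a wrapper class that mimics the bindings' behavior. Value and get_memref here are stand-ins, not the actual binding types:

```python
class Value:
    """Models a Python binding wrapper: each access mints a fresh
    Python object around the same underlying C++ value."""
    def __init__(self, raw):
        self.raw = raw
    def __hash__(self):              # delegates, like mlirValueHash
        return hash(self.raw)
    def __eq__(self, other):
        return isinstance(other, Value) and self.raw == other.raw

def get_memref():
    """Every call returns a fresh wrapper, as .owner / operand
    accessors do in the Python MLIR bindings."""
    return Value("memref<4xi8>")

a, b = get_memref(), get_memref()
print(id(a) == id(b))       # False: distinct wrapper objects
print(hash(a) == hash(b))   # True: same underlying value
```

Grouping by id() therefore splits loads from the same memref into separate groups, while hash()-based grouping (and dict/set membership, which uses __hash__ plus __eq__) identifies them correctly.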
…) with the loop IV increment S_ADD_U32(iv, step=2), since both compute arg8 + 2. The liveness pass's hasBufferLoadWARHazard only detected WAR hazards for buffer_load ops, so the merged S_ADD_U32 was tied to the IV block_arg and allocated to the same SGPR (s3). The in-place s_add_u32 s3, s3, 2 clobbered the original IV, causing the second unrolled copy to compute arg8+5 instead of arg8+3 for its prefetch guard, zeroing out SRD num_records prematurely.
Fix: generalized hasBufferLoadWARHazard → hasWARHazard in waveasm/lib/Transforms/Liveness.cpp. The new function detects def/use overlap for ALL iter_args (not just buffer_loads), preventing the tied equivalence class from being formed when the iter_arg's def point precedes uses of the corresponding block_arg. The allocator now assigns the IV increment to a separate register (s43), preserving the original IV (s3) for the second copy's computations.
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
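The generalized WAR check reduces to a position comparison over instruction indices. has_war_hazard is a schematic model of the Liveness.cpp logic, not its actual signature:

```python
def has_war_hazard(def_index, block_arg_use_indices):
    """An iter_arg may share a register with its block_arg (be
    'tied', enabling an in-place update) only if the in-loop def
    comes after every use of the block_arg.  If any use follows the
    def, the in-place write would clobber a value still needed."""
    return any(def_index < u for u in block_arg_use_indices)

# the IV is incremented at op 3, but the second unrolled copy still
# reads the old IV at op 7 -> tying would clobber it
print(has_war_hazard(3, [1, 2, 7]))   # True

# all reads precede the increment -> tying is safe
print(has_war_hazard(7, [1, 2, 3]))   # False
```

When the check fires, the allocator keeps the increment in a separate register instead of forming the tied equivalence class.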
…ation identity
_find_mergeable_groups used id(init_src.owner) and id(ysrc.owner) to group vector<1xi8> iter_args by their defining MLIR Operation. The Python MLIR bindings create fresh wrapper objects on each .owner access, so id() returns different values for the same underlying C++ Operation. This caused iter_args that should be coalesced into a single vector<4xi8> to end up in separate groups, leaving stale vector<1xi8> iter_args that produced incorrect scale factors for scaled_mfma ops (~50% output mismatch).
This is the same bug pattern fixed in _merge_scale_byte_loads (45a336a). Switching to hash() delegates to mlirOperationHash, which correctly identifies the same underlying object regardless of Python wrapper identity.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…mediates
Broadens the rematerialization pass to handle two additional patterns beyond the original V_MOV_B32-from-immediate:
- V_MOV_B32 with SGPR source (scalar-to-vector address copies)
- S_MOV_B32 with immediate operands (scalar constant materialization)
This shortens the live ranges of pre-loop address setup values, reducing peak VGPR pressure enough for the new merged A-data + A-scale schedule to compile within the 256-VGPR hardware limit.
Key design constraint: VGPR-producing ops are never cloned into loop bodies (preserving the VALU-free loop property that avoids MFMA/VALU pipeline contention). SALU ops (S_MOV_B32) are exempt since they use the scalar ALU pipeline.
Also adds dumpPeakPressureInfo calls to computeLiveness for debug-mode pressure diagnostics.
Validated: the old schedule still passes on the WaveASM backend.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When the pipeliner shifts the A-scale LDS address by one byte between
init and yield iterations, the extract_strided_slice offsets differ
(e.g. init extracts bytes {1,3} from addr X, yield extracts {0,2}
from addr X+1). The coalescer was rejecting these groups because
init_off != yield_off, leaving individual vector<1xi8> iter_args that
became v_bfe_u32 VALU instructions inside the loop body.
Two fixes:
1. In _find_mergeable_groups: treat offset-mismatched yields as
"untraceable" instead of rejecting the group entirely.
2. In _coalesce_vector_iter_args: when the yield value traces to an
extract_strided_slice from a vector<4xi8>, reuse the source vector
directly as the merged yield value (no need to construct a new load).
This eliminates the v_bfe_u32 VALU ops from the loop body and enables
the new merged A-data + A-scale interleaved schedule to compile and
pass on the WaveASM backend with zero VALU instructions in the loop.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Remove shutil.copy debug dump to /tmp in compile.py
- Restore use_ticketed_waitcnt heuristic (was hardcoded False)
- Remove ~21 bare print() statements from mapping_utils.py
- Remove scratch test file test_sympy_diff.py
- Remove commented-out K >= 2048 assumption
- Fix duplicate section header in Passes.td
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Do not merge, this branch is to track changes merged into main.