Linearize read indices pipeline#1196

Merged: willghatch merged 18 commits into main from users/willghatch/linearize on Apr 3, 2026

@willghatch
Contributor

The main goal is to simplify the memory addresses of reads inside a loop to `start + IV * stride` (IV being the loop induction variable), where each of those values is either constant or can at least be hoisted out of the loop body.

Adds a pre-codegen pipeline that flattens N-dimensional read addresses into 1-dimensional LINEAR_INDEX accesses:

  • flatten_read_indices: rewrites mapped reads to use a single flat offset
  • annotate_iv_strides: extracts constant IV strides for loop-carried reads
  • Codegen LINEAR_INDEX paths for both vector loads and GatherToLDS
  • Removes the old _try_iv_split_offset codegen approach in favor of the new pipeline-based linearization
  • Helper functions (mem_simplify, linearize_dims, _infer_floor_to_exact) in mapping_utils for symbolic floor/Mod cancellation
  • Adjust bounds for linearized reads
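
To make the target form concrete, here is a minimal sketch in plain Python, not the PR's implementation. The names `row_major_strides`, `flatten_index`, and `probe_iv_stride`, as well as the shape and tile size, are hypothetical and chosen only for illustration. It flattens a 2-D index into one offset and then checks numerically that the offset has the form `start + IV * stride`:

```python
from math import prod

def row_major_strides(shape):
    """Element strides for a contiguous row-major buffer of the given shape."""
    return [prod(shape[i + 1:]) for i in range(len(shape))]

def flatten_index(indices, strides):
    """Collapse an N-dimensional index into a single flat element offset."""
    return sum(i * s for i, s in zip(indices, strides))

def probe_iv_stride(flat_offset, samples=range(8)):
    """Return (start, stride) if flat_offset(iv) == start + iv * stride
    holds on every sampled iv, else None. This is a numeric check, not a proof."""
    start = flat_offset(0)
    stride = flat_offset(1) - start
    if all(flat_offset(iv) == start + iv * stride for iv in samples):
        return start, stride
    return None

# Hypothetical read: a (64, 256) tensor; each loop iteration advances the
# second dimension by a tile of 32 elements.
shape = (64, 256)
strides = row_major_strides(shape)  # [256, 1]
row, col0, tile_k = 5, 8, 32

def offset(iv):
    return flatten_index((row, col0 + iv * tile_k), strides)

result = probe_iv_stride(offset)
print(result)  # (1288, 32)
```

Here `start` (1288) does not depend on the induction variable, so it could be computed once outside the loop, and the per-iteration stride is the constant 32.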

This adds some new lit tests that show that with our mxfp4 shuffle layout we can generate linearized reads with constant stride.

This changes a ton of other lit tests. Disclaimer: I was asked to get this PR up ASAP, and this is a ton of churn in the lit tests, and I have not yet validated that they are all correct.

@@ -0,0 +1,234 @@
# Copyright 2025 The IREE Authors
Contributor

nit: 2026 everywhere

@willghatch force-pushed the users/willghatch/linearize branch 2 times, most recently from 9ebc371 to f037848, on March 30, 2026 18:26
@willghatch
Contributor Author

I've gone back and forth on supporting mapping_dynamic_vals in flattening. It should work in theory, but it's been tricky to get all of the tests passing, so I've set it to skip in that case at this point. So if we were to land it in the current state, we would need follow-ups to (1) get it working for mapping_dynamic_vals and (2) get it working with Water.

@willghatch force-pushed the users/willghatch/linearize branch 4 times, most recently from fa75e88 to b0da273, on April 1, 2026 15:34
Contributor

@Hardcode84 left a comment

Mostly LGTM, just a few comments

"""
ip = InsertionPoint.current
owner = ip.block.owner
is_in_loop = not isinstance(owner, func_d.FuncOp) and owner.name == "scf.for"
Contributor

Ideally, this should be done by the LICM pass in our waveasm pipeline. We can probably add a TODO.

Comment on lines +955 to +972
zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(sym_strides)
lin_src, _ = _linearize_memref(
    kb_src, zero_indices, zero_indices, strides_vals
)
if buffer_ops_enabled:
    valid_bytes = _compute_valid_bytes(
        lin_src,
        element_type,
        input_shape,
        emitter,
        use_real_bounds_override=True,
    )
    lin_src = _cast_buffer_and_encode_stride(
        lin_src,
        strides_vals,
        element_type,
        valid_bytes,
    )
Contributor

We can put it into a lambda to avoid copy-paste.


def _get_iv_symbols(expr: sympy.Expr) -> list[sympy.Symbol]:
    """Return all induction-variable symbols in *expr*."""
    return [s for s in expr.free_symbols if str(s).startswith(_INDUCTION_PREFIX)]
Contributor

This string matching is unfortunate. Can you at least move this function to symbol_utils.py and use _INDUCTION_SYMBOL_PREFIX?

@Hardcode84
Contributor

Also, we need unit tests for the new symbol/numeric probing functions.

@github-actions

github-actions bot commented Apr 2, 2026

Water Code Coverage

Filename                                                           Functions  Missed Functions  Executed       Lines      Missed Lines     Cover    Branches   Missed Branches     Cover
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
lib/Transforms/MemrefDecomposition.cpp                                    28                 0   100.00%         600                49    91.83%         104                46    55.77%
lib/Transforms/AllocToAlloca.cpp                                           2                 0   100.00%          17                 0   100.00%           0                 0         -
lib/Transforms/CheckStaticAssertions.cpp                                   2                 0   100.00%          22                 1    95.45%           8                 4    50.00%
lib/Transforms/GPUModuleToBinary.cpp                                      19                 5    73.68%         339               115    66.08%         128                57    55.47%
lib/Transforms/DropTransformOps.cpp                                        2                 0   100.00%          16                 0   100.00%           2                 0   100.00%
lib/Transforms/GPUToGPURuntime.cpp                                        14                 0   100.00%         298                23    92.28%          40                17    57.50%
lib/Transforms/SLPVectorizer.cpp                                          61                 3    95.08%        1065               102    90.42%         558               165    70.43%
lib/Transforms/AccessCheckers.cpp                                         35                 1    97.14%         446                40    91.03%         124                30    75.81%
lib/Transforms/AssembleISA.cpp                                             4                 1    75.00%          30                 2    93.33%           2                 1    50.00%
lib/Dialect/Wave/Transforms/LoweringPatterns.cpp                          48                 2    95.83%         966               146    84.89%         272                82    69.85%
lib/Dialect/Wave/Transforms/PropagateDefaultsFromConstraints.cpp           3                 3     0.00%          35                35     0.00%          12                12     0.00%
lib/Dialect/Wave/Transforms/TypeConverter.cpp                              7                 2    71.43%          96                26    72.92%          32                17    46.88%
lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp                         10                 0   100.00%         238                18    92.44%          58                11    81.03%
lib/Dialect/Wave/Transforms/DetectNormalForms.cpp                          4                 0   100.00%          48                 0   100.00%           8                 0   100.00%
lib/Dialect/Wave/Transforms/ExpandVariadicReductions.cpp                   2                 0   100.00%          24                 1    95.83%           6                 1    83.33%
lib/Dialect/Wave/Transforms/InferTypes.cpp                               110                14    87.27%        1923               150    92.20%         880               439    50.11%
lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp                            5                 0   100.00%         130                 1    99.23%          16                 2    87.50%
lib/Dialect/Wave/Transforms/Utils.cpp                                      6                 0   100.00%          96                 5    94.79%          26                 4    84.62%
lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp              7                 0   100.00%         183                16    91.26%          32                14    56.25%
lib/Dialect/Wave/IR/WaveOps.cpp                                          167                19    88.62%        3159               344    89.11%        1266               297    76.54%
lib/Dialect/Wave/IR/WaveAttrs.cpp                                         73                 6    91.78%         966                97    89.96%         424                62    85.38%
lib/Dialect/Wave/IR/IndexExpr.cpp                                         11                 0   100.00%         119                 1    99.16%          24                 3    87.50%
lib/Dialect/Wave/IR/WaveDialect.cpp                                       14                 0   100.00%         528                18    96.59%         194                10    94.85%
lib/Dialect/Wave/IR/WaveTypes.cpp                                          9                 1    88.89%          75                 8    89.33%          18                 3    83.33%
lib/Dialect/Wave/IR/WaveInterfaces.cpp                                   108                 3    97.22%        1649                98    94.06%         666               104    84.38%
lib/Dialect/Wave/IR/WaveUtils.cpp                                         21                 0   100.00%         190                 8    95.79%          78                13    83.33%
lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp                3                 0   100.00%          34                 6    82.35%           8                 2    75.00%
lib/Dialect/NormalForm/IR/NormalFormDialect.cpp                            1                 0   100.00%           6                 0   100.00%           0                 0         -
lib/Dialect/NormalForm/IR/NormalFormOps.cpp                               12                 0   100.00%         201                 9    95.52%          58                 7    87.93%
lib/Pipelines/Pipelines.cpp                                                2                 0   100.00%          27                 0   100.00%           0                 0         -
lib/Analysis/InUseForSpeculation.cpp                                      12                 1    91.67%         142                 8    94.37%          32                 4    87.50%
include/water/Dialect/Wave/Transforms/LoweringPatterns.h                   1                 0   100.00%           3                 0   100.00%           0                 0         -
include/water/Dialect/Wave/IR/IndexExpr.h                                  1                 0   100.00%          10                 0   100.00%           2                 0   100.00%
include/water/Dialect/Wave/IR/WaveInterfaces.h                            40                 3    92.50%         159                 8    94.97%           8                 2    75.00%
include/water/Dialect/Wave/IR/WaveTypes.h                                  1                 0   100.00%           5                 0   100.00%           4                 0   100.00%
include/water/Dialect/Wave/IR/WaveUtils.h                                  1                 0   100.00%           5                 0   100.00%           4                 1    75.00%
include/water/Dialect/Wave/IR/WaveAttrs.h                                  4                 0   100.00%          14                 0   100.00%           0                 0         -
include/water/Dialect/NormalForm/IR/NormalFormInterfaces.h                 1                 1     0.00%           4                 4     0.00%           0                 0         -
include/water/Analysis/InUseForSpeculation.h                              12                 3    75.00%          39                17    56.41%          16                10    37.50%
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
TOTAL                                                                    863                68    92.12%       13907              1356    90.25%        5110              1420    72.21%

@willghatch force-pushed the users/willghatch/linearize branch from ede0a8f to 6d6c076 on April 2, 2026 19:59
Signed-off-by: William G Hatch <william@hatch.uno>
Three fixes to the flatten_read_indices pass:

1. Skip reads with MemoryAccessFlags (VOLATILE, NONTEMPORAL).  The
   LINEAR_INDEX codegen fallback path uses vector.maskedload which
   drops volatile semantics.  This caused incorrect streamk partial
   buffer synchronization (stale spinlock reads).

2. Use physical_layout.shape for stride computation when a
   MemoryLayout is present, matching the strides the emitter
   creates for the memref via reinterpret_cast.

3. Use physical (post-mapping) start expressions as bound keys in
   _convert_bounds, falling back to the original index when the
   physical start contains $dynamic_val symbols that are only
   resolvable through the mapping at codegen time.

Made-with: Cursor

Signed-off-by: William G Hatch <william@hatch.uno>
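
As an illustration of the second fix in the commit above (strides taken from `physical_layout.shape`), the sketch below shows how the same index maps to different flat offsets when strides come from a different shape than the one the memref actually uses. Both shapes and the helper `row_major_strides` are hypothetical. The sketch assumes a row-major layout in which the physical shape is a padded version of the logical one, which the commit message does not state.

```python
from math import prod

def row_major_strides(shape):
    """Element strides for a contiguous row-major buffer of the given shape."""
    return [prod(shape[i + 1:]) for i in range(len(shape))]

logical_shape = (64, 200)   # hypothetical logical tensor shape
physical_shape = (64, 256)  # hypothetical padded shape carried by the layout

index = (3, 10)
logical_offset = sum(i * s for i, s in zip(index, row_major_strides(logical_shape)))
physical_offset = sum(i * s for i, s in zip(index, row_major_strides(physical_shape)))

print(logical_offset, physical_offset)  # 610 778
```

A flattened read that used the first offset against a buffer laid out with the second set of strides would address the wrong element, which is why the flat offset has to be computed from the same strides the emitter encodes in the memref.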
Replace the water-specific emit_water_dialect flag with a general
linearize_reads option (default True).  Set it to False in:

- water_e2e_test: water-opt does not yet understand $LINEAR_INDEX
- waveasm 256x224x256 dynamic+bufops: linearized reads push VGPR
  count past the 256-register limit; disabling linearization lets
  the test pass instead of needing an xfail

Signed-off-by: William G Hatch <william@hatch.uno>
These tests check MLIR output via FileCheck and their CHECK
patterns have not been updated for the linearized-read output.
The water e2e tests already disable linearize_reads; align
the lit tests to match.

Made-with: Cursor

Signed-off-by: William G Hatch <william@hatch.uno>
256x224x256 with dynamic dims overflows VGPRs regardless of
buffer ops when linearized reads are enabled (unscheduled).

256x160x256 with dynamic dims produces numerical mismatches
when linearized, likely due to the preshuffle mapping's
floor/Mod expressions over dynamic K not simplifying correctly
after flatten_read_indices.

Disable linearize_reads for both configurations when
unscheduled + dynamic dims.

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
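
For context on the floor/Mod simplification mentioned in the commit above, here is a hedged sketch of the kind of cancellation involved and of a numeric probe that can check it. It is not the code in mapping_utils; `split_and_recombine`, `bad_recombine`, and `probe_equal` are hypothetical names. Splitting an index `x` into `(x // k, x % k)` and recombining with inner extent `k` is the identity, while recombining with any other extent is not:

```python
def split_and_recombine(x, k):
    """floor(x / k) * k + Mod(x, k), which cancels back to x for k > 0."""
    return (x // k) * k + (x % k)

def bad_recombine(x, k):
    """Same split, but recombined with the wrong inner extent (k + 1)."""
    return (x // k) * (k + 1) + (x % k)

def probe_equal(f, g, xs, ks):
    """Numerically compare two index expressions on a grid of sample values."""
    return all(f(x, k) == g(x, k) for x in xs for k in ks)

def identity(x, k):
    return x

xs, ks = range(200), range(1, 17)
cancels = probe_equal(split_and_recombine, identity, xs, ks)
mismatch = probe_equal(bad_recombine, identity, xs, ks)
print(cancels, mismatch)  # True False
```

A probe like this only samples values, so agreement is evidence rather than proof. A disagreement, however, is a concrete counterexample that an expression was simplified incorrectly.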
Signed-off-by: William G Hatch <william@hatch.uno>
- Update CHECK line numbers in mlir_converter_debug_locations.py
  (shifted +1 by linearize_reads=False insertion)
- Add annotate_iv_strides to expected_failures in mlir_roundtrip_pipeline.py
- Disable flatten_read_indices when dynamic_strides or water backend
  is active (incompatible with pre-computed symbolic strides)
- Broaden waveasm e2e MXFP4 skip to all dynamic-dim configs
  (floor/Mod expressions over dynamic K cause numerical mismatches)
- Update expected VGPR/SGPR counts for testScaledBatchedGemmMXFP4Codegen
  (172/62 with linearize_reads vs previous 170/61)

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
A previous commit disabled linearization for dynamic strides cases.

Signed-off-by: William G Hatch <william@hatch.uno>
Linearized reads set the SRD base to the buffer start (offset=0) and
use the full flat offset as the load index.  For masked-off threads at
workgroup boundaries, the computed flat offset can exceed the actual
tensor size.  With validBytes set to the hardware maximum (~2GB), the
SRD bounds check does not catch these, causing GPU memory access faults
on large tensors (e.g. GEMM shape 4096x20480x2560).

Fix by passing use_real_bounds_override=True to _compute_valid_bytes
for all linearized read paths (handle_read LINEAR_INDEX and
handle_gather_to_lds LINEAR_INDEX), so the SRD's validBytes reflects
the actual tensor size and OOB accesses from masked-off threads
harmlessly return zero.

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
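
The effect of the bounds change in the commit above can be modeled in a few lines of Python. This is a toy model, not the emitter code: `buffer_load`, `LOOSE_VALID_BYTES`, and the tensor shape are hypothetical, and an `IndexError` stands in for a GPU memory access fault.

```python
def buffer_load(memory, elem_offset, elem_bytes, valid_bytes):
    """Toy bounds-checked buffer load: reads past valid_bytes return 0."""
    byte_offset = elem_offset * elem_bytes
    if byte_offset + elem_bytes > valid_bytes:
        return 0
    # If valid_bytes is too loose, this indexes past the real allocation.
    return memory[elem_offset]

rows, cols, elem_bytes = 4, 8, 2             # hypothetical 4x8 tensor of 2-byte elements
memory = list(range(rows * cols))
real_valid_bytes = rows * cols * elem_bytes  # 64 bytes: the actual tensor size
LOOSE_VALID_BYTES = 2**31 - 1                # stand-in for the "hardware maximum" bound

# A masked-off thread at a workgroup boundary computes an offset in row 5,
# which is past the end of the 4-row tensor.
oob_offset = 5 * cols + 3

safe_value = buffer_load(memory, oob_offset, elem_bytes, real_valid_bytes)

try:
    buffer_load(memory, oob_offset, elem_bytes, LOOSE_VALID_BYTES)
    faulted = False
except IndexError:
    faulted = True

print(safe_value, faulted)  # 0 True
```

With the real tensor size as the bound, the stray access returns zero; with the loose bound, the check lets it through and the access lands outside the allocation.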
…er reads

Linearized reads with buffer ops used vector.maskedload on
fat_raw_buffer memrefs to handle boundary elements.  This does not
lower correctly through the AMDGPU backend: the masked vector load can
interact poorly with the SRD bounds check when validBytes is set to the
exact tensor size, causing valid elements near the boundary to be
zeroed.

Switch to the same OOB-index-redirect strategy that the write path
already uses: for each element, arith.select between the real offset
and an out-of-bounds index based on the mask.  When the mask is
splatted (uniform across all elements), use a scalar select on the
vector load offset for better codegen.

This fixes 8 shape1 (111, 813) test failures in mi35x CI:
test_copy, test_dynamic_copy, test_vector_add, test_bound_check,
test_read_write_same, test_offset_read_one, test_offset_write,
test_transpose_write.

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
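
The OOB-index-redirect strategy described in the commit above can be sketched as follows. This is a plain-Python model rather than the MLIR the emitter generates: `bounds_checked_load` and `load_with_redirect` are hypothetical names, and the model assumes a non-empty mask and a buffer whose bounds check returns zero for out-of-range reads.

```python
def bounds_checked_load(memory, offset):
    """Toy buffer load whose bounds check returns 0 when out of range."""
    return memory[offset] if 0 <= offset < len(memory) else 0

def load_with_redirect(memory, base, mask):
    """Load len(mask) elements starting at base; masked-off lanes read an
    out-of-bounds index instead of going through a masked vector load."""
    oob = len(memory)  # any index at or past the end is out of bounds
    if all(m == mask[0] for m in mask):
        # Splatted (uniform) mask: a single scalar select on the start offset.
        start = base if mask[0] else oob
        return [bounds_checked_load(memory, start + lane) for lane in range(len(mask))]
    # General case: per-lane select between the real offset and the OOB index.
    offsets = [base + lane if m else oob for lane, m in enumerate(mask)]
    return [bounds_checked_load(memory, off) for off in offsets]

memory = [10, 11, 12, 13, 14, 15, 16, 17]
print(load_with_redirect(memory, 6, [True, True, False, False]))  # [16, 17, 0, 0]
print(load_with_redirect(memory, 2, [True] * 4))                  # [12, 13, 14, 15]
print(load_with_redirect(memory, 6, [False] * 4))                 # [0, 0, 0, 0]
```

In this model, valid elements next to the boundary (16 and 17 in the first call) are still read normally, and only the masked-off lanes come back as zero.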
_convert_bounds was using the post-mapping (physical) start expression
as the bounds-check key.  When a mapping adds a runtime offset via
set_symbol (e.g. OFFSET in prefill_attention, EXT_IDX in
extend_attention), the physical start is in a different coordinate
space than the bound value.  For example, with mapping N_Q: j + OFFSET
and bound N_Q (per-sequence length 64), the mask became
(j + OFFSET) < 64, which is always false for OFFSET >= 64.

Generalize the existing $dynamic_val fallback: instead of string-
matching "dynamic_val", detect any extra free symbols in the physical
start compared to the original (pre-mapping) index.  Extra symbols
indicate a coordinate-space shift, so fall back to the original index
for bounds checking.

Made-with: Cursor

Signed-off-by: William G Hatch <william@hatch.uno>
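
The bounds-key problem from the commit above can be reproduced numerically. In the sketch below, `pick_bounds_index` is a hypothetical stand-in for the decision made in `_convert_bounds`; it works on plain sets of symbol names rather than on sympy `free_symbols`, and the `OFFSET` value is made up.

```python
def pick_bounds_index(original_symbols, physical_symbols):
    """Fall back to the original index when the physical start introduces
    symbols the original index did not have (a coordinate-space shift)."""
    extra = set(physical_symbols) - set(original_symbols)
    return "original" if extra else "physical"

BOUND = 64    # per-sequence length N_Q from the example in the commit message
OFFSET = 128  # hypothetical runtime offset added by the mapping

# Mask built from the physical start: compares shifted coordinates to the bound.
physical_mask = [(j + OFFSET) < BOUND for j in range(BOUND)]
# Mask built from the original index: same coordinate space as the bound.
original_mask = [j < BOUND for j in range(BOUND)]

print(any(physical_mask), all(original_mask))     # False True
print(pick_bounds_index({"j"}, {"j", "OFFSET"}))  # original
print(pick_bounds_index({"j"}, {"j"}))            # physical
```

Every element is masked off when the shifted expression is compared against the per-sequence bound, whereas the original index yields the intended mask.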
There are still some dynamic cases used by these kernels that fail for linearization.
But I want to land something, so let's iterate on those later along with the other follow-ups.

Signed-off-by: William G Hatch <william@hatch.uno>
Signed-off-by: William G Hatch <william@hatch.uno>
Signed-off-by: William G Hatch <william@hatch.uno>
Signed-off-by: William G Hatch <william@hatch.uno>
Signed-off-by: William G Hatch <william@hatch.uno>
@willghatch force-pushed the users/willghatch/linearize branch from 6d6c076 to 2f89f54 on April 2, 2026 22:51
The pass only operates on LINEAR_INDEX reads, so it should not run
when linearize_reads is disabled.

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
The Water backend does not yet handle $LINEAR_INDEX symbols, so set
linearize_reads=False for water e2e, roundtrip pipeline, and the new
mxfp4 scaled MMA converter tests.

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
@willghatch force-pushed the users/willghatch/linearize branch from 2f89f54 to 6dbdfe7 on April 2, 2026 22:52
@willghatch
Contributor Author

@Hardcode84 I've addressed your comments, let me know if you have any other comments.

@willghatch merged commit 059c25b into main on Apr 3, 2026
18 of 19 checks passed
@willghatch deleted the users/willghatch/linearize branch on April 3, 2026 15:06

2 participants