[water] provenance-based propagation heuristic #1206

Merged
ftynse merged 3 commits into main from users/ftynse/heuristic on Apr 2, 2026

@ftynse
Contributor

@ftynse ftynse commented Mar 30, 2026

It is unfortunately necessary to introduce this additional complexity to reflect the naive non-dataflow approach in the Python counterpart: keep track of the vector shape of the operation from which the lattice originally propagates. This is required by the propagation-stopper heuristic, which cannot infer this value from the per-value vector shapes that water correctly cleans up to correspond to the shape. The shape is stored as a separate lattice field that is simply carried over based on priority. Since all Mma operations now have different priorities by construction, this never leads to lattice tops outside of synthetic test cases. For Write operations, the other source, the priorities are equal, but so are the vector shapes, since those are simply taken from the hardware constraint object.

This is highly dubious due to non-local propagation rules and blind overwriting of vector shapes in the Python code, but necessary to match the current behavior. It must be revised as soon as possible to avoid priorities and non-local behavior altogether.
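
A minimal sketch of the priority-based carry-over described above, for readers who want to see the join rule in isolation. It is not the actual `IndexExprsLatticeStorage` code: the `SourceShape` struct, the `joinSourceShape` function, the single scalar priority, and the use of `std::nullopt` to stand for the lattice top are all hypothetical simplifications. The "higher priority wins" direction is an assumption; the description only states that equal priorities with equal shapes join cleanly and that equal priorities with different shapes go to top.

```cpp
#include <cstdint>
#include <map>
#include <optional>
#include <string>

// Hypothetical stand-in for the new lattice field; the real storage in
// IndexExprsLatticeStorage is laid out differently.
struct SourceShape {
  std::map<std::string, int64_t> shape; // e.g. {{"M", 16}, {"N", 16}}
  int priority = 0;                     // assumed: one scalar priority

  bool operator==(const SourceShape &other) const {
    return shape == other.shape && priority == other.priority;
  }
};

// In this sketch, std::nullopt models the lattice top (a conflict).
using JoinResult = std::optional<SourceShape>;

JoinResult joinSourceShape(const SourceShape &lhs, const SourceShape &rhs) {
  // Assumption: the side with the strictly higher priority is carried over.
  if (lhs.priority > rhs.priority)
    return lhs;
  if (rhs.priority > lhs.priority)
    return rhs;
  // Equal priorities: identical shapes join cleanly, different ones conflict.
  if (lhs.shape == rhs.shape)
    return lhs;
  return std::nullopt;
}
```

Under these assumptions, joining a shape `{M = 16, N = 16}` at priority 5 with `{M = 16, N = 0}` at priority 3 yields the former, while the same two shapes at equal priority yield the conflict value.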

@github-actions

github-actions bot commented Mar 30, 2026

Water Code Coverage

Filename                                                           Functions  Missed Functions  Executed       Lines      Missed Lines     Cover    Branches   Missed Branches     Cover
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
lib/Transforms/MemrefDecomposition.cpp                                    28                 0   100.00%         600                49    91.83%         104                46    55.77%
lib/Transforms/AllocToAlloca.cpp                                           2                 0   100.00%          17                 0   100.00%           0                 0         -
lib/Transforms/CheckStaticAssertions.cpp                                   2                 0   100.00%          22                 1    95.45%           8                 4    50.00%
lib/Transforms/GPUModuleToBinary.cpp                                      19                 5    73.68%         339               115    66.08%         128                57    55.47%
lib/Transforms/DropTransformOps.cpp                                        2                 0   100.00%          16                 0   100.00%           2                 0   100.00%
lib/Transforms/GPUToGPURuntime.cpp                                        14                 0   100.00%         298                23    92.28%          40                17    57.50%
lib/Transforms/SLPVectorizer.cpp                                          61                 3    95.08%        1065               102    90.42%         558               167    70.07%
lib/Transforms/AccessCheckers.cpp                                         35                 1    97.14%         446                40    91.03%         124                30    75.81%
lib/Transforms/AssembleISA.cpp                                             4                 1    75.00%          30                 2    93.33%           2                 1    50.00%
lib/Dialect/Wave/Transforms/LoweringPatterns.cpp                          48                 2    95.83%         966               146    84.89%         272                82    69.85%
lib/Dialect/Wave/Transforms/PropagateDefaultsFromConstraints.cpp           3                 3     0.00%          35                35     0.00%          12                12     0.00%
lib/Dialect/Wave/Transforms/TypeConverter.cpp                              7                 2    71.43%          96                26    72.92%          32                17    46.88%
lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp                         10                 0   100.00%         238                18    92.44%          58                11    81.03%
lib/Dialect/Wave/Transforms/DetectNormalForms.cpp                          4                 0   100.00%          48                 0   100.00%           8                 0   100.00%
lib/Dialect/Wave/Transforms/ExpandVariadicReductions.cpp                   2                 0   100.00%          24                 1    95.83%           6                 1    83.33%
lib/Dialect/Wave/Transforms/InferTypes.cpp                               110                14    87.27%        1923               150    92.20%         880               439    50.11%
lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp                            5                 0   100.00%         130                 1    99.23%          16                 2    87.50%
lib/Dialect/Wave/Transforms/Utils.cpp                                      6                 0   100.00%          96                 5    94.79%          26                 4    84.62%
lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp              7                 0   100.00%         183                16    91.26%          32                14    56.25%
lib/Dialect/Wave/IR/WaveOps.cpp                                          167                19    88.62%        3159               344    89.11%        1266               295    76.70%
lib/Dialect/Wave/IR/WaveAttrs.cpp                                         73                 6    91.78%         966                97    89.96%         424                63    85.14%
lib/Dialect/Wave/IR/IndexExpr.cpp                                         11                 0   100.00%         119                 1    99.16%          24                 3    87.50%
lib/Dialect/Wave/IR/WaveDialect.cpp                                       14                 0   100.00%         528                24    95.45%         194                11    94.33%
lib/Dialect/Wave/IR/WaveTypes.cpp                                          9                 1    88.89%          75                 8    89.33%          18                 3    83.33%
lib/Dialect/Wave/IR/WaveInterfaces.cpp                                   108                 3    97.22%        1649               100    93.94%         666               104    84.38%
lib/Dialect/Wave/IR/WaveUtils.cpp                                         21                 0   100.00%         190                 8    95.79%          78                13    83.33%
lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp                3                 0   100.00%          34                 6    82.35%           8                 2    75.00%
lib/Dialect/NormalForm/IR/NormalFormDialect.cpp                            1                 0   100.00%           6                 0   100.00%           0                 0         -
lib/Dialect/NormalForm/IR/NormalFormOps.cpp                               12                 0   100.00%         201                 9    95.52%          58                 7    87.93%
lib/Pipelines/Pipelines.cpp                                                2                 0   100.00%          27                 0   100.00%           0                 0         -
lib/Analysis/InUseForSpeculation.cpp                                      12                 1    91.67%         142                 8    94.37%          32                 4    87.50%
include/water/Dialect/Wave/Transforms/LoweringPatterns.h                   1                 0   100.00%           3                 0   100.00%           0                 0         -
include/water/Dialect/Wave/IR/IndexExpr.h                                  1                 0   100.00%          10                 0   100.00%           2                 0   100.00%
include/water/Dialect/Wave/IR/WaveInterfaces.h                            40                 3    92.50%         159                 8    94.97%           8                 2    75.00%
include/water/Dialect/Wave/IR/WaveTypes.h                                  1                 0   100.00%           5                 0   100.00%           4                 0   100.00%
include/water/Dialect/Wave/IR/WaveUtils.h                                  1                 0   100.00%           5                 0   100.00%           4                 1    75.00%
include/water/Dialect/Wave/IR/WaveAttrs.h                                  4                 0   100.00%          14                 0   100.00%           0                 0         -
include/water/Dialect/NormalForm/IR/NormalFormInterfaces.h                 1                 1     0.00%           4                 4     0.00%           0                 0         -
include/water/Analysis/InUseForSpeculation.h                              12                 3    75.00%          39                17    56.41%          16                10    37.50%
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
TOTAL                                                                    863                68    92.12%       13907              1364    90.19%        5110              1422    72.17%

Contributor

Copilot AI left a comment

Pull request overview

This PR updates Wave index-expression propagation to track and propagate an additional “source vector shape” (the originating op’s vector shape) in the lattice, in order to preserve Python-prototype propagation/stopper behavior that can’t be reconstructed from the per-value vector shapes alone.

Changes:

  • Extend IndexExprsLatticeStorage with sourceVectorShape and sourceVectorShapePriority, and join them using priority semantics.
  • Update the propagation-stopper heuristic (shouldPropagateIndexExprs) and key ops (MMA init, Write sideways propagation, Permute stride permutation) to preserve/use the new source vector shape.
  • Adjust and expand MLIR tests and design notes to reflect the new behavior.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| water/lib/Dialect/Wave/IR/WaveOps.cpp | Updates propagation heuristic and ensures MMA/Write/Permute flows preserve sourceVectorShape. |
| water/lib/Dialect/Wave/IR/WaveInterfaces.cpp | Implements storage, printing, equality, and join semantics for sourceVectorShape. |
| water/include/water/Dialect/Wave/IR/WaveInterfaces.h | Documents new lattice fields and updated propagation heuristic behavior. |
| water/test/Dialect/Wave/infer-index-exprs.mlir | Adjusts an existing test expectation around backward propagation into a register. |
| water/test/Dialect/Wave/infer-index-exprs-lattice.mlir | Updates skip tests and adds new tests covering sourceVectorShape join behavior and conflicts. |
| docs/wave/ir_design_notes.rst | Updates design notes to describe the new source-vector-shape-based heuristic. |

Comment on lines +1696 to +1702
    // Equal priorities: identical values join cleanly, different go to top.
    if (lhsSVS == rhsSVS) {
      joinedSourceVectorShape = lhsSVS;
      joinedSourceVectorShapePriority = lhsPri;
    } else {
      return IndexExprsLatticeStorage::top();
    }

Copilot AI Mar 30, 2026

IndexExprsLatticeStorage::join can now return top() due to a sourceVectorShape conflict (equal sourceVectorShapePriority but different sourceVectorShape). Current diagnostics in propagation treat such cases as "index expression" conflicts (since getJoinedVectorShape can still succeed), which is misleading and makes failures hard to debug. Consider adding an explicit source-vector-shape join/check helper so callers can classify and report this as a distinct "source vector shape" conflict (and include both source shapes/priorities in the notes).
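
One possible shape for the helper suggested here, as a rough sketch only. `SourceShape`, `formatShape`, and `describeSourceShapeConflict` are hypothetical names, and the single scalar priority is an assumption; none of them come from the water codebase. The idea is that a caller runs this check before reporting a generic index-expression conflict, so that the note names both source shapes and their priorities.

```cpp
#include <cstdint>
#include <map>
#include <optional>
#include <sstream>
#include <string>

// Hypothetical simplified view of the source-vector-shape lattice field.
struct SourceShape {
  std::map<std::string, int64_t> shape;
  int priority = 0; // assumed: one scalar priority
};

// Renders a shape as "{M = 16, N = 16} @ priority 5" for use in a note.
static std::string formatShape(const SourceShape &s) {
  std::ostringstream os;
  os << "{";
  bool first = true;
  for (const auto &entry : s.shape) {
    if (!first)
      os << ", ";
    os << entry.first << " = " << entry.second;
    first = false;
  }
  os << "} @ priority " << s.priority;
  return os.str();
}

// Returns a diagnostic note when the two fields would join to top
// (equal priority, different shapes), and std::nullopt otherwise.
std::optional<std::string>
describeSourceShapeConflict(const SourceShape &lhs, const SourceShape &rhs) {
  if (lhs.priority != rhs.priority || lhs.shape == rhs.shape)
    return std::nullopt;
  return "source vector shape conflict: " + formatShape(lhs) + " vs " +
         formatShape(rhs);
}
```

Keeping the check separate from the join itself leaves the core lattice method free of diagnostics, which matches the preference stated later in this thread for not asserting inside it.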

Contributor

Valid concern. Maybe in a follow-up, or maybe not, as we do not intend to keep this forever.

Contributor Author

This should never happen outside of synthetic cases. I just didn't want to assert in the core lattice method.

@ftynse ftynse requested a review from martin-luecke April 1, 2026 07:42
Contributor

@martin-luecke martin-luecke left a comment

Not 100% done reviewing yet, but leaving some comments to address

Comment on lines +2035 to +2046
  // CHECK-DAG: M : <[#wave.index_symbol<T0>] -> (T0 * 32, 1, 1)>
  // CHECK-DAG: N : <[#wave.index_symbol<T0>] -> (T0 * 10, 1, 1)>
  %result = wave.add %a, %b {wave_test.override_operand_index = [
    [{M = 5 : i32, N = 5 : i32}, {
      M = #wave.index_mapping<[#wave.index_symbol<T0>] -> (T0 * 32, 1, 1)>,
      N = #wave.index_mapping<[#wave.index_symbol<T0>] -> (T0 * 10, 1, 1)>
    }, {M = 16 : i64, N = 16 : i64}],
    [{M = 3 : i32, N = 3 : i32}, {
      M = #wave.index_mapping<[#wave.index_symbol<T0>] -> (T0 * 32, 1, 1)>,
      N = #wave.index_mapping<[#wave.index_symbol<T0>] -> (T0 * 10, 1, 1)>
    }, {M = 16 : i64, N = 0 : i64}]
  ]}
Contributor

We don't actually check here that {M=16,N=16} wins in propagation. Having a consuming op where this is propagated to would enable checking that this works as expected. Same for the other non-failure tests

Contributor Author

So the actual problem here is that we don't even print the source vector shape anywhere, so we can't check it directly, only by proxy. I'll factor this out into a separate file.

Contributor

@martin-luecke martin-luecke left a comment

LGTM, with a few comments and questions to address

@ftynse ftynse force-pushed the users/ftynse/heuristic branch from 2dbb998 to 1d4587e Compare April 2, 2026 13:47
ftynse added 3 commits April 2, 2026 16:03
Signed-off-by: Alex Zinenko <git@ozinenko.com>
@ftynse ftynse force-pushed the users/ftynse/heuristic branch from 1d4587e to c70921d Compare April 2, 2026 14:06
Contributor

@tgymnich tgymnich left a comment

LGTM. maybe add ScaledMMA

Comment on lines +2162 to +2163
SmallVector<wave::IndexExprsLatticeStorage> slots = llvm::map_to_vector(
valuesForIndexExpr, [&](Value v) { return getLatticeValue(v); });
Contributor

Suggested change
- SmallVector<wave::IndexExprsLatticeStorage> slots = llvm::map_to_vector(
-     valuesForIndexExpr, [&](Value v) { return getLatticeValue(v); });
+ SmallVector<wave::IndexExprsLatticeStorage> slots =
+     llvm::map_to_vector(valuesForIndexExpr, getLatticeValue);
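
Whether the shorter form compiles depends on what `getLatticeValue` is at that call site, which the excerpt does not show. The sketch below illustrates the distinction using only the standard library: `mapToVector` is a hypothetical stand-in for `llvm::map_to_vector`, and `Analysis`, `viaWrapper`, and `viaLambda` are made-up names. A local lambda or function object can be passed by name, whereas a non-static member function needs a wrapping lambda (or an equivalent binding) to supply `this`.

```cpp
#include <type_traits>
#include <vector>

// Hypothetical std-only analogue of llvm::map_to_vector, for illustration.
template <typename T, typename Fn>
auto mapToVector(const std::vector<T> &input, Fn fn) {
  using Result = std::decay_t<decltype(fn(input.front()))>;
  std::vector<Result> result;
  result.reserve(input.size());
  for (const T &item : input)
    result.push_back(fn(item));
  return result;
}

struct Analysis {
  int offset = 100;

  // A non-static member function: its bare name is not a callable value.
  int getLatticeValue(int v) const { return v + offset; }

  std::vector<int> viaWrapper(const std::vector<int> &values) const {
    // The wrapping lambda captures `this`, so the member call is well-formed.
    return mapToVector(values, [&](int v) { return getLatticeValue(v); });
  }
};

std::vector<int> viaLambda(const std::vector<int> &values) {
  // A local lambda is already a callable object and can be passed directly.
  auto getLatticeValue = [](int v) { return v * 2; };
  return mapToVector(values, getLatticeValue);
}
```

If `getLatticeValue` in the reviewed code is a local lambda, the suggested simplification is purely cosmetic; if it is a member function, the original wrapping lambda is doing necessary work.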

@ftynse
Contributor Author

ftynse commented Apr 2, 2026

I'll update scaled MMA separately to avoid issues with an interfering refactor.

@ftynse ftynse merged commit de592dc into main Apr 2, 2026
18 of 19 checks passed
@ftynse ftynse deleted the users/ftynse/heuristic branch April 2, 2026 15:41
4 participants