x86_64 AMX (int8 + bf16) & AVX-512-VNNI GEMM kernels — validated on AMX hardware #14

Open
czoli1976 wants to merge 288 commits into main from claude/zealous-galileo-fEQ3d

@czoli1976

⚠️ Scope note (please read first)

This branch sits linearly on top of current main (b9d05f1) but is 288 commits / 373 files (+31k / −1.7k) ahead. Only ~21 of those commits are the AMX/VNNI work described below; the rest spans unrelated areas — pulse/blockify streaming rewrites, CUDA/Metal GPU kernels, ARM64 SVE/SME, WASM SIMD, ONNX com.microsoft ops, tdim simplifier, docs/CI. If a focused AMX-only PR is preferred, those ~21 commits can be split onto a fresh branch off main instead.


x86_64 AMX / AVX-512-VNNI int8 & bf16 GEMM kernels

New kernels in linalg/src/x86_64_fma/:

  • avx512amx_mmm_i32_8x8 / _16x16 — Intel AMX int8 (tdpbssd)
  • avx512amx_mmm_f32_16x16 — AMX bf16→f32 (tdpbf16ps)
  • avx512vnni_mmm_i32_8x8 / _16x16 — AVX-512-VNNI (vpdpbusd)
  • avxvnni_mmm_i32_8x8 — AVX-VNNI ymm (Atom-class cores)

Validated on a real Intel AMX Xeon (amx_tile/amx_int8/amx_bf16, kernel 6.18.5) per linalg/AMX_BENCH_RUNBOOK.md. Full results: linalg/AMX_BENCH_RESULTS.md.

Correctness

  • AMX confirmed live (CPUID + arch_prctl XTILEDATA permission gate).
  • Bugfix 99eb75b validated on silicon — swapped operands in the AMX 16×16 sub fused-op handlers; all scalar_sub/per_row_sub/per_col_sub tests now pass for both int8 and bf16 16×16 (previously only build-verifiable on the non-AMX dev box).
  • 3 AMX-bf16 tests fail under the harness's f32-grade tolerance — root-caused to the test oracle, not the kernel. packed_packed.rs:367 picks the comparison tolerance from the f32 accumulator, ignoring that the f32f32_bf16 packing truncates inputs to bf16. Empirically verified on the AMX host: the kernel matches an independent bf16-truncated reference (built with the project's own f32_to_bf16_rne) with 0 outliers across ~335k elements at the same tight bar, vs 282,788 outliers against a pure-f32 reference. Proposed fix: select SuperApproximate for bf16 packings. These tests short-circuit/skip on non-AMX CI.
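
To illustrate why the pure-f32 oracle disagrees with the kernel, here is a minimal sketch of a bf16-truncated reference. `round_to_bf16` and `bf16_reference_dot` are hypothetical stand-ins written for this note, not the project's `f32_to_bf16_rne` or the harness code; the sketch only assumes round-to-nearest-even on the top 16 bits of the f32 encoding.

```rust
/// Round an f32 to bf16 precision (round-to-nearest-even on the upper
/// 16 bits) and widen it back to f32. Hypothetical helper: NaN and
/// overflow-to-infinity edge cases are not treated specially.
fn round_to_bf16(x: f32) -> f32 {
    let bits = x.to_bits();
    // 0x7FFF plus the lowest kept bit makes exact ties round to even.
    let rounding_bias = 0x7FFF + ((bits >> 16) & 1);
    f32::from_bits(bits.wrapping_add(rounding_bias) & 0xFFFF_0000)
}

/// Dot product accumulated in f32 after rounding both inputs to bf16,
/// i.e. the kind of reference the kernel should be compared against.
fn bf16_reference_dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| round_to_bf16(*x) * round_to_bf16(*y))
        .sum()
}
```

With 7 explicit mantissa bits, an input such as 1.003 rounds to 1.0 before it is multiplied, so a pure-f32 reference differs in the third decimal even though the accumulator is f32.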

Performance (Gelem/s, AMX Xeon @ 2.1 GHz; see results file for full tables)

  • AMX int8 16×16 hits 228–280 Gelem/s, 1.5–1.8× the new VNNI 16×16 on the same silicon (justifies the boost(100) > boost(50) dispatch ordering).
  • AMX 16×16 is 1.6–3.5× the AMX 8×8.
  • New VNNI 16×16 (zmm) is 4.6–12.5× the 8×8 (ymm) — well above the dev box's 1.3–2.1×.
  • AMX bf16 16×16 is 3.1–5.4× the AVX-512 f32 kernel (documented bf16 precision trade).

Test plan

  • cargo test -p tract-linalg --lib avx512amx (on AMX hardware) — int8 suites clean; the 3 bf16 failures are the tolerance issue explained above.
  • cargo test -p tract-linalg --lib x86_64_fma::mmm
  • cargo bench -p tract-linalg --bench {amx_i32,amx_f32,vnni_i32} (AMX host; they skip with a message otherwise)

Note: benchmarks require an AMX host (Sapphire Rapids / Emerald Rapids / Granite Rapids Xeon, or Xeon Max). One re-measure (VNNI-8×8 @ 512³) is still outstanding — see §6 of the results file.

https://claude.ai/code/session_018Hes6yEvk2TSWB26SAJfqT


Generated by Claude Code

kali and others added 30 commits May 1, 2026 16:26
natural_for_rank produced a square mapping using inputs[0].rank() for
both sides, but ONNX-style broadcast right-aligns input over output, so
output_rank can exceed input_rank. The leading output-only axes were
absent from the mapping, tripping the optimizer's axes-mapping check
under paranoid_assertions on test_expand_dim_changed and
test_expand_shape_model4.
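
A minimal sketch of the right-aligned mapping described above. `right_aligned_axes` is a hypothetical illustration, not tract's `natural_for_rank` or its `AxesMapping` API; it only shows which output axes have no input counterpart.

```rust
/// For every output axis, return the input axis it aligns with under
/// ONNX-style right-aligned broadcast, or `None` for the leading
/// output-only axes. Hypothetical helper for illustration.
fn right_aligned_axes(input_rank: usize, output_rank: usize) -> Vec<Option<usize>> {
    assert!(output_rank >= input_rank, "broadcast cannot drop axes");
    let offset = output_rank - input_rank;
    // Output axis `o` maps to input axis `o - offset` when that exists.
    (0..output_rank).map(|o| o.checked_sub(offset)).collect()
}
```

For a rank-2 input expanded to rank 4, this yields `[None, None, Some(0), Some(1)]`: the two leading axes are the ones a square mapping built from the input rank alone would miss.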
compare --stream contrasts the pulsed plan against the unpulsed
reference on the same random input. Catches state-recurrence
breakage on stateful Scan ops (issue sonos#2157) — turn_0 matches but
turn_1+ diverges when the body's State input gets re-seeded each
call, which is what an over-eager iter=1 Scan inliner produces.
Override the default `disconnected` axes_mapping with one that reflects
the op's actual axis flow:

* every non-reducing axis (= every axis but the last) is identity-mapped
  from input 0 → output 0, and — when ranks line up — also from the
  mask (input 1) → input 0, with the bool mask right-aligned over input
  0's trailing axes;
* the softmax-reducing axis (last axis of input 0, plus last axis of
  the mask) is **deliberately kept disconnected** from the output: its
  size is preserved, but it is not "the same axis" — splitting it
  through a reshape would break softmax's normalisation semantics.
  Leaving it disconnected lets axis-tracking passes (e.g. `track_axis`)
  correctly answer "no" when asked whether an axis-changing transform
  on the reducing axis can pass through.

Adds 4 unit tests covering same-rank float/bool mask, lower-rank
broadcast bool mask (right-aligned), and scalar bool mask.
`Arc<rayon::ThreadPool>` cannot be constructed on
wasm32-unknown-unknown — rayon's default `spawn_handler` calls
`std::thread::spawn`, which is unsupported there. Today the
`Executor::MultiThread(Arc<ThreadPool>)` API is therefore unusable
on browser targets via wasm-bindgen-rayon, even after a successful
`init_thread_pool(N)` JS-side: any `ThreadPoolBuilder::new().build()`
silently fails, so execution falls back to a single thread.

Add `Executor::RayonGlobal` — a payload-free variant signalling "use
rayon's GLOBAL pool". On wasm-bindgen-rayon this routes work to the
N web Workers spawned by `init_thread_pool`. On native it uses
rayon's auto-initialised default. The MMM dispatcher (next commit)
uses `into_par_iter` directly on this variant, sidestepping the
broken `pool.install` path.

This is the enabling change — without it, the chunked-dispatch
improvement in the next commit is unreachable on
wasm32-unknown-unknown.
Replace the existing 1D `into_par_iter` over a single panel axis
with a 2D chunked grid in all three `run_with_scratch_space_*` arms,
and gate threading off entirely for tiny MMMs.

Three changes:

* `chunk_grid()` — ggml-style 16-tile (or 64-tile when one dim is 1)
  panel chunks, with a "block-per-thread along the longer axis"
  fallback when the natural grid would have fewer than `4*nth`
  chunks. Mirrors `ggml-cpu.c:1378-1398`.

* `chunked_dispatch_rayon()` — takes `Option<&rayon::ThreadPool>`.
  `Some(p)` with multi-thread pool: scoped via `p.install`. `None`
  or single-thread `p`: dispatched via `into_par_iter` directly,
  using rayon's GLOBAL pool. The latter is the only working route
  on wasm32-unknown-unknown (via wasm-bindgen-rayon); the former
  preserves native semantics for users passing custom pools.

* `THREADING_PANEL_THRESHOLD = 64` — for shapes whose panel-grid
  product is below this, skip the dispatcher entirely and run
  inline single-threaded. Per-call rayon overhead (~5 µs native,
  ~50 µs wasm-bindgen-rayon worker dispatch) exceeds parallel
  speedup on tiny MMMs. Empirically tuned against MobileNet v2 +
  DFN3 streaming inference.
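
The grid selection in the first bullet can be sketched as follows. This is a reading of the description above (and of the ggml logic it cites), not the code in this commit; the signature and the tie-breaking rule when both axes are equal are assumptions.

```rust
/// Sketch of a ggml-style chunk grid. Inputs are panel counts along m
/// and n and the thread count; the result is (chunks along m, chunks
/// along n). Hypothetical signature.
fn chunk_grid(m_panels: usize, n_panels: usize, nth: usize) -> (usize, usize) {
    // 16 panels per chunk, or 64 when one dimension is degenerate.
    let chunk = if m_panels == 1 || n_panels == 1 { 64 } else { 16 };
    let grid_m = m_panels.div_ceil(chunk);
    let grid_n = n_panels.div_ceil(chunk);
    if grid_m * grid_n < 4 * nth {
        // Too few chunks to balance: one block per thread along the
        // longer axis (ties go to n, an assumption of this sketch).
        if m_panels > n_panels { (nth, 1) } else { (1, nth) }
    } else {
        (grid_m, grid_n)
    }
}
```

For 64x64 panels and 4 threads the natural 4x4 grid is kept; for 32x32 panels the natural grid has only 4 chunks, so the fallback returns one block per thread.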

Adds the `Executor::RayonGlobal` arm in all three dispatch sites
(introduced by the previous commit).

Cross-engine validation (Chromium / WebKit / Firefox via Playwright,
real DFN3 + synthetic dense MMMs from 32x32 to 1024x1024):

  Workload                                      4t speedup
  1024x1024x1024 (transformer FFN)              3.34 / 3.12 / 3.40
  512x768x768 (BERT FFN)                        3.07 / 3.30 / 3.42
  256x256x128                                   2.19 / 3.30 / 3.41
  64x256x64 (DFN-like small)                    1.0 / 2.00 / 3.11
  32x32x32 (tiny)                               1.0 / 1.0 / 1.0  (threshold gates)
  DFN3 streaming inference                      1.04 / 1.08 / 1.03  (Amdahl-bound)

All 60 measured cells produce bit-equal output. Existing 3524
linalg lib tests pass on native with `multithread-mm`.

Native (macOS aarch64, rayon path): net-neutral on common shapes
(256x256, 512x512, 64x256 within ±2%). Latent benefit on
1D-parallelism-hostile shapes (m=8 n=2048).

Full benchmark data in linalg/MULTITHREAD_BENCHMARKS.md.
The chunked-dispatch threshold was a hardcoded constant (`64`). Promote
it to a runtime-tunable `AtomicUsize` with `set_threading_panel_threshold`
and `current_threading_panel_threshold` accessors, default unchanged.

Use cases:
  * Transformer-only workloads (LLM, BERT) want a lower threshold —
    every MMM is large enough to thread profitably.
  * Streaming RNN / mobile-vision workloads keep the default or go
    higher — many tiny MMMs that should stay inline.
  * Tuning experiments without recompile.

Reader uses `Ordering::Relaxed` on the dispatch hot path (no lock).
Default `64` matches the previous constant; existing behaviour
unchanged unless explicitly retuned. `0` disables the gate.
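
A minimal sketch of the accessor pair, assuming a single process-wide static. The accessor names come from this commit message; the bodies, the writer's memory ordering, and the `should_thread` helper are assumptions made for illustration.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Default matches the previous hardcoded constant.
static THREADING_PANEL_THRESHOLD: AtomicUsize = AtomicUsize::new(64);

/// Retune the gate at runtime; 0 disables it.
pub fn set_threading_panel_threshold(value: usize) {
    // Relaxed on the writer side is an assumption of this sketch.
    THREADING_PANEL_THRESHOLD.store(value, Ordering::Relaxed);
}

/// Lock-free read, cheap enough for the dispatch hot path.
pub fn current_threading_panel_threshold() -> usize {
    THREADING_PANEL_THRESHOLD.load(Ordering::Relaxed)
}

/// Hypothetical gate: thread only when the panel-grid product reaches
/// the threshold.
pub fn should_thread(m_panels: usize, n_panels: usize) -> bool {
    m_panels * n_panels >= current_threading_panel_threshold()
}
```

A transformer-only deployment could call `set_threading_panel_threshold(0)` once at startup so every MMM goes through the dispatcher.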
When a depthwise ConvTranspose with a 1×N kernel lowers (in
`Deconv::wire_with_deconv_sum`) to an EinSum with effective K=1, the
matmul kernel runs one FMA per output tile against fixed per-tile setup
(clear, panel-load, store), so the kernel preamble dominates runtime
even though no real matmul work is happening.

This rewrite is wired in two places:

- `EinSum::declutter_with_session` — primary path, per Kali's review
  on sonos#2183 (semantic-equivalence rewrites belong in declutter; downstream
  rules compose better with a Mul than an EinSum).

- `EinSum::codegen` (fast-path before `detect_rule`) — for K=1 einsums
  introduced during codegen itself. ConvTranspose lowering emits an
  EinSum + DeconvSum pair within a single codegen pass, before any
  post-codegen declutter sees it. Calls the same helper, single source
  of truth.

Scope is intentionally narrow: only fires when the einsum's output is
consumed by DeconvSum. That precisely catches the original target case
(DFN3 / GTCRN depthwise ConvTranspose with 1×N kernel — see PR sonos#2183)
without disturbing other K=1-shaped einsums (degenerate Q@K^T inside
SDPA when head_dim=1, vision attention proptests, etc.) where backend-
specific pipelines pattern-match on the matmul shape and break when we
substitute a Mul.

Quantized einsums are left untouched: the existing `dequant` path in
`EinSumMatMul::codegen` produces a non-q einsum that the declutter rule
then catches naturally on the next pass.
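
The equivalence behind the rewrite can be checked on plain slices: with K=1 the matmul degenerates to an outer product, which is a broadcast Mul of an [M, 1] column by a [1, N] row. Both functions below are hypothetical reference code, not tract ops.

```rust
/// Reference matmul on row-major data: a is [m, k], b is [k, n].
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                out[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    out
}

/// Broadcast multiply of an [m, 1] column by a [1, n] row.
fn outer_mul(a_col: &[f32], b_row: &[f32]) -> Vec<f32> {
    a_col
        .iter()
        .flat_map(|&x| b_row.iter().map(move |&y| x * y))
        .collect()
}
```

With `k = 1`, `matmul(a, b, m, 1, n)` and `outer_mul(a, b)` produce the same values, since each output element is a single product with nothing to accumulate.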

Verified end-to-end on DeepFilterNet 3 enhance() pipeline (which pins
tract 0.22.1; backport applied for the WAV-bit-exact test): real WAV
input through STFT -> enc -> erb_dec -> df_dec -> iSTFT [-> post-filter]
-> output WAV is bit-for-bit identical between clean tract and tract +
this patch (md5 match, 0/1,309,932 PCM samples differ). End-to-end
wall-clock -41% (3.32s -> 1.96s mean of 3 on a 27.29s 48 kHz noisy WAV).

erb_dec.onnx einsum-level on M-series macOS:
- /convt1/ConvTranspose.einsum (M=1600, K=1, N=3): 7.43 ms -> 0.224 ms (33x)
- /convt2/0/ConvTranspose.einsum (M=800, K=1, N=3): 3.93 ms -> 0.109 ms (36x)

Tests: 6 explicit K=1 cases in prefix_matmul::test (positive ConvTranspose-
shape patterns + negative SDPA-shape + high-rank cases that should NOT
fire), all 237 tract-core lib tests pass, all 21,898 test-metal tests
pass.
Mirrors the prepare-job pattern from full.yml: workflow_dispatch
takes an optional pr_number, the prepare job resolves it to the
PR's head SHA via the GitHub API, and both linux and apple jobs
check out that ref. Lets us run the embedded build (and the
attached S3 bench-bundle upload) on a fork PR without mirroring
its branch onto sonos/tract first.

Bundle metadata: prepare also emits a test_branch output formatted
as 'pr-NNNN-<head ref>'. The Cross-script steps surface it (and the
resolved SHA) via GITHUB_HEAD_REF/GITHUB_SHA env, which make_bundle.sh
already consults — so graphite series can be located by PR number
and forks sharing branch names don't collide.
ex01-block-diag-reduce: two streams [T,4], pairwise dot product into a
[T,T] score matrix, block-diagonal mask (chunk_size=2), sum-reduce on
the row axis -> [T] output.  No attention-specific framing (no Q/K/V,
no softmax, no value tensor, no pre-chunked input shape) so the
pulsifier has to discover the per-pulse window from the mask structure
alone.

ex01-blockified is the hand-written reference of what Blockify should
produce: model interface preserved (inputs [2*S, 4], output [2*S]),
internal reshape factors the streaming axis into [S, 2] chunks.
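
The relationship between the two harness models can be sketched on plain arrays. This assumes the reduction runs over j for each row i (output[i] = sum over j of mask[i][j] * scores[i][j]); both functions are hypothetical reference code, not the NNEF graphs themselves.

```rust
fn dot(x: &[f32; 4], y: &[f32; 4]) -> f32 {
    x.iter().zip(y).map(|(p, q)| p * q).sum()
}

/// Dense form: full [T, T] scores, block-diagonal mask, sum over j.
fn dense_reference(a: &[[f32; 4]], b: &[[f32; 4]], k: usize) -> Vec<f32> {
    let t = a.len();
    (0..t)
        .map(|i| {
            (0..t)
                .filter(|&j| i / k == j / k) // same chunk only
                .map(|j| dot(&a[i], &b[j]))
                .sum::<f32>()
        })
        .collect()
}

/// Chunked form: T = k * S, each chunk of a only meets its own chunk of b.
fn chunked(a: &[[f32; 4]], b: &[[f32; 4]], k: usize) -> Vec<f32> {
    a.chunks(k)
        .zip(b.chunks(k))
        .flat_map(|(a_blk, b_blk)| {
            a_blk
                .iter()
                .map(move |x| b_blk.iter().map(|y| dot(x, y)).sum::<f32>())
        })
        .collect()
}
```

The chunked form never materialises the [T, T] score matrix, which is what makes a per-chunk pulse possible.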
Phase B POC.  Recognises the block-diagonal pattern in a typed model:

    EinSum(a, b) -> scores [T, T]
    Mul(scores, mask)  where uniform_tdim(mask) = (coord_i/k == coord_j/k)
    Reduce<Sum> on a streaming axis

and rewrites the surrounding subgraph into chunked form:

    a_blk = Reshape [k*S, ...] -> [S, k, ...]
    b_blk = Reshape [k*S, ...] -> [S, k, ...]
    EinSum chunk-batched (subscript prepended with 's')
    Reduce on the within-chunk row axis
    Reshape [S, k] -> [k*S] at the boundary

The mask construction subgraph is dead post-rewrite and is dropped.
Substitutes the user's stream symbol T -> k*S throughout the rewritten
subgraph, introducing a chunk symbol S.  The user's pulse value is
translated by /k so --pulse T=2 with chunk_size=2 becomes --pulse S=1
on the rewritten graph.

Hooked into PulsedModel::new_with_mapping; no-op when no pattern
matches.

Recogniser scope: a single block-diagonal pattern, identified by string
match on the uniform_tdim form.  Banded patterns and the more general
algebraic recognition come later.

Integration test loads harness/pulse-multi-axis/ex01-block-diag-reduce
NNEF graph, runs batch and pulsified, compares numerically — passes
with delay=0 across 3 chunks.
The Phase B POC matched the mask's uniform_tdim by formatting it to a
string and comparing against literal candidates — fragile to any TDim
Display change and bounded to a 2..=64 chunk-size range.

Replace with direct enum destructuring:

    Eq(Div(Sym(🎯i), k1), Div(Sym(🎯j), k2))   with k1 == k2 == k

Uses tract_core::ops::logic::sym_to_coord_axis() (already pub) to map
the syms back to streaming-axis indices, and verifies both sides
reference the streaming axes (in either order) and share the same k.

Recogniser is now robust to Display formatting and accepts any
positive-integer chunk size.  Seven unit tests cover canonical form, large
chunk size, swapped-axis order, mismatched k rejection, non-streaming
axis rejection, non-Eq root rejection, and offset-in-numerator
rejection (a documented variant we intentionally don't accept yet).
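
A sketch of the destructuring approach on a toy expression type. `Expr` is a hypothetical stand-in for `TDim` and `block_diagonal_chunk_size` is not the function in this commit; the real recogniser goes through `sym_to_coord_axis()` rather than storing axis indices directly in the symbol.

```rust
/// Toy expression tree standing in for TDim (hypothetical).
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Sym(usize), // coordinate symbol, carrying its axis index
    Val(i64),
    Add(Box<Expr>, Box<Expr>),
    Div(Box<Expr>, u64),
    Eq(Box<Expr>, Box<Expr>),
}

fn sym(axis: usize) -> Box<Expr> {
    Box::new(Expr::Sym(axis))
}

/// Returns k when `e` is Eq(Div(Sym(i), k), Div(Sym(j), k)) and (i, j)
/// are the two streaming axes in either order.
fn block_diagonal_chunk_size(e: &Expr, streaming: (usize, usize)) -> Option<u64> {
    let Expr::Eq(lhs, rhs) = e else { return None };
    let (Expr::Div(l, k1), Expr::Div(r, k2)) = (lhs.as_ref(), rhs.as_ref()) else {
        return None;
    };
    // Anything but a bare symbol in the numerator (e.g. an offset) is rejected.
    let (Expr::Sym(i), Expr::Sym(j)) = (l.as_ref(), r.as_ref()) else {
        return None;
    };
    let axes_ok = (*i, *j) == streaming || (*j, *i) == streaming;
    (axes_ok && k1 == k2 && *k1 > 0).then_some(*k1)
}
```

Because the match is on structure rather than on a formatted string, a change to how the expression is displayed does not affect recognition.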
The Phase B POC's einsum chunkifier prepended the chunk axis at
position 0 of every input/output subscript, regardless of where the
streaming axis actually lived.  That was correct only because ex01
happens to put streaming axes at position 0 everywhere; on a multi-head
shape like [B, H, T_q, T_k] (streaming at positions 2 and 3) the
prepend would silently produce a broken model that tract's einsum
output_facts would reject far away from the actual problem.

Make the recogniser explicit about the assumptions and the rewriter
honour them:

- Pattern now stores `einsum_in_streaming_axes` (one position per input)
  and requires every input to have exactly one streaming axis.  Drops
  the dead `einsum_side_in_mul` field.
- Output streaming axes must be exactly two and contiguous (the chunk
  batch axis is inserted at the first one's slot, and the within-chunk
  versions stay adjacent).
- chunkify_einsum inserts the chunk char at each input's streaming-
  axis position and at the first streaming output axis position,
  shifting everything after.
- chunked_axis_index uses the chunk insertion position to translate
  downstream Reduce/RmAxis indices: axes before the chunk position
  stay put, axes at-or-after shift by one.
- The boundary-merge reshape at model output finds the chunk axis
  dynamically instead of assuming it sits at position 0.

ex01-block-diag-reduce continues to pass end-to-end.  Five new unit
tests cover chunkify_einsum and chunked_axis_index with non-aligned
streaming-axis positions (the multi-head [B, H, T_q, T_k] shape and a
mixed-position case).
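
The two index manipulations can be sketched in a few lines. `insert_chunk_axis` is a hypothetical helper for one subscript string; `chunked_axis_index` follows the rule stated above, with a body assumed for illustration.

```rust
/// Insert the chunk character into one einsum subscript at `pos`,
/// shifting every later axis right by one. Hypothetical helper.
fn insert_chunk_axis(subscript: &str, pos: usize, chunk: char) -> String {
    let mut chars: Vec<char> = subscript.chars().collect();
    chars.insert(pos, chunk);
    chars.into_iter().collect()
}

/// Translate an axis index of the original tensor into the chunked
/// tensor: axes before the chunk position stay put, the rest shift by one.
fn chunked_axis_index(orig_axis: usize, chunk_pos: usize) -> usize {
    if orig_axis < chunk_pos { orig_axis } else { orig_axis + 1 }
}
```

On the multi-head shape [B, H, T_q, T_k] written as `bhqk`, inserting `s` at position 2 gives `bhsqk`, and a downstream Reduce on axis 3 becomes a Reduce on axis 4.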
Same block-diagonal structure as ex01, but the row-axis Reduce<Sum> is
replaced by a second EinSum against a third stream c [T, D]:

    output[i, d] = sum_j masked[i, j] * c[j, d]

— the SDPA shape Q·Kᵀ → mask → attn·V without softmax.  Smallest
synthetic exercising a downstream second EinSum on the masked score
matrix.

Batch passes against the Python reference.  Pulsification with
--pulse T=2 currently fails: Blockify's recogniser only matches
Mul-by-mask → Reduce<Sum>, so it does not fire on this graph;
v1's pulsifier then produces a model with undetermined-symbol
broadcasts that panics at runtime.  Documented as the next concrete
Blockify target in the harness README.
Pins where ex02-block-diag-bilinear fails today in the pulse pipeline,
so we notice if the boundary moves as Blockify evolves.

Pulsification + plan construction silently succeed despite Blockify's
recogniser not firing on the Mul-by-mask -> 2nd-EinSum form; the
underlying machinery surfaces the failure as a recoverable Result::Err
at state.run() time, not as a panic.  The CLI panic seen earlier is a
downstream .unwrap() in libcli/src/tensor.rs on the same Err — CLI
ergonomics, not a library bug.

Test asserts the current boundary; it will start failing when Blockify
learns the second-EinSum pattern (and the failure either disappears or
moves to a different op).
Refactor the recogniser into the layered design:

  Phase 1 (op-agnostic, topological)
    Find every node whose output has multi-T-axis shape (>=2 streaming-
    symbol axes).  This is the section.  Initiators are section nodes
    whose inputs are all outside; terminators are non-section nodes
    consuming an in-section wire.

  Phase 2 (op-agnostic, structural coverage)
    At least one wire in the section must carry uniform_tdim or
    region_of_interest — the anchor used by phase 3.  (Wires without
    either are tolerated; the score matrix downstream of a Mul-by-mask
    has its values bounded by the mask without inheriting the
    annotation.)

  Phase 3 (op-agnostic, mask-form recognition)
    Find at least one wire with uniform_tdim matching the canonical
    block-diagonal form `(coord_i / k) == (coord_j / k)`; record k.

  Phase 4 (op-aware, bridge to existing rewriter)
    Pattern::from_section identifies the EinSum initiator (compute,
    no uniform_tdim on output), the Mul-by-mask body node, the mask
    outlet (the Mul's uniform_tdim input), and the Reduce<Sum>
    terminator.  This is where op-specific knowledge enters; the
    rewriter is unchanged.

The new layering means:
  - The detector knows nothing about EinSum, Mul, or Reduce — it
    just identifies the multi-T-axis subgraph.
  - The op-specific bridge (Pattern::from_section) is the explicit
    place where new initiator/terminator op-types will be added.
  - Multi-initiator graphs are already detected (ex01 has two:
    the score EinSum and the Eq for the mask).  Multi-terminator
    too, when they appear.

Behaviour-preserving on ex01 (integration test passes); ex02 still
fails at the same boundary as before.
Before this commit, find_quadratic_section lumped every multi-T-axis
node in the model into a single section regardless of dataflow
connectivity.  On a graph with two independent quadratic subgraphs
(e.g. parallel attention layers), Pattern::from_section would silently
pick the first matching EinSum/Reduce it found across both — possibly
from different subgraphs — and produce a partial, structurally wrong
rewrite.

Replace with find_quadratic_sections (plural): collect all multi-T-axis
nodes, partition them into connected components by dataflow edges, and
run phase 2+3 (annotation coverage + mask form) per component.  Each
surviving component becomes its own QuadraticSection.

find_pattern then handles 0/1/many cases explicitly:
  - 0: blockify is a no-op.
  - 1: build the bridge Pattern, rewrite as before.
  - 2+: refuse with a clear error.  Multi-section rewriting is a
    follow-up; refusing cleanly is preferable to a partial rewrite.

The connected-components walker is factored out as a separate function
(`connected_components`) so it's unit-testable in isolation.  Test
verifies that two parallel chains of nodes get partitioned into two
components even though they're declared "multi-T-axis" together.
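
A minimal sketch of such a walker over plain node ids. The function name matches the commit message, but the signature and data layout are assumptions; the real one operates on the typed model's nodes.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// Partition `nodes` into connected components, treating dataflow edges
/// as undirected and ignoring edges that touch a node outside the set.
fn connected_components(nodes: &[usize], edges: &[(usize, usize)]) -> Vec<Vec<usize>> {
    let members: BTreeSet<usize> = nodes.iter().copied().collect();
    let mut adjacency: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
    for &(from, to) in edges {
        if members.contains(&from) && members.contains(&to) {
            adjacency.entry(from).or_default().push(to);
            adjacency.entry(to).or_default().push(from);
        }
    }
    let mut seen = BTreeSet::new();
    let mut components = Vec::new();
    for &start in &members {
        if !seen.insert(start) {
            continue; // already placed in an earlier component
        }
        let mut component = Vec::new();
        let mut stack = vec![start];
        while let Some(node) = stack.pop() {
            component.push(node);
            if let Some(neighbours) = adjacency.get(&node) {
                for &next in neighbours {
                    if seen.insert(next) {
                        stack.push(next);
                    }
                }
            }
        }
        component.sort_unstable();
        components.push(component);
    }
    components
}
```

Two parallel chains that share no in-set edge come back as two components, so each can be checked for annotation coverage and mask form on its own.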

ex01 (single-section) continues to pass end-to-end; total tests now
26 pulse unit + 55 proptests + 2 integration, all green.
Refactor the rewriter to dispatch on the terminator's op-type instead of
hardcoding Reduce<Sum>.  No new behaviour: ex01 still pulsifies and
matches batch numerically, ex02 still fails at the same boundary
(Pattern::from_section does not yet recognise EinSum terminators;
follow-up commit).

What changed:

  - Pattern.reduce_node renamed to terminator_node; the op-specific
    reduce_axis field dropped (re-derivable from `op_as::<Reduce>()`
    at the dispatch site).

  - The rewriter dispatches via if-let-chain on the terminator op-type:
      if let Some(reduce_op) = term_node.op_as::<Reduce>() {
          rewrite_reduce_terminator(...)
      } else {
          bail!("Blockify: unsupported terminator op-type ...")
      }
    The Reduce-specific axis-shifting logic is now a separate
    `rewrite_reduce_terminator` helper.

  - Same dispatch shape covers the Rm-on-reduce-axis follow-up
    consumer rewrite (axis index translated through the chunk
    insertion when applicable, otherwise op copied verbatim).

The recogniser (Pattern::from_section) still only matches Reduce<Sum>
on its terminator; the next commit adds the EinSum-terminator arm and
turns ex02 green.
Extend the recogniser and the rewriter to cover the case where the body
Mul-by-mask is followed by a contracting EinSum instead of Reduce<Sum>:
the SDPA shape Q·Kᵀ → mask → attn·V (without softmax).

Recogniser changes (Pattern::from_section):
  - Try Reduce<Sum> first; if not found, try EinSum.
  - An EinSum terminator candidate must take the masked wire as one of
    its inputs and must contract at least one of the streaming subscripts
    on that input — i.e. drop a streaming subscript char from the output.
  - Both arms produce the same Pattern; dispatch happens in the rewriter.

Rewriter changes:
  - chunkify_einsum generalised: accepts per-input/per-output streaming
    starts (Option<usize>) instead of pulling them off Pattern; the same
    routine now handles both the initiator (1-streaming inputs, 2-
    streaming output) and the terminator (potentially 2-streaming
    multi-T-axis input, 1-streaming auxiliary inputs, 1-streaming output).
  - rewrite_einsum_terminator helper: locates the multi-T-axis input slot
    (where the masked wire enters), reuses the chunked initiator output
    for that slot, chunkifies each remaining auxiliary input via a
    Reshape on its streaming axis, and wires the chunked subscript.
  - rewrite() dispatches via if-let-chain on the terminator op-type
    (Reduce vs EinSum); other op-types still bail with a clear error.

Numerical correctness verified end-to-end:
  - harness/core-proptest-pulse/tests/blockify_ex01.rs adds
    ex02_block_diag_bilinear_blockified_pulse_matches_batch which loads
    the NNEF graph, runs batch and pulsified, and asserts numerical
    equality on the [6, 4] output across three chunks.
  - The previous ex02 failure-boundary probe is replaced with the new
    numerical match test.

ex01 continues to pass.  Total: 27 pulse unit tests + 55 proptests +
2 integration tests, all green.
With QuadraticSection carrying topological info and the rewriter
already dispatching on op-type, the intermediate `Pattern` struct was
just a bag of derived locals — it added an indirection without
encoding anything the rewriter couldn't compute itself.

Drop it.  The rewriter now:
  - takes &QuadraticSection directly (no bridge layer);
  - identifies the compute initiator, body Mul-by-mask, mask outlet,
    and terminator inline at the top of `rewrite()`;
  - returns Ok(None) if the section's op-types don't match — Blockify
    becomes a no-op upstream, same behaviour as before.

Helper signatures cleaned up:
  - chunked_axis_index(orig_axis, chunk_pos) — no &Pattern, just the
    chunk insertion position; also dropped the unnecessary Result.
  - rewrite_reduce_terminator takes chunk_axis_in_output: usize.
  - rewrite_einsum_terminator takes mul_node_id: usize.
  - collect_dead_mask_nodes takes mask_outlet_node, einsum_node_id,
    mul_node_id directly.

Net diff: -147 lines.  All 27 pulse unit tests + 55 proptests +
2 integration tests still green; no behavioural change on ex01 or
ex02.
…ode"

The previous rewrite() identified specific nodes (compute initiator,
body Mul, terminator) and hand-wired the rewrite around them, then ran
a topological pass for non-section nodes.  That structure baked in the
assumption of one initiator + one body node + one terminator, even
though the QuadraticSection already exposes initiators/terminators
as collections.

Restructure to a single topological pass with per-role dispatch:

  for each node in eval order:
      if dead:          skip
      if Source:        translate_source       (T → k·S substitute)
      if initiator:     translate_initiator    (dispatch on op-type)
      if terminator:    translate_terminator   (dispatch on op-type)
      if section node:  translate_body         (dispatch on op-type)
      else:             translate_outside

Each per-role function returns Ok(true) on success, Ok(false) if its
op-type isn't handled (Blockify becomes a no-op upstream), or Err on
hard failure.

Per-op-type translators are independent:
  - translate_initiator dispatches to translate_initiator_einsum.
  - translate_body handles Mul-by-mask by aliasing the body node's
    output to its compute input in `mapping` — no node emitted in
    `out` (the chunked graph multiplies by 1).
  - translate_terminator dispatches to translate_terminator_reduce or
    translate_terminator_einsum.
  - translate_outside handles AxisOp specifically: when the input was
    chunkified upstream (rank grew by 1), it shifts the AxisOp's axis
    parameters through the chunk insertion via a generic
    shift_axisop_through_chunk helper.  Other ops are copied verbatim.

Adding a new initiator / body / terminator op-type now means adding
one arm to the corresponding dispatcher.  No central "find THE
EinSum / Mul / Reduce" logic remains.

Dead-node identification is op-agnostic now: a section node whose
output has uniform_tdim is dead (mask construction), plus any node
whose only consumers are dead.  Replaces the previous walk-back-from-
mask-outlet that needed to know einsum_node_id and mul_node_id.
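
The op-agnostic rule reads as a small fixed-point computation. The sketch below works on a consumer list indexed by node id; the signature is hypothetical, and it assumes a node with no consumers at all (for example a model output) is never marked dead by propagation.

```rust
use std::collections::BTreeSet;

/// `consumers[n]` lists the nodes reading node n's output. `seeds` are
/// the section nodes whose output carries uniform_tdim. Dead-ness then
/// propagates to every node whose consumers are all dead.
fn collect_dead_nodes(consumers: &[Vec<usize>], seeds: &[usize]) -> BTreeSet<usize> {
    let mut dead: BTreeSet<usize> = seeds.iter().copied().collect();
    loop {
        let mut changed = false;
        for (node, users) in consumers.iter().enumerate() {
            let all_users_dead =
                !users.is_empty() && users.iter().all(|u| dead.contains(u));
            if all_users_dead && dead.insert(node) {
                changed = true;
            }
        }
        if !changed {
            return dead; // fixed point reached
        }
    }
}
```

For a chain range -> div -> eq(mask) feeding a Mul that also reads the score EinSum, seeding with the mask node marks the range and div nodes dead and leaves the EinSum alive.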

ex01 (Reduce terminator) and ex02 (EinSum terminator) both still
match batch numerically.  All 27 pulse unit tests + 55 proptests + 2
integration tests green.
…tion

Replace the manual model rebuild (translate_source / translate_outside /
translate_initiator / translate_body / translate_terminator / boundary_merge
+ topological dispatch) with two clearly-named phases:

Phase 1: introduce a fresh chunk symbol S, substitute T → k·S globally
on the whole model via core's `substitute_symbols`.  No more YOLO model
rebuild; the substitution is one HashMap-driven core API call.

Phase 2: for each detected QuadraticSection, build a TypedModelPatch and
apply it.  The patch:
  - taps each EinSum input from the substituted model
  - wires split-reshapes inside the patch
  - wires the chunked initiator EinSum
  - dispatches on terminator op-type (Reduce or contracting EinSum) and
    wires the chunked terminator
  - if a downstream RmAxis-on-the-reduce-axis exists, wires the chunked
    RmAxis with the shifted axis
  - wires a merge reshape that flattens [..., S, k, ...] back to
    [..., k·S, ...] so the boundary outlet's shape matches the original
    post-substitution shape
  - shunts the boundary outlet to the merged result
  - obliterates the original initiator, body Mul, terminator,
    post-terminator op (if any), and the mask-construction subgraph

`shunt_outside`'s compatibility check passes because phase 1 already
substituted T → k·S in every fact, so the original outlet (boundary)
has shape [k·S, ...] post-substitution and the patch's merged outlet
matches syntactically.
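
On concrete shapes, the split and merge reshapes are inverse bookkeeping operations. Both helpers below are hypothetical and work on integer shapes only; the patch itself wires Reshape ops over symbolic dimensions.

```rust
/// [.., k*S, ..] -> [.., S, k, ..]; None when the axis is missing or
/// its length is not a multiple of k.
fn split_axis(shape: &[usize], axis: usize, k: usize) -> Option<Vec<usize>> {
    let len = *shape.get(axis)?;
    if k == 0 || len % k != 0 {
        return None;
    }
    let mut out = shape.to_vec();
    out[axis] = len / k;
    out.insert(axis + 1, k);
    Some(out)
}

/// [.., S, k, ..] -> [.., k*S, ..]; merges `axis` with the one after it.
fn merge_axis(shape: &[usize], axis: usize) -> Option<Vec<usize>> {
    if axis + 1 >= shape.len() {
        return None;
    }
    let mut out = shape.to_vec();
    let k = out.remove(axis + 1);
    out[axis] *= k;
    Some(out)
}
```

Splitting [1, 8, 6, 4] on axis 2 with k=2 gives [1, 8, 3, 2, 4], and merging axis 2 restores the original shape, which is the property the boundary outlet relies on.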

The "one patch per section" structure handles multi-section graphs
naturally: each section gets its own patch built and applied
independently; inter-section state lives only in the shared model.
The earlier "refuse if more than one section" guardrail is dropped —
mismatched chunk sizes across sections still bail (a single global
substitution can't cover them).

Net diff: -200 LOC.  The deleted helpers (translate_source,
translate_outside, translate_initiator, translate_initiator_einsum,
translate_body, translate_terminator, translate_terminator_reduce,
translate_terminator_einsum, boundary_merge, shift_axisop_through_chunk,
collect_dead_nodes) collapse into the single build_section_patch
function plus core's substitute_symbols.

Tests unchanged at the integration level: 27 pulse unit + 55 proptests
+ 2 blockify integration tests, all green.
Replace the ad-hoc `blockify(model, sym, pulse) -> BlockifyResult`
function with a `BlockifyTransform` that implements `ModelTransform`,
matching the convention used by `PulseTransform` and the rest of
tract.

The transform's config is `BlockifyConfig { symbol: Option<String> }`
(streaming-symbol name, defaults to "S").  `transform()` finds the
quadratic sections, substitutes T → k·S globally via core's
`substitute_symbols`, applies one TypedModelPatch per section, and
stashes the substitution metadata in model properties:

  - `blockify.chunk_symbol`     (1-element string tensor, name of the
                                 new streaming symbol)
  - `blockify.chunk_size`       (scalar i64, the divisor k)
  - `blockify.original_symbol`  (1-element string tensor, what was
                                 substituted away — informational)

A `blockify_output(model)` helper reads these back into
`(Symbol, i64)`.  `PulsedModel::new_with_mapping` uses it to translate
the user's pulse value from token-units to chunk-units before invoking
the pulsifier.
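
The unit translation itself is a single division. A sketch, assuming the pulse must be a whole number of chunks; `pulse_in_chunks` and its error handling are hypothetical, not the code in `PulsedModel::new_with_mapping`.

```rust
/// Translate a pulse expressed in tokens (symbol T) into chunks
/// (symbol S), given chunk size k. Hypothetical helper.
fn pulse_in_chunks(pulse_tokens: usize, chunk_size: usize) -> Result<usize, String> {
    if chunk_size == 0 || pulse_tokens % chunk_size != 0 {
        return Err(format!(
            "pulse {pulse_tokens} is not a multiple of chunk size {chunk_size}"
        ));
    }
    Ok(pulse_tokens / chunk_size)
}
```

With chunk size 2, a user pulse of T=2 becomes S=1 on the rewritten graph.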

Registered with `register_model_transform!` so the CLI sees it:

  tract MODEL --transform 'blockify(symbol: Some("T"))' \
              --transform 'pulse(symbol: Some("S"), pulse: "1")' \
              dump

prints both `blockify.*` and `pulse.*` properties on the final model
and a clean chunked + pulsed graph.

Drops the `blockify()` free function and `BlockifyResult` struct —
the transform is the public surface now.

All 27 pulse unit + 55 proptests + 2 blockify integration tests still
green.
…tions

Reframe the patch builder so it reads as the layered structure we agreed
on: top-level iterates op-agnostically over initiators, body nodes,
terminators; each iteration dispatches to a per-op-type sub-function
that is independent of the others.

Top-level `build_section_patch` is now ~70 lines of clear roles:

  for each initiator in sec.initiators:        wire_initiator(...)?
  for each body node in section topo order:    wire_body(...)?
  for each terminator in sec.terminators:      wire_terminator(...)?
  for each (boundary, chunked_form) in shunts: wire_merge_reshape + shunt
  collect_dead_nodes(model, sec, &patch.shunts) → obliterate

Sub-functions:

  wire_initiator              dispatches op-type → wire_initiator_einsum
  wire_body                   dispatches op-type → wire_body_mul_by_mask
  wire_terminator             dispatches op-type → wire_terminator_reduce
                                                 / wire_terminator_einsum
  wire_initiator_einsum       taps inputs, wires reshapes, wires chunked op
  wire_body_mul_by_mask       aliases compute input as chunked output
                              (no node added to patch)
  wire_terminator_reduce      wires chunked Reduce, optionally chunked
                              RmAxis on the (former) reduce axis
  wire_terminator_einsum      taps + reshapes auxiliary inputs, wires
                              chunked op; multi-T-axis input passes
                              through `chunked` map
  wire_merge_reshape          collapses [..., S, k, ...] → [..., k·S, ...]

`collect_dead_nodes` is op-agnostic: section nodes + shunted boundary
nodes + transitive (consumers-all-dead).  Excludes model inputs.

Adding a new op-type for any role now means adding one arm in the
relevant `wire_*` dispatcher and one new `wire_*_<op>` function.  No
shared state with other roles' translators beyond the `chunked` map
and the patch-in-progress.

ex01 + ex02 still numerically match batch; 27 pulse unit + 55 proptests
+ 2 integration tests all green.
NNEF supports `tract_core_shape_of(x)[i]` directly (cf. nemotron's
patch transforms which use the same idiom).  Collapse the three-line
slice+squeeze that was extracting T from the input shape:

    a_shape = tract_core_shape_of(a);
    T_slice = slice(a_shape, axes = [0], begin = [0], end = [1], stride = [1]);
    T_dim   = squeeze(T_slice, axes = [0]);
    pos     = tract_core_range(0, T_dim, step = 1);

into the one-liner:

    pos = tract_core_range(0, tract_core_shape_of(a)[0], step = 1);

Applied to both ex01-block-diag-reduce and ex02-block-diag-bilinear.
Both still pulsify and numerically match the batch reference.
The mask-construction chain was casting the position range to f32,
dividing in float, flooring, and casting back to i64 — five lines for
a single integer floor-division.  NNEF's `div` produces a TDim wire
on integer inputs which simplifies to `(coord)/k` directly; the rest
of the chain (unsqueezes, eq, cast-to-f32) doesn't care whether its
input is i64 or TDim.

The recogniser sees the same `((🎯0)/2 == (🎯1)/2)` uniform_tdim it
saw before; both ex01 and ex02 still pulsify and numerically match.
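
The equivalence this simplification relies on can be checked with a small sketch. The function names are illustrative and not part of tract. It assumes non-negative positions and a positive chunk size, where truncating integer division agrees with flooring the float quotient.

```rust
/// Chunk index via the old chain: cast to f32, divide, floor, cast back.
fn chunk_index_float(pos: i64, k: i64) -> i64 {
    ((pos as f32) / (k as f32)).floor() as i64
}

/// Chunk index via one integer division (assumes pos >= 0 and k > 0).
fn chunk_index_int(pos: i64, k: i64) -> i64 {
    pos / k
}

/// Block-diagonal mask entry: 1.0 when both positions share a chunk.
fn block_diag_mask(i: i64, j: i64, k: i64) -> f32 {
    if chunk_index_int(i, k) == chunk_index_int(j, k) {
        1.0
    } else {
        0.0
    }
}
```

For negative positions the two forms would diverge (truncation versus floor), which is why the sketch restricts itself to the non-negative range a position counter produces.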
New typed op `WindowOnAxis { axis, window, start }` in pulse-opl: inserts
a static window axis after the streaming axis where slot `w` carries the
input shifted by `start + w` (with zero pad past the boundaries).  The
pulse pulsifier in pulse/src/ops/window.rs lowers it to
`Delay(0, W-1) → PulsePad(before = -start) → PulsedExposeWindow`:

* Delay accumulates the latest W chunks per pulse.
* PulsePad zero-fills the leading -start positions and shifts
  stream.delay back, so past-window forms (start < 0) end up causal
  (stream.delay = 0 after PulsePad).
* PulsedExposeWindow reshapes the per-pulse [W, ...] view into
  [1, W, ...], exposing W as a static axis.

Constraint: pulse=1 on the windowed axis.  Constraint:
start <= 0 <= start + window - 1 (window straddles current chunk).
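
A reference for the intended batch semantics, as a minimal sketch on a 1-D stream. It assumes "shifted by `start + w`" means slot `w` at time `t` reads `x[t + start + w]`; the function name and the `Vec<Vec<f32>>` layout are illustrative, not the pulse-opl implementation.

```rust
/// out[t][w] = x[t + start + w], or 0.0 when that index is outside the stream.
fn window_on_axis(x: &[f32], window: usize, start: isize) -> Vec<Vec<f32>> {
    (0..x.len())
        .map(|t| {
            (0..window)
                .map(|w| {
                    let src = t as isize + start + w as isize;
                    if src >= 0 && (src as usize) < x.len() {
                        x[src as usize]
                    } else {
                        0.0 // zero pad past either boundary
                    }
                })
                .collect()
        })
        .collect()
}
```

With `window = 2` and `start = -1` each output row holds the previous and the current element, which is the causal past-window case the PulsePad step handles.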
Extends the mask recogniser from a single block-diagonal pattern
(coord_a/k == coord_b/k) to a closed-enum `MaskForm::Banded { lower,
upper, k, axis_a, axis_b }` covering both bounds on the chunk-index
diff.  Block-diagonal is now the special case lower == upper == 0.

Recognises two AST shapes after `reduce()`:
  Eq(Div(Sym(coord_a), k), Div(Sym(coord_b), k))
  Mul([Ge(upper, D), Ge(D, lower)]) where D = coord_a/k - coord_b/k

For banded sections (window straddling current chunk: lower <= 0 <=
upper, W = upper - lower + 1 > 1), the rewrite wraps the contracted-
axis input — identified by tracking the einsum's stream axis to the
output position matching mask.axis_a — with WindowOnAxis(window=W,
start=lower) followed by a flatten reshape that merges the new W
axis into the within-chunk axis.  The chunked einsum subscripts stay
unchanged; only the contracted axis carries W·k elements instead of
k.  Block-diagonal sections take the same path with W = 1 (a no-op
WindowOnAxis is skipped).

Output stream.delay is upper at the section boundary (0 for causal
masks, > 0 for future-window masks like ex03).  Purely-future
(lower > 0, skip current) and purely-past (upper < 0) forms are
recognised but not rewritten — they don't appear in real attention
masks and would need different pulsifier wiring.
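
The banded predicate and the derived window width can be written out as a short sketch. The names are illustrative (the real recogniser works on TDim expressions, not on plain integers), and positions are assumed non-negative.

```rust
/// True when lower <= chunk(i) - chunk(j) <= upper, with chunk(p) = p / k.
fn banded_mask(i: i64, j: i64, k: i64, lower: i64, upper: i64) -> bool {
    let d = i / k - j / k;
    lower <= d && d <= upper
}

/// Number of chunks the contracted axis must expose: W = upper - lower + 1.
fn window_chunks(lower: i64, upper: i64) -> i64 {
    upper - lower + 1
}

/// A form is rewritten only when the window straddles the current chunk.
fn straddles_current_chunk(lower: i64, upper: i64) -> bool {
    lower <= 0 && 0 <= upper
}
```

Setting `lower == upper == 0` reduces `banded_mask` to the block-diagonal test `i / k == j / k`, with `window_chunks` returning 1.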
…ize change

stream.delay counts elements on the streaming axis.  When the reshape
changes the per-pulse size on that axis (e.g. the (S, k) -> S·k merge
that Blockify emits at a section boundary), the same physical lag
should cover a different element count, so stream.delay must rescale
by new_per_pulse / old_per_pulse.

Without this, a banded section's chunk-axis stream.delay was carried
verbatim across the boundary merge into the token axis — a delay of
"1 chunk = k tokens" was advertised as "1 token", off by k - 1.
Downstream the CLI's `--assert-output-bundle` slicing path uses
`pulse.delay` to skip warmup tokens; with the wrong delay it skipped
too few and the streamed output failed to align with the batch
reference.

Rescale at PulsedReshape pulsed_output_facts: read the per-pulse sizes
from the stored (already substituted) `from`/`to` of the wrapped
AxisOp::Reshape, scale stream.delay accordingly, and bail with a
clear error if the result wouldn't be integer.

For pulse-v1's existing wavenet-style consumers this changes nothing
(no reshape across the streaming axis).  Block-diagonal Blockify
sections also unaffected (stream.delay stays 0 throughout).
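
The rescale rule can be sketched as a standalone function. The name, the `usize` types and the `String` error are assumptions for illustration; the actual code works on the pulsed facts of PulsedReshape.

```rust
/// Rescale a delay counted in elements of the old streaming axis to the new one:
/// new_delay = delay * new_per_pulse / old_per_pulse, rejecting non-integer results.
fn rescale_delay(
    delay: usize,
    old_per_pulse: usize,
    new_per_pulse: usize,
) -> Result<usize, String> {
    if old_per_pulse == 0 {
        return Err("old per-pulse size must be non-zero".to_string());
    }
    let scaled = delay * new_per_pulse;
    if scaled % old_per_pulse != 0 {
        return Err(format!(
            "delay {delay} does not rescale to an integer ({new_per_pulse}/{old_per_pulse})"
        ));
    }
    Ok(scaled / old_per_pulse)
}
```

For the (S, k) -> S·k merge with k = 2, a delay of 1 chunk becomes `rescale_delay(1, 1, 2)`, that is 2 tokens, while a delay of 0 is unchanged for any ratio.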
…nme.sh

* ex03-banded-reduce: future-window mask (0 <= chunk(i) - chunk(j) <= 1).
  Mimics multi-chunk attention with right-context.  Output stream.delay
  is L*P = 2 tokens; the CLI's --assert-output-bundle path uses the
  rescaled pulse.delay to skip warmup tokens before comparing against
  io.npz.

* ex04-banded-causal: past-window mask (-1 <= chunk(i) - chunk(j) <= 0).
  The causal counterpart, exercising lower < 0 in the recogniser and
  the PulsePad zero-fill in the WindowOnAxis pulsifier (so the streamed
  first-chunk output matches batch's "out-of-stream chunks contribute
  zero" convention).  Output stream.delay = 0 (fully causal).

Both synthetics ship with a `runme.sh` that runs batch (--set T=6) and
pulsified (--pulse 'T=2') against io.npz with --approx approximate.
Existing ex01 / ex02 get the same treatment so all four cases run via
the standard CLI test harness.  .travis/cli-tests.sh extended to
discover harness/pulse-multi-axis/**/runme.sh alongside nnef-test-cases.

The previous Rust integration tests in
harness/core-proptest-pulse/tests/blockify_ex01.rs are gone — runme.sh
covers the same ground with one-tenth the lines, exercises the actual
CLI code path, and surfaces CLI-side bugs (like the pulse.delay
rescale issue fixed in the previous commit) as part of normal
validation rather than masking them with bespoke Rust glue.
`ModelPatch::apply` (core/src/model/patch.rs:368-385) seeds a garbage
walk from each shunted outlet's source node and replaces the chain with
Dummy, propagating upstream as long as a node has no remaining
successors.  That sweeps the same set our hand-rolled walk did:
initiator EinSum + mul-by-mask + terminator (the shunted boundaries)
plus the orphan mask-construction chain that only fed the now-Dummy
Mul.  Verified by `dump --transform blockify`: stale nodes are gone
without our explicit obliterate pass.

Net -54 lines.
kali and others added 29 commits May 28, 2026 09:29
cli/compare: strip the .fused_axis_op suffix the CUDA translator adds
when an op absorbs adjacent axis ops, so per-node compare lines up GPU
outputs against the CPU reference (covers ~17% more nodes on a typical
pulsified GPU model).

cuda/transform: TRACT_CUDA_FORCE_CPU=substr[,substr,...] env var that
forces matching nodes to the CPU fallback path.  Pinpointed
CudaGgmlGemm on selfAttn_xMatmul.blockified as the source of the
encoder pulse + CUDA drift; useful keep-around for the next time.
PulsePad is on `can_fuse_move`'s allowlist of ops that accept
non-contiguous (Move-permuted) inputs.  Its partial fills already use
`copy_with_origins`, but the initial 'copy the whole input to output'
used `flat_copy` — a flat memcpy that reads the buffer in pre-Move
byte order while the output is laid out in post-Move natural strides.

Visible symptom on the pulsified Nemotron encoder under --cuda: the
attention-output matmul fed a GpuPulsePad with a fused
GpuMoveAxis(0→1) on its input; the bad initial copy garbled the
matmul output before downstream layers consumed it, accumulating
~26% outliers end-to-end.
…nputs

The fallback pre-check called `gpu_op.output_facts` on the raw
target-side input facts, but those can be a mix of host facts (e.g. a
kv-cache past tensor) and device facts (current-turn output).  GPU op
`output_facts` impls bail with 'Inconsistent facts' on mixed inputs,
which then wrongly trips the CPU fallback.  Symptom on Llama 3.2 1B
f32f32 --cuda: kv-cache Concat and residual Add ops all fell back to
CPU, blowing the LLM CI op-only allowlist.

Mirror what wire-time `sync_inputs_if_required(ToDevice)` does:
wrap each non-device input as a DeviceFact-from-host before calling
`output_facts` for the pre-check.

Also adds an opt-in `TRACT_CUDA_TRANSLATE_DEBUG` env var that prints
each rejected node and the underlying error chain — handy for the
next time a pre-check decision needs investigation.
Adds a sibling kernel that loads the mask as uchar/char, substitutes
-inf at masked positions before softmax, and when post_softmax_mask is
set scrubs fully-masked rows (sum == 0 / NaN) to 0 on write-back.
Lifts the GpuScaledMaskedSoftmax guards so bool masks aren't rejected
by output_facts, and drops the rule_if!(!post_softmax_mask) gate on
both backends.

For the nemotron-streaming encoder this moves all 24 SMS nodes off CPU
(--cuda matches the recorded io bundle at --approx very).  Metal mirror
matches structurally; CI's macOS nemotron harness covers the numeric
check there.
…, Reduce)

Both SMS (this branch) and DiagGather (already on main) now have CUDA +
Metal kernels.  IsNan and Reduce don't appear in any of the 4 streaming
models — IsNan never did, Reduce<Sum> is always F32 which both backends
handle.  Audit on --cuda confirms zero CPU instances; --metal mirrors the
same allowlist now that the Metal kernels exist.  Tight placement check:
any regression that puts one of these on CPU now fails CI.
The decoder is stepped one token at a time by the caller (external state
plumbed through the outer graph), so iters resolves to 1 and the Scan
body can be inlined.  Apply the existing core force_scan_external_state
transform on the decoder run; the two LSTM cells now land on GPU.

Drop Scan from the gpu allowlists — no model in the harness keeps a
Scan node on CPU after this.
Now that cuda + metal Gather kernels are on main, the decoder embedding
lookup runs on GPU.  Audit confirms zero CPU Gather across all 4 models
(decoder is run with -t force_scan_external_state so its embedding
input is fed directly to CudaGather/MetalGather).
Residual (input + skip + bias) followed by LayerNormalization over the
last axis scaled by gamma (+ beta), computed in f32. Optionally emits
mean / inv_std_var / input_skip_bias_sum outputs.

Validated bit-close against onnxruntime (output + input_skip_bias_sum).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…8x8)

Route qmmm_i32 through VPDPBUSD when AVX-512 VNNI is available, replacing the
AVX2 per-K widening-multiply inner loop. Consumes the existing K=4-inner
PackedI8K4 layout; A is offset by +128 for VPDPBUSD's u8*s8 form and the
128*sum_k(B) bias is removed per output column, so the i32 accumulators stay
bit-identical to the AVX2 path and the whole quantization epilogue is reused.
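
The +128 bias identity can be checked in scalar code. This is a sketch of the arithmetic only, not of the kernel: the function names are illustrative, and each slice stands for one row of A against one column of B.

```rust
/// Reference signed dot product: sum_k a[k] * b[k] in i32.
fn dot_s8s8(a: &[i8], b: &[i8]) -> i32 {
    a.iter().zip(b).map(|(&x, &y)| x as i32 * y as i32).sum()
}

/// Same result through a u8 x s8 product, as VPDPBUSD requires:
/// sum_k (a[k] + 128) * b[k] - 128 * sum_k b[k].
fn dot_via_u8s8(a: &[i8], b: &[i8]) -> i32 {
    let biased: i32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let u = (x as i32 + 128) as u8; // always in 0..=255
            u as i32 * y as i32
        })
        .sum();
    let sum_b: i32 = b.iter().map(|&y| y as i32).sum();
    biased - 128 * sum_b
}
```

Because the correction term depends only on B, it is one value per output column, which is why it can be removed once after the K loop.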

Runtime-gated via where(AVX512VNNI); non-VNNI x86 keeps the AVX2 fallback.
Includes a vnni_i32 microbench (VNNI vs AVX2 int8).

Measured on Cascade Lake (single-thread, kernel-only): 9-14x the AVX2
throughput across representative shapes (e.g. 512x512x512: 8.23 -> 99.5 Gelem/s).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Mirrors the SME probe pattern: a tiny dummy_amx.S file containing the
mnemonics the upcoming kernel needs (ldtilecfg, tilezero, tdpbusd,
tilerelease) is compiled by the build script. On toolchains predating
AMX support — notably Debian stretch's gas 2.28 — the probe fails and
the `tract_amx_int8` cfg is not emitted, so the (forthcoming) kernel
file is excluded from compilation and the Rust side never references the
absent symbol. Dispatch then falls back to VNNI or AVX2 silently.

Sets up infrastructure for the next commit which adds the actual kernel.
No behaviour change yet: amx_int8_files is empty until the kernel lands.

https://claude.ai/code/session_01MtQweWhf8E1W7pEJMF6ywf
Route qmmm_i32 through Intel AMX TDPBSSD when CPUID reports amx-int8/amx-tile
AND the OS grants tile-data XSAVE permission (Linux: arch_prctl ARCH_REQ_XCOMP_PERM).
The kernel exposes the same 8x8 ymm-accumulator tile as avx512vnni_mmm_i32_8x8
and reuses its entire post-matmul dispatcher epilogue (per_rows / per_cols /
scalars / q_scale / q_shr / add_unicast / store) unchanged — only the inner-K
matmul phase changes.

Tile geometry inside the kernel:
  tmm0 (C): 8 rows x 32 colsb -> 8 M x 8 N i32 accumulator (the 8x8 tile)
  tmm1 (A): 8 rows x 64 colsb -> 8 M x 64 K-bytes per inner iter
  tmm2 (B): 16 rows x 32 colsb -> 16 K-pair-rows x (8 N-cols * 4 K-bytes)
Per TDPBSSD: 8 * 8 * 64 = 4096 i32 mul-acc ops (128x a single vpdpbusd ymm).

After the matmul phase, tmm0 is tilestored to a 256-byte stack scratch and
loaded back as 8 row-major ymm registers, then a 24-instruction 8x8 i32
transpose (vpunpckl/h + vpunpcklqdq/h + vperm2i128) brings the accumulators
into the column-major ymm0..ymm7 layout the existing epilogue expects.

Packing:
- B reuses the existing K=4-inner PackedI8K4 layout unchanged (the same byte
  layout that VNNI feeds vpdpbusd; tileloadd with stride=32 and cfg.colsb=32
  reads it as one K-pair-row per tile row).
- A uses a NEW M-major-within-panel layout (PackedAmxA): per 8-M-row panel,
  bytes are laid out row-major as panel[m*K_padded + k] = A[m, k], with
  K_padded = ceil(K / 64) * 64. tileloadd with stride=K_padded reads 8
  contiguous M-rows of 64 K-bytes per inner iter.
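
The A-side layout formula above can be sketched as a packer for a single panel. This is a hypothetical helper, not the `PackedAmxA` code: it assumes A is a dense row-major M x K slice and zero-fills both the K padding and any rows past M.

```rust
/// Pack rows [panel_row0, panel_row0 + r) of a row-major M x K matrix so that
/// panel[m * k_padded + k] = A[panel_row0 + m, k], with K padded to a multiple of 64.
fn pack_amx_a_panel(a: &[i8], m: usize, k: usize, panel_row0: usize, r: usize) -> Vec<i8> {
    let k_padded = (k + 63) / 64 * 64;
    let mut panel = vec![0i8; r * k_padded];
    for row in 0..r {
        let src_row = panel_row0 + row;
        if src_row >= m {
            break; // remaining panel rows stay zero
        }
        let dst = row * k_padded;
        let src = src_row * k;
        panel[dst..dst + k].copy_from_slice(&a[src..src + k]);
    }
    panel
}
```

With this layout consecutive panel rows are `k_padded` bytes apart, matching the tile-load stride stated above.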

TDPBSSD is s8 x s8 -> i32 (Sapphire Rapids+, AMX-INT8 baseline), so no +128
bias trick is needed (unlike VNNI's vpdpbusd). The i32 accumulators are
bit-identical to the AVX2 / VNNI paths.

Build-time gating: a `tract_amx_int8` cfg is emitted only when the assembler
accepts the AMX mnemonics (ldtilecfg, tilezero, tdpbssd, tilerelease,
tileloadd, tilestored), checked by the assembler_supports_amx_int8 probe
introduced in the previous commit. Old toolchains (Debian stretch binutils
2.28) fall back to VNNI silently.

Runtime gating: has_amx_int8() does both CPUID (leaf 7 sub-leaf 0 EDX
bits 24/25, since `is_x86_feature_detected!("amx-int8")` is gated on the
nightly x86_amx_intrinsics feature) and a one-shot Linux arch_prctl
ARCH_REQ_XCOMP_PERM call for XFEATURE_XTILEDATA (=18) via raw syscall.
Result is OnceLock-memoised. Non-Linux returns false.
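
The CPUID half of that gate can be sketched as follows. This is not the `has_amx_int8()` implementation: the function names are illustrative, and the `arch_prctl` permission request and the memoisation are deliberately left out.

```rust
/// CPUID.(EAX=7, ECX=0):EDX bit positions named in the text above.
const AMX_TILE_BIT: u32 = 24;
const AMX_INT8_BIT: u32 = 25;

/// Pure decode step: both amx-tile and amx-int8 must be reported.
fn amx_int8_bits_set(edx: u32) -> bool {
    (edx >> AMX_TILE_BIT) & 1 == 1 && (edx >> AMX_INT8_BIT) & 1 == 1
}

/// Query the CPU. Only the CPUID part of the gate; OS permission is not checked here.
#[cfg(target_arch = "x86_64")]
#[allow(unused_unsafe)]
fn cpu_reports_amx_int8() -> bool {
    use std::arch::x86_64::{__cpuid, __cpuid_count};
    // Leaf 0 returns the highest supported basic leaf in EAX.
    let max_leaf = unsafe { __cpuid(0) }.eax;
    if max_leaf < 7 {
        return false;
    }
    let leaf7 = unsafe { __cpuid_count(7, 0) };
    amx_int8_bits_set(leaf7.edx)
}

/// Other architectures never report AMX.
#[cfg(not(target_arch = "x86_64"))]
fn cpu_reports_amx_int8() -> bool {
    false
}
```

A true result here is necessary but not sufficient: without the tile-data permission from the OS, executing a tile instruction would still fault.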

Validation:
- `cargo test --release -p tract-linalg`: 2885+9 tests pass, 0 failed.
- The avx512amx_mmm_i32_8x8 kernel passes the full MMM property-test suite
  (i8i8 frame::prop, i32i32 frame::prop, store_i32/i8 row/col/arbitrary,
  return_q_scale across all rounding policies + pot/nonpot scales, etc.) —
  bit-identical to AVX2 and VNNI on the same inputs.

https://claude.ai/code/session_01MtQweWhf8E1W7pEJMF6ywf
Same M/K/N shapes as the vnni_i32 bench (64x256x64, 256x256x256,
512x512x512, 1024x1024x64). All three kernels run the i8i8 packing path
(index 1) so the only difference is the matmul inner loop. Skipped at
runtime when `has_amx_int8()` returns false (= CPUID lacks amx-int8/tile
or the arch_prctl XSAVE permission was denied), and at build time when
the `tract_amx_int8` cfg was not emitted.

https://claude.ai/code/session_01MtQweWhf8E1W7pEJMF6ywf
The AMX kernel uses a custom A-side packer `PackedAmxA` (M-major-within-
panel rows, K padded to multiples of 64). When dispatched on AMX hardware,
`OptMatMulPack::eval_with_session` in tract-core sees `PackedAmxA` as the
packing format and previously bailed with "OptMatMulPack does not support
packing format PackedAmxA". On Cascade Lake the bug was latent (the AMX
dispatcher never activated); on Sapphire Rapids/Emerald Rapids it caused
29 quant/matmul tests to fail end-to-end.

Fix:

* `core/src/ops/matmul/pack.rs::pack_view_with`: add a `PackedAmxA`
  downcast arm parallel to the existing `PackedI8K4` arm. Gate the import
  on `target_arch = "x86_64"` since `tract_linalg::x86_64_fma` only
  exists there.

* `linalg/src/x86_64_fma.rs`: drop `#[cfg(tract_amx_int8)]` from
  `pub mod amx;`. `PackedAmxA` and `has_amx_int8()` are pure data-layout
  / CPUID code with no AMX-specific assembly — they can compile and
  exist on any x86_64 host regardless of whether the assembler can encode
  AMX instructions. Only the kernel registration in `mmm.rs` and the
  `where(AVX512AMX)` gate need `tract_amx_int8`.

This lets tract-core reference `PackedAmxA` unconditionally, removing
the cross-crate cfg-gating problem (tract-core's build.rs doesn't run
the AMX assembler probe, so it can't see `tract_amx_int8`).

Test plan:

* `cargo test --release` across tract-linalg / tract-core / tract-data /
  tract-nnef / tract-onnx / tract-pulse / tract-transformers / tract-hir
  / tract on Emerald Rapids (model 207, amx-int8 + amx-tile flags):
  **3458 passed, 0 failed**, including the AVX-512 AMX MMM property
  suite (`avx512amx_mmm_i32_8x8::{i8i8,i32i32}::frame::prop`,
  `store_i32/i8::*`, `return_q_scale_*`, `fuse::prop`) and the
  tract-core `ops::matmul::quant::*` suite that exercises the
  `OptMatMulPack` -> `PackedAmxA` codepath end-to-end.

* All 15 quantized NNEF test cases (conv-q40 × 13, qmul, copy-requant)
  pass with output assertion against `io.npz` reference on AMX hardware.
Add prefetcht0 hints inside the K=64 inner loop of avx512amx_mmm_i32_8x8
for the data the NEXT iteration will consume. tileloadd brings the active
A/B tile data into L1 on demand; the prefetches ask the hardware
prefetcher to start the L2->L1 fill earlier so the next iter's tileloadd
sees the data already warm.

* A side: 1 prefetch per iter at [rax + 64] -- next iter's A row 0 start.
  The 7 other rows are stride r8 = K_padded apart; the hardware stream
  detector picks those up.
* B side: 8 prefetches at [rbx + 512..960] -- all 8 cache lines of next
  iter's 512-byte B panel.

Numbers on Emerald Rapids (model 207, 1 thread, `cargo bench
-p tract-linalg --bench amx_i32`), packed_packed avx512amx, i8*i8->i32:

| Shape (M*K*N)    | Before (Gelem/s) | After (Gelem/s) | Delta |
|------------------|-----------------:|----------------:|------:|
| 64 * 256 * 64    |             64.5 |            66.5 | +3.2% |
| 256 * 256 * 256  |             64.5 |            64.5 |   ~0% |
| 512 * 512 * 512  |              110 |             113 | +2.7% |
| 1024 * 1024 * 64 |              173 |             174 | +0.6% |

Small, consistent win where B-side L2->L1 traffic matters (64x256x64,
512x512x512); flat on the saturating 256x256x256 shape and nearly flat
on 1024x1024x64 (N=64).

Test plan:

* `cargo test --release -p tract-linalg --lib avx512amx_mmm_i32_8x8`
  on ER: **114 passed, 0 failed** -- the full AMX MMM property suite
  (i8i8 frame::prop, i32i32 frame::prop, fuse::prop, store_i32/i8,
  return_q_scale_*) confirms prefetches did not change kernel semantics.
avx512amx_mmm_i32_16x16 hits the maximum AMX i8 tile geometry (16 rows x
64 colsb = 1024 B per tile, both tmm1 A and tmm2 B). One `tdpbssd` now
does 16 * 16 * 64 = 16384 mul-adds vs the 8x8 sibling's 4096 -- a 4x
work-per-instruction gain, expected to translate to ~2x throughput on
512x512x512 / 1024x1024x64 (the 8x8 path is already memory-bound after
the prefetch tuning).

Register layout: ROW-MAJOR accumulators (zmm{m} = row m of C with 16 i32
lanes for n=0..15). This matches `tilestored`'s output layout directly,
so the hot path (Clear -> AddMatMul -> Store/store_strides_i32_row_contig)
needs zero transposes. The 16x16 zmm transpose that a col-major layout
would have required is ~30 cross-lane permutes.

Epilogue surface re-implemented for AVX-512 zmm:
  - scalar / per_row / per_col elementwise ops (zmm broadcasts)
  - leaky_relu via vpcmpgtd mask + vpblendmd
  - 6x q_scale rounding policies (vpsignd has no AVX-512 form; emulated
    via vpcmpgtd k1, 0, acc + vpsubd + vpblendmd)
  - 6x q_shr rounding policies + q_shl (vpsravd / vpsllvd zmm)
  - Store: row-contig fast path (1 vmovdqu32 or vpmovdb per row), generic
    scalar fallback for arbitrary strides
  - AddUnicast: gather via vpgatherdd with index = lane * col_stride
  - LoadTile: gather from col-major scratch with constant index vector
  - AddRowColProducts: outer product via row_data[m] broadcast x col_data

A reuses PackedAmxA(16); B reuses PackedI8K4(16). Both packers are r-
generic (K-padded to multiples of 64; K=4-inner block of 16 N-cols).

The 16x16 is plugged as the primary `qmmm_i32` dispatch target whenever
`has_amx_int8` is true; the 8x8 stays registered in `mmm_impls` so the
dispatcher can pick it for smaller problems. Property-test surface mirrors
the 8x8: 114 tests, skip-pass on non-AMX hosts via the runtime gate.
Adds the new 16x16 kernel alongside the existing avx2 / avx512vnni /
avx512amx_8x8 entries so reviewers running the bench on Sapphire Rapids+
can see the per-shape throughput delta between the two AMX variants
(8x8 vs 16x16) on the same M/K/N points (64x256x64, 256x256x256,
512x512x512, 1024x1024x64).
oneDNN (the Intel-backed reference AMX implementation in jit_brgemm_amx_uker)
distinguishes two roles in the inner-K loop:

  - A is REUSED across the outer matmul's N-tile sweep, so it benefits from
    being cached in L1.  oneDNN uses `tileloadd` (cached) for A with a
    light `prefetcht0` hint to L1.

  - B STREAMS THROUGH once per kernel call (each N-tile gets its own B
    panel).  For the AMX-typical large-matmul case the per-call B working
    set exceeds L1d (48 KB on Sapphire Rapids).  oneDNN's heuristic
    `try_load_nt = footprint(A)+footprint(B)+footprint(C) >= L1` flips B's
    load to `tileloaddt1` (non-temporal, bypasses L1) and steers B-side
    prefetches at L2 (`prefetcht1`) instead of L1.

The previous 16x16 prefetch block (17 `prefetcht0`'s + `tileloadd` for B)
matched the 8x8 pattern proportionally but over-ran Sapphire Rapids' 16
Line Fill Buffer budget: 1 A-prefetch + 16 B-prefetches + 2 active
tileloadds = 19 in-flight slots demanded, vs 16 available.  That backs up
real loads behind dropped prefetches.

This patch aligns 16x16 with oneDNN's defaults for the large-matmul case:

  - A:  prefetcht0 + tileloadd        (1 LFB for prefetch + 1 for load)
  - B:  6x prefetcht1 + tileloaddt1   (6 LFBs for L2 priming + 1 NT load)
        -> primes the head 384 B of next-iter B-panel (6 of 16 lines);
        the SPR/EMR/GNR HW stream prefetcher reliably covers the
        remaining 10 lines once the 1024-B stride is detected.

Total in-flight per iter: 9 (was 19).  This leaves headroom for the OoO
engine to overlap multiple iterations.  The 8x8 kernel is left untouched
since (a) its existing 9-prefetch pattern already fits the LFB budget,
and (b) its 119 GElem/s @ 512x512x512 on EMR has been validated.

Property tests (avx512amx_mmm_i32_16x16 suite) skip-pass on this CL host
via the runtime gate; will be re-validated on AMX HW.

Refs:
  - oneDNN src/cpu/x64/brgemm/brgemm_utils.cpp (load_nt heuristic)
  - oneDNN src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp (tileloaddt1 use)
  - Intel SDM Vol 1, sec 18.3 (AMX), Vol 3 (XSAVE tile state)
  - chipsandcheese.com SPR deep-dive (LFB count = 16)
…all)

The `qmmm_i32` closure now selects between the 8x8 and 16x16 AMX
kernels based on the (m, k, n) hint -- mirroring oneDNN's BRGEMM
ukernel-variant selection logic, where the MR/NR pair is picked per
problem size rather than fixed at build time.

Rationale:
  - 16x16 (1024 B/tile, 16384 mul-adds per tdpbssd) wins on big problems
    where the per-call setup cost (ldtilecfg + 16-row epilogue scratch)
    is amortised across many K-iters.
  - 8x8 (256 B/tile, 4096 mul-adds per tdpbssd) wins on small problems
    where 16x16 would over-pad and pay full epilogue cost on a mostly-
    empty C tile.

Threshold: 16x16 is picked iff m >= 16 AND n >= 16 AND k >= 64. A
dimension given as Option<usize>::None ("streaming / unknown") counts as
"large enough", so dynamic-shape models default to the higher-throughput
16x16.

The exact crossover should be re-validated on AMX HW; this is a
heuristic best-guess until then.
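
The threshold rule reads as follows in code. This is a sketch under the stated thresholds; the enum and function names are illustrative and do not mirror the actual `qmmm_i32` closure signature.

```rust
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum AmxTile {
    T8x8,
    T16x16,
}

/// 16x16 iff m >= 16, n >= 16 and k >= 64; an unknown (None) dimension
/// is treated as large enough.
fn pick_amx_int8_tile(m: Option<usize>, k: Option<usize>, n: Option<usize>) -> AmxTile {
    let at_least = |dim: Option<usize>, min: usize| dim.map_or(true, |d| d >= min);
    if at_least(m, 16) && at_least(n, 16) && at_least(k, 64) {
        AmxTile::T16x16
    } else {
        AmxTile::T8x8
    }
}
```

A single known-small dimension is enough to fall back to 8x8, even when the other two are unknown.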
Adds `cache_sizes() -> CacheSizes { l1d_bytes, l2_bytes, l3_bytes }`, the
analog of oneDNN's `platform::get_per_core_cache_size`. Probes CPUID leaf
4 (deterministic cache parameters), iterating over sub-leaves until the
cache-type field (EAX[4:0]) reads zero; computes per-cache size as
(ways+1) * (partitions+1) * (line_size+1) * (sets+1). Memoised behind a
OnceLock since the values are constant for the lifetime of the process.
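
Decoding one sub-leaf looks roughly like this. It is a sketch of the size formula only, taking raw register values as input; the function name is illustrative, and the field positions follow the CPUID leaf 4 layout (EBX[11:0] line size, EBX[21:12] partitions, EBX[31:22] ways, ECX sets, each stored minus one).

```rust
/// Size in bytes of the cache described by one CPUID leaf 4 sub-leaf.
/// Returns None when the cache-type field (EAX[4:0]) is zero, which ends the enumeration.
fn cache_size_bytes(eax: u32, ebx: u32, ecx: u32) -> Option<u64> {
    if eax & 0x1f == 0 {
        return None;
    }
    let line_size = (ebx & 0xfff) as u64 + 1;
    let partitions = ((ebx >> 12) & 0x3ff) as u64 + 1;
    let ways = ((ebx >> 22) & 0x3ff) as u64 + 1;
    let sets = ecx as u64 + 1;
    Some(ways * partitions * line_size * sets)
}
```

For example, a 12-way cache with 64-byte lines, one partition and 64 sets decodes to 12 * 1 * 64 * 64 = 49152 bytes.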

Currently used at AMX-int8 plug time to log the detected cache hierarchy
(useful for diagnostics + future tuning); the public API exists so that
future shape-adaptive kernel variants can mirror oneDNN's `try_load_nt =
footprint(A)+footprint(B)+footprint(C) >= L1` heuristic at runtime.

This makes the existing 16x16 kernel's static "use tileloaddt1 + L2
prefetch for B" choice (currently hardcoded to the AMX-typical large-
working-set case) honest about the assumption, and gives us the
instrument to add a small-working-set 16x16 variant later if HW bench
data shows it's worth it.
Adds an AMX-BF16 path to mmm_f32 mirroring the int8 16x16 work: f32 inputs
are rounded to bf16 at pack time (round-to-nearest-even, matching Intel
VCVTNEPS2BF16) and the inner loop calls TDPBF16PS (16M x 16N x 32K bf16 =
8192 fma per instruction). The f32 accumulators differ from a pure-f32 FMA
reference by ~1/2^8 relative per multiply (bf16 keeps 8 significand bits,
7 of them stored, vs f32's 24, 23 of them stored) -- the same precision
profile as oneDNN "fast-math" f32 matmul on AMX, acceptable for inference
workloads (LLMs, CNNs) that already tolerate bf16.

* avx512amx_mmm_f32_16x16.S.j2 -- 16x16 row-major-zmm-accumulator kernel
  with the same oneDNN-style prefetch pattern as the i32 16x16 (A: tileloadd
  + 1x prefetcht0, B: tileloaddt1 + 6x prefetcht1). q_scale/q_shr/q_shl jump
  to "unsupported" (not meaningful for f32).
* amx_bf16.rs -- PackedAmxBf16A (A side, M-major within panel, K padded to
  multiples of 32 bf16) and PackedBf16K2 (B side, K=2-inner analog of
  PackedI8K4). f32_to_bf16_rne() does the lane-level conversion at pack time.
* amx.rs -- request_amx_tile_xcomp_perm() extracted so the int8 and bf16
  has_*() gates share the single XSAVE permission request (arch_prctl is
  process-wide; only one call is needed for both data types).
* build.rs -- dummy_bf16.S probe checks the assembler accepts TDPBF16PS,
  gated independently of the int8 probe so a future AMX-FP16/FP8 (Diamond
  Rapids+) probe slots in alongside. Sets tract_amx_bf16 cfg on success.
* mmm.rs -- registers the kernel as packing[1]=f32f32_bf16 and overlays the
  AMX 16x16 path onto mmm_f32 for problems where every axis comfortably
  fills at least one tile (M>=16, N>=16, K>=32). Smaller problems defer to
  the prior AVX-512/FMA picker, same shape-adaptive pattern as qmmm_i32.
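
Round-to-nearest-even from f32 to bf16 is usually done on the bit pattern. The sketch below shows that common technique. It is not the project's `f32_to_bf16_rne` (the `_sketch` suffix marks it as hypothetical), and it does not model any denormal handling the hardware instruction may apply.

```rust
/// Convert f32 to bf16 bits with round-to-nearest, ties to even.
fn f32_to_bf16_rne_sketch(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Keep sign and upper payload, force the quiet bit so the result stays NaN.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7fff plus the lowest kept bit: ties go to the even neighbour.
    let lsb = (bits >> 16) & 1;
    let rounded = bits.wrapping_add(0x7fff + lsb);
    (rounded >> 16) as u16
}

/// Widen bf16 bits back to f32 (exact: the low 16 bits are zero).
fn bf16_to_f32(h: u16) -> f32 {
    f32::from_bits((h as u32) << 16)
}
```

A bf16-aware test oracle can round both inputs through these two functions before multiplying in f32, so that only accumulation-order error remains against the kernel output.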
Mirrors the i32 amx bench (same shapes: 64x256x64 / 256x256x256 /
512x512x512 / 1024x1024x64) but exercises the bf16 path. Three columns:
fma f32 16x6 (AVX2 baseline), avx512 f32 16x12 (AVX-512 reference), and
the new AMX bf16 16x16 kernel under packing index 1 (the f32->bf16
RNE pack path). Skipped when has_amx_bf16() returns false and at build
time when tract_amx_bf16 is unset.
Forks avx512vnni_mmm_i32_8x8.S.j2 with the {vex} instruction prefix on
VPDPBUSD so gas emits the AVX-VNNI (VEX) encoding instead of the
AVX-512-VNNI (EVEX) encoding it defaults to. Body is otherwise byte-for-
byte identical: 8x8 ymm accumulators, PackedI8K4 inner-K (4-byte dot),
+128 bias trick to bridge VPDPBUSD's u8 x s8 into the AVX2 s8 x s8
reference. This ships VPDPBUSD-accelerated i8 GEMM to AVX2-only Atom-class
cores that don't have AVX-512:

  - Alder Lake / Raptor Lake / Meteor Lake E-cores (Gracemont, Crestmont)
  - Sierra Forest (Sierra Glen)
  - Clearwater Forest (Darkmont) -- the gap called out by the user

* avxvnni_mmm_i32_8x8.S.j2 -- the kernel; only the two VPDPBUSD lines
  are prefixed with {vex}.
* avxvnni.rs -- runtime gate via CPUID leaf 7 sub-leaf 1 EAX bit 4 (the
  AVX-VNNI capability bit). Memoised; no XSAVE permission needed (unlike
  AMX, AVX-VNNI uses no extended state).
* build.rs -- assembler probe (dummy_avxvnni.S) checks gas accepts the
  {vex} prefix on VPDPBUSD (binutils 2.36+). Sets tract_avxvnni cfg on
  success; pulls avxvnni_*.S.j2 out of the bulk -mfma compile so older
  toolchains aren't broken.
* mmm.rs -- registers the kernel as packing[1]=i8i8 (same PackedI8K4 as
  AVX-512-VNNI for layout compatibility) and plugs qmmm_i32 to it when
  AVX-VNNI is the highest-quality int8 ISA. On big cores that have both
  AVX-512-VNNI and AVX-VNNI (Sapphire Rapids+, some Alder Lake P-cores)
  plug_avx512vnni runs after this and clobbers qmmm_i32 with the EVEX
  kernel; on AVX-VNNI-only Atom cores this path stays.

All 114 kernel tests pass on this AVX-512-VNNI host. Big cores carry
AVX-VNNI alongside AVX-512-VNNI only on Sapphire Rapids+; on this Cascade
Lake-class CPU the runtime gate stays off and the kernel is exercised
only via the test harness's direct call path.
Two small finishers on the AMX / AVX-VNNI work:

* mmm.rs -- boost(|| 100) on both AMX 16x16 kernels (i32 and f32). The
  einsum kernel-selection scorer is `-quality_cost*1000 + boost`, so all
  current ManuallyOptimized kernels tie at score 0. The boost makes the
  optimizer prefer the AMX 16x16 tile over the equally-tier'd AVX-512-VNNI
  (i32) and AVX-512 / FMA (f32) candidates when at least one dim is
  symbolic and the shape-adaptive `qmmm_i32` / `mmm_f32` picker isn't
  the path of selection.

* benches/avxvnni_i32.rs -- mirror of amx_i32: same shapes
  (64x256x64 / 256x256x256 / 512x512x512 / 1024x1024x64), three columns
  (avx2 baseline, avxvnni new, avx512vnni reference when present).
  Skipped when has_avxvnni() returns false (CPUID 7.1 EAX.4 unset).
  Ready for an Atom-class host (Sierra Forest / Clearwater Forest /
  Alder Lake-E) to drop in and measure the VPDPBUSD speedup over the
  vpmaddubsw-emulation AVX2 path.
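
The effect of the boost on the scorer can be illustrated with a small sketch. The `Candidate` struct and the `quality_cost = 0` value for ManuallyOptimized kernels are assumptions inferred from "tie at score 0"; only the formula itself comes from the text above.

```rust
/// One kernel candidate as seen by a scorer of the form described above.
struct Candidate {
    name: &'static str,
    quality_cost: i64,
    boost: i64,
}

/// score = -quality_cost * 1000 + boost
fn score(c: &Candidate) -> i64 {
    -c.quality_cost * 1000 + c.boost
}

/// Highest-scoring candidate, if any.
fn best(cands: &[Candidate]) -> Option<&Candidate> {
    cands.iter().max_by_key(|c| score(c))
}
```

Since the quality term is scaled by 1000, a boost of 100 only breaks ties within a quality tier and cannot lift a kernel over one with a lower `quality_cost`.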
The inline scalar_sub / per_row_sub / per_col_sub handlers (and their
_flipped twins) in the AMX int8 and bf16 16x16 kernels had their operand
order reversed relative to the shared fma_mmm_ymm_ops.j2 convention:
non-flipped sub must compute `operand - acc`, flipped `acc - operand`.
Both kernels did the opposite, so a ScalarSub / per-row / per-col subtract
fused into the matmul produced negated results.

The bug never surfaced because these kernels' test suites are skipped on
hosts without AMX (is_supported_here() == false), and the dev/CI hardware
here is Cascade Lake-class (AVX-512-VNNI, no AMX). It was caught by the new
avx512vnni_mmm_i32_16x16 kernel, which reuses the same epilogue and whose
tests DO run on VNNI hardware: scalar_sub / per_row_sub / per_col_sub each
failed with exactly negated output. The commutative ops (min/max/mul/add)
were unaffected.
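
The operand-order convention and the symptom of the bug fit in a few lines of scalar code. The enum and function names are illustrative; the real handlers are assembly macros operating on whole accumulator registers.

```rust
#[derive(Clone, Copy)]
enum SubOp {
    /// Non-flipped: result = operand - acc
    Sub(i32),
    /// Flipped: result = acc - operand
    SubFlipped(i32),
}

/// Convention stated above.
fn apply(acc: i32, op: SubOp) -> i32 {
    match op {
        SubOp::Sub(x) => x - acc,
        SubOp::SubFlipped(x) => acc - x,
    }
}

/// The pre-fix behaviour: the two arms swapped.
fn apply_swapped(acc: i32, op: SubOp) -> i32 {
    match op {
        SubOp::Sub(x) => acc - x,
        SubOp::SubFlipped(x) => x - acc,
    }
}
```

Swapping the arms yields the exact negation for every input, which matches the "exactly negated output" failure signature and explains why commutative ops could not expose it.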

https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV
avx512vnni_mmm_i32_16x16 is the zmm-wide (512-bit) sibling of the existing
avx512vnni_mmm_i32_8x8: 16 row-major i32 accumulators (zmm{m} = row m of C),
one VPDPBUSD per row per K=4 block over PackedI8K4(16) for both A and B, so
it issues 16 * 16 * 4 = 1024 mul-adds/block -- 4x the 8x8 ymm kernel's
8 * 8 * 4 = 256 per block.
Same u8 x s8 +128 A-bias trick as the 8x8 kernel, but the row-major layout
makes the per-column 128*sum_k(B) correction a single vector subtract.

Built by adapting the AMX 16x16 i32 template (whose zmm row-major epilogue is
reused verbatim), replacing the AMX tile inner loop with the VPDPBUSD loop and
dropping the tile-config preamble / tilerelease. Because the file is named
avx512vnni_* it stays in the generic -mfma assembler bulk-compile (VPDPBUSD
needs no special gas gating, same as the 8x8 kernel).

Wired into plug_avx512vnni with a shape-adaptive qmmm_i32 picker (16x16 when
M,N >= 16, else 8x8) mirroring the AMX int8 path, plus boost(50) so the einsum
scorer prefers it over the 8x8 for unknown shapes while staying below the AMX
kernels' boost(100) (AMX still wins when both are present). This gives big
cores with AVX-512-VNNI but no AMX (Cascade Lake / Ice Lake / Tiger Lake) a
wider int8 GEMM throughput tier. Added as a third column in the vnni_i32
microbench.
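
The dispatch behaviour described above can be sketched roughly as follows. All type and function names are hypothetical, the 8x8 boost of 0 is an assumption, and the real einsum scorer likely weighs more than the boost value alone.

```rust
/// Hypothetical kernel identifiers for this sketch.
#[derive(Debug, PartialEq, Clone, Copy)]
enum VnniKernel {
    Mmm8x8,
    Mmm16x16,
}

/// Shape-adaptive pick for known shapes: 16x16 when both M and N are at least 16.
fn pick_vnni_i32(m: usize, n: usize) -> VnniKernel {
    if m >= 16 && n >= 16 {
        VnniKernel::Mmm16x16
    } else {
        VnniKernel::Mmm8x8
    }
}

/// A candidate as seen by a simplified scorer (assumed structure).
struct Candidate {
    name: &'static str,
    boost: i32,
}

/// Simplified stand-in for the scorer: the highest boost wins.
fn best(cands: &[Candidate]) -> Option<&Candidate> {
    cands.iter().max_by_key(|c| c.boost)
}
```

With boosts of 0 (assumed), 50 and 100 for the 8x8 VNNI, 16x16 VNNI and AMX kernels respectively, AMX wins when present and the 16x16 VNNI kernel wins otherwise.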

All 114 auto-generated kernel tests (packed-packed i8i8 + i32i32, fused-op
frame, quant rounding, stores, proptest) pass on AVX-512-VNNI hardware.

https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV
Maintainer note covering the AVX2 / AVX-512-VNNI (8x8 + 16x16) / AVX-VNNI /
AMX (8x8 + 16x16 int8, 16x16 bf16) kernel family: the u8 x s8 +128 bias trick,
the PackedI8K4 / PackedAmxA / bf16 packing layouts, the build.rs assembler-probe
cfg gates (tract_amx_int8 / tract_amx_bf16 / tract_avxvnni), the plug() and
qmmm_i32 dispatch cascade with the einsum scorer boost values, the testing model
and why the AMX sub-handler bug stayed hidden (kernel tests skip when the host
CPU lacks the feature), and a short follow-up list.

https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV
… hosts

Self-contained runbook for a session on a CPU with Intel AMX. It tasks that
session with benchmarking every int8/bf16 GEMM kernel in the tree -- the AMX
kernels (int8 8x8 + 16x16, bf16 16x16) and the improved AVX-512-VNNI kernels
(8x8 + the new zmm 16x16) -- and with running the AMX correctness suite, which
validates the AMX 16x16 sub fused-op bugfix that could not be exercised on the
non-AMX dev box.

Covers: AMX prerequisites (CPUID amx_*, kernel >= 5.16 for the arch_prctl
XTILEDATA permission), the gotcha that AMX kernel tests silently no-op (report
"ok") when the host can't run AMX, using the benches as the authoritative
runtime gate-check, exact test/bench commands, the bench column layout, the
head-to-head comparisons to report (AMX 16x16 vs VNNI 16x16 etc.), a one-shot
script, and a note that Intel SDE can emulate AMX for correctness but not perf.

https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV
Results from running linalg/AMX_BENCH_RUNBOOK.md on an AMX-capable Xeon
(2026-06-02): AMX-live confirmation; correctness (bugfix 99eb75b validated
on silicon; the 3 bf16 test failures root-caused to an f32-grade harness
tolerance and empirically verified against a bf16 reference -- not a kernel
defect); the three int8/bf16 throughput tables; and the four head-to-head
ratios. Includes a reproducibility note (the AMX host was later reclaimed).
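
To illustrate why an f32-grade tolerance is too tight once inputs pass through bf16, here is a minimal round-to-nearest-even conversion sketch. It is not the project's conversion routine; the function names are hypothetical and the NaN handling is a simplifying assumption.

```rust
/// Round an f32 to bf16 (upper 16 bits) with round-to-nearest-even.
fn f32_to_bf16_rne_sketch(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Keep it a NaN by forcing a mantissa bit (assumed policy).
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit, so exact ties round to even.
    let lsb = (bits >> 16) & 1;
    let rounded = bits.wrapping_add(0x7FFF + lsb);
    (rounded >> 16) as u16
}

/// Widen bf16 back to f32 by restoring the dropped low 16 bits as zeros.
fn bf16_to_f32(h: u16) -> f32 {
    f32::from_bits((h as u32) << 16)
}

/// Relative error introduced by one f32 -> bf16 -> f32 round trip.
fn bf16_roundtrip_rel_err(x: f32) -> f32 {
    let r = bf16_to_f32(f32_to_bf16_rne_sketch(x));
    ((r - x) / x).abs()
}
```

bf16 keeps 8 significant bits, so a single rounded input can carry a relative error of up to 2^-8, several orders of magnitude above f32 epsilon. A reference computed from unrounded f32 inputs will therefore disagree with a correct kernel at an f32-grade tolerance.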

https://claude.ai/code/session_018Hes6yEvk2TSWB26SAJfqT
- name: Configure AWS Credentials
continue-on-error: true
uses: aws-actions/configure-aws-credentials@v6
uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6
- name: Configure AWS Credentials
continue-on-error: true
uses: aws-actions/configure-aws-credentials@v6
uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6
# if: github.repository == 'sonos/tract'
continue-on-error: true
uses: aws-actions/configure-aws-credentials@v6
uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6
- name: Configure AWS Credentials
continue-on-error: true
uses: aws-actions/configure-aws-credentials@v6
uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6
- name: Configure AWS Credentials
continue-on-error: true
uses: aws-actions/configure-aws-credentials@v6
uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6

- name: Upload asset
uses: softprops/action-gh-release@v3
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3

- name: Create Release
uses: softprops/action-gh-release@v3
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
path: dist

- uses: pypa/gh-action-pypi-publish@v1.14.0
- uses: pypa/gh-action-pypi-publish@cef221092ed1bacb1cc03d23a2d87d1d172e277b # v1.14.0