diff --git a/core/src/ops/matmul/pack.rs b/core/src/ops/matmul/pack.rs index 8a3bcc93ae..934abfee86 100644 --- a/core/src/ops/matmul/pack.rs +++ b/core/src/ops/matmul/pack.rs @@ -7,12 +7,15 @@ use tract_linalg::block_quant::{ }; use tract_linalg::mmm::{MMMInputFormat, MMMInputValue, PackedMatrixStorage}; use tract_linalg::pack::{PackedFormat, PackedI8K4}; +#[cfg(target_arch = "x86_64")] +use tract_linalg::x86_64_fma::amx::PackedAmxA; use super::ModePicker; // Pack one (possibly strided) view with a dynamic packing format. Keeps the // PackedFormat fast path byte-identical; routes the K=4-inner SMOPA packer -// (PackedI8K4) through its view packer. Other formats are unsupported here. +// (PackedI8K4) and the AMX A-side packer (PackedAmxA) through their view +// packers. Other formats are unsupported here. fn pack_view_with( packer: &dyn MMMInputFormat, t: &TensorView, @@ -20,12 +23,16 @@ fn pack_view_with( mn_axis: usize, ) -> TractResult> { if let Some(pf) = packer.downcast_ref::() { - pf.pack_tensor_view(t, k_axis, mn_axis) - } else if let Some(p4) = packer.downcast_ref::() { - p4.pack_view(t, k_axis, mn_axis) - } else { - bail!("OptMatMulPack does not support packing format {packer:?}") + return pf.pack_tensor_view(t, k_axis, mn_axis); } + if let Some(p4) = packer.downcast_ref::() { + return p4.pack_view(t, k_axis, mn_axis); + } + #[cfg(target_arch = "x86_64")] + if let Some(pa) = packer.downcast_ref::() { + return pa.pack_view(t, k_axis, mn_axis); + } + bail!("OptMatMulPack does not support packing format {packer:?}") } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/linalg/AMX_BENCH_RESULTS.md b/linalg/AMX_BENCH_RESULTS.md new file mode 100644 index 0000000000..b1261e61d3 --- /dev/null +++ b/linalg/AMX_BENCH_RESULTS.md @@ -0,0 +1,85 @@ +# AMX validation & benchmark results + +Run of `linalg/AMX_BENCH_RUNBOOK.md` on real Intel AMX hardware. + +- **Host:** `Intel(R) Xeon(R) Processor @ 2.10GHz` (Sapphire/Emerald Rapids-class), 4 vCPU +- **ISA:** `amx_tile amx_int8 amx_bf16` + AVX-512-VNNI; kernel `6.18.5` (≥5.16); binutils `2.42`; rustc `1.94.1` +- **Branch:** `claude/zealous-galileo-fEQ3d` @ `7a23812` +- **Method:** `cargo bench`, default criterion sampling, pinned to core 2 (`taskset -c 2`), idle box (load ≈ 1.0) +- **Date:** 2026-06-02 + +## 1. AMX live confirmation ✅ + +Gate-check (`amx_i32` bench) produced `avx512amx_8x8`/`avx512amx_16x16` columns with real `thrpt:` numbers — **neither** "tract not built with AMX" (build probe) **nor** "AMX not available, skipping" (runtime CPUID + `arch_prctl` XTILEDATA gate) appeared. AMX is genuinely exercised. + +## 2. Correctness + +| Suite | Result | +|---|---| +| `cargo test -p tract-linalg --lib avx512amx` | **297 passed; 3 failed** | +| `cargo test -p tract-linalg --lib x86_64_fma::mmm` | **1833 passed; 3 failed** | + +**Bugfix `99eb75b9d` VALIDATED on silicon** ✅ — every `scalar_sub` / `per_row_sub` / `per_col_sub` (+`_f`) test passed for **both** `avx512amx_mmm_i32_16x16` and `avx512amx_mmm_f32_16x16`. + +**3 failures — all in the AMX bf16 path** (`avx512amx_mmm_f32_16x16::f32f32_bf16`): `fuse::prop`, `frame::prop`, `fuse::packed_packed_bug_3`. + +**Root cause = test-harness tolerance, NOT a kernel defect.** `packed_packed.rs:367` selects the comparison tolerance from the **accumulator** dtype: +```rust +let app = if K::Acc::datum_type() == f16::datum_type() + { Approximation::SuperApproximate } else { Approximation::Approximate }; +``` +This kernel accumulates in **f32** (TDPBF16PS: bf16×bf16→f32), so it gets `Approximate` = `(atol 1e-4, rtol 5e-4, 0 outliers)` — but the `f32f32_bf16` packing truncates inputs to bf16 (~2⁻⁸ ≈ 0.39% rel). bf16-grade error is checked against an f32-grade bar with zero tolerated outliers ⇒ guaranteed failure. `SuperApproximate` `(atol 0.1, rtol 0.05, 1e-4 outliers)` would pass. The structurally identical int8 16×16 kernel passes 100%. + +**Proposed fix:** in `check()`, pick `SuperApproximate` when the packing is bf16-based, not only when `K::Acc == f16`. + +**Empirically verified (on the AMX host):** the kernel was run on 7 cases (including the exact `bug_3` input) and compared against an independent **bf16-truncated** reference — built with the project's own `f32_to_bf16_rne` — judged by the *same* tight `Approximate` bar: **0 outliers across ~335k output elements** (max abs err ≤ 1.3e-5), versus **282,788 outliers** against a pure-f32 reference. The kernel reproduces "truncate inputs→bf16, accumulate→f32" exactly; the 3 red tests are 100% the f32 oracle, with no kernel defect. + +## 3. Benchmarks — throughput (Gelem/s, point estimate) + +### `amx_i32` — int8 GEMM +| M×K×N | avx2 | avx512vnni (8×8) | avx512amx_8×8 | avx512amx_16×16 | +|---|---:|---:|---:|---:| +| 64×256×64 | 0.41 | 11.21 | 68.41 | **233.64** | +| 256×256×256 | 0.41 | 11.31 | 68.47 | **237.29** | +| 512×512×512 | 0.39 | 8.94 † | 112.86 | **228.15** | +| 1024×1024×64 | 0.41 | 34.84 | 178.42 | **279.51** | + +### `amx_f32` — bf16→f32 GEMM +| M×K×N | fma_16×6 | avx512_16×12 | avx512amx_bf16_16×16 | +|---|---:|---:|---:| +| 64×256×64 | 37.12 | 64.31 | **207.35** | +| 256×256×256 | 37.90 | 71.90 | **225.74** | +| 512×512×512 | 39.37 | 64.69 | **348.38** | +| 1024×1024×64 | 36.85 | 59.22 | **318.36** | + +### `vnni_i32` — int8 GEMM (new 16×16 in isolation) +| M×K×N | avx2 | avx512vnni (8×8) | avx512vnni_16×16 | +|---|---:|---:|---:| +| 64×256×64 | 0.41 | 10.90 | **135.74** | +| 256×256×256 | 0.40 | 10.78 | **134.92** | +| 512×512×512 | 0.40 | 20.53 | **154.39** | +| 1024×1024×64 | 0.41 | 34.77 | **161.27** | + +† `avx512vnni`@512³ read 8.94 here vs 20.53 in `vnni_i32` (same kernel/shape). Treat **20.53** as the credible value (it fits the size trend 11.3→20.5→34.8); 8.94 was an outlier. A higher-sampling re-measure was attempted but could not complete — see §6. + +## 4. Head-to-head ratios + +| Comparison | 64×256×64 | 256×256×256 | 512×512×512 | 1024×1024×64 | +|---|---:|---:|---:|---:| +| **AMX 16×16 ÷ VNNI 16×16** (int8, same CPU) | 1.72× | 1.76× | 1.48× | 1.73× | +| **AMX 16×16 ÷ AMX 8×8** (int8) | 3.42× | 3.47× | 2.02× | 1.57× | +| **VNNI 16×16 ÷ VNNI 8×8** (int8) | 12.45× | 12.51× | 7.52× | 4.64× | +| **AMX bf16 16×16 ÷ AVX-512 f32 16×12** | 3.22× | 3.14× | 5.39× | 5.38× | +| *(bonus) AMX bf16 ÷ FMA f32 16×6* | 5.59× | 5.96× | 8.85× | 8.64× | + +## 5. Findings + +1. **AMX int8 16×16 wins everywhere — justifies `boost(100)` > VNNI `boost(50)`.** 1.48–1.76× over the new VNNI 16×16 on the *same* silicon. Dispatch ordering is correct. +2. **AMX 16×16 vs 8×8: 1.57–3.47×.** 16×16 leads on all tested shapes; the 4×-work/instr advantage is largest on compact shapes (3.4× @ 64×256×64) and narrowest on tall-skinny 1024×1024×64 (1.57×, N=64). No tested shape favors 8×8 — any crossover lives below this suite (smaller M or N<16). `qmmm_i32` defaulting to 16×16 here is sound. +3. **VNNI 16×16 vs 8×8: 4.64–12.5× — far above the dev box's 1.3–2.1×.** Likely the 8×8 kernel's ymm (256-bit) accumulators vs the new kernel's zmm (512-bit), amplified on Sapphire Rapids (no AVX-512 license downclock that Cascade Lake suffers). Strongly validates the new kernel; the magnitude warrants one sanity re-check (see #4). +4. **Data-quality flag (resolved by inspection):** `avx512vnni` 8×8 @ 512³ read 8.94 (in `amx_i32`) vs 20.53 (in `vnni_i32`) — a 2.3× swing on the same kernel/shape. **20.53 is the credible figure** (it continues the monotone size trend 11.3 @ 256³ → 20.5 @ 512³ → 34.8 @ 1024×1024×64; 8.94 breaks it). A `--sample-size 200` re-measure was launched but the AMX host was reclaimed before it could run (see §6); the ratio table already uses the consistent 20.53 pairing. AMX columns were stable across runs. +5. **AMX bf16 is 3.1–5.4× the AVX-512 f32 kernel** (5.6–8.9× over FMA), scaling up on larger shapes (348 Gelem/s @ 512³) — with the documented bf16 precision trade (see §2 and `X86_64_INT8_GEMM.md`). + +## 6. Reproducibility note + +Numbers were collected **2026-06-02** on an AMX-capable `Intel(R) Xeon(R) @ 2.10GHz` (`amx_tile/int8/bf16` + AVX-512-VNNI, kernel 6.18.5). The ephemeral session container was subsequently reclaimed and re-provisioned onto a different `Intel(R) Xeon(R) @ 2.80GHz` with **neither AMX nor AVX-512-VNNI** (only `avx512f`), on which `amx_i32`/`vnni_i32` both short-circuit and skip — so the one outstanding re-measure (VNNI-8×8 @ 512³) could not be completed in this session. To reproduce or extend, run on an AMX host (Sapphire Rapids / Emerald Rapids / Granite Rapids Xeon, or Xeon Max) following `linalg/AMX_BENCH_RUNBOOK.md`. diff --git a/linalg/AMX_BENCH_RUNBOOK.md b/linalg/AMX_BENCH_RUNBOOK.md new file mode 100644 index 0000000000..c5fde39c07 --- /dev/null +++ b/linalg/AMX_BENCH_RUNBOOK.md @@ -0,0 +1,211 @@ +# AMX validation & benchmark runbook + +**For: a Claude Code session (or human) on an x86_64 CPU that has Intel AMX.** + +The kernel work on branch `claude/zealous-galileo-fEQ3d` was developed on a +Cascade Lake-class container (AVX-512-VNNI, **no AMX**). Everything that can run +without AMX is already validated there. This runbook covers the two things that +box **could not** do and that need a real AMX CPU. + +## Your task + +**Benchmark every int8 / bf16 GEMM kernel in this tree on this AMX CPU — all the +AMX kernels *and* the AVX-512-VNNI kernels we just improved — and run the AMX +correctness suite.** Full kernel inventory to cover: + +| Kernel | ISA | Covered by bench | +|---|---|---| +| `avx512amx_mmm_i32_8x8` | AMX int8 (`tdpbssd`) | `amx_i32` | +| `avx512amx_mmm_i32_16x16` | AMX int8 (`tdpbssd`) | `amx_i32` | +| `avx512amx_mmm_f32_16x16` | AMX bf16→f32 (`tdpbf16ps`) | `amx_f32` | +| `avx512vnni_mmm_i32_8x8` | AVX-512-VNNI (`vpdpbusd`) | `vnni_i32`, `amx_i32` | +| **`avx512vnni_mmm_i32_16x16`** ← new | AVX-512-VNNI (`vpdpbusd`, zmm) | `vnni_i32` | +| `avx2_mmm_i32_8x8` (baseline) | AVX2 | both i32 benches | + +Running the three benches in Step 4 covers all of the above. Yes — bench the VNNI +kernels here too: an AMX CPU (Sapphire Rapids+) also has AVX-512-VNNI, so it's the +one place you can measure AMX 16×16 and VNNI 16×16 **on the same silicon** and see +how much AMX actually wins. + +In addition, this AMX CPU is the only place that can: + +1. **Correctness-test the AMX kernels** — including a recent bugfix to the AMX + 16×16 `sub` fused-op handlers that was invisible on non-AMX hardware. +2. **Benchmark** the AMX int8 / bf16 kernels and the new AVX-512-VNNI 16×16 + kernel head-to-head. + +> ⚠️ **Most important caveat:** every AMX kernel test short-circuits to "ok" when +> the host can't run AMX (`is_supported_here()` is false). So a green +> `cargo test` on the wrong box proves **nothing**. You must first confirm AMX is +> actually live (Step 2). The **benchmarks are the authoritative gate-check** — +> they print an explicit "AMX … not available, skipping" message and emit no AMX +> columns if the gate is closed. + +--- + +## 0. Prerequisites + +| Requirement | Why | Check | +|---|---|---| +| AMX-capable CPU (Sapphire Rapids / Emerald Rapids / Granite Rapids Xeon, or Xeon Max) | `tdpbssd` / `tdpbf16ps` | `grep -o 'amx[_a-z]*' /proc/cpuinfo \| sort -u` → expect `amx_bf16 amx_int8 amx_tile` | +| Linux kernel ≥ 5.16 | AMX tile-data XSAVE permission via `arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)` | `uname -r` | +| binutils/gas ≥ 2.34 (≥ 2.36 ideal) | assembles AMX mnemonics (and `{vex}` for AVX-VNNI) | `as --version` | +| Rust stable (dev used 1.94–1.96) | build | `cargo --version` | + +If `/proc/cpuinfo` shows no `amx_*` flags, this is the wrong machine — stop here. + +--- + +## 1. Get the code + +**Fresh clone (preferred):** +```sh +git clone https://github.com/czoli1976/tract.git +cd tract +git checkout claude/zealous-galileo-fEQ3d +``` + +**Existing checkout:** +```sh +git fetch origin claude/zealous-galileo-fEQ3d +git checkout claude/zealous-galileo-fEQ3d && git pull +# IMPORTANT when pulling into a checkout that was built before: the new kernel +# template (avx512vnni_mmm_i32_16x16.S.j2) may not trigger a build-script rerun +# (build.rs emits per-file rerun-if-changed). Force it once: +touch linalg/build.rs +``` +(A fresh clone needs no `touch` — it renders every template on first build.) + +--- + +## 2. Confirm AMX is actually live (do this first) + +The AMX kernels are gated by CPUID **and** the kernel granting tile-data XSAVE +permission. The benchmark is the cleanest runtime probe — if AMX is unavailable +it prints a skip line instead of numbers: + +```sh +cargo bench -p tract-linalg --bench amx_i32 -- --warm-up-time 0.2 --measurement-time 0.5 --sample-size 10 2>&1 | head -20 +``` + +- ✅ **Good:** you see `avx512amx_8x8` and `avx512amx_16x16` lines with `thrpt:`. +- ❌ **Bad:** `AMX int8 not available (CPUID + arch_prctl gate failed), skipping` + → AMX isn't usable (check kernel ≥ 5.16, not in a VM that masks AMX, XSAVE + permission not blocked by a seccomp/container policy). Don't proceed — the + correctness tests would silently no-op. + +Optional: `RUST_LOG=info cargo test -p tract-linalg --lib avx512amx_mmm_i32_16x16 -- --nocapture 2>&1 | grep -i activated` +should log `qmmm_i32: x86_64/avx512amx_int8 (16x16 + 8x8 adaptive) activated`. + +--- + +## 3. Correctness validation (the priority) + +Only meaningful once Step 2 confirms AMX is live. + +```sh +# All three AMX kernel suites: int8 8x8, int8 16x16, bf16 16x16. +cargo test -p tract-linalg --lib avx512amx 2>&1 | tail -30 + +# Full x86_64 mmm suite (AMX + VNNI + AVX2 + FMA + AVX-512), for completeness. +cargo test -p tract-linalg --lib x86_64_fma::mmm 2>&1 | tail -5 +``` + +**Expected:** `test result: ok. passed; 0 failed`. + +**What this specifically proves (and the dev box couldn't):** the +`scalar_sub` / `per_row_sub` / `per_col_sub` (+ `_flipped`) fused-op tests for +`test_avx512amx_mmm_i32_16x16` and `test_avx512amx_mmm_f32_16x16` **actually +execute**. Those guard commit `99eb75b9d`, which fixed swapped operands in the +AMX `sub` handlers (they were computing `acc − operand` instead of +`operand − acc`, i.e. negated results). This fix is currently only +build-verified — **this run is what confirms it on real silicon.** + +--- + +## 4. Benchmarks + +On real hardware use default sampling (drop the reduced flags) and pin a core for +stable numbers. Idle box, turbo/frequency-scaling fixed if you can. + +```sh +# int8: AVX2 vs VNNI 8x8 vs AMX 8x8 vs AMX 16x16 +taskset -c 2 cargo bench -p tract-linalg --bench amx_i32 + +# f32 via bf16: FMA 16x6 vs AVX-512 16x12 vs AMX-BF16 16x16 +taskset -c 2 cargo bench -p tract-linalg --bench amx_f32 + +# the new kernel in isolation: AVX2 vs VNNI 8x8 vs VNNI 16x16 +taskset -c 2 cargo bench -p tract-linalg --bench vnni_i32 +``` + +Bench layout (group `… /packed_packed`, shapes `64x256x64`, `256x256x256`, +`512x512x512`, `1024x1024x64`, throughput in `Gelem/s`): + +| Bench | Columns | +|---|---| +| `amx_i32` | `avx2`, `avx512vnni`, `avx512amx_8x8`, `avx512amx_16x16` | +| `amx_f32` | `fma_16x6`, `avx512_16x12`, `avx512amx_bf16_16x16` | +| `vnni_i32` | `avx2`, `avx512vnni` (8×8), `avx512vnni_16x16` | + +Criterion writes HTML reports under `target/criterion/`. + +--- + +## 5. What to report back + +**Correctness** +- Confirm AMX was live (Step 2 showed AMX columns / cpuinfo has `amx_int8`). +- `cargo test … avx512amx` result line (`N passed; 0 failed`), confirming the + AMX `*_sub` fused-op tests passed → bugfix `99eb75b9d` validated on hardware. + +**Performance** — the `thrpt:` (Gelem/s) per shape per column for all three +benches, plus these head-to-head reads: + +1. **AMX 16×16 vs VNNI 16×16** (compare `amx_i32`'s `avx512amx_16x16` against + `vnni_i32`'s `avx512vnni_16x16`, same shapes). AMX should win — that justifies + the dispatch ordering (`boost(100)` for AMX 16×16 > `boost(50)` for VNNI + 16×16). Report the ratio. +2. **AMX 16×16 vs AMX 8×8** — the 4×-work-per-instruction claim and where 8×8 + wins on small shapes (informs the `qmmm_i32` 16/8 crossover). +3. **VNNI 16×16 vs 8×8** — does the ~1.3–2.1× measured on Cascade Lake hold on + this CPU too? +4. **AMX-BF16 16×16 vs AVX-512 f32 16×12** — bf16 throughput win (with the bf16 + precision trade-off noted in `linalg/X86_64_INT8_GEMM.md`). + +--- + +## Appendix A — one-shot script + +```sh +set -e +echo "## CPU AMX flags:"; grep -o 'amx[_a-z]*' /proc/cpuinfo | sort -u || true +echo "## kernel:"; uname -r +echo "## gate check (expect AMX columns, not a skip message):" +cargo bench -p tract-linalg --bench amx_i32 -- --warm-up-time 0.2 --measurement-time 0.5 --sample-size 10 2>&1 | grep -iE "amx|skipping|thrpt" | head +echo "## correctness:" +cargo test -p tract-linalg --lib avx512amx 2>&1 | tail -3 +cargo test -p tract-linalg --lib x86_64_fma::mmm 2>&1 | tail -3 +echo "## full benches:" +taskset -c 2 cargo bench -p tract-linalg --bench amx_i32 +taskset -c 2 cargo bench -p tract-linalg --bench amx_f32 +taskset -c 2 cargo bench -p tract-linalg --bench vnni_i32 +``` + +## Appendix B — what's on this branch + +Three commits on top of the prior AMX/VNNI work: + +| Commit | Summary | +|---|---| +| `9e8f1c5aa` | doc: `linalg/X86_64_INT8_GEMM.md` — the full int8 GEMM kernel cascade | +| `26726db8e` | **feat**: `avx512vnni_mmm_i32_16x16` — zmm-wide int8 VNNI kernel (1.3–2.1× over 8×8 on Cascade Lake) | +| `99eb75b9d` | **fix**: swapped operands in AMX 16×16 `sub` fused-op handlers (int8 + bf16) — **needs AMX to validate** | + +Background and the kernel-selection/dispatch model: see +`linalg/X86_64_INT8_GEMM.md`. + +> Note on Intel SDE: SDE *can* emulate AMX for **functional/correctness** checks +> on a non-AMX box (`sde64 -spr -- `), but it is **not** a +> performance model — timings under SDE are meaningless. Use it only if no AMX +> hardware is available, and never for the benchmark numbers above. diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index d29cfa27ae..993dc0428d 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -167,3 +167,15 @@ harness = false [[bench]] name = "vnni_i32" harness = false + +[[bench]] +name = "amx_i32" +harness = false + +[[bench]] +name = "amx_f32" +harness = false + +[[bench]] +name = "avxvnni_i32" +harness = false diff --git a/linalg/X86_64_INT8_GEMM.md b/linalg/X86_64_INT8_GEMM.md new file mode 100644 index 0000000000..c960b5f8e5 --- /dev/null +++ b/linalg/X86_64_INT8_GEMM.md @@ -0,0 +1,135 @@ +# x86_64 int8 GEMM kernels + +This note documents the int8 (i32-accumulator) matrix-multiply kernel family for +x86_64, for maintainers touching `linalg/src/x86_64_fma/mmm.rs` (Rust +registration + dispatch) and `linalg/x86_64/fma/*.S.j2` (assembly templates). + +The kernels form a throughput cascade from the portable AVX2 emulation up to +Intel AMX, with AVX-512-VNNI in between. The right kernel is chosen at runtime +from CPUID + (for selection among ties) the einsum kernel scorer. + +## Kernel family + +| Kernel | ISA | Tile M×N | Matmul instr | A packing | B packing | Build gate | +|---|---|---|---|---|---|---| +| `avx2_mmm_i32_8x8` | AVX2 | 8×8 (ymm) | `vpmaddubsw` emulation | `PackedFormat` i8 | `PackedFormat` i8 | always | +| `avx512vnni_mmm_i32_8x8` | AVX-512-VNNI | 8×8 (ymm) | `vpdpbusd` | `PackedI8K4(8)` | `PackedI8K4(8)` | always | +| `avx512vnni_mmm_i32_16x16` | AVX-512-VNNI | 16×16 (zmm) | `vpdpbusd` ×16 rows | `PackedI8K4(16)` | `PackedI8K4(16)` | always | +| `avxvnni_mmm_i32_8x8` | AVX-VNNI (VEX) | 8×8 (ymm) | `{vex} vpdpbusd` | `PackedI8K4(8)` | `PackedI8K4(8)` | `tract_avxvnni` | +| `avx512amx_mmm_i32_8x8` | AMX-INT8 | 8×8 | `tdpbssd` | `PackedAmxA(8)` | `PackedI8K4(8)` | `tract_amx_int8` | +| `avx512amx_mmm_i32_16x16` | AMX-INT8 | 16×16 | `tdpbssd` (16384 MACs) | `PackedAmxA(16)` | `PackedI8K4(16)` | `tract_amx_int8` | +| `avx512amx_mmm_f32_16x16` (f32) | AMX-BF16 | 16×16 | `tdpbf16ps` | `PackedAmxBf16A(16)` | `PackedBf16K2(16)` | `tract_amx_bf16` | + +The two AVX-512-VNNI kernels and the AVX2 one are always compiled (their +mnemonics are in every supported binutils); the AMX and AVX-VNNI kernels are +behind assembler-probe cfgs (see below). + +## The u8×s8 `+128` bias trick (VNNI / AVX-VNNI) + +`vpdpbusd` is **u8 × s8** (unsigned first operand). To compute the s8×s8 product +we need, the kernel offsets the A bytes by `+128` (modular `vpaddb`, making them +u8 in `[0,255]`) and then removes the resulting per-column bias +`128 * sum_k(B[n])` after the K loop. The bias is accumulated cheaply during the +loop with a `vpdpbusd` against an all-`0x01` u8 vector. + +- **8×8 (ymm)** accumulators are *column-major* (`ymm{n}` = column n), so the + bias is computed per column and splatted back with `vpermd`. +- **16×16 (zmm)** accumulators are *row-major* (`zmm{m}` = row m, 16 columns in + the 16 lanes). The per-column bias is then a single lane-aligned vector, so the + correction is one `vpsubd` per row — cleaner and cheaper than the 8×8 path. + +AMX `tdpbssd` is **s8 × s8**, so the AMX int8 kernels need no `+128` trick; their +i32 accumulators are bit-identical to the AVX2 / VNNI reference. + +## Packing formats (see `linalg/src/frame/pack.rs`) + +- **`PackedI8K4(r)`** — K=4-inner. Per K=4 block, `r` elements × 4 K-bytes (= `4r` + bytes); element `e` sits at byte offset `e*4` holding `[e, 4kb..4kb+3]`. K is + zero-padded to a multiple of 4, so kernels read `ceil(k/4)` full blocks safely. +- **`PackedAmxA(r)`** — AMX A layout: per panel of `r` M-rows, row-major within + the panel, K-bytes contiguous, K padded to a multiple of 64 (one `tdpbssd` step + consumes 64 K-bytes). +- **`PackedAmxBf16A` / `PackedBf16K2`** — f32 inputs truncated to bf16 at pack + time (round-to-nearest-even, matching `VCVTNEPS2BF16`) for the AMX-BF16 f32 path. + +## Build-time cfg gating (`linalg/build.rs`) + +Some mnemonics are too new for old toolchains, so each is guarded by an +**assembler probe** that tries to compile a tiny dummy `.S`. The probe sets a cfg +that gates *both* compiling the kernel template and referencing its Rust symbol: + +| cfg | enables | requires | +|---|---|---| +| `tract_amx_int8` | AMX int8 kernels (`tdpbssd`) | gas ≥ 2.34 | +| `tract_amx_bf16` | AMX bf16 kernel (`tdpbf16ps`) | gas ≥ 2.34 | +| `tract_avxvnni` | AVX-VNNI ymm kernel (`{vex}` prefix) | binutils ≥ 2.36 | + +Kernel `.S.j2` templates are sorted by filename prefix in `build.rs`: +`avx512amx_*_i32_*` and `*_f32_*` are pulled into their own gated compiles; +`avxvnni_*` likewise; everything else (including `avx512vnni_*`) stays in the +generic `-mfma` bulk compile. **A new `avx512vnni_*` kernel needs no `build.rs` +change** — but note that adding a brand-new template file may not trigger a +`build.rs` re-run on an incremental build (it emits per-file `rerun-if-changed`), +so `touch linalg/build.rs` after creating one. + +These cfgs reflect **assembler** capability, not the host CPU. A kernel can be +*compiled* (assembler supports the mnemonic) yet never *run* (CPU lacks the +feature) — which matters for tests (below). + +## Dispatch + +`plug()` installs kernels in nested feature order, richest ISA last: + +``` +avx2 → [avxvnni] → fma → avx512f → avx512vnni → [amx_int8] (int8 path) + → [amx_bf16 overlay] (f32 path) +``` + +Each `plug_*` pushes kernels into `ops.mmm_impls` and may set the explicit int8 +picker `ops.qmmm_i32`. Because later plugs overwrite `qmmm_i32`, the best +available ISA wins. The pickers are **shape-adaptive**: the 16×16 tile is the +throughput champion when both M and N fill at least one tile; the 8×8 kernel has +lower per-call setup and wins on small problems. (AMX additionally requires +K ≥ 64; VNNI has no K gate since one `vpdpbusd` step is just 4 K-bytes.) + +For paths that don't go through `qmmm_i32` (symbolic / unknown shapes via the +einsum kernel scorer), selection among equal-quality kernels uses +`-quality_cost*1000 + boost`. All `ManuallyOptimized` kernels tie on quality, so +`boost` breaks the tie: + +| Kernel | boost | +|---|---| +| `avx512amx_mmm_i32_16x16`, `avx512amx_mmm_f32_16x16` | 100 | +| `avx512vnni_mmm_i32_16x16` | 50 | +| all 8×8 kernels | 0 | + +So for unknown shapes: AMX 16×16 ≻ VNNI 16×16 ≻ {VNNI/AMX 8×8}. When AMX is +absent, VNNI 16×16 is the int8 champion. + +## Testing and a cautionary tale + +`MMMExternKernel!` auto-generates a `#[cfg(test)] mod test_` with +packed-packed (per packing), fused-op frame, quant-rounding, store, and proptest +coverage. The harness **skips a kernel when `ker.is_supported_here()` is false** +(runtime CPUID). Consequently **AMX kernel tests only run on AMX hardware.** + +The usual dev/CI host is Cascade Lake-class (AVX-512-VNNI, no AMX), so the AMX +tests are skipped there. That let a swapped-operand bug in the AMX 16×16 `sub` +fused-op handlers (`scalar_sub` / `per_row_sub` / `per_col_sub` and their +`_flipped` twins computed `acc - operand` instead of the correct `operand - acc`) +go unnoticed — until `avx512vnni_mmm_i32_16x16`, which **reuses the same zmm +row-major epilogue** and *does* run on VNNI hardware, exposed it (negated +results). Takeaway: a VNNI kernel that shares an AMX kernel's epilogue effectively +becomes the on-hardware test for that shared epilogue. The convention for the +non-commutative `sub` lives in `linalg/x86_64/fma/fma_mmm_ymm_ops.j2` +(`scalar` / `per_row` / `per_col` macros, `flipped` flag). + +## Possible follow-ups + +- A dispatch integration test asserting `qmmm_i32` selects the 16×16 kernel for + large M,N and the 8×8 for small (no precedent for kernel-selection asserts + in-tree yet; would need a small helper to read back the chosen `MatMatMul`). +- On Sapphire Rapids+ hardware: validate the AMX `sub` fix end-to-end, benchmark + the AMX kernels, and re-check the 16×16/8×8 crossover and the `boost` values. +- A wider AVX-512-BF16 (`vdpbf16ps`) f32 kernel for Cooper Lake-class cores, and + a Q4_0/Q8_0 → s8 packer feeding the AMX/VNNI 16×16 path directly. diff --git a/linalg/benches/amx_f32.rs b/linalg/benches/amx_f32.rs new file mode 100644 index 0000000000..1206bbde57 --- /dev/null +++ b/linalg/benches/amx_f32.rs @@ -0,0 +1,98 @@ +#![allow(dead_code)] +// Kernel-level benchmark: Intel AMX bf16 GEMM for f32 matmul +// (avx512amx_mmm_f32_16x16, TDPBF16PS over 16x16 f32 tile with K=32 bf16 inner) +// vs the AVX-512 f32 16x12 path (avx512_mmm_f32_16x12, FMA) vs the AVX2/FMA +// f32 16x6 path (fma_mmm_f32_16x6). +// +// The AMX path runs the f32f32_bf16 packing (index 1) which truncates f32 to +// bf16 at pack time (round-to-nearest-even, matching VCVTNEPS2BF16) so the f32 +// accumulators carry the bf16 precision profile -- same trade-off as oneDNN +// "fast-math" f32 matmul on AMX. The two reference kernels run their default +// f32 packing (index 0). +// +// Skipped at runtime when has_amx_bf16() returns false (= CPUID lacks +// amx-bf16/tile or the arch_prctl XSAVE permission was denied), and at build +// time when the tract_amx_bf16 cfg was not emitted. +use criterion::*; +use tract_data::internal::*; +use tract_linalg::mmm::{AsInputValue, FusedSpec, MatMatMul}; + +fn run_kernel(be: &mut Bencher, mmm: &dyn MatMatMul, packing: usize, m: usize, k: usize, n: usize) { + let a = Tensor::zero_dt(DatumType::F32, &[m, k]).unwrap(); + let b = Tensor::zero_dt(DatumType::F32, &[k, n]).unwrap(); + let (pack_a, pack_b) = &mmm.packings()[packing]; + let pa = pack_a.prepare_one(&a, 1, 0).unwrap(); + let pb = pack_b.prepare_one(&b, 0, 1).unwrap(); + let mut scratch = unsafe { mmm.allocate_scratch_space() }; + be.iter_custom(|iters| { + let mut dur = std::time::Duration::default(); + for _ in 0..iters { + let t = std::time::Instant::now(); + unsafe { + mmm.run_with_scratch_space( + m, + n, + scratch.as_mut(), + &[FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing, + }], + ) + .unwrap() + }; + dur += t.elapsed(); + } + dur + }); +} + +fn benches(c: &mut Criterion) { + #[cfg(tract_amx_bf16)] + { + use tract_linalg::x86_64_fma::amx_bf16::has_amx_bf16; + use tract_linalg::x86_64_fma::mmm::*; + if !has_amx_bf16() { + eprintln!("AMX bf16 not available (CPUID + arch_prctl gate failed), skipping"); + return; + } + // Same shapes as amx_i32 so reviewers can directly compare bf16->f32 vs + // i8->i32 throughput at matching M/K/N. K=32 (single tdpbf16ps step) + // and K=64 (one i8 tile) are tested via 256 / 256x256 / 512x512x512. + for &(m, k, n) in + &[(64usize, 256usize, 64usize), (256, 256, 256), (512, 512, 512), (1024, 1024, 64)] + { + let id = format!("{m}x{k}x{n}"); + let mut g = c.benchmark_group("amx_f32/packed_packed"); + g.throughput(Throughput::Elements((m * k * n) as u64)); + // Reference: FMA f32 16x6 (the kernel mmm_f32 picks for these N). + g.bench_with_input(BenchmarkId::new("fma_16x6", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*fma_mmm_f32_16x6.mmm(), 0, m, k, n) + }); + if std::is_x86_feature_detected!("avx512f") { + // Reference: AVX-512 f32 16x12. + g.bench_with_input( + BenchmarkId::new("avx512_16x12", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512_mmm_f32_16x12.mmm(), 0, m, k, n), + ); + } + // AMX bf16 path (packing index 1 = f32f32_bf16: pack-time RNE + // conversion of f32 -> bf16, then TDPBF16PS in the inner loop). + g.bench_with_input( + BenchmarkId::new("avx512amx_bf16_16x16", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512amx_mmm_f32_16x16.mmm(), 1, m, k, n), + ); + g.finish(); + } + } + #[cfg(not(tract_amx_bf16))] + { + eprintln!("tract not built with AMX bf16 support (probe failed at build time)"); + let _ = c; + } +} + +criterion_group!(g, benches); +criterion_main!(g); diff --git a/linalg/benches/amx_i32.rs b/linalg/benches/amx_i32.rs new file mode 100644 index 0000000000..ae5f58ebc8 --- /dev/null +++ b/linalg/benches/amx_i32.rs @@ -0,0 +1,88 @@ +#![allow(dead_code)] +// Kernel-level benchmark: Intel AMX int8 GEMM (avx512amx_mmm_i32_8x8, TDPBSSD over +// 8x8 i32 tile with K=64 inner) vs the AVX-512 VNNI int8 path (avx512vnni_mmm_i32_8x8, +// VPDPBUSD over PackedI8K4 with K=4 inner) vs the AVX2 int8 path +// (avx2_mmm_i32_8x8, vpmaddubsw-style widening). All three run the same i8i8 +// packing index (1) over the same M/K/N so the only difference is the matmul +// inner loop. +use criterion::*; +use tract_data::internal::*; +use tract_linalg::mmm::{AsInputValue, FusedSpec, MatMatMul}; + +fn run_kernel(be: &mut Bencher, mmm: &dyn MatMatMul, m: usize, k: usize, n: usize) { + let a = Tensor::zero_dt(DatumType::I8, &[m, k]).unwrap(); + let b = Tensor::zero_dt(DatumType::I8, &[k, n]).unwrap(); + let (pack_a, pack_b) = &mmm.packings()[1]; + let pa = pack_a.prepare_one(&a, 1, 0).unwrap(); + let pb = pack_b.prepare_one(&b, 0, 1).unwrap(); + let mut scratch = unsafe { mmm.allocate_scratch_space() }; + be.iter_custom(|iters| { + let mut dur = std::time::Duration::default(); + for _ in 0..iters { + let t = std::time::Instant::now(); + unsafe { + mmm.run_with_scratch_space( + m, + n, + scratch.as_mut(), + &[FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 1, + }], + ) + .unwrap() + }; + dur += t.elapsed(); + } + dur + }); +} + +fn benches(c: &mut Criterion) { + #[cfg(tract_amx_int8)] + { + use tract_linalg::x86_64_fma::amx::has_amx_int8; + use tract_linalg::x86_64_fma::mmm::*; + if !has_amx_int8() { + eprintln!("AMX int8 not available (CPUID + arch_prctl gate failed), skipping"); + return; + } + for &(m, k, n) in + &[(64usize, 256usize, 64usize), (256, 256, 256), (512, 512, 512), (1024, 1024, 64)] + { + let id = format!("{m}x{k}x{n}"); + let mut g = c.benchmark_group("amx_i32/packed_packed"); + g.throughput(Throughput::Elements((m * k * n) as u64)); + g.bench_with_input(BenchmarkId::new("avx2", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*avx2_mmm_i32_8x8.mmm(), m, k, n) + }); + if std::is_x86_feature_detected!("avx512vnni") { + g.bench_with_input( + BenchmarkId::new("avx512vnni", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512vnni_mmm_i32_8x8.mmm(), m, k, n), + ); + } + g.bench_with_input( + BenchmarkId::new("avx512amx_8x8", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512amx_mmm_i32_8x8.mmm(), m, k, n), + ); + g.bench_with_input( + BenchmarkId::new("avx512amx_16x16", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512amx_mmm_i32_16x16.mmm(), m, k, n), + ); + g.finish(); + } + } + #[cfg(not(tract_amx_int8))] + { + eprintln!("tract not built with AMX int8 support (probe failed at build time)"); + let _ = c; + } +} + +criterion_group!(g, benches); +criterion_main!(g); diff --git a/linalg/benches/avxvnni_i32.rs b/linalg/benches/avxvnni_i32.rs new file mode 100644 index 0000000000..19deb1f013 --- /dev/null +++ b/linalg/benches/avxvnni_i32.rs @@ -0,0 +1,97 @@ +#![allow(dead_code)] +// Kernel-level benchmark: AVX-VNNI ymm int8 GEMM (avxvnni_mmm_i32_8x8, +// VEX-encoded VPDPBUSD over PackedI8K4 with K=4 inner) vs the AVX2 emulation +// path (avx2_mmm_i32_8x8, vpmaddubsw-style widening). Both kernels run the +// same i8i8 packing index (1) over the same M/K/N so the only difference is +// the matmul inner loop. +// +// Designed for Atom-class hosts that have AVX-VNNI but no AVX-512: +// +// * Alder Lake / Raptor Lake / Meteor Lake E-cores (Gracemont, Crestmont) +// * Sierra Forest (Sierra Glen) +// * Clearwater Forest (Darkmont) +// +// Big cores with both AVX-512-VNNI and AVX-VNNI still run AVX-VNNI here for +// comparison purposes; in production dispatch the EVEX-encoded +// avx512vnni_mmm_i32_8x8 wins on those CPUs because it can later be widened +// to zmm without an ISA-level rewrite. +use criterion::*; +use tract_data::internal::*; +use tract_linalg::mmm::{AsInputValue, FusedSpec, MatMatMul}; + +fn run_kernel(be: &mut Bencher, mmm: &dyn MatMatMul, m: usize, k: usize, n: usize) { + let a = Tensor::zero_dt(DatumType::I8, &[m, k]).unwrap(); + let b = Tensor::zero_dt(DatumType::I8, &[k, n]).unwrap(); + let (pack_a, pack_b) = &mmm.packings()[1]; + let pa = pack_a.prepare_one(&a, 1, 0).unwrap(); + let pb = pack_b.prepare_one(&b, 0, 1).unwrap(); + let mut scratch = unsafe { mmm.allocate_scratch_space() }; + be.iter_custom(|iters| { + let mut dur = std::time::Duration::default(); + for _ in 0..iters { + let t = std::time::Instant::now(); + unsafe { + mmm.run_with_scratch_space( + m, + n, + scratch.as_mut(), + &[FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 1, + }], + ) + .unwrap() + }; + dur += t.elapsed(); + } + dur + }); +} + +fn benches(c: &mut Criterion) { + #[cfg(tract_avxvnni)] + { + use tract_linalg::x86_64_fma::avxvnni::has_avxvnni; + use tract_linalg::x86_64_fma::mmm::*; + if !has_avxvnni() { + eprintln!("AVX-VNNI not available (CPUID leaf 7.1 EAX.4 unset), skipping"); + return; + } + // Same shapes as amx_i32 / vnni_i32 for direct side-by-side comparison. + for &(m, k, n) in + &[(64usize, 256usize, 64usize), (256, 256, 256), (512, 512, 512), (1024, 1024, 64)] + { + let id = format!("{m}x{k}x{n}"); + let mut g = c.benchmark_group("avxvnni_i32/packed_packed"); + g.throughput(Throughput::Elements((m * k * n) as u64)); + g.bench_with_input(BenchmarkId::new("avx2", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*avx2_mmm_i32_8x8.mmm(), m, k, n) + }); + g.bench_with_input(BenchmarkId::new("avxvnni", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*avxvnni_mmm_i32_8x8.mmm(), m, k, n) + }); + // When the same host also reports AVX-512-VNNI, include it as a + // reference point: the same kernel body runs as EVEX/zmm-encoded + // VPDPBUSD, which should match the AVX-VNNI throughput on Sapphire + // Rapids+ but can diverge on Cooper/Cascade Lake where the EVEX + // decoder is on the AVX-512 fused unit. + if std::is_x86_feature_detected!("avx512vnni") { + g.bench_with_input( + BenchmarkId::new("avx512vnni", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512vnni_mmm_i32_8x8.mmm(), m, k, n), + ); + } + g.finish(); + } + } + #[cfg(not(tract_avxvnni))] + { + eprintln!("tract not built with AVX-VNNI support (probe failed at build time)"); + let _ = c; + } +} + +criterion_group!(g, benches); +criterion_main!(g); diff --git a/linalg/benches/vnni_i32.rs b/linalg/benches/vnni_i32.rs index 59e6f01676..6901427f4a 100644 --- a/linalg/benches/vnni_i32.rs +++ b/linalg/benches/vnni_i32.rs @@ -1,8 +1,10 @@ #![allow(dead_code)] -// Kernel-level benchmark: AVX-512 VNNI int8 GEMM (avx512vnni_mmm_i32_8x8, VPDPBUSD -// over the K=4-inner PackedI8K4 layout) vs the AVX2 int8 path (avx2_mmm_i32_8x8, -// vpmaddubsw-style widening). Both run the i8i8 packing (index 1) over the same -// M/K/N so the only difference is the matmul inner loop. +// Kernel-level benchmark: AVX-512 VNNI int8 GEMM over the K=4-inner PackedI8K4 +// layout (VPDPBUSD) vs the AVX2 int8 path (avx2_mmm_i32_8x8, vpmaddubsw-style +// widening). Three columns: the AVX2 baseline, the 8x8 ymm VNNI kernel, and the +// 16x16 zmm VNNI kernel (twice the columns per accumulator). All run the i8i8 +// packing (index 1) over the same M/K/N so the only difference is the matmul +// inner loop and tile geometry. use criterion::*; use tract_data::internal::*; use tract_linalg::mmm::{AsInputValue, FusedSpec, MatMatMul}; @@ -55,6 +57,11 @@ fn benches(c: &mut Criterion) { g.bench_with_input(BenchmarkId::new("avx512vnni", &id), &(m, k, n), |b, &(m, k, n)| { run_kernel(b, &*avx512vnni_mmm_i32_8x8.mmm(), m, k, n) }); + g.bench_with_input( + BenchmarkId::new("avx512vnni_16x16", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512vnni_mmm_i32_16x16.mmm(), m, k, n), + ); g.finish(); } } diff --git a/linalg/build.rs b/linalg/build.rs index 600a2115f6..cb337ddaa5 100644 --- a/linalg/build.rs +++ b/linalg/build.rs @@ -68,6 +68,55 @@ fn assembler_supports_avx512vnni() -> bool { .is_ok() } +// Probe whether the target assembler can actually assemble Intel AMX int8 +// instructions (`ldtilecfg`, `tilezero`, `tdpbusd`, `tilerelease`). Older +// binutils (e.g. Debian stretch's gas 2.28) predate AMX and reject these +// mnemonics outright, which would break the x86_64 build for users on those +// toolchains. When the probe fails we skip the AMX kernel entirely; the +// matching `tract_amx_int8` cfg keeps the Rust side from referencing the +// (absent) kernel symbol, and `qmmm_i32` dispatch falls back to VNNI (or +// AVX2 when VNNI is itself unavailable). +fn assembler_supports_amx_int8() -> bool { + cc::Build::new() + .file("x86_64/avx512amx/dummy.S") + .cargo_metadata(false) + .cargo_warnings(false) + .warnings(false) + .try_compile("tract_amx_int8_probe") + .is_ok() +} + +// Probe whether the assembler accepts the `{vex}` prefix on VPDPBUSD -- +// needed to force the AVX-VNNI (VEX) form instead of the AVX-512-VNNI +// (EVEX) form gas defaults to. `{vex}` / `{evex}` instruction prefixes +// were added in binutils 2.36; older toolchains reject them. When the +// probe fails the avxvnni_mmm_i32_8x8 kernel is skipped and dispatch +// falls back to the AVX2 emulation kernel on AVX-VNNI-only hardware. +fn assembler_supports_avxvnni() -> bool { + cc::Build::new() + .file("x86_64/avx512amx/dummy_avxvnni.S") + .cargo_metadata(false) + .cargo_warnings(false) + .warnings(false) + .try_compile("tract_avxvnni_probe") + .is_ok() +} + +// Probe whether the target assembler can assemble AMX bf16 instructions +// (`tdpbf16ps`). Both int8 and bf16 AMX mnemonics require binutils >= 2.34, +// so in practice this probe succeeds whenever `assembler_supports_amx_int8` +// does. Provided separately so the two cfgs are independently controlled +// and users on exotic toolchains can opt-out of just the bf16 kernel. +fn assembler_supports_amx_bf16() -> bool { + cc::Build::new() + .file("x86_64/avx512amx/dummy_bf16.S") + .cargo_metadata(false) + .cargo_warnings(false) + .warnings(false) + .try_compile("tract_amx_bf16_probe") + .is_ok() +} + fn include_sve() -> bool { // SVE/SVE2 lives on ARMv9 server/mobile cores (Neoverse V1+/N2+, Cortex-X2+, // Graviton 3/4) — Linux aarch64. No Apple silicon has SVE. @@ -165,6 +214,14 @@ fn main() { println!("cargo:rustc-check-cfg=cfg(tract_arm64_dotprod)"); // Set below only when the x86_64 assembler probe for vpdpbusd ymm passes. println!("cargo:rustc-check-cfg=cfg(tract_avx512vnni)"); + // Set below only when the x86_64 assembler accepts AMX int8 mnemonics + // (avoids breaking the build on toolchains predating AMX). + println!("cargo:rustc-check-cfg=cfg(tract_amx_int8)"); + // Set below only when the assembler accepts AMX bf16 mnemonics (tdpbf16ps). + println!("cargo:rustc-check-cfg=cfg(tract_amx_bf16)"); + // Set below only when the assembler accepts the `{vex}` prefix on + // VPDPBUSD (binutils >= 2.36) -- needed for the AVX-VNNI ymm kernel. + println!("cargo:rustc-check-cfg=cfg(tract_avxvnni)"); match arch.as_ref() { "x86_64" => { @@ -176,6 +233,54 @@ fn main() { }); files.extend(preprocess_files("x86_64/avx512", &[], &suffix, false)); + // Pull the AMX kernel templates out of the generic fma bulk-compile + // so they can be gated behind assembler probes below. All AMX + // mnemonics require gas >= 2.34; old toolchains (Debian stretch's + // binutils 2.28) would otherwise fail the whole build. + // + // Split by accumulator type: + // avx512amx_*_i32_* → tdpbssd → gated on tract_amx_int8 + // avx512amx_*_f32_* → tdpbf16ps → gated on tract_amx_bf16 + let amx_int8_files: Vec = files + .iter() + .filter(|f| { + f.file_name() + .and_then(|n| n.to_str()) + .map(|n| n.starts_with("avx512amx_") && n.contains("_i32_")) + .unwrap_or(false) + }) + .cloned() + .collect(); + let amx_bf16_files: Vec = files + .iter() + .filter(|f| { + f.file_name() + .and_then(|n| n.to_str()) + .map(|n| n.starts_with("avx512amx_") && n.contains("_f32_")) + .unwrap_or(false) + }) + .cloned() + .collect(); + // AVX-VNNI ymm kernel: gas requires the `{vex}` instruction prefix + // (binutils 2.36+) -- pulled aside so the bulk -mfma compile, which + // is fine on older binutils, isn't broken when the AVX-VNNI cfg is + // disabled. + let avxvnni_files: Vec = files + .iter() + .filter(|f| { + f.file_name() + .and_then(|n| n.to_str()) + .map(|n| n.starts_with("avxvnni_")) + .unwrap_or(false) + }) + .cloned() + .collect(); + files.retain(|f| { + !amx_int8_files.contains(f) + && !amx_bf16_files.contains(f) + && !avxvnni_files.contains(f) + }); + if os == "windows" { if use_masm() { let mut lib_exe = cc::windows_registry::find(&target, "lib.exe") @@ -224,20 +329,57 @@ fn main() { } else { cc::Build::new().files(files).flag("-mfma").compile("x86_64_fma"); } - // VNNI kernel compiled separately so old assemblers (binutils < 2.30, + // VNNI kernels compiled separately so old assemblers (binutils < 2.30, // e.g. Debian stretch) that can't encode `vpdpbusd ymm` don't break // the whole x86_64 build. The `tract_avx512vnni` cfg gates the // matching Rust extern declarations and dispatch registration. // - // The template stays in x86_64/fma/ (alongside dispatcher.j2 and the - // other partials it includes) so the jinja env can resolve its includes. + // The templates stay in x86_64/fma/ (alongside dispatcher.j2 and the + // other partials they include) so the jinja env can resolve its includes. if assembler_supports_avx512vnni() { let tmpl = path::Path::new("x86_64/fma/avx512vnni_mmm_i32_8x8.S.j2"); let out = out_dir.join(format!("avx512vnni_mmm_i32_8x8_{suffix}.S")); preprocess_file(tmpl, &out, &[], &suffix, false); - cc::Build::new().file(&out).flag("-mfma").compile("x86_64_avx512vnni"); + // The zmm 16x16 sibling shares the VPDPBUSD probe; compile it into + // the same object so `tract_avx512vnni` gates both kernels together. + let tmpl16 = path::Path::new("x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2"); + let out16 = out_dir.join(format!("avx512vnni_mmm_i32_16x16_{suffix}.S")); + preprocess_file(tmpl16, &out16, &[], &suffix, false); + cc::Build::new().file(&out).file(&out16).flag("-mfma").compile("x86_64_avx512vnni"); println!("cargo:rustc-cfg=tract_avx512vnni"); } + + // AMX int8 kernel: compile only when the assembler accepts the + // mnemonics, and the kernel template was actually pulled aside + // above. Unix only for now (the .S uses the GAS intel-syntax + // path). The `tract_amx_int8` cfg gates the Rust-side symbol + // reference: when the probe fails on old toolchains (e.g. Debian + // stretch's binutils 2.28), the kernel is omitted and `qmmm_i32` + // dispatch falls back to VNNI or AVX2 with no build error. + if os != "windows" && !amx_int8_files.is_empty() && assembler_supports_amx_int8() { + cc::Build::new().files(&amx_int8_files).compile("x86_64_avx512amx"); + println!("cargo:rustc-cfg=tract_amx_int8"); + } + + // AMX bf16 kernel for f32 matmul (tdpbf16ps). Same toolchain + // requirement and Unix-only constraint as the int8 path. When the + // probe fails, the `tract_amx_bf16` cfg stays unset and + // `plug_avx512amx_bf16` is compiled out — `mmm_f32` then falls + // back to AVX-512 / FMA without any build error. + if os != "windows" && !amx_bf16_files.is_empty() && assembler_supports_amx_bf16() { + cc::Build::new().files(&amx_bf16_files).compile("x86_64_avx512amx_bf16"); + println!("cargo:rustc-cfg=tract_amx_bf16"); + } + + // AVX-VNNI ymm int8 kernel. Independent of the AMX gates: this + // kernel ships VPDPBUSD-accelerated i8 GEMM to Atom-class cores + // (Alder Lake-E, Sierra Forest, Clearwater Forest / Darkmont) + // that have AVX-VNNI but no AVX-512, falling back to AVX2 + // emulation when the runtime CPUID detection misses. + if os != "windows" && !avxvnni_files.is_empty() && assembler_supports_avxvnni() { + cc::Build::new().files(&avxvnni_files).compile("x86_64_avxvnni"); + println!("cargo:rustc-cfg=tract_avxvnni"); + } } "arm" | "armv7" => { let files = preprocess_files("arm32/armvfpv2", &[], &suffix, false); diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index e61baa2efe..1deaf41081 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -9,6 +9,10 @@ pub mod mmm; pub mod act; pub mod act_f16; pub mod act_f16_fp16; + +pub mod amx; +pub mod amx_bf16; +pub mod avxvnni; pub mod by_scalar; pub mod erf; mod intel; diff --git a/linalg/src/x86_64_fma/amx.rs b/linalg/src/x86_64_fma/amx.rs new file mode 100644 index 0000000000..a830ac2ac2 --- /dev/null +++ b/linalg/src/x86_64_fma/amx.rs @@ -0,0 +1,264 @@ +// Intel AMX int8 support: A packing format and runtime gate. +// +// The kernel `avx512amx_mmm_i32_8x8` uses TDPBSSD (signed-signed). Per +// iteration of its inner loop it consumes one 8x64-byte A tile and one +// 16x32-byte B tile and updates an 8x8 i32 C tile. The B-side packing +// matches the existing K=4-inner `PackedI8K4` layout, so it is reused +// unchanged. The A-side packing is novel: AMX's tile-A semantics require +// M-major-within-panel row-major bytes, which is incompatible with the +// K-major-outer `PackedI8K4`. `PackedAmxA` below produces that layout. +// +// Runtime gate: CPUID `amx-int8` is necessary but not sufficient on Linux — +// the kernel must also call `arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)` +// to receive AMX tile-data XSAVE permission from the kernel before any tile +// instruction can run. `has_amx_int8()` performs both checks once and caches +// the result; it returns false on non-Linux even if CPUID reports AMX. + +use std::sync::OnceLock; + +use tract_data::internal::*; + +use crate::WeightType; +use crate::frame::mmm::{ + EagerPackedInput, MMMInputFormat, MMMInputValue, PackedExoticFact, PackedMatrixStorage, +}; + +/// Per-cache geometry from CPUID leaf 4 deterministic cache parameters +/// (the mechanism oneDNN's `platform::get_per_core_cache_size` ultimately +/// reads). Used here for runtime adaptive choices that depend on the +/// hardware -- e.g. picking `tileloadd` vs `tileloaddt1` based on whether +/// the matmul working set fits in L1d (oneDNN's `try_load_nt` heuristic). +#[derive(Clone, Copy, Debug, Default)] +pub struct CacheSizes { + pub l1d_bytes: usize, + pub l2_bytes: usize, + pub l3_bytes: usize, +} + +/// Probe per-core L1d/L2/L3 cache sizes via CPUID leaf 4 deterministic +/// cache parameters. Iterates sub-leaves 0, 1, 2, ... until cache type = 0 +/// (no more caches). Each cache is described by: +/// EAX[4:0] = cache type (0=null, 1=data, 2=instr, 3=unified) +/// EAX[7:5] = cache level (1, 2, 3, ...) +/// EBX[11:0] = ways - 1 +/// EBX[21:12]= partitions - 1 +/// EBX[31:22]= line_size_bytes - 1 +/// ECX = sets - 1 +/// cache_bytes = (ways+1) * (partitions+1) * (line_size+1) * (sets+1) +/// Returns zeros for unknown levels (e.g. on a CPU without an L3, or if +/// the CPUID interface is unavailable). Memoised; called at most once. +pub fn cache_sizes() -> CacheSizes { + static CACHE: OnceLock = OnceLock::new(); + *CACHE.get_or_init(|| { + let mut out = CacheSizes::default(); + for sub in 0..16 { + #[allow(unused_unsafe)] + let r = unsafe { std::arch::x86_64::__cpuid_count(4, sub) }; + let cache_type = r.eax & 0x1F; + if cache_type == 0 { + break; + } + let level = (r.eax >> 5) & 0x7; + let ways = ((r.ebx >> 22) & 0x3FF) + 1; + let partitions = ((r.ebx >> 12) & 0x3FF) + 1; + let line_size = (r.ebx & 0xFFF) + 1; + let sets = r.ecx + 1; + let bytes = + (ways as usize) * (partitions as usize) * (line_size as usize) * (sets as usize); + // type=1 (data), type=3 (unified) for L1d / L2 / L3 + match (level, cache_type) { + (1, 1) => out.l1d_bytes = bytes, + (2, 1 | 3) => out.l2_bytes = bytes, + (3, 1 | 3) => out.l3_bytes = bytes, + _ => {} + } + } + out + }) +} + +/// Detect AMX-INT8 + AMX-TILE via CPUID leaf 7 sub-leaf 0 (EDX bits 24-25). +/// Stable-Rust friendly: `is_x86_feature_detected!("amx-int8")` is gated on +/// the nightly `x86_amx_intrinsics` feature, so we read CPUID by hand. +fn cpu_has_amx_int8() -> bool { + if !std::is_x86_feature_detected!("avx512f") { + return false; + } + #[allow(unused_unsafe)] + let r = unsafe { std::arch::x86_64::__cpuid_count(7, 0) }; + // bit 24 = AMX-TILE, bit 25 = AMX-INT8 in EDX. + const AMX_TILE: u32 = 1 << 24; + const AMX_INT8: u32 = 1 << 25; + (r.edx & AMX_TILE) != 0 && (r.edx & AMX_INT8) != 0 +} + +/// Linux only: ask the kernel for permission to use the AMX tile-data XSAVE +/// state via `arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)`. Returns +/// true if the kernel grants permission (or if the process already has it). +/// Exposed via `request_amx_tile_xcomp_perm()` below so the bf16 path can +/// share the same OS-level gate. +#[cfg(target_os = "linux")] +unsafe fn request_amx_xcomp_perm() -> bool { + // x86_64 syscall: rax=158 (arch_prctl), rdi=0x1023 (REQ_XCOMP_PERM), + // rsi=18 (XFEATURE_XTILEDATA). Returns 0 on success. + let rc: i64; + unsafe { + std::arch::asm!( + "syscall", + in("rax") 158i64, + in("rdi") 0x1023i64, + in("rsi") 18i64, + lateout("rax") rc, + out("rcx") _, + out("r11") _, + options(nostack), + ); + } + rc == 0 +} + +/// Memoised wrapper around `request_amx_xcomp_perm` -- arch_prctl has a +/// process-wide effect and only needs to be called once for the whole +/// lifetime of the process. Returns true iff the OS has granted permission +/// for XFEATURE_XTILEDATA (and hence enables both AMX int8 AND AMX bf16 +/// kernels). Returns false on non-Linux. +pub fn request_amx_tile_xcomp_perm() -> bool { + static GATE: OnceLock = OnceLock::new(); + *GATE.get_or_init(|| { + #[cfg(target_os = "linux")] + { + unsafe { request_amx_xcomp_perm() } + } + #[cfg(not(target_os = "linux"))] + { + false + } + }) +} + +/// Returns true iff Intel AMX int8 is available AND the OS has granted this +/// process permission to use the AMX tile-data XSAVE state. Result is +/// memoised — the arch_prctl call has process-wide effect and only needs to +/// run once. +pub fn has_amx_int8() -> bool { + static GATE: OnceLock = OnceLock::new(); + *GATE.get_or_init(|| cpu_has_amx_int8() && request_amx_tile_xcomp_perm()) +} + +/// AMX-friendly A packing: per `r`-row panel, M-rows are laid out row-major +/// across `K_padded = ceil(K / 64) * 64` contiguous bytes per row. AMX's +/// `tileloadd` with stride = K_padded reads exactly 8 contiguous M-rows of +/// 64 K-bytes each per call. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PackedAmxA { + pub r: usize, + pub align: usize, +} + +impl PackedAmxA { + pub fn new(r: usize) -> Self { + PackedAmxA { r, align: 64 } + } + fn k_padded(&self, k: usize) -> usize { + k.div_ceil(64) * 64 + } + fn panel(&self, k: usize) -> usize { + self.k_padded(k) * self.r + } + pub fn single_panel_len(&self, k: usize) -> usize { + self.panel(k) + } + pub fn len(&self, k: usize, mn: usize) -> usize { + mn.div_ceil(self.r) * self.panel(k) + } + pub fn alignment(&self) -> usize { + self.align + } + + pub fn pack_view( + &self, + t: &TensorView, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + let k = t.shape()[k_axis]; + let mn = t.shape()[mn_axis]; + let kp = self.k_padded(k); + let pl = kp * self.r; + let panels = mn.div_ceil(self.r); + let st = t.strides(); + let (ks, ms) = (st[k_axis], st[mn_axis]); + let mut blob = unsafe { Blob::new_for_size_and_align(panels * pl, self.align) }; + blob.as_bytes_mut().fill(0); + unsafe { + let src = t.as_ptr_unchecked::(); + let dst = blob.as_mut_ptr() as *mut i8; + for p in 0..panels { + let pw = self.r.min(mn - p * self.r); + let panel = dst.add(p * pl); + let mn0 = (p * self.r) as isize; + for lm in 0..pw { + let drow = panel.add(lm * kp); + let srow_base = src.offset((mn0 + lm as isize) * ms); + for kk in 0..k { + *drow.add(kk) = *srow_base.offset(kk as isize * ks); + } + } + } + } + Ok(Box::new(EagerPackedInput { + fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k }, + packed: blob.into(), + panel_bytes: pl, + mn, + })) + } +} + +impl std::fmt::Display for PackedAmxA { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "AmxA[{}]", self.r) + } +} + +impl MMMInputFormat for PackedAmxA { + fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult { + Ok(PackedMatrixStorage::new(self.prepare_one(t, k_axis, mn_axis)?) + .into_tensor(t.datum_type())) + } + fn prepare_one( + &self, + t: &Tensor, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + self.pack_view(&t.view(), k_axis, mn_axis) + } + fn precursor(&self) -> WeightType { + WeightType::Plain(i8::datum_type()) + } + fn r(&self) -> usize { + self.r + } + fn k_alignment(&self) -> usize { + // AMX consumes K=64 bytes per tdpbssd inner iteration; the packer + // already pads internally, but expose the alignment so upstream + // schedulers can reason about K-blocking. + 64 + } + fn merge_with<'o, 'a: 'o, 'b: 'o>( + &'a self, + o: &'b dyn MMMInputFormat, + ) -> Option<&'o dyn MMMInputFormat> { + o.downcast_ref::().filter(|x| x.r == self.r).map(|_| self as _) + } + fn mem_size(&self, k: TDim, mn: TDim) -> TDim { + mn.divceil(self.r) * self.panel(k.to_usize().unwrap_or(0)) + } + fn extract_at_mn_f16(&self, _: &EagerPackedInput, _: usize, _: &mut [f16]) -> TractResult<()> { + bail!("no f16 extract") + } + fn extract_at_mn_f32(&self, _: &EagerPackedInput, _: usize, _: &mut [f32]) -> TractResult<()> { + bail!("no f32 extract") + } +} diff --git a/linalg/src/x86_64_fma/amx_bf16.rs b/linalg/src/x86_64_fma/amx_bf16.rs new file mode 100644 index 0000000000..abda445c63 --- /dev/null +++ b/linalg/src/x86_64_fma/amx_bf16.rs @@ -0,0 +1,316 @@ +// Intel AMX bf16 support: f32 -> bf16 packers and the AMX bf16 runtime gate. +// +// The kernel `avx512amx_mmm_f32_16x16` uses TDPBF16PS (bf16 x bf16 -> f32) to +// accelerate f32 matmul on Sapphire Rapids+ AMX hardware. The inputs are +// truncated from f32 to bf16 at pack time (round-to-nearest-even, matching +// Intel's VCVTNEPS2BF16 semantics); the f32 accumulators are bit-identical +// to a "scalar bf16 multiply + f32 accumulate" reference but DIFFER from a +// pure-f32 FMA reference by ~1 / 2^8 relative per multiply (bf16 has 8 +// mantissa bits vs f32's 23). This precision loss is the same as oneDNN +// "fast-math" f32 matmul on AMX and is acceptable for inference workloads +// (LLMs, CNNs) that already tolerate bf16. +// +// Tile geometry mirrors the i32 16x16 kernel: 16 rows x 64 colsb per tile. +// Per TDPBF16PS: 16 M-rows x 16 N-cols x 32 K-bf16 = 8192 fma operations +// per single instruction -- the same throughput as TDPBSSD. + +use std::sync::OnceLock; + +use tract_data::internal::*; + +use crate::WeightType; +use crate::frame::mmm::{ + EagerPackedInput, MMMInputFormat, MMMInputValue, PackedExoticFact, PackedMatrixStorage, +}; + +/// Detect AMX-BF16 + AMX-TILE via CPUID leaf 7 sub-leaf 0 (EDX bits 22, 24). +/// AMX-BF16 is the bit-22 capability; AMX-TILE (bit 24) is mandatory for any +/// AMX use. Returns false unless both are present. +fn cpu_has_amx_bf16() -> bool { + if !std::is_x86_feature_detected!("avx512f") { + return false; + } + #[allow(unused_unsafe)] + let r = unsafe { std::arch::x86_64::__cpuid_count(7, 0) }; + const AMX_BF16: u32 = 1 << 22; + const AMX_TILE: u32 = 1 << 24; + (r.edx & AMX_BF16) != 0 && (r.edx & AMX_TILE) != 0 +} + +/// Returns true iff Intel AMX bf16 is available AND the OS has granted this +/// process permission to use the AMX tile-data XSAVE state. Reuses the +/// arch_prctl XCOMP-perm request mechanism from the int8 path -- the same +/// XFEATURE_XTILEDATA permission gates both data types. +pub fn has_amx_bf16() -> bool { + static GATE: OnceLock = OnceLock::new(); + *GATE.get_or_init(|| cpu_has_amx_bf16() && super::amx::request_amx_tile_xcomp_perm()) +} + +/// Convert an f32 to bf16 with round-to-nearest-even (matches Intel's +/// VCVTNEPS2BF16). NaN inputs are preserved as quiet NaN. Used by the bf16 +/// packers below (scalar; AMX hardware is on Sapphire Rapids+ which has the +/// AVX-512-BF16 intrinsic for batched conversion, but packing is amortised +/// over many kernel calls so the scalar path is fine). +#[inline] +pub fn f32_to_bf16_rne(x: f32) -> u16 { + let bits = x.to_bits(); + // NaN check: exponent all-ones and mantissa nonzero. + if (bits & 0x7F80_0000) == 0x7F80_0000 && (bits & 0x007F_FFFF) != 0 { + // Quiet NaN: set the top mantissa bit of the bf16 result. + ((bits >> 16) as u16) | 0x0040 + } else { + // round-to-nearest-even: add 0x7FFF + (lsb of bf16) before truncating. + let lsb = (bits >> 16) & 1; + let rounding = 0x0000_7FFF + lsb; + (bits.wrapping_add(rounding) >> 16) as u16 + } +} + +/// AMX-friendly A packing for f32 matmul via bf16. Per `r`-row panel, the +/// M-rows are laid out row-major in bf16 across `K_padded` contiguous bf16 +/// per row (K_padded = ceil(K/32)*32, so each row is a whole number of +/// AMX K-tile widths). Source is f32; conversion happens at pack time. +/// +/// panel_bytes = r * K_padded * 2 (each bf16 = 2 bytes) +/// +/// AMX `tileloadd` with stride = K_padded*2 reads exactly 16 M-rows of +/// 64 bytes (= 32 bf16) per call -- one inner-K iter's worth. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PackedAmxBf16A { + pub r: usize, + pub align: usize, +} + +impl PackedAmxBf16A { + pub fn new(r: usize) -> Self { + PackedAmxBf16A { r, align: 64 } + } + fn k_padded(&self, k: usize) -> usize { + k.div_ceil(32) * 32 + } + fn panel(&self, k: usize) -> usize { + self.k_padded(k) * self.r * 2 + } + pub fn single_panel_len(&self, k: usize) -> usize { + self.panel(k) + } + pub fn len(&self, k: usize, mn: usize) -> usize { + mn.div_ceil(self.r) * self.panel(k) + } + pub fn alignment(&self) -> usize { + self.align + } + + pub fn pack_view( + &self, + t: &TensorView, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + let k = t.shape()[k_axis]; + let mn = t.shape()[mn_axis]; + let kp = self.k_padded(k); + let pl = kp * self.r * 2; // bytes per panel + let panels = mn.div_ceil(self.r); + let st = t.strides(); + let (ks, ms) = (st[k_axis], st[mn_axis]); + let mut blob = unsafe { Blob::new_for_size_and_align(panels * pl, self.align) }; + blob.as_bytes_mut().fill(0); + unsafe { + let src = t.as_ptr_unchecked::(); + let dst = blob.as_mut_ptr() as *mut u16; + for p in 0..panels { + let pw = self.r.min(mn - p * self.r); + let panel = dst.add(p * (kp * self.r)); // panel_offset in u16 elements + let mn0 = (p * self.r) as isize; + for lm in 0..pw { + let drow = panel.add(lm * kp); + let srow_base = src.offset((mn0 + lm as isize) * ms); + for kk in 0..k { + let v = *srow_base.offset(kk as isize * ks); + *drow.add(kk) = f32_to_bf16_rne(v); + } + } + } + } + Ok(Box::new(EagerPackedInput { + fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k }, + packed: blob.into(), + panel_bytes: pl, + mn, + })) + } +} + +impl std::fmt::Display for PackedAmxBf16A { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "AmxBf16A[{}]", self.r) + } +} + +impl MMMInputFormat for PackedAmxBf16A { + fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult { + Ok(PackedMatrixStorage::new(self.prepare_one(t, k_axis, mn_axis)?) + .into_tensor(t.datum_type())) + } + fn prepare_one( + &self, + t: &Tensor, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + self.pack_view(&t.view(), k_axis, mn_axis) + } + fn k_alignment(&self) -> usize { + // tdpbf16ps consumes 32 bf16 per K-step. + 32 + } + fn r(&self) -> usize { + self.r + } + fn precursor(&self) -> WeightType { + WeightType::Plain(f32::datum_type()) + } + fn merge_with<'o, 'a: 'o, 'b: 'o>( + &'a self, + o: &'b dyn MMMInputFormat, + ) -> Option<&'o dyn MMMInputFormat> { + o.downcast_ref::().filter(|x| x.r == self.r).map(|_| self as _) + } + fn mem_size(&self, k: TDim, mn: TDim) -> TDim { + mn.divceil(self.r) * self.panel(k.to_usize().unwrap_or(0)) + } + fn extract_at_mn_f16(&self, _: &EagerPackedInput, _: usize, _: &mut [f16]) -> TractResult<()> { + bail!("no f16 extract") + } + fn extract_at_mn_f32(&self, _: &EagerPackedInput, _: usize, _: &mut [f32]) -> TractResult<()> { + bail!("no f32 extract") + } +} + +/// AMX-friendly B packing for f32 matmul via bf16 (analog of PackedI8K4 but +/// K=2-inner instead of K=4-inner -- tdpbf16ps groups 2 bf16 per K-step). +/// +/// Per K=2 block: r N-cols x 2 K-bf16 = r * 2 * 2 bytes = 4r bytes. +/// Block layout: byte (n*4 + ki*2..(n*4 + ki*2 + 2)) = bf16 of B[2kb+ki, n]. +/// For r=16: 64 bytes per K=2 block, 16 blocks per K=32 AMX tile -> 1024 B. +/// +/// One AMX `tileloadd` with stride = 4r bytes reads 16 K-pair-rows of +/// r * 4 bytes each = one inner-K iter's worth of B. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PackedBf16K2 { + pub r: usize, + pub align: usize, +} + +impl PackedBf16K2 { + pub fn new(r: usize) -> Self { + PackedBf16K2 { r, align: 64 } + } + fn k_padded(&self, k: usize) -> usize { + k.div_ceil(2) * 2 + } + fn panel(&self, k: usize) -> usize { + self.k_padded(k) * self.r * 2 + } + pub fn single_panel_len(&self, k: usize) -> usize { + self.panel(k) + } + pub fn len(&self, k: usize, mn: usize) -> usize { + mn.div_ceil(self.r) * self.panel(k) + } + pub fn alignment(&self) -> usize { + self.align + } + pub fn pack_view( + &self, + t: &TensorView, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + let k = t.shape()[k_axis]; + let mn = t.shape()[mn_axis]; + let kp = self.k_padded(k); + let pl = kp * self.r * 2; // bytes per panel + let panels = mn.div_ceil(self.r); + let st = t.strides(); + let mut blob = unsafe { Blob::new_for_size_and_align(panels * pl, self.align) }; + blob.as_bytes_mut().fill(0); + let (ks, ms) = (st[k_axis], st[mn_axis]); + let kblocks = kp / 2; + unsafe { + let src = t.as_ptr_unchecked::(); + let dst = blob.as_mut_ptr() as *mut u16; + for p in 0..panels { + let pw = self.r.min(mn - p * self.r); + let panel = dst.add(p * (kp * self.r)); + let mn0 = (p * self.r) as isize; + for kb in 0..kblocks { + for ki in 0..2 { + let kk = kb * 2 + ki; + if kk >= k { + break; + } + let srow = src.offset(kk as isize * ks + mn0 * ms); + let dblock = panel.add(kb * self.r * 2 + ki); + for lm in 0..pw { + let v = *srow.offset(lm as isize * ms); + *dblock.add(lm * 2) = f32_to_bf16_rne(v); + } + } + } + } + } + Ok(Box::new(EagerPackedInput { + fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k }, + packed: blob.into(), + panel_bytes: pl, + mn, + })) + } +} + +impl std::fmt::Display for PackedBf16K2 { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Bf16K2[{}]", self.r) + } +} + +impl MMMInputFormat for PackedBf16K2 { + fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult { + Ok(PackedMatrixStorage::new(self.prepare_one(t, k_axis, mn_axis)?) + .into_tensor(t.datum_type())) + } + fn prepare_one( + &self, + t: &Tensor, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + self.pack_view(&t.view(), k_axis, mn_axis) + } + fn k_alignment(&self) -> usize { + 2 + } + fn r(&self) -> usize { + self.r + } + fn precursor(&self) -> WeightType { + WeightType::Plain(f32::datum_type()) + } + fn merge_with<'o, 'a: 'o, 'b: 'o>( + &'a self, + o: &'b dyn MMMInputFormat, + ) -> Option<&'o dyn MMMInputFormat> { + o.downcast_ref::().filter(|x| x.r == self.r).map(|_| self as _) + } + fn mem_size(&self, k: TDim, mn: TDim) -> TDim { + mn.divceil(self.r) * self.panel(k.to_usize().unwrap_or(0)) + } + fn extract_at_mn_f16(&self, _: &EagerPackedInput, _: usize, _: &mut [f16]) -> TractResult<()> { + bail!("no f16 extract") + } + fn extract_at_mn_f32(&self, _: &EagerPackedInput, _: usize, _: &mut [f32]) -> TractResult<()> { + bail!("no f32 extract") + } +} diff --git a/linalg/src/x86_64_fma/avxvnni.rs b/linalg/src/x86_64_fma/avxvnni.rs new file mode 100644 index 0000000000..ee6af7d857 --- /dev/null +++ b/linalg/src/x86_64_fma/avxvnni.rs @@ -0,0 +1,44 @@ +// AVX-VNNI int8 GEMM runtime gate. +// +// AVX-VNNI (CPUID leaf 7 sub-leaf 1 EAX bit 4) is the VEX-encoded sibling of +// AVX-512-VNNI's VPDPBUSD: same i32 += u8 * s8 dot4 semantics, but addressable +// from VEX (= AVX2-class) decoders. It exists primarily for Atom-class +// server / E-core SKUs that have AVX2 + AVX-VNNI but no AVX-512: +// +// * Alder Lake / Raptor Lake / Meteor Lake E-cores (Gracemont, Crestmont) +// * Sierra Forest (Sierra Glen) +// * Clearwater Forest (Darkmont) +// +// On a CPU with AVX-512-VNNI (Cascade Lake, Ice Lake, Sapphire Rapids+), this +// detector still returns true if CPUID leaf 7.1 EAX.4 is set -- some big-core +// SKUs report AVX-VNNI alongside AVX-512-VNNI -- but the dispatch in mmm.rs +// prefers the EVEX-encoded avx512vnni kernel in that case (same throughput, +// 32 zmm registers available for unrolling). The AVX-VNNI kernel is only +// selected when AVX-512-VNNI is absent. + +use std::sync::OnceLock; + +/// CPUID leaf 7 sub-leaf 1, EAX bit 4 = AVX-VNNI (Intel SDM Vol 2 Table 1-7). +/// Sub-leaf 1 is only valid when CPUID.7.0.EAX (the max sub-leaf field) >= 1; +/// older CPUs return zeroed structures. We check the max-sub-leaf first to +/// avoid a misleading bit on pre-AVX-VNNI silicon. +fn cpu_has_avxvnni() -> bool { + if !std::is_x86_feature_detected!("avx2") { + return false; + } + #[allow(unused_unsafe)] + let max_sub = unsafe { std::arch::x86_64::__cpuid_count(7, 0) }.eax; + if max_sub < 1 { + return false; + } + #[allow(unused_unsafe)] + let r = unsafe { std::arch::x86_64::__cpuid_count(7, 1) }; + (r.eax & (1 << 4)) != 0 +} + +/// Returns true iff CPUID reports AVX-VNNI on this CPU. Memoised; no OS +/// permission gate is required (unlike AMX, AVX-VNNI uses no extended state). +pub fn has_avxvnni() -> bool { + static GATE: OnceLock = OnceLock::new(); + *GATE.get_or_init(cpu_has_avxvnni) +} diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index eaaf40b9d0..1579ac5d37 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -2,10 +2,25 @@ use crate::Ops; use crate::block_quant::*; use crate::mmm::ImplementationQuality::ManuallyOptimized; use crate::mmm::MatMatMul; -use crate::pack::{PackedFormat, PackedI8K4}; +use crate::pack::PackedFormat; +#[cfg(any(tract_avx512vnni, tract_avxvnni, tract_amx_int8))] +use crate::pack::PackedI8K4; +#[cfg(tract_amx_int8)] +use super::amx::{PackedAmxA, has_amx_int8}; +#[cfg(tract_amx_bf16)] +use super::amx_bf16::{PackedAmxBf16A, PackedBf16K2, has_amx_bf16}; +#[cfg(tract_avxvnni)] +use super::avxvnni::has_avxvnni; use super::*; +#[cfg(tract_amx_int8)] +const AVX512AMX: fn() -> bool = has_amx_int8; +#[cfg(tract_amx_bf16)] +const AVX512AMX_BF16: fn() -> bool = has_amx_bf16; +#[cfg(tract_avxvnni)] +const AVXVNNI: fn() -> bool = has_avxvnni; + /// One candidate kernel in a dispatcher's pool, with its tile geometry /// and a relative-throughput scale (1.0 = baseline, used to break /// near-ties between kernels with similar tile waste). @@ -107,9 +122,105 @@ MMMExternKernel! { avx512vnni_mmm_i32_8x8(8,8)@(256,4) where(AVX512VNNI) store(i8) } +// AVX-512 VNNI int8 GEMM, zmm-wide 16x16 sibling of avx512vnni_mmm_i32_8x8. +// Accumulators are ROW-MAJOR (zmm{m} = row m of C, 16 columns per zmm), so one +// VPDPBUSD covers 16 columns x 4 K and the K=4 inner step issues 16 of them +// (one per row) = 1024 mul-adds/block, 2x the 8x8 ymm kernel's work per +// iteration. Same +128 A-bias / per-column correction as the 8x8 kernel, and +// the same PackedI8K4 layout (r=16 for both A and B). This is the int8 +// throughput tier of qmmm_i32 for big cores with AVX-512-VNNI but no AMX +// (Cascade Lake / Ice Lake / Tiger Lake server + client). +// +// boost(50) lifts it above the 8x8 VNNI candidate in the einsum kernel-selection +// scorer for unknown shapes, while staying below the AMX 16x16 kernels' boost(100) +// so AMX still wins when both are present. +#[cfg(tract_avx512vnni)] +MMMExternKernel! { avx512vnni_mmm_i32_16x16(16,16)@(64,4) where(AVX512VNNI) + packing[1] = i8i8 => |k| k.with_packing(PackedI8K4::new(16), PackedI8K4::new(16)); + quality(ManuallyOptimized) + boost(|| 50) + store(i8) +} + +// AVX-VNNI ymm int8 GEMM: byte-for-byte the same body as avx512vnni_mmm_i32_8x8 +// (8x8 ymm accumulators, PackedI8K4 inner-K, +128 bias trick), but the +// VPDPBUSD instructions are forced to the VEX (AVX-VNNI) encoding via the +// `{vex}` prefix. Runs on Atom-class cores (Alder Lake-E, Sierra Forest, +// Clearwater Forest / Darkmont) which have AVX-VNNI but no AVX-512. On big +// cores with both AVX-512-VNNI and AVX-VNNI (Sapphire Rapids+, some Alder +// Lake P-core SKUs) dispatch prefers the EVEX-encoded kernel above. +#[cfg(tract_avxvnni)] +MMMExternKernel! { avxvnni_mmm_i32_8x8(8,8)@(256,4) where(AVXVNNI) + packing[1] = i8i8 => |k| k.with_packing(PackedI8K4::new(8), PackedI8K4::new(8)); + quality(ManuallyOptimized) + store(i8) +} + +// Same epilogue as avx512vnni_mmm_i32_8x8 (8x8 ymm accumulators), but the i8i8 +// matmul inner loop uses TDPBSSD (16-M x 16-N x 64-K mul-acc per instruction) +// over AMX tiles. A's packing is novel (PackedAmxA, M-major-within-panel, +// K-padded to multiples of 64); B reuses VNNI's K=4-inner PackedI8K4 layout +// unchanged. TDPBSSD is s8 x s8 so no +128 bias trick — accumulators are +// bit-identical to AVX2/VNNI. Gated by `where(AVX512AMX)` (= CPUID amx-int8 +// AND Linux XSAVE permission via arch_prctl). +#[cfg(tract_amx_int8)] +MMMExternKernel! { avx512amx_mmm_i32_8x8(8,8)@(64,4) where(AVX512AMX) + packing[1] = i8i8 => |k| k.with_packing(PackedAmxA::new(8), PackedI8K4::new(8)); + quality(ManuallyOptimized) + store(i8) +} + +// 16x16 i32 sibling. One tdpbssd does 16*16*64 = 16384 mul-adds (4x the 8x8). +// Same A/B packing (PackedAmxA, PackedI8K4) just with r=16. Row-major +// accumulators (zmm{m} = row m of C) so the hot path (Clear -> AddMatMul -> +// Store) needs no transpose. +// +// boost(100) pushes this kernel above the equally-ManuallyOptimized AVX-512-VNNI +// and AMX 8x8 candidates in the einsum kernel-selection scorer (which uses +// `-quality_cost*1000 + boost` per kernel). When more than one dim is symbolic +// the shape-adaptive `qmmm_i32` picker isn't invoked, so the boost is what +// causes the optimizer to prefer the 16x16 tile for unknown-shape matmuls. +#[cfg(tract_amx_int8)] +MMMExternKernel! { avx512amx_mmm_i32_16x16(16,16)@(64,4) where(AVX512AMX) + packing[1] = i8i8 => |k| k.with_packing(PackedAmxA::new(16), PackedI8K4::new(16)); + quality(ManuallyOptimized) + boost(|| 100) + store(i8) +} + +// AMX bf16 16x16 kernel for f32 matmul: uses TDPBF16PS (bf16 x bf16 -> f32). +// f32 inputs are truncated to bf16 at pack time (round-to-nearest-even, matching +// Intel VCVTNEPS2BF16). One tdpbf16ps consumes 16M x 16N x 32K bf16 = 8192 fma +// per instruction. f32 accumulators differ from a pure-f32 reference by ~1/2^8 +// relative per multiply (bf16 = 8 mantissa bits vs f32's 23) -- same precision +// loss profile as oneDNN "fast-math" f32 matmul on AMX, acceptable for +// inference workloads (LLMs, CNNs) that already tolerate bf16. +// +// Default packing[0] (the framework's PackedFormat) is retained so the +// kernel can still be selected for f32 paths even when the BF16 packer +// isn't a precursor match; packing[1] is the fast bf16-from-f32 path. +// boost(100) puts this AMX kernel above the AVX-512 f32 / FMA f32 kernels at +// the same ManuallyOptimized tier so the einsum scorer prefers it whenever +// supported, mirroring the i32 16x16 behaviour. The bf16 vs f32 precision +// trade is intentional and amortised over the same call sites that already +// use bf16-via-`dotbf16ps`-style fast-math elsewhere in the stack. +#[cfg(tract_amx_bf16)] +MMMExternKernel! { avx512amx_mmm_f32_16x16(16,16)@(64,4) where(AVX512AMX_BF16) + packing[1] = f32f32_bf16 => |k| k.with_packing(PackedAmxBf16A::new(16), PackedBf16K2::new(16)); + quality(ManuallyOptimized) + boost(|| 100) +} + pub fn plug(ops: &mut Ops) { if is_x86_feature_detected!("avx2") { plug_avx2(ops); + // AVX-VNNI runs on AVX2-only Atom-class cores (Alder Lake-E, Sierra + // Forest, Clearwater Forest / Darkmont). Plug it here so big cores + // can overlay AVX-512-VNNI / AMX on top below. + #[cfg(tract_avxvnni)] + if has_avxvnni() { + plug_avxvnni(ops); + } if is_x86_feature_detected!("fma") { plug_fma(ops); if is_x86_feature_detected!("avx512f") { @@ -117,6 +228,19 @@ pub fn plug(ops: &mut Ops) { #[cfg(tract_avx512vnni)] if is_x86_feature_detected!("avx512vnni") { plug_avx512vnni(ops); + // AMX int8 preferred over VNNI when both available AND the OS + // has granted XSAVE tile-data permission (see `has_amx_int8`). + #[cfg(tract_amx_int8)] + if has_amx_int8() { + plug_avx512amx_int8(ops); + } + } + // AMX bf16 for f32 matmul is independent of int8/VNNI gates: + // a future Xeon SKU could ship AMX-BF16 without VNNI, and the + // permission gate is shared with the int8 path inside has_amx_bf16(). + #[cfg(tract_amx_bf16)] + if has_amx_bf16() { + plug_avx512amx_bf16(ops); } } } @@ -126,8 +250,100 @@ pub fn plug(ops: &mut Ops) { #[cfg(tract_avx512vnni)] pub fn plug_avx512vnni(ops: &mut Ops) { ops.mmm_impls.push(avx512vnni_mmm_i32_8x8.mmm()); - ops.qmmm_i32 = Box::new(|_, _, _| avx512vnni_mmm_i32_8x8.mmm()); - log::info!("qmmm_i32: x86_64/avx512vnni activated"); + ops.mmm_impls.push(avx512vnni_mmm_i32_16x16.mmm()); + // Shape-adaptive dispatch mirroring the AMX int8 path: the zmm 16x16 tile is + // the throughput champion when each of M and N fills at least one tile; the + // 8x8 ymm kernel has lower per-call setup (smaller epilogue, half the + // accumulator file) and wins on small problems where the 16x16 tile-padding + // overhead dominates. Unknown dims default to the 16x16 champion. (No K gate: + // one VPDPBUSD step is only 4 K-bytes, so any K is fine; the choice is about + // filling the 16-wide M/N tile.) + ops.qmmm_i32 = Box::new(|m, _, n| { + let big = |o: Option, t: usize| o.is_none_or(|v| v >= t); + if big(m, 16) && big(n, 16) { + avx512vnni_mmm_i32_16x16.mmm() + } else { + avx512vnni_mmm_i32_8x8.mmm() + } + }); + log::info!("qmmm_i32: x86_64/avx512vnni (16x16 + 8x8 adaptive) activated"); +} + +#[cfg(tract_avxvnni)] +pub fn plug_avxvnni(ops: &mut Ops) { + ops.mmm_impls.push(avxvnni_mmm_i32_8x8.mmm()); + // On AVX-VNNI-only cores (no AVX-512) this is the int8 throughput champion; + // replace the AVX2 emulation default. On big cores that also have + // AVX-512-VNNI, plug_avx512vnni below runs after this and clobbers + // qmmm_i32 again with the EVEX kernel. + ops.qmmm_i32 = Box::new(|_, _, _| avxvnni_mmm_i32_8x8.mmm()); + log::info!("qmmm_i32: x86_64/avxvnni (VEX-encoded VPDPBUSD) activated"); +} + +#[cfg(tract_amx_bf16)] +pub fn plug_avx512amx_bf16(ops: &mut Ops) { + ops.mmm_impls.push(avx512amx_mmm_f32_16x16.mmm()); + // Save the previously-installed f32 picker so we can defer to it when + // the AMX kernel isn't a good fit (small M/N, or K < 32 -- one TDPBF16PS + // consumes 32 bf16 K-lanes so the panel must have at least one full step). + let prev: crate::MMMImpl = + std::mem::replace(&mut ops.mmm_f32, Box::new(|_, _, _| unreachable!())); + ops.mmm_f32 = Box::new(move |m, k, n| { + let big = |o: Option, t: usize| o.is_none_or(|v| v >= t); + // Same dispatch shape as the int8 16x16/8x8 split: hand off to AMX + // only when each axis comfortably fills at least one tile. The 32-K + // threshold matches PackedAmxBf16A::k_alignment() (one tdpbf16ps = + // 32 bf16 K-lanes); below that, the AVX-512 / FMA path's smaller + // tiles waste less work. + if big(m, 16) && big(n, 16) && big(k, 32) { + avx512amx_mmm_f32_16x16.mmm() + } else { + prev(m, k, n) + } + }); + let c = super::amx::cache_sizes(); + log::info!( + "mmm_f32: x86_64/avx512amx_bf16 (16x16) overlay activated; \ + L1d={} KB, L2={} KB, L3={} KB", + c.l1d_bytes / 1024, + c.l2_bytes / 1024, + c.l3_bytes / 1024, + ); +} + +#[cfg(tract_amx_int8)] +pub fn plug_avx512amx_int8(ops: &mut Ops) { + ops.mmm_impls.push(avx512amx_mmm_i32_8x8.mmm()); + ops.mmm_impls.push(avx512amx_mmm_i32_16x16.mmm()); + // Shape-adaptive dispatch: + // - 16x16 hits the full AMX tile (1024 B/tile, 16384 mul-adds per + // tdpbssd) and is the throughput champion when at least one tile + // of each dim is fully utilised. + // - 8x8 has lower per-call setup cost (1/4 the tile-store scratch, + // half the prefetch budget, smaller epilogue) and beats 16x16 on + // small problems where the framework's tile-padding overhead + // dominates. + // The exact crossover should be re-validated on AMX HW; oneDNN uses + // similar shape-based MR/NR selection for its BRGEMM ukernel variants. + ops.qmmm_i32 = Box::new(|m, k, n| { + // m, k, n are Option -- None means "unknown / streaming dim". + // For unknown dims default to the throughput champion (16x16); only + // fall back to 8x8 when a static dim is known to be tiny. + let big = |o: Option, t: usize| o.is_none_or(|v| v >= t); + if big(m, 16) && big(n, 16) && big(k, 64) { + avx512amx_mmm_i32_16x16.mmm() + } else { + avx512amx_mmm_i32_8x8.mmm() + } + }); + let c = super::amx::cache_sizes(); + log::info!( + "qmmm_i32: x86_64/avx512amx_int8 (16x16 + 8x8 adaptive) activated; \ + L1d={} KB, L2={} KB, L3={} KB", + c.l1d_bytes / 1024, + c.l2_bytes / 1024, + c.l3_bytes / 1024, + ); } pub fn plug_avx2(ops: &mut Ops) { diff --git a/linalg/x86_64/avx512amx/dummy.S b/linalg/x86_64/avx512amx/dummy.S new file mode 100644 index 0000000000..544e9c749f --- /dev/null +++ b/linalg/x86_64/avx512amx/dummy.S @@ -0,0 +1,29 @@ +// Build-time capability probe for the assembler, used by build.rs +// (assembler_supports_amx_int8). Older binutils — notably the Debian stretch +// x86_64 cross-toolchain in CI — predate AMX and cannot assemble these +// mnemonics. If this file fails to assemble, build.rs skips the AMX kernels +// and the `tract_amx_int8` cfg, and the runtime falls back to VNNI (or AVX2) +// for `qmmm_i32`. Not linked into anything. +.intel_syntax noprefix +.text +.globl tract_amx_int8_probe +tract_amx_int8_probe: + push rbp + mov rbp, rsp + sub rsp, 64 // room for the tilecfg block + mov qword ptr [rsp], 0 + mov qword ptr [rsp+8], 0 + mov qword ptr [rsp+16], 0 + mov qword ptr [rsp+24], 0 + mov qword ptr [rsp+32], 0 + mov qword ptr [rsp+40], 0 + mov qword ptr [rsp+48], 0 + mov qword ptr [rsp+56], 0 + mov byte ptr [rsp], 1 // palette = 1 + ldtilecfg [rsp] + tilezero tmm0 + tdpbusd tmm0, tmm1, tmm2 + tilerelease + mov rsp, rbp + pop rbp + ret diff --git a/linalg/x86_64/avx512amx/dummy_avxvnni.S b/linalg/x86_64/avx512amx/dummy_avxvnni.S new file mode 100644 index 0000000000..0579b2ed84 --- /dev/null +++ b/linalg/x86_64/avx512amx/dummy_avxvnni.S @@ -0,0 +1,16 @@ +// Build-time capability probe for the assembler, used by build.rs +// (assembler_supports_avxvnni). Checks that the assembler accepts the +// `{vex}` prefix on VPDPBUSD, which forces the AVX-VNNI (VEX-encoded) +// form instead of the AVX-512-VNNI (EVEX-encoded) form gas defaults to. +// Requires binutils >= 2.36 (which added `{vex}`/`{evex}` prefixes for +// explicit encoding selection). When the probe fails the AVX-VNNI kernel +// is skipped and dispatch falls back to AVX2 emulation on AVX-VNNI-only +// hardware (Clearwater Forest / Sierra Forest / Alder Lake E-cores). +// Not linked into anything. +.intel_syntax noprefix +.text +.globl tract_avxvnni_probe +tract_avxvnni_probe: + // AVX-VNNI: u8 x s8 -> i32 dot4 (VEX-encoded) + {vex} vpdpbusd ymm0, ymm1, ymm2 + ret diff --git a/linalg/x86_64/avx512amx/dummy_bf16.S b/linalg/x86_64/avx512amx/dummy_bf16.S new file mode 100644 index 0000000000..03f5a8d7a4 --- /dev/null +++ b/linalg/x86_64/avx512amx/dummy_bf16.S @@ -0,0 +1,28 @@ +// Build-time capability probe for the assembler, used by build.rs +// (assembler_supports_amx_bf16). Checks that the assembler accepts the +// TDPBF16PS mnemonic (AMX bf16 dot-product). Same binutils version requirement +// as AMX int8 (>= 2.34); provided as a separate probe so the two cfgs can be +// set independently if needed. Not linked into anything. +.intel_syntax noprefix +.text +.globl tract_amx_bf16_probe +tract_amx_bf16_probe: + push rbp + mov rbp, rsp + sub rsp, 64 + mov qword ptr [rsp], 0 + mov qword ptr [rsp+8], 0 + mov qword ptr [rsp+16], 0 + mov qword ptr [rsp+24], 0 + mov qword ptr [rsp+32], 0 + mov qword ptr [rsp+40], 0 + mov qword ptr [rsp+48], 0 + mov qword ptr [rsp+56], 0 + mov byte ptr [rsp], 1 // palette = 1 + ldtilecfg [rsp] + tilezero tmm0 + tdpbf16ps tmm0, tmm1, tmm2 // AMX bf16: the instruction this probe checks + tilerelease + mov rsp, rbp + pop rbp + ret diff --git a/linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 b/linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 new file mode 100644 index 0000000000..654cc660cd --- /dev/null +++ b/linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 @@ -0,0 +1,514 @@ +// vim: set syntax=asm : +// +// Intel AMX bf16 GEMM kernel, 16 M-rows x 16 N-cols f32 accumulator output. +// +// One `tdpbf16ps tmm0, tmm1, tmm2` instruction performs: +// tmm0[m, n] += sum_{k=0..31} A[m, k] * B[k, n] (multiplies in bf16, +// accumulates in f32) +// for m=0..15, n=0..15: 16 * 16 * 32 = 8192 fma per single instruction -- +// the same throughput as TDPBSSD on the same hardware. Accelerates f32 +// matmul on Sapphire Rapids+ at the cost of bf16 input truncation +// (~1/256 relative error per multiply; sqrt(K) compounded by FMA chain). +// +// Tile geometry (palette 1, the maximum-bytes AMX tile shape): +// tmm0 = C accumulator: 16 rows x 64 colsb = 16 M-rows x 16 N-cols of f32 +// tmm1 = A tile: 16 rows x 64 colsb = 16 M-rows x 32 K-bf16 / iter +// tmm2 = B tile: 16 rows x 64 colsb = 16 K-pair-rows x 16 N x 2 bf16 +// +// A is packed via PackedAmxBf16A(16): per panel of 16 M-rows, row-major +// within the panel, K-bf16 contiguous along the row, K_padded = +// ceil(K/32)*32 bf16. Source f32 is truncated to bf16 at pack time using +// round-to-nearest-even (matches VCVTNEPS2BF16 semantics). +// +// B is packed via PackedBf16K2(16): per K=2 block, 16 N-cols x 2 K-bf16 = +// 64 bytes; 16 K-blocks per tmm2 tile. Source f32 -> bf16 same as A. +// +// REGISTER LAYOUT (mirrors the i32 16x16 sibling): +// zmm0..zmm15 = accumulators, ROW-MAJOR: zmm{m} = row m of C as 16 f32 +// lanes [C[m, 0], C[m, 1], ..., C[m, 15]]. + +{% if msvc %} + +_text segment +avx512amx_mmm_f32_16x16_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avx512amx_mmm_f32_16x16_{{suffix}} +{{G}}avx512amx_mmm_f32_16x16_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + + // Reserve 64 bytes for the AMX tile-config block, zero it, populate + // palette + dims (all three tiles are 16 rows x 64 colsb). Same shape + // as the i32 16x16 sibling. + sub rsp, 64 + vpxor xmm15, xmm15, xmm15 + vmovdqu [rsp ], xmm15 + vmovdqu [rsp + 16], xmm15 + vmovdqu [rsp + 32], xmm15 + vmovdqu [rsp + 48], xmm15 + mov byte ptr [rsp + 0 ], 1 // palette = 1 + mov word ptr [rsp + 16], 64 // colsb[0] = 64 (tmm0) + mov word ptr [rsp + 18], 64 // colsb[1] = 64 (tmm1) + mov word ptr [rsp + 20], 64 // colsb[2] = 64 (tmm2) + mov byte ptr [rsp + 48], 16 // rows[0] = 16 (tmm0) + mov byte ptr [rsp + 49], 16 // rows[1] = 16 (tmm1) + mov byte ptr [rsp + 50], 16 // rows[2] = 16 (tmm2) + ldtilecfg [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: + {% for r in range(0, 16) %} + vpxorq zmm{{r}}, zmm{{r}}, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_bf16 + +{{L}}main_loop_packed_packed: + // Generic f32 x f32 fallback path (non-AMX). For row-major + // accumulators zmm{m} = row m of C, accumulating C[m, n] += A[m, k] * + // B[k, n]: load 16 B values for this K row into zmm16, then for each + // m broadcast A[m, k] and FMA add to zmm{m}. + vmovups zmm16, [rbx] // 16 f32 of B at this K row + + {% for m in range(0, 16) %} + vbroadcastss zmm17, dword ptr [rax + {{m}} * 4] + vfmadd231ps zmm{{m}}, zmm16, zmm17 + {% endfor %} + + add rax, 64 // 16 f32 lanes per K step + add rbx, 64 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_bf16: + // AMX bf16 layout: + // A panel: 16 M-rows x K_padded bf16 ROW-major within the panel + // (PackedAmxBf16A, K_padded = ceil(K/32)*32 bf16 = + // ceil(K/32)*64 bytes per row). + // B panel: PackedBf16K2(16) -- 16 N-cols x 2 K-bf16 per K=2 block, + // with 16 K-blocks per tdpbf16ps iter (16 K-pair-rows x + // 64 colsb). + // + // Per tdpbf16ps: tmm0[m, n] += sum_{k=0..31} A[m, k] * B[k, n] + // with multiplies in bf16 and accumulation in f32. Inner loop steps + // along K in 32-bf16 chunks (= 64 bytes per A row). + + // r8 <- K_padded_in_bytes = ceil(k/32)*64 = byte-stride between A's + // M-rows. (Each bf16 is 2 bytes, so K_padded_bf16 * 2.) + mov r8, rcx + add r8, 31 + and r8, -32 + shl r8, 1 // *2 (bf16 = 2 bytes) + + // rcx <- ceil(k/32) = number of K=32 AMX inner iterations. + add rcx, 31 + shr rcx, 5 + + // r9 <- 64 = byte-stride between B's K-pair-rows (16 N-cols * 4 bytes + // per K-pair = 16 * 4 = 64). + mov r9, 64 + + tilezero tmm0 + +{{L}}loop_32k_amx_bf16_16x16: + // oneDNN-aligned cache strategy (same as the i32 sibling): + // A -> cached (tileloadd + prefetcht0 to L1), reused across N-tiles. + // B -> non-temporal (tileloaddt1 + prefetcht1 to L2), streams once. + // Each iter advances A by 64 bytes and B by 1024 bytes; we prime the + // first 6 of next-iter's 16 B cache lines and let the SPR HW stream + // prefetcher cover the remaining 10. + prefetcht0 [rax + 64] + prefetcht1 [rbx + 1024] + prefetcht1 [rbx + 1088] + prefetcht1 [rbx + 1152] + prefetcht1 [rbx + 1216] + prefetcht1 [rbx + 1280] + prefetcht1 [rbx + 1344] + tileloadd tmm1, [rax + r8 * 1] // A tile (cached): stride = K_padded_bytes + tileloaddt1 tmm2, [rbx + r9 * 1] // B tile (non-temporal): stride = 64 + tdpbf16ps tmm0, tmm1, tmm2 + add rax, 64 // +32 bf16 in A row 0 + add rbx, 1024 // 16 K-pairs * 64 = 1024 B + dec rcx + jnz {{L}}loop_32k_amx_bf16_16x16 + + // tmm0 -> stack scratch (16 rows x 64 bytes = 1024 B row-major f32). + // Each row's 16 f32 are contiguous, so one 64-byte load per row. + sub rsp, 1024 + mov r10, rsp + mov r11, 64 + tilestored [r10 + r11 * 1], tmm0 + + {% for m in range(0, 16) %} + vmovups zmm{{m}}, [r10 + {{ m * 64 }}] + {% endfor %} + + add rsp, 1024 + + jmp {{L}}non_linear_loop + +// ---- Scalar / per-row / per-col f32 epilogues ---------------------------- + +{{L}}scalar_min: + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vminps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_max: + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vmaxps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_add: + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vaddps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_mul: + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vmulps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub: + // non-flipped sub = operand - acc (matches fma_mmm_ymm_ops.j2 scalar macro) + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub_flipped: + // flipped sub = acc - operand + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}leaky_relu: + // C[m, n] = (C[m, n] >= 0) ? C[m, n] : alpha * C[m, n] + vbroadcastss zmm17, dword ptr [rdi + 8] // alpha + vpxorq zmm16, zmm16, zmm16 // 0.0 + {% for r in range(0, 16) %} + vmulps zmm18, zmm{{r}}, zmm17 // alpha * x + vcmpps k1, zmm{{r}}, zmm16, 1 // imm 1 = LT (signed): 1 where x < 0 + vblendmps zmm{{r}}{k1}, zmm{{r}}, zmm18 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_min: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vminps zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_max: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vmaxps zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_add: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vaddps zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_mul: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vmulps zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub: + // non-flipped sub = operand - acc + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vsubps zmm{{m}}, zmm16, zmm{{m}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub_flipped: + // flipped sub = acc - operand + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vsubps zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_min: + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vminps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_max: + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vmaxps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_add: + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vaddps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_mul: + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vmulps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub: + // non-flipped sub = operand - acc + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub_flipped: + // flipped sub = acc - operand + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}load_tile: + // Scratch layout is COL-MAJOR f32 (col_byte_stride = item_size * MR = + // 4 * 16 = 64): tile[col][row] at offset col*64 + row*4. Gather row m's + // 16 cols at index step 64. + mov r8, [rdi + 8] + vmovdqa32 zmm16, [rip + {{L}}lane_offsets_64] + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm{{m}}{k1}, [r8 + zmm16 + {{m * 4}}] + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_unicast: + mov r10, [rdi + 8] // c ptr (base) + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size (4 for f32) + + cmp r8, 4 + jne {{L}}unsupported // f32 kernel: only item_size=4 + + // i32-strided gather (f32 same bit-width: vpgatherdd is correct). + mov eax, ebx + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 + vpmulld zmm16, zmm16, [rip + {{L}}lane_indices] + + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm17{k1}, [r10 + zmm16] + vaddps zmm{{m}}, zmm{{m}}, zmm17 + add r10, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_row_col_products: + // bias[m, n] = row_data[m] * col_data[n], FMA-add to C[m, n]. + mov rax, [rdi + 8] + mov rbx, [rdi + 16] + + vmovups zmm16, [rbx] // 16 col_data values + + {% for m in range(0, 16) %} + vbroadcastss zmm17, dword ptr [rax + {{m * 4}}] // splat row_data[m] + vfmadd231ps zmm{{m}}, zmm17, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +// ---- q_scale / q_shr / q_shl: not meaningful for f32, stub to unsupported. +{{L}}q_scale: +{{L}}q_shl: +{{L}}q_shr: + jmp {{L}}unsupported + +// ---- Store --------------------------------------------------------------- + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + jne {{L}}unsupported // f32 kernel: only item_size=4 + + cmp rdx, 4 + je {{L}}store_strides_f32_row_contig + + // Generic f32 strided store + {% for m in range(0, 16) %} + mov r10, r8 + vextracti32x4 xmm20, zmm{{m}}, 0 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 1 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 2 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 3 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_f32_row_contig: + // C is row-major in memory: each row's 16 f32 are contiguous; one + // 64-byte vmovups per row. + {% for m in range(0, 16) %} + vmovups [r8], zmm{{m}} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}return: + tilerelease + add rsp, 64 + + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + +// ---- Read-only data (RIP-relative) --------------------------------------- + +.p2align 6 +{{L}}lane_offsets_64: + .int 0, 64, 128, 192, 256, 320, 384, 448 + .int 512, 576, 640, 704, 768, 832, 896, 960 + +.p2align 6 +{{L}}lane_indices: + .int 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + +{% if msvc %} +avx512amx_mmm_f32_16x16_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %} diff --git a/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 b/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 new file mode 100644 index 0000000000..5c97cb7c6a --- /dev/null +++ b/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 @@ -0,0 +1,932 @@ +// vim: set syntax=asm : +// +// Intel AMX int8 GEMM kernel, 16 M-rows x 16 N-cols i32 accumulator output. +// +// One `tdpbssd tmm0, tmm1, tmm2` instruction performs: +// tmm0[m, n] += sum_{k=0..63} A[m, k] * B[k, n] +// for m=0..15, n=0..15: 16 * 16 * 64 = 16384 mul-adds per single instruction. +// That's 4x the work-per-instruction of the 8x8 sibling kernel, hitting the +// full AMX i8 tile geometry (max colsb=64, max rows=16, max bytes=1024). +// +// Tile geometry (palette 1): +// tmm0 = C accumulator: 16 rows x 64 colsb = 16 M-rows x 16 N-cols of i32 +// tmm1 = A tile: 16 rows x 64 colsb = 16 M-rows x 64 K-bytes per iter +// tmm2 = B tile: 16 rows x 64 colsb = 16 K-pair-rows x 16 N-cols * 4 K +// +// `tdpbssd` is signed-signed, so no +128 trick is needed; the i32 accumulators +// are bit-identical to the AVX2 / VNNI / 8x8-AMX reference paths. +// +// A is packed via PackedAmxA(16): per panel of 16 M-rows, row-major within the +// panel, K-bytes contiguous along the row, K_padded = ceil(K/64)*64. +// B reuses PackedI8K4(16): per K=4 block, 16 N-cols * 4 K-bytes = 64 bytes; +// 16 such K-blocks per tmm2 tile = 1024 bytes = one tileloadd. +// +// REGISTER LAYOUT +// zmm0..zmm15 = accumulators, ROW-MAJOR: zmm{m} holds the 16 i32 lanes +// [C[m, 0], C[m, 1], ..., C[m, 15]] for row m. +// This matches the row-major i32 layout that `tilestored` writes directly, +// so the hot path (Clear -> AddMatMul -> Store) needs no transpose. + +{% if msvc %} + +_text segment +avx512amx_mmm_i32_16x16_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avx512amx_mmm_i32_16x16_{{suffix}} +{{G}}avx512amx_mmm_i32_16x16_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + + // Reserve 64 bytes for the AMX tile-config block, zero it, populate + // palette + dims (all three tiles are 16 rows x 64 colsb, the maximum + // i8 tile geometry on Sapphire Rapids / Emerald Rapids / Granite Rapids). + sub rsp, 64 + vpxor xmm15, xmm15, xmm15 + vmovdqu [rsp ], xmm15 + vmovdqu [rsp + 16], xmm15 + vmovdqu [rsp + 32], xmm15 + vmovdqu [rsp + 48], xmm15 + mov byte ptr [rsp + 0 ], 1 // palette = 1 + mov word ptr [rsp + 16], 64 // colsb[0] = 64 (tmm0) + mov word ptr [rsp + 18], 64 // colsb[1] = 64 (tmm1) + mov word ptr [rsp + 20], 64 // colsb[2] = 64 (tmm2) + mov byte ptr [rsp + 48], 16 // rows[0] = 16 (tmm0) + mov byte ptr [rsp + 49], 16 // rows[1] = 16 (tmm1) + mov byte ptr [rsp + 50], 16 // rows[2] = 16 (tmm2) + ldtilecfg [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: + // vzeroall only zeros lower-256 of zmm0..15; explicitly zero the full + // accumulators (zmm0..zmm15) for AMX. + {% for r in range(0, 16) %} + vpxorq zmm{{r}}, zmm{{r}}, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_i8i8 + +{{L}}main_loop_packed_packed: + // Generic i32 x i32 fallback path (not AMX). For row-major accumulators + // with zmm{m} = row m of C, accumulating C[m, n] += A[m, k] * B[k, n]: + // - load 16 B values for this K row into zmm16 (row of B) + // - for each m: broadcast A[m, k], multiply by zmm16, add to zmm{m} + vmovups zmm16, [rbx] // 16 i32 of B at this K row + + {% for m in range(0, 16) %} + vpbroadcastd zmm17, dword ptr [rax + {{m}} * 4] + vpmulld zmm18, zmm16, zmm17 + vpaddd zmm{{m}}, zmm{{m}}, zmm18 + {% endfor %} + + add rax, 64 // 16 i32 lanes per K step + add rbx, 64 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_i8i8: + // AMX i8 layout: + // A panel: 16 M-rows x K_padded K-bytes ROW-major within the panel + // (PackedAmxA, K_padded = ceil(K/64)*64). + // B panel: PackedI8K4(16) -- 16 N-cols x 4 K-bytes per K=4 block, with + // 16 K-blocks per tileloadd (16 K-pair-rows x 64 colsb). + // + // Per tdpbssd: tmm0[m, n] += sum_{k=0..63} A[m, k] * B[k, n]. + // Inner loop steps along K in 64-K chunks. + + // r8 <- K_padded = ceil(k/64) * 64 = byte-stride between A's M-rows. + mov r8, rcx + add r8, 63 + and r8, -64 + + // rcx <- ceil(k/64) = number of K=64 AMX inner iterations. + add rcx, 63 + shr rcx, 6 + + // r9 <- 64 = byte-stride between B's K-pair-rows (each row = 16 N-cols * 4 K). + mov r9, 64 + + tilezero tmm0 + +{{L}}loop_64k_amx_i8i8_16x16: + // Cache strategy follows oneDNN's AMX BRGEMM heuristics (Intel-backed): + // - A is reused across N-tiles in tract's outer matmul loop, so we use + // `tileloadd` (cached, brings into L1) and `prefetcht0` (to L1) for A. + // - B streams through once per kernel call (one B-panel per N-tile), and + // for the AMX-typical large-matmul case the B working set exceeds the + // 32 KB L1d on Sapphire Rapids+. We use `tileloaddt1` (non-temporal, + // bypasses L1) and `prefetcht1` (to L2) for B -- the same pattern + // oneDNN picks when its footprint heuristic crosses the L1 threshold. + // - Sapphire Rapids has 16 L1d Fill Buffers (LFBs); each in-flight + // prefetch/load consumes one. The previous version's 17 prefetches + + // 2 active tileloadds overflowed the LFB budget. The reduced count + // below leaves headroom and lets the HW streaming prefetcher cover + // the remaining B-panel lines. + // + // A advances 64 B / iter (one cache line). B advances 1024 B / iter + // (16 cache lines). We prime 6 of the next 16 B-lines at +1024..+1344, + // then trust the HW stream prefetcher (very aggressive on SPR/EMR/GNR) + // to cover lines +1408..+1984. + prefetcht0 [rax + 64] // next A-row K-block (to L1) + prefetcht1 [rbx + 1024] // next B-panel head (to L2) + prefetcht1 [rbx + 1088] + prefetcht1 [rbx + 1152] + prefetcht1 [rbx + 1216] + prefetcht1 [rbx + 1280] + prefetcht1 [rbx + 1344] + tileloadd tmm1, [rax + r8 * 1] // A tile (cached): stride = K_padded + tileloaddt1 tmm2, [rbx + r9 * 1] // B tile (non-temporal): stride = 64 + tdpbssd tmm0, tmm1, tmm2 + add rax, 64 // +64 K-bytes in A row 0 + add rbx, 1024 // 16 K-pairs * 64 = 1024 bytes + dec rcx + jnz {{L}}loop_64k_amx_i8i8_16x16 + + // tmm0 -> stack scratch (16 rows x 64 bytes = 1024 B row-major i32). + // Then load each row into zmm0..zmm15. Row m's 16 i32 are contiguous + // in memory, so each load is a single 64-byte vmovdqu32. + sub rsp, 1024 + mov r10, rsp + mov r11, 64 + tilestored [r10 + r11 * 1], tmm0 + + {% for m in range(0, 16) %} + vmovdqu32 zmm{{m}}, [r10 + {{ m * 64 }}] + {% endfor %} + + add rsp, 1024 + + jmp {{L}}non_linear_loop + +// ---- Scalar / per-row / per-col elementwise epilogues ------------------- + +{{L}}scalar_min: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpminsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_max: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpmaxsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_add: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpaddd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_mul: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpmulld zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub: + // non-flipped sub = operand - acc (matches fma_mmm_ymm_ops.j2 scalar macro) + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub_flipped: + // flipped sub = acc - operand + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}leaky_relu: + // C[m, n] = (C[m, n] >= 0) ? C[m, n] : alpha * C[m, n] + vpbroadcastd zmm17, dword ptr [rdi + 8] // alpha as i32 scale factor + vpxorq zmm16, zmm16, zmm16 + {% for r in range(0, 16) %} + vpmulld zmm18, zmm{{r}}, zmm17 + vpcmpgtd k1, zmm16, zmm{{r}} // 1 where C < 0 + vpblendmd zmm{{r}}{k1}, zmm{{r}}, zmm18 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_min: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpminsd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_max: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpmaxsd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_add: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpaddd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_mul: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpmulld zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub: + // non-flipped sub = operand - acc + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpsubd zmm{{m}}, zmm16, zmm{{m}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub_flipped: + // flipped sub = acc - operand + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpsubd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_min: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpminsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_max: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpmaxsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_add: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpaddd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_mul: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpmulld zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub: + // non-flipped sub = operand - acc + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub_flipped: + // flipped sub = acc - operand + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}load_tile: + // Scratch layout is COL-MAJOR i32 from scratch.rs Store/AddUnicast remnant: + // tile[col][row] at offset (col*MR + row)*4 with MR=16 + // = offset col*64 + row*4 + // For row-major accumulators we gather row m's 16 cols at index step 64. + mov r8, [rdi + 8] + vmovdqa32 zmm16, [rip + {{L}}lane_offsets_64] // [0, 64, 128, ..., 15*64] + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm{{m}}{k1}, [r8 + zmm16 + {{m * 4}}] + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_unicast: + mov r10, [rdi + 8] // c ptr (base) + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size + + cmp r8, 4 + je {{L}}non_linear_addc_i32 + + // i8 path: read 16 i8 from [r10 + m*rsi + n*rbx] for n=0..15, sign-extend + // to i32, add to zmm{m}. Use a stack scratch buffer (16 bytes per row). + sub rsp, 16 + {% for m in range(0, 16) %} + mov r8, r10 + {% for n in range(0, 16) %} + mov al, [r8] + mov byte ptr [rsp + {{n}}], al + add r8, rbx + {% endfor %} + vpmovsxbd zmm16, [rsp] + vpaddd zmm{{m}}, zmm{{m}}, zmm16 + add r10, rsi + {% endfor %} + add rsp, 16 + jmp {{L}}non_linear_loop + +{{L}}non_linear_addc_i32: + // i32 strided read of external (or scratch) tile. Build per-lane index + // vector [0, rbx, 2*rbx, ..., 15*rbx] once, then gather row by row. + mov eax, ebx + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 + vpmulld zmm16, zmm16, [rip + {{L}}lane_indices] // [0, rbx, 2*rbx, ...] + + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm17{k1}, [r10 + zmm16] + vpaddd zmm{{m}}, zmm{{m}}, zmm17 + add r10, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_row_col_products: + // bias[m, n] = row_data[m] * col_data[n], add to C[m, n]. + // For row-major regs: load 16 col_data values once into zmm16, + // for each m: broadcast row_data[m], FMA add. + mov rax, [rdi + 8] + mov rbx, [rdi + 16] + + vmovdqu32 zmm16, [rax] // 16 row_data values + vmovdqu32 zmm17, [rbx] // 16 col_data values + + {% for m in range(0, 16) %} + vpbroadcastd zmm18, dword ptr [rax + {{m * 4}}] // splat row_data[m] + vpmulld zmm19, zmm18, zmm17 + vpaddd zmm{{m}}, zmm{{m}}, zmm19 + {% endfor %} + jmp {{L}}non_linear_loop + +// ---- Q-scale (mult-shift with rounding) --------------------------------- + +{{L}}q_scale: + mov r8, [rdi + 16] // policy + vpbroadcastd zmm16, dword ptr [rdi + 24] // multi (broadcast i32) + + mov rax, 1 + vmovq xmm17, rax + vpbroadcastq zmm17, xmm17 // zmm17 <- 1 (i64 lanes) + + mov rax, [rdi + 8] // shift + add rax, 31 + vmovq xmm18, rax + vpbroadcastq zmm18, xmm18 // zmm18 <- (shift+31) (i64 lanes) + + vpsubq zmm19, zmm18, zmm17 + vpsllvq zmm19, zmm17, zmm19 // zmm19 <- 1 << (shift+31-1) (i64) + + // Per-lane interleave mask for blending evens / shifted-odds. + // bit i = 1 means take from "evens" source in vpblendmd; bit 0,2,4,...,14 set. + mov eax, 0x5555 + kmovw k7, eax + + cmp r8, 1 + je {{L}}q_scale_rounding_zero + cmp r8, 2 + je {{L}}q_scale_rounding_away + cmp r8, 3 + je {{L}}q_scale_rounding_minus_inf + cmp r8, 4 + je {{L}}q_scale_rounding_plus_inf + cmp r8, 5 + je {{L}}q_scale_rounding_even + cmp r8, 6 + je {{L}}q_scale_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_scale_rounding_zero: // signum * ( (abs + nudge - 1) >> shift ) +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 // even-lane i32 -> i64 mul + vpmuldq zmm21, zmm21, zmm16 // odd-lane i32 -> i64 mul + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsubq zmm20, zmm20, zmm17 + vpsubq zmm21, zmm21, zmm17 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 // k7=0x5555: evens from zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_away: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_minus_inf: // nudge by -1 where input was negative +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpxorq zmm22, zmm22, zmm22 + vpcmpgtd k1, zmm{{i}}, zmm22 // k1: 1 where input > 0 (we want the inverse, see below) + knotw k1, k1 // 1 where input <= 0 -- we want "input was negative => subtract 1" + // For "<0": use compare against 0 with vpcmpltd + vpxorq zmm22, zmm22, zmm22 + vpcmpltd k1, zmm{{i}}, zmm22 // 1 where input < 0 + vmovdqa32 zmm23{k1}{z}, [rip + {{L}}all_ones_i32] // (1 << 0) per neg lane, 0 elsewhere + + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + // Subtract 1 from i64-evens / i64-odds where the original i32 input was < 0. + vpsrldq zmm24, zmm23, 4 + vpmovsxdq zmm25, ymm23 + vpmovsxdq zmm26, ymm24 + vpsubq zmm20, zmm20, zmm25 + vpsubq zmm21, zmm21, zmm26 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_plus_inf: // nudge by +1 where input was non-negative +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpxorq zmm22, zmm22, zmm22 + vpcmpled k1, zmm22, zmm{{i}} // 1 where input >= 0 + vmovdqa32 zmm23{k1}{z}, [rip + {{L}}all_ones_i32] + + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrldq zmm24, zmm23, 4 + vpmovsxdq zmm25, ymm23 + vpmovsxdq zmm26, ymm24 + vpsubq zmm20, zmm20, zmm25 + vpsubq zmm21, zmm21, zmm26 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_even: // banker's: round half to even +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpsrlq zmm22, zmm20, xmm18 + vpandq zmm22, zmm22, zmm17 + vpaddq zmm20, zmm20, zmm22 + vpsubq zmm20, zmm20, zmm17 + + vpsrlq zmm22, zmm21, xmm18 + vpandq zmm22, zmm22, zmm17 + vpaddq zmm21, zmm21, zmm22 + vpsubq zmm21, zmm21, zmm17 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_odd: // round half to odd +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpsrlq zmm22, zmm20, xmm18 + vpandq zmm22, zmm22, zmm17 + vpsubq zmm20, zmm20, zmm22 + + vpsrlq zmm22, zmm21, xmm18 + vpandq zmm22, zmm22, zmm17 + vpsubq zmm21, zmm21, zmm22 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shl: + mov eax, [rdi + 8] // -shift (count: i32) + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 + {% for i in range(0, 16) %}vpsllvd zmm{{i}}, zmm{{i}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr: + mov r8, [rdi + 16] // policy + + mov eax, 1 + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 // zmm16 <- 1 (i32 lanes) + + mov eax, [rdi + 8] // shift + vmovd xmm17, eax + vpbroadcastd zmm17, xmm17 // zmm17 <- shift (i32 lanes) + + mov ebx, 1 + mov cl, al + sub cl, 1 + sal ebx, cl // ebx <- 1 << (shift - 1) + vmovd xmm18, ebx + vpbroadcastd zmm18, xmm18 // zmm18 <- "half" + + vpxorq zmm19, zmm19, zmm19 // zeroes + + cmp r8, 1 + je {{L}}q_shr_rounding_zero + cmp r8, 2 + je {{L}}q_shr_rounding_away + cmp r8, 3 + je {{L}}q_shr_rounding_minus_inf + cmp r8, 4 + je {{L}}q_shr_rounding_plus_inf + cmp r8, 5 + je {{L}}q_shr_rounding_even + cmp r8, 6 + je {{L}}q_shr_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_shr_rounding_zero: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsubd zmm20, zmm20, zmm16 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_away: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_minus_inf: +{% for i in range(0, 16) %} + vpsubd zmm{{i}}, zmm{{i}}, zmm16 + vpaddd zmm{{i}}, zmm{{i}}, zmm18 + vpsravd zmm{{i}}, zmm{{i}}, zmm17 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_plus_inf: +{% for i in range(0, 16) %} + vpaddd zmm{{i}}, zmm{{i}}, zmm18 + vpsravd zmm{{i}}, zmm{{i}}, zmm17 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_even: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsravd zmm21, zmm20, zmm17 + vpandq zmm21, zmm21, zmm16 + vpsubd zmm21, zmm21, zmm16 // nudge = ((abs >>l shift) & 1) - 1 + vpaddd zmm20, zmm20, zmm21 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_odd: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsravd zmm21, zmm20, zmm17 + vpandq zmm21, zmm21, zmm16 + vpsubd zmm21, zmm19, zmm21 // nudge = -((abs >>l shift) & 1) + vpaddd zmm20, zmm20, zmm21 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +// ---- Store --------------------------------------------------------------- + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + je {{L}}store_strides_i32 + // else: i8 fallthrough + + cmp rdx, 1 + je {{L}}store_strides_i8_row_contig + + // Generic i8 strided store: per row, per lane scalar byte stores + {% for m in range(0, 16) %} + mov r10, r8 + // Extract from each 128-bit slice of zmm{{m}} + vextracti32x4 xmm20, zmm{{m}}, 0 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 1 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 2 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 3 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i8_row_contig: + // Each row is 16 i8 contiguous; one vpmovdb per row. + {% for m in range(0, 16) %} + vpmovdb [r8], zmm{{m}} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32: + cmp rdx, 4 + je {{L}}store_strides_i32_row_contig + + // Generic i32 strided store + {% for m in range(0, 16) %} + mov r10, r8 + vextracti32x4 xmm20, zmm{{m}}, 0 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 1 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 2 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 3 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32_row_contig: + // C is row-major in memory: each row's 16 i32 are contiguous; one + // 64-byte aligned-or-unaligned store per row. + {% for m in range(0, 16) %} + vmovdqu32 [r8], zmm{{m}} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}return: + tilerelease + add rsp, 64 + + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + +// ---- Read-only data (RIP-relative) --------------------------------------- + +.p2align 6 +{{L}}lane_offsets_64: + .int 0, 64, 128, 192, 256, 320, 384, 448 + .int 512, 576, 640, 704, 768, 832, 896, 960 + +.p2align 6 +{{L}}lane_indices: + .int 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + +.p2align 6 +{{L}}all_ones_i32: + .int 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + +{% if msvc %} +avx512amx_mmm_i32_16x16_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %} diff --git a/linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 b/linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 new file mode 100644 index 0000000000..ff65b6484a --- /dev/null +++ b/linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 @@ -0,0 +1,764 @@ +{# +// vim: set syntax=asm : + +/* mmm 8x8: + + ymm0 ymm1 ymm2 ymm3 ymm4 ymm5 ymm6 ymm7 + +System V ABI: + args: rdi, rsi, rdx, rcx, r8, r9 + preserve: rbx, rsp, rbp, r12, r13, r14, r15 + scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11 + return: rax (+rdx) + +Windows ABI: + args: RCX, RDX, R8, R9 + preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15 + scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15 + return: rax (+rdx) +*/ +#} + +{% if msvc %} + +_text segment +avx512amx_mmm_i32_8x8_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avx512amx_mmm_i32_8x8_{{suffix}} +{{G}}avx512amx_mmm_i32_8x8_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} +// https://www.agner.org/optimize/calling_conventions.pdf xmm6-15 are not scratch +// https://stackoverflow.com/questions/43358429/save-value-of-xmm-registers + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + + // Reserve 64 bytes of stack for the AMX tile-config block, zero it, + // populate palette + tile dimensions, then ldtilecfg. The tile config + // stays live for the whole function; tilerelease is emitted at return. + // + // tmm0 = C accumulator: 8 rows x 32 colsb (8 M-rows x 8 N-cols of i32) + // tmm1 = A tile: 8 rows x 64 colsb (8 M-rows x 64 K-bytes per inner iter) + // tmm2 = B tile: 16 rows x 32 colsb (16 K-pair-rows x 8 N-cols * 4 K-bytes) + sub rsp, 64 + vpxor xmm15, xmm15, xmm15 + vmovdqu [rsp ], xmm15 + vmovdqu [rsp + 16], xmm15 + vmovdqu [rsp + 32], xmm15 + vmovdqu [rsp + 48], xmm15 + mov byte ptr [rsp + 0 ], 1 // palette = 1 + mov word ptr [rsp + 16], 32 // colsb[0] = 32 (tmm0) + mov word ptr [rsp + 18], 64 // colsb[1] = 64 (tmm1) + mov word ptr [rsp + 20], 32 // colsb[2] = 32 (tmm2) + mov byte ptr [rsp + 48], 8 // rows[0] = 8 (tmm0) + mov byte ptr [rsp + 49], 8 // rows[1] = 8 (tmm1) + mov byte ptr [rsp + 50], 16 // rows[2] = 16 (tmm2) + ldtilecfg [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: + vzeroall + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_i8i8 + +{{L}}main_loop_packed_packed: + vmovaps ymm12, [rax] + + {% for i in range(0, 8) %} + vbroadcastss ymm14, dword ptr [rbx + {{i}} * 4] + vpmulld ymm13, ymm12, ymm14 + vpaddd ymm{{i}}, ymm{{i}}, ymm13 + {% endfor %} + + add rax, 32 + add rbx, 32 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_i8i8: + // AMX i8 layout: A panel is 8 M-rows x K_padded K-bytes ROW-major within + // each 8-row panel (PackedAmxA8); B panel reuses the existing VNNI K=4- + // inner format (8 N-cols x 4 K-bytes per K-block, 16 such blocks per + // K=64 AMX tile). K is padded to a multiple of 64 by the packer. + // + // tdpbssd is s8 x s8 -> i32 (Sapphire Rapids+), so no +128 trick is needed: + // the i32 accumulators are bit-identical to the AVX2 / VNNI paths. + // + // Per tdpbssd: tmm0[m, n] += sum_{k=0..63} A[m, k] * B[k, n] + // (16 M-rows x 16 N-i32-lanes x 64 K = 16384 mul-acc per instruction) + + // r8 <- K_padded = ceil(k/64) * 64 = byte-stride between A's M-rows. + mov r8, rcx + add r8, 63 + and r8, -64 + + // rcx <- ceil(k/64) = number of K=64 AMX inner iterations. + add rcx, 63 + shr rcx, 6 + + // r9 <- 32 = byte-stride between B's K-pair-rows. + mov r9, 32 + + tilezero tmm0 + +{{L}}loop_64k_amx_i8i8: + // Prefetch the data we'll need ONE iteration ahead. tileloadd brings + // the active tile data into L1 on demand; the prefetcht0 hints below + // ask the hardware prefetcher to start the L2->L1 fill for the next + // iter's A row (64 B further along the K axis) and the next iter's + // B panel (512 B = 8 cache lines further). For the long K loops + // (K>=256) the B-side prefetch matters most since each iter consumes + // 8 cache lines of B vs 1 cache line of A row 0. + prefetcht0 [rax + 64] + prefetcht0 [rbx + 512] + prefetcht0 [rbx + 576] + prefetcht0 [rbx + 640] + prefetcht0 [rbx + 704] + prefetcht0 [rbx + 768] + prefetcht0 [rbx + 832] + prefetcht0 [rbx + 896] + prefetcht0 [rbx + 960] + tileloadd tmm1, [rax + r8 * 1] // A tile: stride r8 = K_padded + tileloadd tmm2, [rbx + r9 * 1] // B tile: stride r9 = 32 + tdpbssd tmm0, tmm1, tmm2 + add rax, 64 // +64 K-bytes in A row 0 + add rbx, 512 // +16 K-pairs * 32 = 512 B bytes + dec rcx + jnz {{L}}loop_64k_amx_i8i8 + + // tmm0 -> ymm0..ymm7 via 256-byte stack scratch (8 rows x 32 bytes). + // After tilestored, the layout is row-major i32: byte (m*32 + n*4) = C[m, n]. + // We need ymm{n} = column n of C with 8 i32 lanes (rows m=0..7) — the + // dispatcher epilogue convention. So we (a) load 8 ymms = 8 rows of C, + // then (b) transpose 8x8 i32 in place. + sub rsp, 256 + mov r10, rsp + mov r11, 32 + tilestored [r10 + r11 * 1], tmm0 + + {% for r in range(0, 8) %} + vmovdqu ymm{{r}}, [r10 + {{ r * 32 }}] + {% endfor %} + + add rsp, 256 + + // 8x8 i32 transpose: ymm0..ymm7 row-major -> column-major in place. + // Stage 1: interleave 32-bit dwords pairwise (ymm0..ymm7 -> ymm8..ymm15). + vpunpckldq ymm8, ymm0, ymm1 // [r0[0], r1[0], r0[1], r1[1], r0[4], r1[4], r0[5], r1[5]] + vpunpckhdq ymm9, ymm0, ymm1 + vpunpckldq ymm10, ymm2, ymm3 + vpunpckhdq ymm11, ymm2, ymm3 + vpunpckldq ymm12, ymm4, ymm5 + vpunpckhdq ymm13, ymm4, ymm5 + vpunpckldq ymm14, ymm6, ymm7 + vpunpckhdq ymm15, ymm6, ymm7 + + // Stage 2: interleave 64-bit quads (ymm8..ymm15 -> ymm0..ymm7). + vpunpcklqdq ymm0, ymm8, ymm10 // [r0[0], r1[0], r2[0], r3[0], r0[4], r1[4], r2[4], r3[4]] + vpunpckhqdq ymm1, ymm8, ymm10 + vpunpcklqdq ymm2, ymm9, ymm11 + vpunpckhqdq ymm3, ymm9, ymm11 + vpunpcklqdq ymm4, ymm12, ymm14 + vpunpckhqdq ymm5, ymm12, ymm14 + vpunpcklqdq ymm6, ymm13, ymm15 + vpunpckhqdq ymm7, ymm13, ymm15 + + // Stage 3: cross-lane permute (128-bit halves). Two phases so we can + // overwrite the inputs incrementally without clobbering needed data. + vperm2i128 ymm8, ymm0, ymm4, 0x20 // col 0: low(y0) | low(y4) + vperm2i128 ymm9, ymm1, ymm5, 0x20 // col 1 + vperm2i128 ymm10, ymm2, ymm6, 0x20 // col 2 + vperm2i128 ymm11, ymm3, ymm7, 0x20 // col 3 + vperm2i128 ymm12, ymm0, ymm4, 0x31 // col 4: high(y0) | high(y4) + vperm2i128 ymm13, ymm1, ymm5, 0x31 // col 5 + vperm2i128 ymm14, ymm2, ymm6, 0x31 // col 6 + vperm2i128 ymm15, ymm3, ymm7, 0x31 // col 7 + + vmovdqa ymm0, ymm8 + vmovdqa ymm1, ymm9 + vmovdqa ymm2, ymm10 + vmovdqa ymm3, ymm11 + vmovdqa ymm4, ymm12 + vmovdqa ymm5, ymm13 + vmovdqa ymm6, ymm14 + vmovdqa ymm7, ymm15 + + jmp {{L}}non_linear_loop + +{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_scalars.j2" %} +{% set mr = 8 %}{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_per_rows.j2" %} +{% set mr = 8 %}{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_per_cols.j2" %} +{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_load_tile.j2" %} + +{{L}}add_unicast: + + mov r10, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size + + cmp r8, 4 + je {{L}}non_linear_addc_i32 + +{# +// This is not great as vgatherdps reads 32-bits values and goes beyond our buffer. Probably harmless though. +// Commented and replaced with the "mov al" loop beyond to pacify valgrind. +// ymm14 and ymm15 are the same as in the non_linear_addc_i32 case (compute them before the test right above here. +// {% for i in range(0, 8) %} +// vpcmpeqd ymm15, ymm15, ymm15 +// vgatherdps ymm12, [ r10 + ymm14 ], ymm15 // 0xxx 1xxx 2xxx 3xxx 4xxx 5xxx 6xxx 7xxx +// +// // we need to go through vpmovsxbd, shuffling naively erases signs +// vpshufb ymm12, ymm12, ymm10 // 0123 0123 0123 0123 4567 4567 4567 4567 +// +// vpermd ymm12, ymm11, ymm12 // 0123 4567 +// vpmovsxbd ymm12, xmm12 // sign extend +// +// vpaddd ymm{{i}}, ymm{{i}}, ymm12 +// add r10, rbx +// {% endfor %} +#} + + {% for col in range(0, 8) %} + mov r8, r10 + {% for half in range(0, 2) %} + {% for lane in range(0, 4) %} + mov al, [ r8 ] + add r8, rsi + movsx eax, al + pinsrd xmm10, eax, {{lane}} + {% endfor %} + vperm2f128 ymm10, ymm10, ymm10, 1 + {% endfor %} + vpaddd ymm{{col}}, ymm{{col}}, ymm10 + add r10, rbx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}non_linear_addc_i32: + + mov eax, 0 +{% for i in range(0, 4) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} + vpermq ymm14, ymm14, 78 // 0b01001110 +{% for i in range(0, 4) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} + vpermq ymm14, ymm14, 78 // 0b01001110 + + +{% if msvc %} + vpbroadcastd ymm10, dword ptr [ offset byte_shuffle ] + vmovups ymm11, dword ptr [ offset i128_shuffle ] +{% else %} + vpbroadcastd ymm10, [ rip + {{L}}byte_shuffle ] + vmovups ymm11, [ rip + {{L}}i128_shuffle ] +{% endif %} + +{% for i in range(0, 8) %} + vpcmpeqd ymm15, ymm15, ymm15 + vgatherdps ymm12, [ r10 + ymm14 ], ymm15 + vpaddd ymm{{i}}, ymm{{i}}, ymm12 + add r10, rbx +{% endfor %} + + jmp {{L}}non_linear_loop + +{% if msvc %} +.data +byte_shuffle dd 201851904 // 0x0c080400 +i128_shuffle dd 0, 4 +.code +{% else %} +{{L}}byte_shuffle: .int 201851904 // 0x0c080400 +{{L}}i128_shuffle: .int 0, 4 +{% endif %} + +{{L}}add_row_col_products: + mov rax, [ rdi + 8 ] + mov rbx, [ rdi + 16 ] + + vmovups ymm12, [rax] + +{% for i in range(0, 8) %} + vbroadcastss ymm14, dword ptr [rbx + {{ i * 4 }} ] + vpmulld ymm15, ymm12, ymm14 + vpaddd ymm{{i}}, ymm{{i}}, ymm15 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale: + mov r8, [ rdi + 16 ] // policy + vbroadcastss ymm8, dword ptr [rdi + 24] // multi + + mov rax, 1 + movq xmm9, rax + vpbroadcastq ymm9, xmm9 // ymm9 <- 1 + + mov rax, [ rdi + 8 ] // xmm10 <- shift + 31 + add rax, 31 + movq xmm10, rax + vpbroadcastq ymm10, xmm10 + + mov rax, 1 + movq xmm11, rax + vpsubq ymm12, ymm10, ymm9 // shift+31 - 1 + vpsllq ymm11, ymm9, xmm12 // ymm11 <- 1 << (shift + 31 - 1) + + cmp r8, 1 + je {{L}}q_scale_rounding_zero + cmp r8, 2 + je {{L}}q_scale_rounding_away + cmp r8, 3 + je {{L}}q_scale_rounding_minus_inf + cmp r8, 4 + je {{L}}q_scale_rounding_plus_inf + cmp r8, 5 + je {{L}}q_scale_rounding_even + cmp r8, 6 + je {{L}}q_scale_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_scale_rounding_zero: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsubq ymm14, ymm14, ymm9 + vpsubq ymm15, ymm15, ymm9 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_away: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_minus_inf: // signum * ( (abs << 32 + 1<<30+shift) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + // sign extract for nudging in the right direction + vpxor ymm13, ymm13, ymm13 + vpcmpgtd ymm13, ymm{{i}}, ymm13 // ymm13 <- s0, s1, ..s8 (signums, as all ones or all zeros) + vpsrld ymm13, ymm13, 31 // then just 0 or 1 + + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + // reinterpret ymm13=s0i32..s7 as i64 and blend with zero to pick the even ones as i64 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm14, ymm14, ymm12 + + vpsrldq ymm13, ymm13, 4 // ymm13 <- s1, s2, .., s7, 0 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm15, ymm15, ymm12 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_plus_inf: // signum * ( (abs << 32 + 1<<30+shift) >> shift ) + + vpbroadcastd ymm9, xmm9 + +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpxor ymm13, ymm13, ymm13 + + // sign extract for nudging in the right direction + vpcmpgtd ymm13, ymm{{i}}, ymm13 // ymm13 <- s0, s1, ..s8 (signums, as all ones or all zeros) + vpaddd ymm13, ymm13, ymm9 // if val >= 0 { 0i32 } else { 1i32 } + + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + // reinterpret ymm13=s0i32..s7 as i64 and blend with zero to pick the even ones as i64 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm14, ymm14, ymm12 + + vpsrldq ymm13, ymm13, 4 // ymm13 <- s1, s2, .., s7, 0 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm15, ymm15, ymm12 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_even: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpsrlq ymm12, ymm14, xmm10 + vpand ymm12, ymm12, ymm9 + vpaddq ymm14, ymm14, ymm12 + vpsubq ymm14, ymm14, ymm9 + + vpsrlq ymm12, ymm15, xmm10 + vpand ymm12, ymm12, ymm9 + vpaddq ymm15, ymm15, ymm12 + vpsubq ymm15, ymm15, ymm9 + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_odd: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpsrlq ymm12, ymm14, xmm10 + vpand ymm12, ymm12, ymm9 + vpsubq ymm14, ymm14, ymm12 + + vpsrlq ymm12, ymm15, xmm10 + vpand ymm12, ymm12, ymm9 + vpsubq ymm15, ymm15, ymm12 + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_shl: + mov eax, [ rdi + 8 ] // xmm10 <- -shift (8 times) + movd xmm10, eax + vpbroadcastd ymm10, xmm10 + +{% for i in range(0, 8) %} + vpsllvd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr: + mov r8, [ rdi + 16 ] // policy + + mov eax, 1 + movd xmm9, eax + vpbroadcastd ymm9, xmm9 // ymm9 <- 1u32 (8 times) + + mov eax, [ rdi + 8 ] // xmm10 <- shift (8 times) + movd xmm10, eax + vpbroadcastd ymm10, xmm10 + + mov ebx, 1 + mov cl, al + sub cl, 1 // rcx <- shift -1 + sal ebx, cl // rbx <- (1 << (shift - 1)) + movd xmm11, ebx + vpbroadcastd ymm11, xmm11 // ymm11 <- "half" + + vpxor ymm12, ymm12, ymm12 // ymm12 <- zeroes + + cmp r8, 1 + je {{L}}q_shr_rounding_zero + cmp r8, 2 + je {{L}}q_shr_rounding_away + cmp r8, 3 + je {{L}}q_shr_rounding_minus_inf + cmp r8, 4 + je {{L}}q_shr_rounding_plus_inf + cmp r8, 5 + je {{L}}q_shr_rounding_even + cmp r8, 6 + je {{L}}q_shr_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_shr_rounding_zero: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsubd ymm14, ymm14, ymm9 + vpaddd ymm14, ymm14, ymm11 + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_away: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpaddd ymm14, ymm14, ymm11 + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_minus_inf: +{% for i in range(0, 8) %} + vpsubd ymm{{i}}, ymm{{i}}, ymm9 + vpaddd ymm{{i}}, ymm{{i}}, ymm11 + vpsravd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_plus_inf: +{% for i in range(0, 8) %} + vpaddd ymm{{i}}, ymm{{i}}, ymm11 + vpsravd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_even: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsravd ymm13, ymm14, ymm10 + vpand ymm13, ymm13, ymm9 + vpsubd ymm13, ymm13, ymm9 // nudge = ((abs >>l shift) & 0x01) - 1 + vpaddd ymm14, ymm14, ymm13 // add nudge + vpaddd ymm14, ymm14, ymm11 // add half + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_odd: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsravd ymm13, ymm14, ymm10 + vpand ymm13, ymm13, ymm9 + vpsubd ymm13, ymm12, ymm13 // nudge = - ((abs >>l shift) & 0x01) + vpaddd ymm14, ymm14, ymm13 // add nudge + vpaddd ymm14, ymm14, ymm11 // add half + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + je {{L}}store_strides_i32 + + {% for col in range(0, 8) %} + mov r10, r8 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov byte ptr [r10], bl + add r10, rsi + {% endfor %} + vperm2f128 ymm{{col}}, ymm{{col}}, ymm{{col}}, 1 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov byte ptr [r10], bl + add r10, rsi + {% endfor %} + add r8, rdx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32: + {% for col in range(0, 8) %} + mov r10, r8 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov dword ptr [r10], ebx + add r10, rsi + {% endfor %} + vperm2f128 ymm{{col}}, ymm{{col}}, ymm{{col}}, 1 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov dword ptr [r10], ebx + add r10, rsi + {% endfor %} + add r8, rdx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}return: + // Tear down AMX state: release tile registers and reclaim the tile-config + // stack space we allocated right after the standard prologue. + tilerelease + add rsp, 64 + + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + + +{{L}}one_32bit: +{% if msvc %} + dd 1 +{% else %} + .int 1 +{% endif %} + +{% if msvc %} +avx512amx_mmm_i32_8x8_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %} diff --git a/linalg/x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2 b/linalg/x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2 new file mode 100644 index 0000000000..4c61169ff9 --- /dev/null +++ b/linalg/x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2 @@ -0,0 +1,885 @@ +// vim: set syntax=asm : +// +// AVX-512 VNNI int8 GEMM kernel, 16 M-rows x 16 N-cols i32 accumulator output. +// +// The zmm-wide (512-bit) sibling of avx512vnni_mmm_i32_8x8: where the 8x8 +// kernel accumulates 8 columns per ymm, this one accumulates 16 columns per +// zmm over 16 rows, so one VPDPBUSD covers a 16-lane x 4-K = 64 mul-add slab +// and the K=4 inner step issues 16 of them (one per row) -- 1024 mul-adds per +// K=4 block, 2x the work-per-iteration of the 8x8 ymm kernel. It is the int8 +// throughput tier of qmmm_i32 on big cores that have AVX-512-VNNI but no AMX +// (Cascade Lake / Ice Lake / Tiger Lake server + client SKUs). +// +// VPDPBUSD is u8 x s8, so (like the 8x8 kernel) the A bytes are offset by +128 +// to become u8 and the resulting 128 * sum_k(B[n]) bias is removed per column +// after the loop; the i32 accumulators are then bit-identical to the AVX2 / +// VNNI-8x8 / AMX reference paths. +// +// A and B both use PackedI8K4(16): per K=4 block, 16 elements x 4 K-bytes = 64 +// bytes, element e at byte offset e*4 holding [e, 4kb..4kb+3]; K is zero-padded +// to a multiple of 4 by the packer. +// +// REGISTER LAYOUT +// zmm0..zmm15 = accumulators, ROW-MAJOR: zmm{m} holds the 16 i32 lanes +// [C[m, 0], C[m, 1], ..., C[m, 15]] for row m. Row-major makes +// the per-column +128 bias a single vector subtract and lets +// the Store path write each row with one vmovdqu32. +// zmm16 = B K=4 block (lane n = B[n, 4kb..]); zmm17 = u8 ones (0x01010101); +// zmm18 = broadcast A[m, 4kb..] (+128 -> u8); zmm19 = bias (sum_k B[n]); +// zmm20 = 0x80808080 (the +128 byte bias added to A). + +{% if msvc %} + +_text segment +avx512vnni_mmm_i32_16x16_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avx512vnni_mmm_i32_16x16_{{suffix}} +{{G}}avx512vnni_mmm_i32_16x16_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: + // vzeroall only zeros lower-256 of zmm0..15; explicitly zero the full + // accumulators (zmm0..zmm15) for AMX. + {% for r in range(0, 16) %} + vpxorq zmm{{r}}, zmm{{r}}, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_i8i8 + +{{L}}main_loop_packed_packed: + // Generic i32 x i32 fallback path (not AMX). For row-major accumulators + // with zmm{m} = row m of C, accumulating C[m, n] += A[m, k] * B[k, n]: + // - load 16 B values for this K row into zmm16 (row of B) + // - for each m: broadcast A[m, k], multiply by zmm16, add to zmm{m} + vmovups zmm16, [rbx] // 16 i32 of B at this K row + + {% for m in range(0, 16) %} + vpbroadcastd zmm17, dword ptr [rax + {{m}} * 4] + vpmulld zmm18, zmm16, zmm17 + vpaddd zmm{{m}}, zmm{{m}}, zmm18 + {% endfor %} + + add rax, 64 // 16 i32 lanes per K step + add rbx, 64 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_i8i8: + // PackedI8K4(16) for both A and B: per K=4 block, 16 elements x 4 K-bytes = + // 64 bytes, element e at byte offset e*4 holding [e, 4kb..4kb+3]. + // B block -> zmm16, lane n = B[n, 4kb..] (the s8 operand) + // A[m] (its 4 K-bytes) broadcast to all 16 lanes, +128 -> the u8 operand + // VPDPBUSD zmm{m}, A_bcast(u8), Bblock(s8): lane n += sum_t (A[m,t]+128)*B[n,t] + // = C[m, n] + 128 * sum_t B[n, t]. That 128*sum_k(B[n]) bias is the same + // for every row m, so it is accumulated once per column in zmm19 (via a u8 + // all-ones VPDPBUSD) and subtracted from every accumulator after the loop, + // leaving the i32 accumulators bit-identical to the AVX2 / 8x8 paths. + + add rcx, 3 + shr rcx, 2 // rcx <- ceil(k/4) K=4 blocks + + mov r8d, 0x01010101 + vmovd xmm17, r8d + vpbroadcastd zmm17, xmm17 // zmm17 <- u8 ones (sum of B) + + mov r8d, 0x80808080 + vmovd xmm20, r8d + vpbroadcastd zmm20, xmm20 // zmm20 <- byte 0x80 (A + 128) + + vpxorq zmm19, zmm19, zmm19 // zmm19 <- per-col sum_k B[n] + +{{L}}loop_4k_i8i8_16x16: + vmovdqu32 zmm16, [rbx] // B block: lane n = B[n, 4kb..] + vpdpbusd zmm19, zmm17, zmm16 // sum_k B[n] += sum_t B[n, 4kb+t] + + {% for m in range(0, 16) %} + vpbroadcastd zmm18, dword ptr [rax + {{ m * 4 }}] + vpaddb zmm18, zmm18, zmm20 // s8 -> u8 (+128, modular) + vpdpbusd zmm{{m}}, zmm18, zmm16 // acc[m][n] += sum_t (A[m]+128)*B[n] + {% endfor %} + + add rax, 64 // next A K=4 block (16 rows * 4 K) + add rbx, 64 // next B K=4 block (16 cols * 4 K) + dec rcx + jnz {{L}}loop_4k_i8i8_16x16 + + // remove the +128 bias added on A: acc[m][n] -= 128 * sum_k B[n] (per column) + vpslld zmm19, zmm19, 7 // lane n <- 128 * sum_k B[n] + {% for m in range(0, 16) %} + vpsubd zmm{{m}}, zmm{{m}}, zmm19 + {% endfor %} + + jmp {{L}}non_linear_loop + +// ---- Scalar / per-row / per-col elementwise epilogues ------------------- + +{{L}}scalar_min: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpminsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_max: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpmaxsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_add: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpaddd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_mul: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpmulld zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub: + // non-flipped sub = operand - acc (matches fma_mmm_ymm_ops.j2 scalar macro) + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub_flipped: + // flipped sub = acc - operand + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}leaky_relu: + // C[m, n] = (C[m, n] >= 0) ? C[m, n] : alpha * C[m, n] + vpbroadcastd zmm17, dword ptr [rdi + 8] // alpha as i32 scale factor + vpxorq zmm16, zmm16, zmm16 + {% for r in range(0, 16) %} + vpmulld zmm18, zmm{{r}}, zmm17 + vpcmpgtd k1, zmm16, zmm{{r}} // 1 where C < 0 + vpblendmd zmm{{r}}{k1}, zmm{{r}}, zmm18 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_min: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpminsd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_max: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpmaxsd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_add: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpaddd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_mul: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpmulld zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub: + // non-flipped sub = operand - acc + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpsubd zmm{{m}}, zmm16, zmm{{m}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub_flipped: + // flipped sub = acc - operand + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpsubd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_min: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpminsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_max: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpmaxsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_add: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpaddd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_mul: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpmulld zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub: + // non-flipped sub = operand - acc + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub_flipped: + // flipped sub = acc - operand + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}load_tile: + // Scratch layout is COL-MAJOR i32 from scratch.rs Store/AddUnicast remnant: + // tile[col][row] at offset (col*MR + row)*4 with MR=16 + // = offset col*64 + row*4 + // For row-major accumulators we gather row m's 16 cols at index step 64. + mov r8, [rdi + 8] + vmovdqa32 zmm16, [rip + {{L}}lane_offsets_64] // [0, 64, 128, ..., 15*64] + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm{{m}}{k1}, [r8 + zmm16 + {{m * 4}}] + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_unicast: + mov r10, [rdi + 8] // c ptr (base) + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size + + cmp r8, 4 + je {{L}}non_linear_addc_i32 + + // i8 path: read 16 i8 from [r10 + m*rsi + n*rbx] for n=0..15, sign-extend + // to i32, add to zmm{m}. Use a stack scratch buffer (16 bytes per row). + sub rsp, 16 + {% for m in range(0, 16) %} + mov r8, r10 + {% for n in range(0, 16) %} + mov al, [r8] + mov byte ptr [rsp + {{n}}], al + add r8, rbx + {% endfor %} + vpmovsxbd zmm16, [rsp] + vpaddd zmm{{m}}, zmm{{m}}, zmm16 + add r10, rsi + {% endfor %} + add rsp, 16 + jmp {{L}}non_linear_loop + +{{L}}non_linear_addc_i32: + // i32 strided read of external (or scratch) tile. Build per-lane index + // vector [0, rbx, 2*rbx, ..., 15*rbx] once, then gather row by row. + mov eax, ebx + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 + vpmulld zmm16, zmm16, [rip + {{L}}lane_indices] // [0, rbx, 2*rbx, ...] + + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm17{k1}, [r10 + zmm16] + vpaddd zmm{{m}}, zmm{{m}}, zmm17 + add r10, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_row_col_products: + // bias[m, n] = row_data[m] * col_data[n], add to C[m, n]. + // For row-major regs: load 16 col_data values once into zmm16, + // for each m: broadcast row_data[m], FMA add. + mov rax, [rdi + 8] + mov rbx, [rdi + 16] + + vmovdqu32 zmm16, [rax] // 16 row_data values + vmovdqu32 zmm17, [rbx] // 16 col_data values + + {% for m in range(0, 16) %} + vpbroadcastd zmm18, dword ptr [rax + {{m * 4}}] // splat row_data[m] + vpmulld zmm19, zmm18, zmm17 + vpaddd zmm{{m}}, zmm{{m}}, zmm19 + {% endfor %} + jmp {{L}}non_linear_loop + +// ---- Q-scale (mult-shift with rounding) --------------------------------- + +{{L}}q_scale: + mov r8, [rdi + 16] // policy + vpbroadcastd zmm16, dword ptr [rdi + 24] // multi (broadcast i32) + + mov rax, 1 + vmovq xmm17, rax + vpbroadcastq zmm17, xmm17 // zmm17 <- 1 (i64 lanes) + + mov rax, [rdi + 8] // shift + add rax, 31 + vmovq xmm18, rax + vpbroadcastq zmm18, xmm18 // zmm18 <- (shift+31) (i64 lanes) + + vpsubq zmm19, zmm18, zmm17 + vpsllvq zmm19, zmm17, zmm19 // zmm19 <- 1 << (shift+31-1) (i64) + + // Per-lane interleave mask for blending evens / shifted-odds. + // bit i = 1 means take from "evens" source in vpblendmd; bit 0,2,4,...,14 set. + mov eax, 0x5555 + kmovw k7, eax + + cmp r8, 1 + je {{L}}q_scale_rounding_zero + cmp r8, 2 + je {{L}}q_scale_rounding_away + cmp r8, 3 + je {{L}}q_scale_rounding_minus_inf + cmp r8, 4 + je {{L}}q_scale_rounding_plus_inf + cmp r8, 5 + je {{L}}q_scale_rounding_even + cmp r8, 6 + je {{L}}q_scale_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_scale_rounding_zero: // signum * ( (abs + nudge - 1) >> shift ) +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 // even-lane i32 -> i64 mul + vpmuldq zmm21, zmm21, zmm16 // odd-lane i32 -> i64 mul + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsubq zmm20, zmm20, zmm17 + vpsubq zmm21, zmm21, zmm17 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 // k7=0x5555: evens from zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_away: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_minus_inf: // nudge by -1 where input was negative +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpxorq zmm22, zmm22, zmm22 + vpcmpgtd k1, zmm{{i}}, zmm22 // k1: 1 where input > 0 (we want the inverse, see below) + knotw k1, k1 // 1 where input <= 0 -- we want "input was negative => subtract 1" + // For "<0": use compare against 0 with vpcmpltd + vpxorq zmm22, zmm22, zmm22 + vpcmpltd k1, zmm{{i}}, zmm22 // 1 where input < 0 + vmovdqa32 zmm23{k1}{z}, [rip + {{L}}all_ones_i32] // (1 << 0) per neg lane, 0 elsewhere + + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + // Subtract 1 from i64-evens / i64-odds where the original i32 input was < 0. + vpsrldq zmm24, zmm23, 4 + vpmovsxdq zmm25, ymm23 + vpmovsxdq zmm26, ymm24 + vpsubq zmm20, zmm20, zmm25 + vpsubq zmm21, zmm21, zmm26 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_plus_inf: // nudge by +1 where input was non-negative +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpxorq zmm22, zmm22, zmm22 + vpcmpled k1, zmm22, zmm{{i}} // 1 where input >= 0 + vmovdqa32 zmm23{k1}{z}, [rip + {{L}}all_ones_i32] + + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrldq zmm24, zmm23, 4 + vpmovsxdq zmm25, ymm23 + vpmovsxdq zmm26, ymm24 + vpsubq zmm20, zmm20, zmm25 + vpsubq zmm21, zmm21, zmm26 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_even: // banker's: round half to even +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpsrlq zmm22, zmm20, xmm18 + vpandq zmm22, zmm22, zmm17 + vpaddq zmm20, zmm20, zmm22 + vpsubq zmm20, zmm20, zmm17 + + vpsrlq zmm22, zmm21, xmm18 + vpandq zmm22, zmm22, zmm17 + vpaddq zmm21, zmm21, zmm22 + vpsubq zmm21, zmm21, zmm17 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_odd: // round half to odd +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpsrlq zmm22, zmm20, xmm18 + vpandq zmm22, zmm22, zmm17 + vpsubq zmm20, zmm20, zmm22 + + vpsrlq zmm22, zmm21, xmm18 + vpandq zmm22, zmm22, zmm17 + vpsubq zmm21, zmm21, zmm22 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shl: + mov eax, [rdi + 8] // -shift (count: i32) + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 + {% for i in range(0, 16) %}vpsllvd zmm{{i}}, zmm{{i}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr: + mov r8, [rdi + 16] // policy + + mov eax, 1 + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 // zmm16 <- 1 (i32 lanes) + + mov eax, [rdi + 8] // shift + vmovd xmm17, eax + vpbroadcastd zmm17, xmm17 // zmm17 <- shift (i32 lanes) + + mov ebx, 1 + mov cl, al + sub cl, 1 + sal ebx, cl // ebx <- 1 << (shift - 1) + vmovd xmm18, ebx + vpbroadcastd zmm18, xmm18 // zmm18 <- "half" + + vpxorq zmm19, zmm19, zmm19 // zeroes + + cmp r8, 1 + je {{L}}q_shr_rounding_zero + cmp r8, 2 + je {{L}}q_shr_rounding_away + cmp r8, 3 + je {{L}}q_shr_rounding_minus_inf + cmp r8, 4 + je {{L}}q_shr_rounding_plus_inf + cmp r8, 5 + je {{L}}q_shr_rounding_even + cmp r8, 6 + je {{L}}q_shr_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_shr_rounding_zero: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsubd zmm20, zmm20, zmm16 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_away: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_minus_inf: +{% for i in range(0, 16) %} + vpsubd zmm{{i}}, zmm{{i}}, zmm16 + vpaddd zmm{{i}}, zmm{{i}}, zmm18 + vpsravd zmm{{i}}, zmm{{i}}, zmm17 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_plus_inf: +{% for i in range(0, 16) %} + vpaddd zmm{{i}}, zmm{{i}}, zmm18 + vpsravd zmm{{i}}, zmm{{i}}, zmm17 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_even: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsravd zmm21, zmm20, zmm17 + vpandq zmm21, zmm21, zmm16 + vpsubd zmm21, zmm21, zmm16 // nudge = ((abs >>l shift) & 1) - 1 + vpaddd zmm20, zmm20, zmm21 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_odd: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsravd zmm21, zmm20, zmm17 + vpandq zmm21, zmm21, zmm16 + vpsubd zmm21, zmm19, zmm21 // nudge = -((abs >>l shift) & 1) + vpaddd zmm20, zmm20, zmm21 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +// ---- Store --------------------------------------------------------------- + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + je {{L}}store_strides_i32 + // else: i8 fallthrough + + cmp rdx, 1 + je {{L}}store_strides_i8_row_contig + + // Generic i8 strided store: per row, per lane scalar byte stores + {% for m in range(0, 16) %} + mov r10, r8 + // Extract from each 128-bit slice of zmm{{m}} + vextracti32x4 xmm20, zmm{{m}}, 0 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 1 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 2 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 3 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i8_row_contig: + // Each row is 16 i8 contiguous; one vpmovdb per row. + {% for m in range(0, 16) %} + vpmovdb [r8], zmm{{m}} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32: + cmp rdx, 4 + je {{L}}store_strides_i32_row_contig + + // Generic i32 strided store + {% for m in range(0, 16) %} + mov r10, r8 + vextracti32x4 xmm20, zmm{{m}}, 0 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 1 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 2 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 3 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32_row_contig: + // C is row-major in memory: each row's 16 i32 are contiguous; one + // 64-byte aligned-or-unaligned store per row. + {% for m in range(0, 16) %} + vmovdqu32 [r8], zmm{{m}} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}return: + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + +// ---- Read-only data (RIP-relative) --------------------------------------- + +.p2align 6 +{{L}}lane_offsets_64: + .int 0, 64, 128, 192, 256, 320, 384, 448 + .int 512, 576, 640, 704, 768, 832, 896, 960 + +.p2align 6 +{{L}}lane_indices: + .int 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + +.p2align 6 +{{L}}all_ones_i32: + .int 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + +{% if msvc %} +avx512vnni_mmm_i32_16x16_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %} diff --git a/linalg/x86_64/fma/avxvnni_mmm_i32_8x8.S.j2 b/linalg/x86_64/fma/avxvnni_mmm_i32_8x8.S.j2 new file mode 100644 index 0000000000..904335c511 --- /dev/null +++ b/linalg/x86_64/fma/avxvnni_mmm_i32_8x8.S.j2 @@ -0,0 +1,685 @@ +{# +// vim: set syntax=asm : + +/* AVX-VNNI int8 GEMM (mmm 8x8), VEX-encoded VPDPBUSD. +// +// Body-identical to avx512vnni_mmm_i32_8x8 (same 8-row x 8-col ymm accumulators, +// same PackedI8K4 inner-K layout, same +128 bias trick to bridge VPDPBUSD's +// u8 x s8 into the AVX2 s8 x s8 reference). The only difference is that the +// two VPDPBUSD instructions are prefixed with {vex} so gas emits the AVX-VNNI +// (VEX) form instead of the AVX-512-VNNI (EVEX) form it defaults to. The VEX +// form runs on Atom-class cores (Alder Lake E-cores, Sierra Forest, Clearwater +// Forest / Darkmont) which have AVX-VNNI but no AVX-512, where the existing +// avx512vnni kernel would fault. + + ymm0 ymm1 ymm2 ymm3 ymm4 ymm5 ymm6 ymm7 + +System V ABI: + args: rdi, rsi, rdx, rcx, r8, r9 + preserve: rbx, rsp, rbp, r12, r13, r14, r15 + scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11 + return: rax (+rdx) + +Windows ABI: + args: RCX, RDX, R8, R9 + preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15 + scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15 + return: rax (+rdx) +*/ +#} + +{% if msvc %} + +_text segment +avxvnni_mmm_i32_8x8_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avxvnni_mmm_i32_8x8_{{suffix}} +{{G}}avxvnni_mmm_i32_8x8_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} +// https://www.agner.org/optimize/calling_conventions.pdf xmm6-15 are not scratch +// https://stackoverflow.com/questions/43358429/save-value-of-xmm-registers + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: + vzeroall + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_i8i8 + +{{L}}main_loop_packed_packed: + vmovaps ymm12, [rax] + + {% for i in range(0, 8) %} + vbroadcastss ymm14, dword ptr [rbx + {{i}} * 4] + vpmulld ymm13, ymm12, ymm14 + vpaddd ymm{{i}}, ymm{{i}}, ymm13 + {% endfor %} + + add rax, 32 + add rbx, 32 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_i8i8: + // PackedI8K4 layout: per K=4 block, the A panel is 8 rows x 4 K-bytes (32 + // bytes, lane m = A[m, 4kb..4kb+3]) and the B panel is 8 cols x 4 K-bytes + // (lane n = B[n, 4kb..4kb+3]). VPDPBUSD is u8 x s8, so A is offset by +128 + // (-> u8) and the resulting 128*sum_k(B[n]) bias is removed per column after + // the loop, leaving the i32 accumulators identical to the AVX2 path. + + add rcx, 3 + shr rcx, 2 // rcx <- ceil(k/4) K=4 blocks + + mov r8d, 0x01010101 + movd xmm11, r8d + vpbroadcastd ymm11, xmm11 // ymm11 <- u8 ones (sum of B) + + mov r8d, 0x80808080 + movd xmm12, r8d + vpbroadcastd ymm12, xmm12 // ymm12 <- byte 0x80 (A + 128) + + vpxor ymm10, ymm10, ymm10 // ymm10 <- per-col sum_k B[n] + +{{L}}loop_4k_i8i8: + vmovdqu ymm8, [rax] // A block: lane m = A[m,4kb..] + vpaddb ymm8, ymm8, ymm12 // s8 -> u8 (+128, modular) + + vmovdqu ymm9, [rbx] // B block: lane n = B[n,4kb..] + {vex} vpdpbusd ymm10, ymm11, ymm9 // sum_k B[n] += sum_t B[n,4kb+t] + + {% for n in range(0, 8) %} + vpbroadcastd ymm13, dword ptr [rbx + {{n}} * 4] + {vex} vpdpbusd ymm{{n}}, ymm8, ymm13 // acc[n][m] += sum_t (A[m]+128)*B[n] + {% endfor %} + + add rax, 32 + add rbx, 32 + dec rcx + jnz {{L}}loop_4k_i8i8 + + // remove the +128 bias added on A: acc[n] -= 128 * sum_k B[n] + vpslld ymm10, ymm10, 7 // lane n <- 128 * sum_k B[n] + {% for n in range(0, 8) %} + mov r8d, {{n}} + movd xmm14, r8d + vpbroadcastd ymm14, xmm14 // index = n in every lane + vpermd ymm15, ymm14, ymm10 // splat 128*sum_k B[n] + vpsubd ymm{{n}}, ymm{{n}}, ymm15 + {% endfor %} + + jmp {{L}}non_linear_loop + +{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_scalars.j2" %} +{% set mr = 8 %}{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_per_rows.j2" %} +{% set mr = 8 %}{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_per_cols.j2" %} +{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_load_tile.j2" %} + +{{L}}add_unicast: + + mov r10, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size + + cmp r8, 4 + je {{L}}non_linear_addc_i32 + +{# +// This is not great as vgatherdps reads 32-bits values and goes beyond our buffer. Probably harmless though. +// Commented and replaced with the "mov al" loop beyond to pacify valgrind. +// ymm14 and ymm15 are the same as in the non_linear_addc_i32 case (compute them before the test right above here. +// {% for i in range(0, 8) %} +// vpcmpeqd ymm15, ymm15, ymm15 +// vgatherdps ymm12, [ r10 + ymm14 ], ymm15 // 0xxx 1xxx 2xxx 3xxx 4xxx 5xxx 6xxx 7xxx +// +// // we need to go through vpmovsxbd, shuffling naively erases signs +// vpshufb ymm12, ymm12, ymm10 // 0123 0123 0123 0123 4567 4567 4567 4567 +// +// vpermd ymm12, ymm11, ymm12 // 0123 4567 +// vpmovsxbd ymm12, xmm12 // sign extend +// +// vpaddd ymm{{i}}, ymm{{i}}, ymm12 +// add r10, rbx +// {% endfor %} +#} + + {% for col in range(0, 8) %} + mov r8, r10 + {% for half in range(0, 2) %} + {% for lane in range(0, 4) %} + mov al, [ r8 ] + add r8, rsi + movsx eax, al + pinsrd xmm10, eax, {{lane}} + {% endfor %} + vperm2f128 ymm10, ymm10, ymm10, 1 + {% endfor %} + vpaddd ymm{{col}}, ymm{{col}}, ymm10 + add r10, rbx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}non_linear_addc_i32: + + mov eax, 0 +{% for i in range(0, 4) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} + vpermq ymm14, ymm14, 78 // 0b01001110 +{% for i in range(0, 4) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} + vpermq ymm14, ymm14, 78 // 0b01001110 + + +{% if msvc %} + vpbroadcastd ymm10, dword ptr [ offset byte_shuffle ] + vmovups ymm11, dword ptr [ offset i128_shuffle ] +{% else %} + vpbroadcastd ymm10, [ rip + {{L}}byte_shuffle ] + vmovups ymm11, [ rip + {{L}}i128_shuffle ] +{% endif %} + +{% for i in range(0, 8) %} + vpcmpeqd ymm15, ymm15, ymm15 + vgatherdps ymm12, [ r10 + ymm14 ], ymm15 + vpaddd ymm{{i}}, ymm{{i}}, ymm12 + add r10, rbx +{% endfor %} + + jmp {{L}}non_linear_loop + +{% if msvc %} +.data +byte_shuffle dd 201851904 // 0x0c080400 +i128_shuffle dd 0, 4 +.code +{% else %} +{{L}}byte_shuffle: .int 201851904 // 0x0c080400 +{{L}}i128_shuffle: .int 0, 4 +{% endif %} + +{{L}}add_row_col_products: + mov rax, [ rdi + 8 ] + mov rbx, [ rdi + 16 ] + + vmovups ymm12, [rax] + +{% for i in range(0, 8) %} + vbroadcastss ymm14, dword ptr [rbx + {{ i * 4 }} ] + vpmulld ymm15, ymm12, ymm14 + vpaddd ymm{{i}}, ymm{{i}}, ymm15 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale: + mov r8, [ rdi + 16 ] // policy + vbroadcastss ymm8, dword ptr [rdi + 24] // multi + + mov rax, 1 + movq xmm9, rax + vpbroadcastq ymm9, xmm9 // ymm9 <- 1 + + mov rax, [ rdi + 8 ] // xmm10 <- shift + 31 + add rax, 31 + movq xmm10, rax + vpbroadcastq ymm10, xmm10 + + mov rax, 1 + movq xmm11, rax + vpsubq ymm12, ymm10, ymm9 // shift+31 - 1 + vpsllq ymm11, ymm9, xmm12 // ymm11 <- 1 << (shift + 31 - 1) + + cmp r8, 1 + je {{L}}q_scale_rounding_zero + cmp r8, 2 + je {{L}}q_scale_rounding_away + cmp r8, 3 + je {{L}}q_scale_rounding_minus_inf + cmp r8, 4 + je {{L}}q_scale_rounding_plus_inf + cmp r8, 5 + je {{L}}q_scale_rounding_even + cmp r8, 6 + je {{L}}q_scale_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_scale_rounding_zero: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsubq ymm14, ymm14, ymm9 + vpsubq ymm15, ymm15, ymm9 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_away: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_minus_inf: // signum * ( (abs << 32 + 1<<30+shift) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + // sign extract for nudging in the right direction + vpxor ymm13, ymm13, ymm13 + vpcmpgtd ymm13, ymm{{i}}, ymm13 // ymm13 <- s0, s1, ..s8 (signums, as all ones or all zeros) + vpsrld ymm13, ymm13, 31 // then just 0 or 1 + + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + // reinterpret ymm13=s0i32..s7 as i64 and blend with zero to pick the even ones as i64 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm14, ymm14, ymm12 + + vpsrldq ymm13, ymm13, 4 // ymm13 <- s1, s2, .., s7, 0 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm15, ymm15, ymm12 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_plus_inf: // signum * ( (abs << 32 + 1<<30+shift) >> shift ) + + vpbroadcastd ymm9, xmm9 + +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpxor ymm13, ymm13, ymm13 + + // sign extract for nudging in the right direction + vpcmpgtd ymm13, ymm{{i}}, ymm13 // ymm13 <- s0, s1, ..s8 (signums, as all ones or all zeros) + vpaddd ymm13, ymm13, ymm9 // if val >= 0 { 0i32 } else { 1i32 } + + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + // reinterpret ymm13=s0i32..s7 as i64 and blend with zero to pick the even ones as i64 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm14, ymm14, ymm12 + + vpsrldq ymm13, ymm13, 4 // ymm13 <- s1, s2, .., s7, 0 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm15, ymm15, ymm12 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_even: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpsrlq ymm12, ymm14, xmm10 + vpand ymm12, ymm12, ymm9 + vpaddq ymm14, ymm14, ymm12 + vpsubq ymm14, ymm14, ymm9 + + vpsrlq ymm12, ymm15, xmm10 + vpand ymm12, ymm12, ymm9 + vpaddq ymm15, ymm15, ymm12 + vpsubq ymm15, ymm15, ymm9 + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_odd: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpsrlq ymm12, ymm14, xmm10 + vpand ymm12, ymm12, ymm9 + vpsubq ymm14, ymm14, ymm12 + + vpsrlq ymm12, ymm15, xmm10 + vpand ymm12, ymm12, ymm9 + vpsubq ymm15, ymm15, ymm12 + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_shl: + mov eax, [ rdi + 8 ] // xmm10 <- -shift (8 times) + movd xmm10, eax + vpbroadcastd ymm10, xmm10 + +{% for i in range(0, 8) %} + vpsllvd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr: + mov r8, [ rdi + 16 ] // policy + + mov eax, 1 + movd xmm9, eax + vpbroadcastd ymm9, xmm9 // ymm9 <- 1u32 (8 times) + + mov eax, [ rdi + 8 ] // xmm10 <- shift (8 times) + movd xmm10, eax + vpbroadcastd ymm10, xmm10 + + mov ebx, 1 + mov cl, al + sub cl, 1 // rcx <- shift -1 + sal ebx, cl // rbx <- (1 << (shift - 1)) + movd xmm11, ebx + vpbroadcastd ymm11, xmm11 // ymm11 <- "half" + + vpxor ymm12, ymm12, ymm12 // ymm12 <- zeroes + + cmp r8, 1 + je {{L}}q_shr_rounding_zero + cmp r8, 2 + je {{L}}q_shr_rounding_away + cmp r8, 3 + je {{L}}q_shr_rounding_minus_inf + cmp r8, 4 + je {{L}}q_shr_rounding_plus_inf + cmp r8, 5 + je {{L}}q_shr_rounding_even + cmp r8, 6 + je {{L}}q_shr_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_shr_rounding_zero: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsubd ymm14, ymm14, ymm9 + vpaddd ymm14, ymm14, ymm11 + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_away: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpaddd ymm14, ymm14, ymm11 + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_minus_inf: +{% for i in range(0, 8) %} + vpsubd ymm{{i}}, ymm{{i}}, ymm9 + vpaddd ymm{{i}}, ymm{{i}}, ymm11 + vpsravd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_plus_inf: +{% for i in range(0, 8) %} + vpaddd ymm{{i}}, ymm{{i}}, ymm11 + vpsravd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_even: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsravd ymm13, ymm14, ymm10 + vpand ymm13, ymm13, ymm9 + vpsubd ymm13, ymm13, ymm9 // nudge = ((abs >>l shift) & 0x01) - 1 + vpaddd ymm14, ymm14, ymm13 // add nudge + vpaddd ymm14, ymm14, ymm11 // add half + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_odd: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsravd ymm13, ymm14, ymm10 + vpand ymm13, ymm13, ymm9 + vpsubd ymm13, ymm12, ymm13 // nudge = - ((abs >>l shift) & 0x01) + vpaddd ymm14, ymm14, ymm13 // add nudge + vpaddd ymm14, ymm14, ymm11 // add half + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + je {{L}}store_strides_i32 + + {% for col in range(0, 8) %} + mov r10, r8 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov byte ptr [r10], bl + add r10, rsi + {% endfor %} + vperm2f128 ymm{{col}}, ymm{{col}}, ymm{{col}}, 1 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov byte ptr [r10], bl + add r10, rsi + {% endfor %} + add r8, rdx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32: + {% for col in range(0, 8) %} + mov r10, r8 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov dword ptr [r10], ebx + add r10, rsi + {% endfor %} + vperm2f128 ymm{{col}}, ymm{{col}}, ymm{{col}}, 1 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov dword ptr [r10], ebx + add r10, rsi + {% endfor %} + add r8, rdx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}return: + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + + +{{L}}one_32bit: +{% if msvc %} + dd 1 +{% else %} + .int 1 +{% endif %} + +{% if msvc %} +avxvnni_mmm_i32_8x8_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %}