From 9efa301b64bde9015abc7e48d95fcd7af598990e Mon Sep 17 00:00:00 2001
From: czoli1976 <64466170+czoli1976@users.noreply.github.com>
Date: Sat, 13 Jun 2026 14:18:45 +0100
Subject: [PATCH] core: don't recompute ReduceMax scalar after the SIMD max kernel
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

`max_t` (the f32 ReduceMax reducer) called the vectorized `max_f32` linalg
kernel, *discarded its result*, then unconditionally recomputed the max with a
scalar partial-ord fold over the same slice — so ReduceMax did the reduction
twice and was effectively scalar-bound (the "optimized" path was strictly
slower than having no kernel at all).

Return the SIMD kernel's result for the f32 contiguous case; fall through to
the scalar fold only for non-f32 dtypes, non-contiguous (strided) slices, or
empty slices. Adds a correctness test covering both branches (contiguous +
tail, strided, single-element).

Benchmark (M-series, f32 max over the trailing axis, via the added
reduce_max_bench example):

  shape          before               after               speedup
  1024 x 4096    2.44 ms / 6.9GB/s    0.32 ms / 52GB/s    7.5x
  4096 x 1024    2.46 ms / 6.8GB/s    0.42 ms / 40GB/s    5.9x
  256 x 65536    9.44 ms / 7.1GB/s    1.04 ms / 65GB/s    9.1x

Identical results. Benefits ReduceMax, MaxPool and the softmax max pre-pass.

Co-Authored-By: Claude Fable 5
---
 core/examples/reduce_max_bench.rs | 39 +++++++++++++++++++++++++++++
 core/src/ops/nn/reduce.rs         | 41 ++++++++++++++++++++++++++++++-
 2 files changed, 79 insertions(+), 1 deletion(-)
 create mode 100644 core/examples/reduce_max_bench.rs

diff --git a/core/examples/reduce_max_bench.rs b/core/examples/reduce_max_bench.rs
new file mode 100644
index 0000000000..e83b6d7ae6
--- /dev/null
+++ b/core/examples/reduce_max_bench.rs
@@ -0,0 +1,39 @@
+//! Benchmark for the f32 max-reduction. Before the fix, ReduceMax ran the SIMD
+//! `max_f32` kernel, threw the result away, then recomputed with a scalar fold;
+//! after, it returns the SIMD result. Times `Reducer::Max` over a contiguous
+//! trailing axis (the SIMD path).
+//!
+//! Run: cargo run --release --example reduce_max_bench -p tract-core
+use std::time::Instant;
+
+use tract_core::internal::*;
+use tract_core::ops::nn::Reducer;
+
+fn main() -> TractResult<()> {
+    for (rows, k) in [(1024usize, 4096usize), (4096, 1024), (256, 65536)] {
+        let n = rows * k;
+        let data: Vec<f32> =
+            (0..n).map(|i| (((i * 2654435761) >> 13) as f32 / 1e6).sin()).collect();
+        let t = Tensor::from_shape(&[rows, k], &data)?;
+
+        for _ in 0..3 {
+            let _ = Reducer::Max.reduce(&[1], &t)?;
+        }
+        let runs = 50;
+        let mut chk = 0f32;
+        let s = Instant::now();
+        for _ in 0..runs {
+            let o = Reducer::Max.reduce(&[1], &t)?;
+            chk += unsafe { o.as_slice_unchecked::<f32>() }[0];
+            std::hint::black_box(&o);
+        }
+        let per = s.elapsed().as_secs_f64() / runs as f64;
+        let gbps = (n * 4) as f64 / per / 1e9;
+        println!(
+            "reduce-max [{rows}x{k}] axis1 : {:>7.3} ms/call {:>6.1} GB/s (chk {chk:.3})",
+            per * 1e3,
+            gbps
+        );
+    }
+    Ok(())
+}
diff --git a/core/src/ops/nn/reduce.rs b/core/src/ops/nn/reduce.rs
index 25f18fcd8e..5138379388 100644
--- a/core/src/ops/nn/reduce.rs
+++ b/core/src/ops/nn/reduce.rs
@@ -295,9 +295,12 @@ where
 {
     if T::datum_type() == f32::datum_type()
         && let Some(slice) = v.as_slice()
+        && !slice.is_empty()
     {
         let slice = unsafe { transmute::<&[T], &[f32]>(slice) };
-        (tract_linalg::ops().max_f32)().run(slice).unwrap();
+        let max = (tract_linalg::ops().max_f32)().run(slice).unwrap();
+        // SAFETY: T is f32 in this branch (checked above).
+        return unsafe { std::mem::transmute_copy::<f32, T>(&max) };
     }
     v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v })
 }
@@ -628,3 +631,39 @@ pub fn expand_mean_of_squares(
     patch.shunt_outside(model, node.id.into(), wire[0])?;
     Ok(Some(patch))
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Guards the f32 max reduction (max_t): the SIMD `max_f32` kernel result must
+    // be returned (contiguous path), with the scalar fold used only for strided
+    // slices. Checked against explicit per-row / per-col references.
+    #[test]
+    fn reduce_max_f32_contiguous_and_strided() {
+        let (r, c) = (5usize, 37usize); // c not a multiple of the SIMD width (tail)
+        let data: Vec<f32> = (0..r * c).map(|i| ((i * 31 % 97) as f32) - 48.0).collect();
+        let t = Tensor::from_shape(&[r, c], &data).unwrap();
+
+        // axis 1: per-row max — contiguous slices -> SIMD path.
+        let got = Reducer::Max.reduce(&[1], &t).unwrap();
+        assert_eq!(got.shape(), &[r, 1]);
+        for (i, &g) in unsafe { got.as_slice_unchecked::<f32>() }.iter().enumerate() {
+            let want = data[i * c..(i + 1) * c].iter().copied().fold(f32::MIN, f32::max);
+            assert_eq!(g, want, "row {i}");
+        }
+
+        // axis 0: per-col max — strided slices -> scalar fold.
+        let got = Reducer::Max.reduce(&[0], &t).unwrap();
+        assert_eq!(got.shape(), &[1, c]);
+        for (j, &g) in unsafe { got.as_slice_unchecked::<f32>() }.iter().enumerate() {
+            let want = (0..r).map(|i| data[i * c + j]).fold(f32::MIN, f32::max);
+            assert_eq!(g, want, "col {j}");
+        }
+
+        // k == 1 (single-element reduction): the smallest slice that passes the non-empty guard.
+        let t1 = Tensor::from_shape(&[3, 1], &[1.0f32, -2.0, 3.0]).unwrap();
+        let got = Reducer::Max.reduce(&[1], &t1).unwrap();
+        assert_eq!(unsafe { got.as_slice_unchecked::<f32>() }, &[1.0, -2.0, 3.0]);
+    }
+}