Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions core/examples/reduce_max_bench.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//! Benchmark for the f32 max-reduction. Before the fix, ReduceMax ran the SIMD
//! `max_f32` kernel, threw the result away, then recomputed with a scalar fold;
//! after, it returns the SIMD result. Times `Reducer::Max` over a contiguous
//! trailing axis (the SIMD path).
//!
//! Run: cargo run --release --example reduce_max_bench -p tract-core
use std::time::Instant;

use tract_core::internal::*;
use tract_core::ops::nn::Reducer;

fn main() -> TractResult<()> {
for (rows, k) in [(1024usize, 4096usize), (4096, 1024), (256, 65536)] {
let n = rows * k;
let data: Vec<f32> =
(0..n).map(|i| (((i * 2654435761) >> 13) as f32 / 1e6).sin()).collect();
let t = Tensor::from_shape(&[rows, k], &data)?;

for _ in 0..3 {
let _ = Reducer::Max.reduce(&[1], &t)?;
}
let runs = 50;
let mut chk = 0f32;
let s = Instant::now();
for _ in 0..runs {
let o = Reducer::Max.reduce(&[1], &t)?;
chk += unsafe { o.as_slice_unchecked::<f32>() }[0];
std::hint::black_box(&o);
}
let per = s.elapsed().as_secs_f64() / runs as f64;
let gbps = (n * 4) as f64 / per / 1e9;
println!(
"reduce-max [{rows}x{k}] axis1 : {:>7.3} ms/call {:>6.1} GB/s (chk {chk:.3})",
per * 1e3,
gbps
);
}
Ok(())
}
41 changes: 40 additions & 1 deletion core/src/ops/nn/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,12 @@ where
{
if T::datum_type() == f32::datum_type()
&& let Some(slice) = v.as_slice()
&& !slice.is_empty()
{
let slice = unsafe { transmute::<&[T], &[f32]>(slice) };
(tract_linalg::ops().max_f32)().run(slice).unwrap();
let max = (tract_linalg::ops().max_f32)().run(slice).unwrap();
// SAFETY: T is f32 in this branch (checked above).
return unsafe { std::mem::transmute_copy::<f32, T>(&max) };
}
v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v })
}
Expand Down Expand Up @@ -628,3 +631,39 @@ pub fn expand_mean_of_squares(
patch.shunt_outside(model, node.id.into(), wire[0])?;
Ok(Some(patch))
}

#[cfg(test)]
mod tests {
use super::*;

// Guards the f32 max reduction (max_t): the SIMD `max_f32` kernel result must
// be returned (contiguous path), with the scalar fold used only for strided
// slices. Checked against explicit per-row / per-col references.
#[test]
fn reduce_max_f32_contiguous_and_strided() {
let (r, c) = (5usize, 37usize); // c not a multiple of the SIMD width (tail)
let data: Vec<f32> = (0..r * c).map(|i| ((i * 31 % 97) as f32) - 48.0).collect();
let t = Tensor::from_shape(&[r, c], &data).unwrap();

// axis 1: per-row max — contiguous slices -> SIMD path.
let got = Reducer::Max.reduce(&[1], &t).unwrap();
assert_eq!(got.shape(), &[r, 1]);
for (i, &g) in unsafe { got.as_slice_unchecked::<f32>() }.iter().enumerate() {
let want = data[i * c..(i + 1) * c].iter().copied().fold(f32::MIN, f32::max);
assert_eq!(g, want, "row {i}");
}

// axis 0: per-col max — strided slices -> scalar fold.
let got = Reducer::Max.reduce(&[0], &t).unwrap();
assert_eq!(got.shape(), &[1, c]);
for (j, &g) in unsafe { got.as_slice_unchecked::<f32>() }.iter().enumerate() {
let want = (0..r).map(|i| data[i * c + j]).fold(f32::MIN, f32::max);
assert_eq!(g, want, "col {j}");
}

// k == 1 (single-element reduction) exercises the SIMD-path length guard.
let t1 = Tensor::from_shape(&[3, 1], &[1.0f32, -2.0, 3.0]).unwrap();
let got = Reducer::Max.reduce(&[1], &t1).unwrap();
assert_eq!(unsafe { got.as_slice_unchecked::<f32>() }, &[1.0, -2.0, 3.0]);
}
}
Loading