Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions core/examples/reduce_max_bench.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//! Benchmark for the f32 max-reduction. Before the fix, ReduceMax ran the SIMD
//! `max_f32` kernel, threw the result away, then recomputed with a scalar fold;
//! after, it returns the SIMD result. Times `Reducer::Max` over a contiguous
//! trailing axis (the SIMD path).
//!
//! Run: cargo run --release --example reduce_max_bench -p tract-core
use std::time::Instant;

use tract_core::internal::*;
use tract_core::ops::nn::Reducer;

/// Times `Reducer::Max` along axis 1 for a handful of `[rows, k]` f32 shapes.
///
/// Every shape gets 3 untimed warm-up calls followed by 50 timed calls. The
/// printed line shows the mean time per call, the input bandwidth that implies
/// (4 bytes per element) and a checksum made of the first output element of
/// each timed call.
fn main() -> TractResult<()> {
    const WARMUP_CALLS: usize = 3;
    const TIMED_CALLS: u32 = 50;

    let shapes = [(1024usize, 4096usize), (4096, 1024), (256, 65536)];
    for (rows, k) in shapes {
        let n = rows * k;
        // Deterministic input: scramble the index with a multiplicative
        // constant, then squash it through sin().
        let data: Vec<f32> = (0..n)
            .map(|i| {
                let scrambled = (i * 2654435761) >> 13;
                (scrambled as f32 / 1e6).sin()
            })
            .collect();
        let t = Tensor::from_shape(&[rows, k], &data)?;

        // Warm-up: results are discarded on purpose.
        for _ in 0..WARMUP_CALLS {
            drop(Reducer::Max.reduce(&[1], &t)?);
        }

        let mut chk = 0f32;
        let started = Instant::now();
        for _ in 0..TIMED_CALLS {
            let reduced = Reducer::Max.reduce(&[1], &t)?;
            // SAFETY: relies on the reduction of an f32 tensor producing f32
            // output; the slice is only read.
            let values = unsafe { reduced.as_slice_unchecked::<f32>() };
            chk += values[0];
            // Keep the whole output observable so the call cannot be elided.
            std::hint::black_box(&reduced);
        }
        let per = started.elapsed().as_secs_f64() / f64::from(TIMED_CALLS);

        let bytes_read = (n * 4) as f64;
        let gbps = bytes_read / per / 1e9;
        println!(
            "reduce-max [{rows}x{k}] axis1 : {:>7.3} ms/call {:>6.1} GB/s (chk {chk:.3})",
            per * 1e3,
            gbps
        );
    }
    Ok(())
}
63 changes: 63 additions & 0 deletions core/examples/reduce_min_bench.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//! Benchmark for the f32 min-reduction. `min_t` now routes the contiguous f32
//! case through the new SIMD `min_f32` (generic `SMin4`) reducer instead of the
//! scalar branchy fold. Times `Reducer::Min` (SIMD) against an inline replica of
//! the previous scalar fold over the same data.
//!
//! Run: cargo run --release --example reduce_min_bench -p tract-core
use std::time::Instant;

use tract_core::internal::*;
use tract_core::ops::nn::Reducer;

/// Benchmark baseline: sums the per-row minima of a row-major `rows x k`
/// buffer with the comparison-based scalar fold that `min_t` used previously.
///
/// Panics if `data` holds fewer than `rows * k` elements. A row of length zero
/// contributes `f32::MAX`.
#[inline(never)]
fn scalar_min_per_row(data: &[f32], rows: usize, k: usize) -> f32 {
    (0..rows)
        .map(|row| {
            let start = row * k;
            let mut lowest = f32::MAX;
            for &value in &data[start..start + k] {
                // Deliberately a plain `<` comparison (not `f32::min`): this is
                // the exact fold being benchmarked against.
                lowest = if lowest < value { lowest } else { value };
            }
            lowest
        })
        .fold(0f32, |total, row_min| total + row_min)
}

/// Compares `Reducer::Min` along axis 1 with the scalar baseline
/// `scalar_min_per_row` on the same data, for a handful of `[rows, k]` shapes.
///
/// Each contender gets 3 untimed warm-up calls and 50 timed calls; the printed
/// line shows both mean times, their ratio, and the input bandwidth of the
/// `Reducer::Min` run (4 bytes per element).
fn main() -> TractResult<()> {
    const WARMUP_CALLS: usize = 3;
    const TIMED_CALLS: u32 = 50;

    let shapes = [(1024usize, 4096usize), (4096, 1024), (256, 65536)];
    for (rows, k) in shapes {
        let n = rows * k;
        // Deterministic input: scramble the index with a multiplicative
        // constant, then squash it through sin().
        let data: Vec<f32> = (0..n)
            .map(|i| {
                let scrambled = (i * 2654435761) >> 13;
                (scrambled as f32 / 1e6).sin()
            })
            .collect();
        let t = Tensor::from_shape(&[rows, k], &data)?;

        // --- Reducer::Min ---
        for _ in 0..WARMUP_CALLS {
            drop(Reducer::Min.reduce(&[1], &t)?);
        }
        let clock = Instant::now();
        for _ in 0..TIMED_CALLS {
            let reduced = Reducer::Min.reduce(&[1], &t)?;
            std::hint::black_box(reduced);
        }
        let simd = clock.elapsed().as_secs_f64() / f64::from(TIMED_CALLS);

        // --- scalar replica of the previous fold ---
        for _ in 0..WARMUP_CALLS {
            std::hint::black_box(scalar_min_per_row(&data, rows, k));
        }
        let clock = Instant::now();
        for _ in 0..TIMED_CALLS {
            std::hint::black_box(scalar_min_per_row(&data, rows, k));
        }
        let scalar = clock.elapsed().as_secs_f64() / f64::from(TIMED_CALLS);

        let speedup = scalar / simd;
        let gbps = (n * 4) as f64 / simd / 1e9;
        println!(
            "min [{rows}x{k}] axis1 : scalar {:>7.3} ms -> SIMD {:>7.3} ms ({:.1}x, {:>5.1} GB/s)",
            scalar * 1e3,
            simd * 1e3,
            speedup,
            gbps
        );
    }
    Ok(())
}
74 changes: 73 additions & 1 deletion core/src/ops/nn/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,12 @@ where
{
if T::datum_type() == f32::datum_type()
&& let Some(slice) = v.as_slice()
&& !slice.is_empty()
{
let slice = unsafe { transmute::<&[T], &[f32]>(slice) };
(tract_linalg::ops().max_f32)().run(slice).unwrap();
let max = (tract_linalg::ops().max_f32)().run(slice).unwrap();
// SAFETY: T is f32 in this branch (checked above).
return unsafe { std::mem::transmute_copy::<f32, T>(&max) };
}
v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v })
}
Expand All @@ -306,6 +309,15 @@ fn min_t<T>(v: ArrayViewD<T>, _: ()) -> T
where
T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
{
if T::datum_type() == f32::datum_type()
&& let Some(slice) = v.as_slice()
&& !slice.is_empty()
{
let slice = unsafe { transmute::<&[T], &[f32]>(slice) };
let min = (tract_linalg::ops().min_f32)().run(slice).unwrap();
// SAFETY: T is f32 in this branch (checked above).
return unsafe { std::mem::transmute_copy::<f32, T>(&min) };
}
v.fold(T::max_value(), |acc, &v| if acc < v { acc } else { v })
}

Expand Down Expand Up @@ -628,3 +640,63 @@ pub fn expand_mean_of_squares(
patch.shunt_outside(model, node.id.into(), wire[0])?;
Ok(Some(patch))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic `rows * cols` f32 values, each in `[-48, 48]`.
    fn sample_data(rows: usize, cols: usize) -> Vec<f32> {
        (0..rows * cols).map(|i| ((i * 31 % 97) as f32) - 48.0).collect()
    }

    /// Reads a tensor's contents as `f32`.
    fn f32s(t: &Tensor) -> &[f32] {
        // SAFETY: only called on tensors built from f32 data in this module, or
        // on reductions of such tensors, which these tests expect to stay f32.
        unsafe { t.as_slice_unchecked::<f32>() }
    }

    /// Shared body for the max and min tests.
    ///
    /// `identity` and `pick` define the scalar reference fold (`f32::MIN` with
    /// `f32::max`, or `f32::MAX` with `f32::min`). Three layouts are checked:
    /// * axis 1 (per-row): contiguous slices, the case the f32 fast path in
    ///   `max_t` / `min_t` is meant to take;
    /// * axis 0 (per-column): strided slices, expected to use the scalar fold;
    /// * `k == 1`: single-element rows, shorter than any kernel block.
    fn check_reducer(reducer: &Reducer, label: &str, identity: f32, pick: fn(f32, f32) -> f32) {
        // 37 columns: not a multiple of 4, 16 or 32, so kernels must handle a tail.
        let (r, c) = (5usize, 37usize);
        let data = sample_data(r, c);
        let t = Tensor::from_shape(&[r, c], &data).unwrap();

        let got = reducer.reduce(&[1], &t).unwrap();
        assert_eq!(got.shape(), &[r, 1]);
        for (i, &g) in f32s(&got).iter().enumerate() {
            let want = data[i * c..(i + 1) * c].iter().copied().fold(identity, pick);
            assert_eq!(g, want, "{label} row {i}");
        }

        let got = reducer.reduce(&[0], &t).unwrap();
        assert_eq!(got.shape(), &[1, c]);
        for (j, &g) in f32s(&got).iter().enumerate() {
            let want = (0..r).map(|i| data[i * c + j]).fold(identity, pick);
            assert_eq!(g, want, "{label} col {j}");
        }

        // Reducing a length-1 axis must hand back every value unchanged.
        let singles = [1.0f32, -2.0, 3.0];
        let t1 = Tensor::from_shape(&[3, 1], &singles).unwrap();
        let got = reducer.reduce(&[1], &t1).unwrap();
        assert_eq!(f32s(&got), &singles, "{label} k == 1");
    }

    // Guards the f32 max reduction (max_t): the value produced for contiguous
    // slices must match an explicit per-row reference, and strided slices must
    // match a per-column reference.
    #[test]
    fn reduce_max_f32_contiguous_and_strided() {
        check_reducer(&Reducer::Max, "max", f32::MIN, f32::max);
    }

    // Same coverage for the f32 min reduction (min_t), now including the
    // single-element case that was previously only exercised for max.
    #[test]
    fn reduce_min_f32_contiguous_and_strided() {
        check_reducer(&Reducer::Min, "min", f32::MAX, f32::min);
    }
}
1 change: 1 addition & 0 deletions linalg/src/arm64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ pub fn plug(ops: &mut Ops) {
ops.sigmoid_f32 = Box::new(|| arm64simd_sigmoid_f32_4n::ew());
ops.tanh_f32 = Box::new(|| arm64simd_tanh_f32_4n::ew());
ops.max_f32 = Box::new(|| arm64simd_max_f32_16n::red());
ops.min_f32 = Box::new(|| arm64simd_min_f32_16n::red());
ops.sum_f32 = Box::new(|| arm64simd_sum_f32_16n::red());
ops.mul_by_scalar_f32 = Box::new(|| arm64simd_mul_by_scalar_f32_16n::ew());
ops.softmax2_fastcompact_f32 = Box::new(|| arm64simd_softmax2_fastcompact_f32_16n::red());
Expand Down
2 changes: 2 additions & 0 deletions linalg/src/arm64/arm64simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod gelu_fused;
mod hardswish;
mod leaky_relu;
mod max;
mod min;
mod panel_extract;
mod rms_norm;
mod silu;
Expand All @@ -18,6 +19,7 @@ pub use gelu_fused::arm64simd_gelu_f32_4n_fused;
pub use hardswish::arm64simd_hardswish_f32_8n;
pub use leaky_relu::arm64simd_leaky_relu_f32_8n;
pub use max::arm64simd_max_f32_16n;
pub use min::arm64simd_min_f32_16n;
pub use rms_norm::rms_norm_f32 as arm64simd_rms_norm_f32;
pub use silu::arm64simd_silu_f32_4n;
pub use silu_fused::arm64simd_silu_f32_4n_fused;
Expand Down
52 changes: 52 additions & 0 deletions linalg/src/arm64/arm64simd/min.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use std::arch::aarch64::{float32x4_t, vdupq_n_f32, vgetq_lane_f32};

// NEON f32 min-reduction kernel, declared through `reduce_impl_wrap!`.
// Arguments as passed here: element type, kernel name, 16, 4, `()` as the
// extra-parameter type, `f32::MAX`, then the `run` and `reduce_two` bodies.
// NOTE(review): 16 and 4 are presumably the kernel's block size in elements
// and an alignment, and `f32::MAX` the neutral value for min — confirm against
// the macro definition.
reduce_impl_wrap!(
f32,
arm64simd_min_f32_16n,
16,
4,
(),
f32::MAX,
#[inline(never)]
fn run(buf: &[f32], _: ()) -> f32 {
// Both checks are load-bearing for the asm loop below: it loads 16 floats
// before testing the remaining count and only exits when that count is
// exactly zero, so an empty or non-multiple-of-16 buffer would read past
// the end of `buf`.
assert!(buf.len() % 16 == 0);
assert!(buf.len() > 0);
unsafe {
let len = buf.len();
let ptr = buf.as_ptr();
// Bound to v0 below: four lanes of f32::MAX as the starting accumulator.
let mut out: float32x4_t = vdupq_n_f32(f32::MAX);
// Asm outline (kept outside the template so its text is untouched):
//  - the three `and vN, v0, v0` copy v0 into v1..v3, giving four
//    independent accumulators;
//  - loop `2:` loads 64 bytes (16 floats) into v4..v7 while post-incrementing
//    ptr, takes the lane-wise fmin into v0..v3, subtracts 16 from len and
//    repeats until len reaches zero;
//  - the accumulators are merged pairwise into v0, then `fminv` folds the
//    four lanes of v0 into s0, i.e. lane 0 of `out`.
std::arch::asm!("
and v1.16b, v0.16b, v0.16b
and v2.16b, v0.16b, v0.16b
and v3.16b, v0.16b, v0.16b
2:
ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}], 64
fmin v0.4s, v0.4s, v4.4s
fmin v1.4s, v1.4s, v5.4s
fmin v2.4s, v2.4s, v6.4s
fmin v3.4s, v3.4s, v7.4s
subs {len}, {len}, 16
bne 2b
fmin v0.4s, v0.4s, v1.4s
fmin v2.4s, v2.4s, v3.4s
fmin v0.4s, v0.4s, v2.4s
fminv s0, v0.4s
",
// len and ptr are modified by the asm; their final values are discarded.
len = inout(reg) len => _,
ptr = inout(reg) ptr => _,
// v0 carries the accumulator in and the result out; v1..v7 are scratch.
inout("v0") out, out("v1") _, out("v2") _, out("v3") _,
out("v4") _, out("v5") _, out("v6") _, out("v7") _,);
// The reduced minimum sits in lane 0 after `fminv`.
vgetq_lane_f32(out, 0)
}
},
// Combines two scalar values with `f32::min`.
// NOTE(review): presumably used by the wrapper to merge partial results;
// `f32::min` and the `fmin` instruction may treat NaN inputs differently —
// confirm if NaN inputs matter.
#[inline(never)]
fn reduce_two(a: f32, b: f32) -> f32 {
a.min(b)
}
);

// Instantiates the shared `min_frame_tests!` suite for the NEON kernel. The
// first argument is the macro's run condition; `true` means the generated
// tests are never skipped at run time.
#[cfg(test)]
mod test_arm64simd_min_f32_16n {
use super::*;
crate::min_frame_tests!(true, f32, arm64simd_min_f32_16n);
}
42 changes: 42 additions & 0 deletions linalg/src/frame/reduce/min.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Test support for min-reduction kernels: the exported `min_frame_tests!`
// macro and the `test_min` driver it expands to calls of.
#[cfg(test)]
#[macro_use]
pub mod test {
use crate::LADatum;
use crate::frame::reduce::ReduceKer;
use num_traits::{AsPrimitive, Float};
use proptest::test_runner::TestCaseResult;

// Generates two tests for kernel `$ker` over element type `$t`:
//  - `prop`: a proptest over vectors of 0 to 99 f32 values drawn from
//    [-25, 25);
//  - `empty`: the same check on an empty input.
// Both bodies are skipped when `$cond` evaluates to false. Inputs are always
// generated as f32; `test_min` converts them to `$t`.
#[macro_export]
macro_rules! min_frame_tests {
($cond:expr, $t: ty, $ker:ty) => {
proptest::proptest! {
#[test]
fn prop(xs in proptest::collection::vec(-25f32..25.0, 0..100)) {
if $cond {
$crate::frame::reduce::min::test::test_min::<$ker, $t>(&*xs).unwrap()
}
}
}

#[test]
fn empty() {
if $cond {
$crate::frame::reduce::min::test::test_min::<$ker, $t>(&[]).unwrap()
}
}
};
}

// Converts `values` to `T` and delegates to the generic `test_reduce` helper
// with `T::max_value()` and `a.min(b)`.
// NOTE(review): these are presumably the neutral element and the reference
// combine function that the kernel `K` is checked against — confirm in
// `frame::reduce::test::test_reduce`.
pub fn test_min<K: ReduceKer<T>, T: LADatum + Float>(values: &[f32]) -> TestCaseResult
where
f32: AsPrimitive<T>,
{
crate::setup_test_logger();
let values: Vec<T> = values.iter().copied().map(|x| x.as_()).collect();
crate::frame::reduce::test::test_reduce::<K, _>(
&values,
<T as Float>::max_value(),
|a, b| a.min(b),
)
}
}
1 change: 1 addition & 0 deletions linalg/src/frame/reduce/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod max;
pub mod min;
pub mod softmax;
pub mod sum;

Expand Down
28 changes: 28 additions & 0 deletions linalg/src/generic/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,34 @@ pub mod max {
}
}

// Reduce<min> generic implementation: the portable f32 min kernel `SMin4`,
// declared through `reduce_impl_wrap!`, plus its test instantiation.
pub mod min {
// NOTE(review): `f16` is not referenced by anything visible in this module;
// presumably carried over from the `max` module — confirm whether this
// re-export is needed.
pub use tract_data::internal::f16;

// Macro arguments: element type, kernel name, 4, 4, `()` as the
// extra-parameter type, `f32::MAX`, then the `run` and `reduce_two` bodies.
// NOTE(review): the two 4s are presumably what `Self::nr()` and
// `Self::alignment_bytes()` report (both are asserted on in `run`), and
// `f32::MAX` the neutral value for min — confirm against the macro.
reduce_impl_wrap!(
f32,
SMin4,
4,
4,
(),
f32::MAX,
fn run(x: &[f32], _: ()) -> f32 {
// Debug-only contract checks: length is a multiple of nr() and the
// slice start is aligned to alignment_bytes().
debug_assert!(x.len() % Self::nr() == 0);
debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
// Smallest element under the `total_cmp` total order. The `unwrap`
// panics on an empty slice, so callers must pass at least one element.
// NOTE(review): `total_cmp` ranks -0.0 below +0.0 and negative NaN below
// every number, while `reduce_two` below uses `f32::min`, which ignores
// a NaN operand — results can differ for NaN inputs; confirm intended.
*x.iter().min_by(|a, b| a.total_cmp(b)).unwrap()
},
// Combines two scalar values with `f32::min`.
fn reduce_two(a: f32, b: f32) -> f32 {
a.min(b)
}
);

// Runs the shared `min_frame_tests!` suite against `SMin4`, unconditionally.
#[cfg(test)]
#[macro_use]
pub mod s {
crate::min_frame_tests!(true, f32, crate::generic::reduce::min::SMin4);
}
}

// Reduce<sum> generic implementation
pub mod sum {
use crate::num_traits::Zero;
Expand Down
2 changes: 2 additions & 0 deletions linalg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ pub struct Ops {

pub max_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
pub max_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,
pub min_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,

pub sum_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
pub sum_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,
Expand Down Expand Up @@ -242,6 +243,7 @@ pub fn generic() -> Ops {
lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::<generic::GenericLut8>::new(table))),
max_f16: Box::new(|| generic::reduce::max::HMax8::red()),
max_f32: Box::new(|| generic::reduce::max::SMax4::red()),
min_f32: Box::new(|| generic::reduce::min::SMin4::red()),
sum_f16: Box::new(|| generic::reduce::sum::HSum8::red()),
sum_f32: Box::new(|| generic::reduce::sum::SSum4::red()),
/*
Expand Down
2 changes: 2 additions & 0 deletions linalg/src/x86_64_fma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub mod by_scalar;
pub mod erf;
mod intel;
pub mod max;
pub mod min;
pub mod panel_extract;
pub mod rms_norm;
pub mod softmax;
Expand Down Expand Up @@ -42,6 +43,7 @@ fn plug_fma(ops: &mut Ops) {

ops.mul_by_scalar_f32 = Box::new(|| by_scalar::x86_64_avx_f32_mul_by_scalar_32n::ew());
ops.max_f32 = Box::new(|| max::x86_64_fma_max_f32_32n::red());
ops.min_f32 = Box::new(|| min::x86_64_fma_min_f32_32n::red());
ops.softmax2_fastcompact_f32 = Box::new(|| x86_64_fma_softmax2_fastcompact_f32_32n::red());

log::info!("sigmoid_f32, tanh_f32: x86_64/fma activated");
Expand Down
Loading