diff --git a/core/examples/reduce_min_bench.rs b/core/examples/reduce_min_bench.rs new file mode 100644 index 0000000000..2423498719 --- /dev/null +++ b/core/examples/reduce_min_bench.rs @@ -0,0 +1,63 @@ +//! Benchmark for the f32 min-reduction. `min_t` now routes the contiguous f32 +//! case through the new SIMD `min_f32` (generic `SMin4`) reducer instead of the +//! scalar branchy fold. Times `Reducer::Min` (SIMD) against an inline replica of +//! the previous scalar fold over the same data. +//! +//! Run: cargo run --release --example reduce_min_bench -p tract-core +use std::time::Instant; + +use tract_core::internal::*; +use tract_core::ops::nn::Reducer; + +#[inline(never)] +fn scalar_min_per_row(data: &[f32], rows: usize, k: usize) -> f32 { + // replica of the old min_t scalar fold (branchy partial-ord) + let mut acc = 0f32; + for r in 0..rows { + let m = data[r * k..(r + 1) * k] + .iter() + .copied() + .fold(f32::MAX, |a, v| if a < v { a } else { v }); + acc += m; + } + acc +} + +fn main() -> TractResult<()> { + for (rows, k) in [(1024usize, 4096usize), (4096, 1024), (256, 65536)] { + let n = rows * k; + let data: Vec = + (0..n).map(|i| (((i * 2654435761) >> 13) as f32 / 1e6).sin()).collect(); + let t = Tensor::from_shape(&[rows, k], &data)?; + + // SIMD path (Reducer::Min) + for _ in 0..3 { + let _ = Reducer::Min.reduce(&[1], &t)?; + } + let runs = 50; + let s = Instant::now(); + for _ in 0..runs { + std::hint::black_box(Reducer::Min.reduce(&[1], &t)?); + } + let simd = s.elapsed().as_secs_f64() / runs as f64; + + // scalar replica of the old fold + for _ in 0..3 { + std::hint::black_box(scalar_min_per_row(&data, rows, k)); + } + let s = Instant::now(); + for _ in 0..runs { + std::hint::black_box(scalar_min_per_row(&data, rows, k)); + } + let scalar = s.elapsed().as_secs_f64() / runs as f64; + + println!( + "min [{rows}x{k}] axis1 : scalar {:>7.3} ms -> SIMD {:>7.3} ms ({:.1}x, {:>5.1} GB/s)", + scalar * 1e3, + simd * 1e3, + scalar / simd, + (n * 4) as f64 / simd / 1e9 + ); + } + Ok(()) +} diff --git a/core/src/ops/nn/reduce.rs b/core/src/ops/nn/reduce.rs index 5138379388..5cfbabafbc 100644 --- a/core/src/ops/nn/reduce.rs +++ b/core/src/ops/nn/reduce.rs @@ -309,6 +309,15 @@ fn min_t(v: ArrayViewD, _: ()) -> T where T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd, { + if T::datum_type() == f32::datum_type() + && let Some(slice) = v.as_slice() + && !slice.is_empty() + { + let slice = unsafe { transmute::<&[T], &[f32]>(slice) }; + let min = (tract_linalg::ops().min_f32)().run(slice).unwrap(); + // SAFETY: T is f32 in this branch (checked above). + return unsafe { std::mem::transmute_copy::(&min) }; + } v.fold(T::max_value(), |acc, &v| if acc < v { acc } else { v }) } @@ -666,4 +675,28 @@ mod tests { let got = Reducer::Max.reduce(&[1], &t1).unwrap(); assert_eq!(unsafe { got.as_slice_unchecked::() }, &[1.0, -2.0, 3.0]); } + + // Same coverage for the f32 min reduction (min_t -> SIMD min_f32 / scalar fold). + #[test] + fn reduce_min_f32_contiguous_and_strided() { + let (r, c) = (5usize, 37usize); // c not a multiple of the SIMD width (tail) + let data: Vec = (0..r * c).map(|i| ((i * 31 % 97) as f32) - 48.0).collect(); + let t = Tensor::from_shape(&[r, c], &data).unwrap(); + + // axis 1: per-row min — contiguous slices -> SIMD path. + let got = Reducer::Min.reduce(&[1], &t).unwrap(); + assert_eq!(got.shape(), &[r, 1]); + for (i, &g) in unsafe { got.as_slice_unchecked::() }.iter().enumerate() { + let want = data[i * c..(i + 1) * c].iter().copied().fold(f32::MAX, f32::min); + assert_eq!(g, want, "row {i}"); + } + + // axis 0: per-col min — strided slices -> scalar fold. + let got = Reducer::Min.reduce(&[0], &t).unwrap(); + assert_eq!(got.shape(), &[1, c]); + for (j, &g) in unsafe { got.as_slice_unchecked::() }.iter().enumerate() { + let want = (0..r).map(|i| data[i * c + j]).fold(f32::MAX, f32::min); + assert_eq!(g, want, "col {j}"); + } + } } diff --git a/linalg/src/arm64.rs b/linalg/src/arm64.rs index 7a61b13601..6d52af6991 100644 --- a/linalg/src/arm64.rs +++ b/linalg/src/arm64.rs @@ -460,6 +460,7 @@ pub fn plug(ops: &mut Ops) { ops.sigmoid_f32 = Box::new(|| arm64simd_sigmoid_f32_4n::ew()); ops.tanh_f32 = Box::new(|| arm64simd_tanh_f32_4n::ew()); ops.max_f32 = Box::new(|| arm64simd_max_f32_16n::red()); + ops.min_f32 = Box::new(|| arm64simd_min_f32_16n::red()); ops.sum_f32 = Box::new(|| arm64simd_sum_f32_16n::red()); ops.mul_by_scalar_f32 = Box::new(|| arm64simd_mul_by_scalar_f32_16n::ew()); ops.softmax2_fastcompact_f32 = Box::new(|| arm64simd_softmax2_fastcompact_f32_16n::red()); diff --git a/linalg/src/arm64/arm64simd.rs b/linalg/src/arm64/arm64simd.rs index 41e20047ff..af97b3c21d 100644 --- a/linalg/src/arm64/arm64simd.rs +++ b/linalg/src/arm64/arm64simd.rs @@ -4,6 +4,7 @@ mod gelu_fused; mod hardswish; mod leaky_relu; mod max; +mod min; mod panel_extract; mod rms_norm; mod silu; @@ -18,6 +19,7 @@ pub use gelu_fused::arm64simd_gelu_f32_4n_fused; pub use hardswish::arm64simd_hardswish_f32_8n; pub use leaky_relu::arm64simd_leaky_relu_f32_8n; pub use max::arm64simd_max_f32_16n; +pub use min::arm64simd_min_f32_16n; pub use rms_norm::rms_norm_f32 as arm64simd_rms_norm_f32; pub use silu::arm64simd_silu_f32_4n; pub use silu_fused::arm64simd_silu_f32_4n_fused; diff --git a/linalg/src/arm64/arm64simd/min.rs b/linalg/src/arm64/arm64simd/min.rs new file mode 100644 index 0000000000..720b9676d7 --- /dev/null +++ b/linalg/src/arm64/arm64simd/min.rs @@ -0,0 +1,52 @@ +use std::arch::aarch64::{float32x4_t, vdupq_n_f32, vgetq_lane_f32}; + +reduce_impl_wrap!( + f32, + arm64simd_min_f32_16n, + 16, + 4, + (), + f32::MAX, + #[inline(never)] + fn run(buf: &[f32], _: ()) -> f32 { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + let mut out: float32x4_t = vdupq_n_f32(f32::MAX); + std::arch::asm!(" + mov v1.16b, v0.16b + mov v2.16b, v0.16b + mov v3.16b, v0.16b + 2: + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}], 64 + fmin v0.4s, v0.4s, v4.4s + fmin v1.4s, v1.4s, v5.4s + fmin v2.4s, v2.4s, v6.4s + fmin v3.4s, v3.4s, v7.4s + subs {len}, {len}, 16 + bne 2b + fmin v0.4s, v0.4s, v1.4s + fmin v2.4s, v2.4s, v3.4s + fmin v0.4s, v0.4s, v2.4s + fminv s0, v0.4s + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("v0") out, out("v1") _, out("v2") _, out("v3") _, + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + vgetq_lane_f32(out, 0) + } + }, + #[inline(never)] + fn reduce_two(a: f32, b: f32) -> f32 { + a.min(b) + } +); + +#[cfg(test)] +mod test_arm64simd_min_f32_16n { + use super::*; + crate::min_frame_tests!(true, f32, arm64simd_min_f32_16n); +} diff --git a/linalg/src/frame/reduce/min.rs b/linalg/src/frame/reduce/min.rs new file mode 100644 index 0000000000..a772fe942b --- /dev/null +++ b/linalg/src/frame/reduce/min.rs @@ -0,0 +1,42 @@ +#[cfg(test)] +#[macro_use] +pub mod test { + use crate::LADatum; + use crate::frame::reduce::ReduceKer; + use num_traits::{AsPrimitive, Float}; + use proptest::test_runner::TestCaseResult; + + #[macro_export] + macro_rules! min_frame_tests { + ($cond:expr, $t: ty, $ker:ty) => { + proptest::proptest! { + #[test] + fn prop(xs in proptest::collection::vec(-25f32..25.0, 0..100)) { + if $cond { + $crate::frame::reduce::min::test::test_min::<$ker, $t>(&*xs).unwrap() + } + } + } + + #[test] + fn empty() { + if $cond { + $crate::frame::reduce::min::test::test_min::<$ker, $t>(&[]).unwrap() + } + } + }; + } + + pub fn test_min, T: LADatum + Float>(values: &[f32]) -> TestCaseResult + where + f32: AsPrimitive, + { + crate::setup_test_logger(); + let values: Vec = values.iter().copied().map(|x| x.as_()).collect(); + crate::frame::reduce::test::test_reduce::( + &values, + ::max_value(), + |a, b| a.min(b), + ) + } +} diff --git a/linalg/src/frame/reduce/mod.rs b/linalg/src/frame/reduce/mod.rs index ecc13535f8..c24d6055bd 100644 --- a/linalg/src/frame/reduce/mod.rs +++ b/linalg/src/frame/reduce/mod.rs @@ -1,4 +1,5 @@ pub mod max; +pub mod min; pub mod softmax; pub mod sum; diff --git a/linalg/src/generic/reduce.rs b/linalg/src/generic/reduce.rs index 495c5e9476..646ba28d5d 100644 --- a/linalg/src/generic/reduce.rs +++ b/linalg/src/generic/reduce.rs @@ -50,6 +50,34 @@ pub mod max { } } +// Reduce generic implementation +pub mod min { + pub use tract_data::internal::f16; + + reduce_impl_wrap!( + f32, + SMin4, + 4, + 4, + (), + f32::MAX, + fn run(x: &[f32], _: ()) -> f32 { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + *x.iter().min_by(|a, b| a.total_cmp(b)).unwrap() + }, + fn reduce_two(a: f32, b: f32) -> f32 { + a.min(b) + } + ); + + #[cfg(test)] + #[macro_use] + pub mod s { + crate::min_frame_tests!(true, f32, crate::generic::reduce::min::SMin4); + } +} + // Reduce generic implementation pub mod sum { use crate::num_traits::Zero; diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index 0f95eea65b..069fef7215 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -99,6 +99,7 @@ pub struct Ops { pub max_f16: Box Box> + Send + Sync>, pub max_f32: Box Box> + Send + Sync>, + pub min_f32: Box Box> + Send + Sync>, pub sum_f16: Box Box> + Send + Sync>, pub sum_f32: Box Box> + Send + Sync>, @@ -244,6 +245,7 @@ pub fn generic() -> Ops { lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::::new(table))), max_f16: Box::new(|| generic::reduce::max::HMax8::red()), max_f32: Box::new(|| generic::reduce::max::SMax4::red()), + min_f32: Box::new(|| generic::reduce::min::SMin4::red()), sum_f16: Box::new(|| generic::reduce::sum::HSum8::red()), sum_f32: Box::new(|| generic::reduce::sum::SSum4::red()), /* diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index e61baa2efe..fc5fbce52a 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -13,6 +13,7 @@ pub mod by_scalar; pub mod erf; mod intel; pub mod max; +pub mod min; pub mod panel_extract; pub mod rms_norm; pub mod softmax; @@ -42,6 +43,7 @@ fn plug_fma(ops: &mut Ops) { ops.mul_by_scalar_f32 = Box::new(|| by_scalar::x86_64_avx_f32_mul_by_scalar_32n::ew()); ops.max_f32 = Box::new(|| max::x86_64_fma_max_f32_32n::red()); + ops.min_f32 = Box::new(|| min::x86_64_fma_min_f32_32n::red()); ops.softmax2_fastcompact_f32 = Box::new(|| x86_64_fma_softmax2_fastcompact_f32_32n::red()); log::info!("sigmoid_f32, tanh_f32: x86_64/fma activated"); diff --git a/linalg/src/x86_64_fma/min.rs b/linalg/src/x86_64_fma/min.rs new file mode 100644 index 0000000000..09f967a788 --- /dev/null +++ b/linalg/src/x86_64_fma/min.rs @@ -0,0 +1,67 @@ +reduce_impl_wrap!( + f32, + x86_64_fma_min_f32_32n, + 32, + 8, + (), + f32::MAX, + #[inline(never)] + fn run(buf: &[f32], _: ()) -> f32 { + assert!(buf.len() % 32 == 0); + assert!(buf.len() > 0); + unsafe { x86_64_fma_min_f32_32n_run(buf) } + }, + #[inline(never)] + fn reduce_two(a: f32, b: f32) -> f32 { + a.min(b) + } +); + +#[target_feature(enable = "avx")] +unsafe fn x86_64_fma_min_f32_32n_run(buf: &[f32]) -> f32 { + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + let mut acc = f32::MAX; + std::arch::asm!(" + vbroadcastss ymm0, xmm0 + vmovaps ymm1, ymm0 + vmovaps ymm2, ymm0 + vmovaps ymm3, ymm0 + 2: + vmovaps ymm4, [{ptr}] + vmovaps ymm5, [{ptr} + 32] + vmovaps ymm6, [{ptr} + 64] + vmovaps ymm7, [{ptr} + 96] + vminps ymm0, ymm0, ymm4 + vminps ymm1, ymm1, ymm5 + vminps ymm2, ymm2, ymm6 + vminps ymm3, ymm3, ymm7 + add {ptr}, 128 + sub {len}, 32 + jnz 2b + vminps ymm0, ymm0, ymm1 + vminps ymm2, ymm2, ymm3 + vminps ymm0, ymm0, ymm2 + vperm2f128 ymm1, ymm0, ymm0, 1 // copy second half (4xf32) of ymm0 to ymm1 + vminps xmm0, xmm0, xmm1 // xmm0 contains 4 values to min + vpermilps xmm1, xmm0, 2 + (3 << 2) // second 2x32 bit half moved to top + vminps xmm0, xmm0, xmm1 // xmm0 containes 2 values + vpermilps xmm1, xmm0, 1 // second f32 to top + vminps xmm0, xmm0, xmm1 + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("ymm0") acc, + out("ymm1") _, out("ymm2") _, out("ymm3") _, + out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _ + ); + acc + } +} + +#[cfg(test)] +mod test_x86_64_fma_min_f32_32n { + use super::*; + crate::min_frame_tests!(is_x86_feature_detected!("avx2"), f32, x86_64_fma_min_f32_32n); +}