From 3a2222204d2bfaccbac45bdb301dffed3408659e Mon Sep 17 00:00:00 2001
From: Adam Mohammed A Latif
Date: Mon, 30 Mar 2026 11:09:53 +0000
Subject: [PATCH] air: make symbolic arena handles safe and generation-checked

Replace raw `u32` arena offsets with opaque `SymbolicNodeRef` handles
that carry arena_id, generation, and offset. Validate all three before
any `read_unaligned`, turning four classes of UB into deterministic
errors:

1. Forgeable handles: `SymbolicNodeRef` fields are private, so
   downstream code can no longer construct arbitrary indices.
2. Stale handles after clear: `clear_arena()` bumps a generation
   counter; old handles fail with `StaleGeneration`.
3. Cross-thread misuse: each thread-local arena gets a unique
   `arena_id` from a global `AtomicU32`; handles from another thread
   fail with `WrongArena`.
4. Index truncation: `alloc_node` uses `u32::try_from(offset)` instead
   of a bare `as u32` cast.

`SymbolicExpression` remains `Copy`. The only downstream migration is
`Operation(u32)` -> `Operation(SymbolicNodeRef)`, which is a mechanical
change (one call site in `rec_aggregation::compilation`).

Closes #170 --- Cargo.lock | 1 + crates/backend/air/Cargo.toml | 3 + crates/backend/air/src/symbolic.rs | 270 ++++++++++++++++++++-- crates/rec_aggregation/src/compilation.rs | 31 +-- 4 files changed, 277 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a1e508d94..10474b0b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -609,6 +609,7 @@ name = "mt-air" version = "0.1.0" dependencies = [ "mt-field", + "mt-koala-bear", "mt-poly", ] diff --git a/crates/backend/air/Cargo.toml b/crates/backend/air/Cargo.toml index 3868ecf2b..2494d5290 100644 --- a/crates/backend/air/Cargo.toml +++ b/crates/backend/air/Cargo.toml @@ -6,3 +6,6 @@ edition.workspace = true [dependencies] field = { path = "../field", package = "mt-field" } poly = { path = "../poly", package = "mt-poly" } + +[dev-dependencies] +koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } diff --git a/crates/backend/air/src/symbolic.rs b/crates/backend/air/src/symbolic.rs index f36286ce3..0bbd982f2 100644 --- a/crates/backend/air/src/symbolic.rs +++ b/crates/backend/air/src/symbolic.rs @@ -5,6 +5,7 @@ use core::iter::{Product, Sum}; use core::marker::PhantomData; use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::cell::RefCell; +use std::sync::atomic::{AtomicU32, Ordering}; use field::{Algebra, Field, InjectiveMonomial, PrimeCharacteristicRing}; @@ -73,37 +74,149 @@ pub struct SymbolicNode { pub rhs: SymbolicExpression, // dummy (ZERO) for Neg } -// We use an arena as a trick to allow SymbolicExpression to be Copy -// (ugly trick but fine in practice since SymbolicExpression is only used once at the start of the program) +/// Opaque handle into the thread-local symbolic arena. +/// +/// Handles are scoped to a specific arena (thread) and generation (clear cycle). +/// Using a handle from a different thread or after the arena has been cleared will +/// produce a deterministic error instead of undefined behaviour. 
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct SymbolicNodeRef { + arena_id: u32, + generation: u32, + offset: u32, + _phantom: PhantomData F>, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum SymbolicNodeAccessError { + WrongArena, + StaleGeneration, + OutOfBounds, +} + +impl core::fmt::Display for SymbolicNodeAccessError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::WrongArena => { + write!(f, "symbolic node handle belongs to a different thread's arena") + } + Self::StaleGeneration => { + write!(f, "symbolic node handle is stale (arena was cleared)") + } + Self::OutOfBounds => write!(f, "symbolic node handle offset is out of bounds"), + } + } +} + +impl std::error::Error for SymbolicNodeAccessError {} + +#[derive(Debug)] +struct ArenaState { + arena_id: u32, + generation: u32, + bytes: Vec, +} + +impl ArenaState { + fn new() -> Self { + Self { + arena_id: next_arena_id(), + generation: 0, + bytes: Vec::new(), + } + } +} + +static NEXT_ARENA_ID: AtomicU32 = AtomicU32::new(1); + +fn next_arena_id_after(id: u32) -> Option { + id.checked_add(1) +} + +fn next_arena_id() -> u32 { + NEXT_ARENA_ID + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, next_arena_id_after) + .expect("symbolic arena id overflow") +} + +fn checked_arena_allocation_range(offset: usize, node_size: usize) -> (u32, usize) { + let end = offset + .checked_add(node_size) + .expect("symbolic arena allocation overflow"); + let offset_u32 = u32::try_from(offset).expect("symbolic arena exceeded u32::MAX bytes"); + u32::try_from(end).expect("symbolic arena exceeded u32::MAX bytes"); + (offset_u32, end) +} + +// We use an arena as a trick to allow SymbolicExpression to be Copy. +// Handles carry arena_id + generation so that stale or cross-thread use +// is caught deterministically instead of reading garbage bytes. thread_local! 
{ - static ARENA: RefCell> = const { RefCell::new(Vec::new()) }; + static ARENA: RefCell = RefCell::new(ArenaState::new()); } -fn alloc_node(node: SymbolicNode) -> u32 { +fn clear_arena() { ARENA.with(|arena| { - let mut bytes = arena.borrow_mut(); + let mut state = arena.borrow_mut(); + state.generation = state + .generation + .checked_add(1) + .expect("symbolic arena generation overflow"); + state.bytes.clear(); + }); +} + +fn alloc_node(node: SymbolicNode) -> SymbolicNodeRef { + ARENA.with(|arena| { + let mut state = arena.borrow_mut(); let node_size = std::mem::size_of::>(); - let idx = bytes.len(); - bytes.resize(idx + node_size, 0); + let offset = state.bytes.len(); + let (offset_u32, end) = checked_arena_allocation_range(offset, node_size); + state.bytes.resize(end, 0); + // SAFETY: We just resized the buffer to `end` bytes, so `offset..end` is valid. unsafe { - std::ptr::write_unaligned(bytes.as_mut_ptr().add(idx) as *mut SymbolicNode, node); + std::ptr::write_unaligned(state.bytes.as_mut_ptr().add(offset).cast::>(), node); + } + SymbolicNodeRef { + arena_id: state.arena_id, + generation: state.generation, + offset: offset_u32, + _phantom: PhantomData, } - idx as u32 }) } -pub fn get_node(idx: u32) -> SymbolicNode { +pub fn try_get_node(handle: SymbolicNodeRef) -> Result, SymbolicNodeAccessError> { ARENA.with(|arena| { - let bytes = arena.borrow(); - unsafe { std::ptr::read_unaligned(bytes.as_ptr().add(idx as usize) as *const SymbolicNode) } + let state = arena.borrow(); + if state.arena_id != handle.arena_id { + return Err(SymbolicNodeAccessError::WrongArena); + } + if state.generation != handle.generation { + return Err(SymbolicNodeAccessError::StaleGeneration); + } + let offset = handle.offset as usize; + let node_size = std::mem::size_of::>(); + let end = offset + .checked_add(node_size) + .ok_or(SymbolicNodeAccessError::OutOfBounds)?; + if end > state.bytes.len() { + return Err(SymbolicNodeAccessError::OutOfBounds); + } + // SAFETY: We verified that 
`offset..end` is within the arena buffer. + Ok(unsafe { std::ptr::read_unaligned(state.bytes.as_ptr().add(offset).cast::>()) }) }) } +pub fn get_node(handle: SymbolicNodeRef) -> SymbolicNode { + try_get_node(handle).expect("invalid or stale symbolic node handle") +} + #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum SymbolicExpression { Variable(SymbolicVariable), Constant(F), - Operation(u32), // index into thread-local arena + Operation(SymbolicNodeRef), } impl Default for SymbolicExpression { @@ -321,8 +434,7 @@ pub fn get_symbolic_constraints_and_bus_data_values(air: &A) - where A::ExtraData: Default, { - // Clear the arena before building constraints - ARENA.with(|arena| arena.borrow_mut().clear()); + clear_arena(); let mut builder = SymbolicAirBuilder::::new(air.n_columns(), air.n_shift_columns()); air.eval(&mut builder, &Default::default()); @@ -332,3 +444,131 @@ where builder.bus_data_values.unwrap(), ) } + +#[cfg(test)] +mod tests { + use super::*; + use koala_bear::KoalaBear; + + type F = KoalaBear; + + const _: () = { + const fn assert_copy() {} + assert_copy::>(); + assert_copy::>(); + }; + + #[test] + fn roundtrip_alloc_get() { + clear_arena(); + let a = SymbolicExpression::::Constant(F::ONE); + let b = SymbolicExpression::::Constant(F::TWO); + let handle = alloc_node(SymbolicNode { + op: SymbolicOperation::Add, + lhs: a, + rhs: b, + }); + let node = get_node::(handle); + assert_eq!(node.op, SymbolicOperation::Add); + assert_eq!(node.lhs, a); + assert_eq!(node.rhs, b); + } + + #[test] + fn stale_handle_rejected_after_clear() { + clear_arena(); + let handle = alloc_node(SymbolicNode { + op: SymbolicOperation::Mul, + lhs: SymbolicExpression::::ONE, + rhs: SymbolicExpression::::TWO, + }); + assert!(try_get_node::(handle).is_ok()); + clear_arena(); + assert!(matches!( + try_get_node::(handle), + Err(SymbolicNodeAccessError::StaleGeneration) + )); + } + + #[test] + fn old_handle_cannot_read_new_generation_bytes() { + clear_arena(); + let old_handle = 
alloc_node(SymbolicNode { + op: SymbolicOperation::Add, + lhs: SymbolicExpression::::ONE, + rhs: SymbolicExpression::::TWO, + }); + clear_arena(); + let _new_handle = alloc_node(SymbolicNode { + op: SymbolicOperation::Sub, + lhs: SymbolicExpression::::ZERO, + rhs: SymbolicExpression::::ONE, + }); + assert!(matches!( + try_get_node::(old_handle), + Err(SymbolicNodeAccessError::StaleGeneration) + )); + } + + #[test] + fn wrong_thread_handle_rejected() { + clear_arena(); + let handle = alloc_node(SymbolicNode { + op: SymbolicOperation::Neg, + lhs: SymbolicExpression::::ONE, + rhs: SymbolicExpression::::ZERO, + }); + let result = std::thread::spawn(move || try_get_node::(handle)).join().unwrap(); + assert!(matches!(result, Err(SymbolicNodeAccessError::WrongArena))); + } + + #[test] + fn out_of_bounds_handle_rejected() { + clear_arena(); + let bogus = SymbolicNodeRef:: { + arena_id: ARENA.with(|a| a.borrow().arena_id), + generation: ARENA.with(|a| a.borrow().generation), + offset: 999_999, + _phantom: PhantomData, + }; + assert!(matches!( + try_get_node::(bogus), + Err(SymbolicNodeAccessError::OutOfBounds) + )); + } + + #[test] + fn offset_truncation_detected() { + assert!(std::panic::catch_unwind(|| checked_arena_allocation_range(u32::MAX as usize, 1)).is_err()); + } + + #[test] + fn arena_id_overflow_detected() { + assert!(next_arena_id_after(u32::MAX).is_none()); + } + + #[test] + fn arithmetic_produces_valid_handles() { + clear_arena(); + let var = SymbolicExpression::::Variable(SymbolicVariable::new(0)); + let c = SymbolicExpression::::Constant(F::TWO); + let sum = var + c; + if let SymbolicExpression::Operation(handle) = sum { + let node = get_node::(handle); + assert_eq!(node.op, SymbolicOperation::Add); + assert_eq!(node.lhs, var); + assert_eq!(node.rhs, c); + } else { + panic!("expected Operation variant from variable + constant"); + } + + let neg = -var; + if let SymbolicExpression::Operation(handle) = neg { + let node = get_node::(handle); + 
assert_eq!(node.op, SymbolicOperation::Neg); + assert_eq!(node.lhs, var); + } else { + panic!("expected Operation variant from neg(variable)"); + } + } +} diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index f3752bf0b..9f5df0c29 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -483,7 +483,7 @@ fn all_air_evals_in_zk_dsl() -> String { const AIR_INNER_VALUES_VAR: &str = "inner_evals"; struct AirCodegenCtx { - expr_cache: HashMap, + expr_cache: HashMap, String>, consts_cache: HashMap, String>, ef_const_cache: HashMap, ctr: Counter, @@ -594,14 +594,14 @@ fn eval_air_constraint( let v = match expr { SymbolicExpression::Constant(c) => ctx.write_embedded_constant(c.as_canonical_u32(), res), SymbolicExpression::Variable(v) => format!("{} + DIM * {}", AIR_INNER_VALUES_VAR, v.index), - SymbolicExpression::Operation(idx) => { - if let Some(v) = ctx.expr_cache.get(&idx) { + SymbolicExpression::Operation(handle) => { + if let Some(v) = ctx.expr_cache.get(&handle) { v.clone() - } else if let Some(v) = try_emit_dot_product_be(idx, dest, ctx, res) { - ctx.expr_cache.insert(idx, v.clone()); + } else if let Some(v) = try_emit_dot_product_be(handle, dest, ctx, res) { + ctx.expr_cache.insert(handle, v.clone()); return v; } else { - let node = get_node::(idx); + let node = get_node::(handle); let v = match node.op { SymbolicOperation::Neg => { let a = eval_air_constraint(node.lhs, None, ctx, res); @@ -611,7 +611,7 @@ fn eval_air_constraint( } _ => eval_air_binary_op(node.op, node.lhs, node.rhs, dest, ctx, res), }; - ctx.expr_cache.insert(idx, v.clone()); + ctx.expr_cache.insert(handle, v.clone()); v } } @@ -626,16 +626,21 @@ fn eval_air_constraint( /// Detect `0 + c0*x0 + c1*x1 + ... + cn*xn` in the expression tree and emit /// a single `dot_product_be` precompile call. Returns None if the pattern doesn't match. 
-fn try_emit_dot_product_be(idx: u32, dest: Option<&str>, ctx: &mut AirCodegenCtx, res: &mut String) -> Option { +fn try_emit_dot_product_be( + handle: SymbolicNodeRef, + dest: Option<&str>, + ctx: &mut AirCodegenCtx, + res: &mut String, +) -> Option { // Walk the left-spine of Add(_, Mul(Const, _)) nodes down to Constant(ZERO). let mut constants = Vec::new(); let mut operands = Vec::new(); - let mut current = SymbolicExpression::::Operation(idx); + let mut current = SymbolicExpression::::Operation(handle); loop { match current { SymbolicExpression::Constant(c) if c == F::ZERO && constants.len() >= 2 => break, SymbolicExpression::Operation(op_idx) => { - if op_idx != idx && ctx.expr_cache.contains_key(&op_idx) { + if op_idx != handle && ctx.expr_cache.contains_key(&op_idx) { return None; } let node = get_node::(op_idx); @@ -699,12 +704,12 @@ fn try_emit_dot_product_be(idx: u32, dest: Option<&str>, ctx: &mut AirCodegenCtx fn try_find_contiguous_buffer(operands: &[SymbolicExpression], ctx: &AirCodegenCtx) -> Option { let mut base: Option<&str> = None; for (i, op) in operands.iter().enumerate() { - let idx = match op { - SymbolicExpression::Operation(idx) => *idx, + let handle = match op { + SymbolicExpression::Operation(handle) => *handle, _ => return None, }; let suffix = format!(" + DIM * {}", i); - let this_base = ctx.expr_cache.get(&idx)?.strip_suffix(&suffix)?; + let this_base = ctx.expr_cache.get(&handle)?.strip_suffix(&suffix)?; match base { None => base = Some(this_base), Some(b) if b == this_base => {}