core: drop per-run String/TypedFact alloc in resolve_symbols_with_states#2364
Open
czoli1976 wants to merge 1 commit into
`SimpleState::resolve_symbols_with_states` runs once per `run()` — i.e. once
per decode token for an LLM. It selected the stateful ops that need symbol
resolution with `s.init_tensor_fact().is_some()`, but `init_tensor_fact()`
clones a `String` (the cache name) and a `TypedFact` and returns them in an
`Option` purely so the result can be tested with `is_some()` and dropped.
Add an allocation-free `OpState::has_init_tensor_fact() -> bool` predicate that
mirrors `init_tensor_fact()`, and use it for the filter. The set of ops that
override `init_tensor_fact` to return `Some` (the transformers and GPU KV-cache
states, with the metal/cuda fused ops delegating) is exactly the set that
overrides `resolve_symbols`, so the filter selects the same states as before —
only the per-state allocation is removed. A drift-guard test keeps the two
methods in sync.
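The relationship between the two methods, the filter, and the drift guard can be sketched as follows. This is a simplified illustration, not tract's actual code: `TypedFact` here is a stand-in struct, `KvCacheState`, `PlainState`, `states_needing_resolution` and `predicates_agree` are hypothetical names, and the real trait signatures may differ.

```rust
// Hypothetical, simplified stand-ins for illustration; not tract's real types.
#[derive(Clone, Debug)]
struct TypedFact {
    shape: Vec<usize>,
}

trait OpState {
    /// Returns owned clones, so a `Some` result costs heap allocations.
    fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
        None
    }

    /// Allocation-free mirror of `init_tensor_fact().is_some()`.
    fn has_init_tensor_fact(&self) -> bool {
        false
    }
}

/// Assumed shape of a KV-cache state: it overrides both methods.
struct KvCacheState {
    name: String,
    fact: TypedFact,
}

impl OpState for KvCacheState {
    fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
        Some((self.name.clone(), self.fact.clone()))
    }

    fn has_init_tensor_fact(&self) -> bool {
        true
    }
}

/// A state that keeps both defaults.
struct PlainState;
impl OpState for PlainState {}

/// Indices of the states that need symbol resolution. The predicate itself
/// clones nothing; only the returned index list is allocated.
fn states_needing_resolution(states: &[Box<dyn OpState>]) -> Vec<usize> {
    states
        .iter()
        .enumerate()
        .filter(|(_, s)| s.has_init_tensor_fact())
        .map(|(i, _)| i)
        .collect()
}

/// Drift guard: the cheap predicate must agree with the allocating method.
fn predicates_agree(states: &[Box<dyn OpState>]) -> bool {
    states
        .iter()
        .all(|s| s.has_init_tensor_fact() == s.init_tensor_fact().is_some())
}

fn main() {
    let states: Vec<Box<dyn OpState>> = vec![
        Box::new(PlainState),
        Box::new(KvCacheState {
            name: "kv_0".to_string(),
            fact: TypedFact { shape: vec![1, 8, 64] },
        }),
        Box::new(PlainState),
    ];
    assert!(predicates_agree(&states));
    assert_eq!(states_needing_resolution(&states), vec![1]);
}
```

The cost of this design is that every implementor must override both methods consistently, which is why a check along the lines of `predicates_agree` is worth keeping in the test suite.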
This is a micro-optimization: it removes exactly one heap allocation per
KV-cache op per decode step (2 * n_layers per token), with no change in
behaviour. On a real model it is a small fraction of per-token work and does
not move wall-clock; the value is reduced allocator pressure on the per-token
hot path, most visible when layer count is high relative to compute and under
allocator contention.
Benchmarks (added as transformers examples):

kv_resolve_probe (compute-light, N KV caches), allocs/decode-step:
  N=16:  52 -> 36
  N=32:  101 -> 69
  N=64:  198 -> 134
  N=128: 391 -> 263
  (exactly N fewer; ~7-20% faster)

llm_decode_bench (Qwen3-1.7B q40, folded decode, 56 KV caches):
  allocs/token: 17444 -> 17388 (exactly -56 = 2 * 28 layers, deterministic)
  tokens/sec: ~19.6 -> ~19.6 (unchanged, within noise)
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
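The PR description states that both benchmark examples use a counting global allocator to obtain the allocation figures. A minimal sketch of that technique is shown below, using only the standard library. It makes no claim about how the actual benchmark examples are written; `CountingAlloc`, `ALLOCS` and `count_allocs` are hypothetical names chosen for illustration.

```rust
use std::alloc::{GlobalAlloc, Layout, System};
use std::hint::black_box;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Wraps the system allocator and counts every allocation request.
struct CountingAlloc;

static ALLOCS: AtomicUsize = AtomicUsize::new(0);

unsafe impl GlobalAlloc for CountingAlloc {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        ALLOCS.fetch_add(1, Ordering::Relaxed);
        unsafe { System.alloc(layout) }
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        unsafe { System.dealloc(ptr, layout) }
    }
}

#[global_allocator]
static GLOBAL: CountingAlloc = CountingAlloc;

/// Runs `f` and returns its result with the number of allocations made while
/// it ran. The counter is process-wide, so other threads are counted too.
fn count_allocs<R>(f: impl FnOnce() -> R) -> (R, usize) {
    let before = ALLOCS.load(Ordering::Relaxed);
    let out = f();
    let after = ALLOCS.load(Ordering::Relaxed);
    (out, after - before)
}

fn main() {
    let name = String::from("kv_cache_0");

    // Old pattern: clone into an Option only to test `is_some()`.
    let (_, cloning) = count_allocs(|| black_box(Some(name.clone())).is_some());

    // New pattern: answer the same question without creating an owned value.
    let (_, borrowing) = count_allocs(|| black_box(!name.is_empty()));

    println!("clone + is_some(): {cloning} allocation(s)");
    println!("bool predicate:    {borrowing} allocation(s)");
}
```

Because the counter is global to the process, a measurement like this is only deterministic when nothing else allocates concurrently, which fits a single-threaded probe around one decode step.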
Collaborator: Rebased!
What

`SimpleState::resolve_symbols_with_states` runs once per `run()`, i.e. once per decode token for an autoregressive model. It selects the stateful ops that need symbol resolution with `s.init_tensor_fact().is_some()`, but `OpState::init_tensor_fact()` clones a `String` (the cache name) and a `TypedFact`, wraps them in an `Option`, and returns them, only for the result to be tested with `is_some()` and immediately dropped.

This PR adds an allocation-free predicate `OpState::has_init_tensor_fact() -> bool` that mirrors `init_tensor_fact()`, and uses it for the filter.

The set of ops that override `init_tensor_fact` to return `Some` (the transformers + GPU `DynKeyValueCache` states, with the metal/cuda fused ops delegating) is exactly the set that overrides `resolve_symbols`, so the filter selects the same states as before; only the per-state allocation is removed. A drift-guard test (`has_init_tensor_fact_matches_init_tensor_fact`) keeps the two in sync, since if they ever disagreed an op's `resolve_symbols` would silently stop running.

Honesty up front: this is a micro-optimization
It removes exactly one heap allocation per KV-cache op per decode step (`2 * n_layers` per token), with no behavioural change. On a real, compute-bound model it is a tiny fraction of per-token work and does not move wall-clock time. The value is reduced allocator pressure on the per-token hot path, most visible when layer count is high relative to compute, and under allocator contention. I'm filing it because it's a strictly-better, zero-risk change with a clear measurement, not because it's a speedup you'll feel on a 1.7B model.

Benchmarks
Two reproducible benchmark examples are included under `transformers/examples/`, both using a counting global allocator.

`kv_resolve_probe`: compute-light model with N real `DynKeyValueCache` ops, allocations per decode step:

- N=16: 52 -> 36
- N=32: 101 -> 69
- N=64: 198 -> 134
- N=128: 391 -> 263

Allocations drop by exactly N (one `String` clone per cache per step; the `TypedFact` clone was inline-smallvec and didn't hit the heap). Wall-time is consistently lower (~7-20% faster), and the saving scales with N when compute is light.

`llm_decode_bench`: end-to-end, Qwen3-1.7B q40, folded `DynKeyValueCache` decode (56 caches = 28 layers × 2), persistent `SimpleState`, 128 decode tokens × 3 runs:

- allocs/token: 17444 -> 17388
- tokens/sec: ~19.6 -> ~19.6 (unchanged, within noise)

Exactly −56 allocations/token (= `2 * n_layers`), deterministic across runs, no wall-clock regression.

(Note: the `causal_llm` example unfolds KV caches into explicit model I/O, which removes these stateful ops entirely, so it does not exercise this path; hence the dedicated folded-mode `llm_decode_bench`.)

Testing
`cargo test -p tract-core` (247) and `-p tract-transformers` (incl. the `dyn_kv_cache` NNEF round-trip + the new drift-guard test) pass.

Files
- `core/src/ops/mod.rs`: new `has_init_tensor_fact()` trait method (default `false`)
- `core/src/plan.rs`: use it in the per-run filter
- `transformers/src/ops/dyn_kv_cache.rs`: override `true` + drift-guard test
- `gpu/src/ops/dyn_kv_cache.rs`: override `true`
- `metal/src/ops/fused_axis_op.rs`, `cuda/src/ops/fused_axis_op.rs`: delegate
- `transformers/examples/{kv_resolve_probe,llm_decode_bench}.rs`: benchmarks

🤖 Generated with Claude Code