metal: f16 activations/output for the GGML matmul (drop f32 round-trips)#2366
Open
czoli1976 wants to merge 1 commit into
Open
metal: f16 activations/output for the GGML matmul (drop f32 round-trips)#2366czoli1976 wants to merge 1 commit into
czoli1976 wants to merge 1 commit into
Conversation
The GGML matmul kernels hardcoded f32 output, and the q4_0 / f16-weight GEMV+GEMM paths required f32 activations. So a q40ef16 model (Q4_0 weights, f16 activations — the common on-device LLM layout) bounced every matmul through f32: the transform inserted a f16->f32 cast on the activation and a f32->f16 cast on the output. Make the output dtype follow the activation dtype and let the kernels consume f16 activations directly: - ggml_mm_mv.metal: the mul_mv output pointer is now the activation type T1 (f16 activations -> f16 output); the q4_0 GEMV is templated on the activation/output type (new kernel_mul_mv_q4_0_f16, accumulating in f32); the GEMM (kernel_mul_mm) is templated on the activation/output type, converting f16 activations to f32 in threadgroup memory and writing f16 output through the f32 simdgroup scratch (simdgroup_store only targets float). New kernel_mul_mm_f16_f16 / kernel_mul_mm_q4_0_f16 instantiations. - ggml_gemm/mod.rs: output_dt returns the activation dtype; the GEMV/GEMM dispatch and dtype guards accept f16 activations and pick the f16 kernels. - transform.rs: drop the forced f16->f32 activation upcast; output_dt now makes the post-matmul f32->f16 cast a no-op too. Correctness: all 53 tract-metal GPU tests pass, including a new mmm_ggml_prop_q4_f16 prop test (q4_0 weights x f16 activations vs f32 CPU reference). End-to-end on Qwen3-1.7B q40ef16 (Metal), greedy output is identical before/after. Benchmark (Qwen3-1.7B q40ef16, Metal decode, examples/causal_llm complete_bench, mean of 3 x 96 tokens): baseline (f32 round-trip): ~41.6 tok/s (24.0 ms/token) this change (f16 direct) : ~45.6 tok/s (21.9 ms/token) ~10% faster No clash with sonos#2320 (it only flips `mod mfa` -> `pub mod mfa`; this touches the ggml_gemm kernels, output_dt and the matmul lowering). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This was referenced Jun 14, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
The GGML matmul kernels hardcoded f32 output, and the q4_0 / f16-weight GEMV+GEMM paths required f32 activations. So a
q40ef16model (Q4_0 weights + f16 activations — the common on-device LLM layout, e.g. the Qwen builds inexamples/causal_llm) bounced through f32 at every matmul: the metal transform inserted af16→f32cast on the activation and af32→f16cast on the output.This makes the GGML matmul keep f16 end-to-end:
ggml_mm_mv.metalkernel_mul_mvoutput pointer is now the activation typeT1(f16 activations → f16 output;*_f32paths unchanged).mul_vec_q_n_f32_impl) is templated on the activation/output type — newkernel_mul_mv_q4_0_f16, still accumulating in f32.kernel_mul_mm) is templated on the activation/output type: f16 activations are converted to f32 in threadgroup memory on load, and f16 output is written through the f32 simdgroup scratch (simdgroup_storecan only target a float buffer). Newkernel_mul_mm_f16_f16/kernel_mul_mm_q4_0_f16.ggml_gemm/mod.rs:output_dtreturns the activation dtype; the GEMV/GEMM dispatch and dtype guards accept f16 activations and select the f16 kernels.transform.rs: drop the forcedf16→f32activation upcast; withoutput_dtfollowing the activation, the post-matmulf32→f16cast also becomes a no-op and is no longer inserted.Correctness
tract-metalGPU tests pass, including a newmmm_ggml_prop_q4_f16prop test (q4_0 weights × f16 activations vs an f32 CPU reference).Benchmark
Qwen3-1.7B q40ef16, Metal decode, via the added
examples/causal_llmcomplete_benchbin (mean of 3 × 96 tokens):≈10% faster decode, identical output. The win comes from dropping ~2 cast dispatches per matmul (~200 matmuls/token) plus reading f16 activations / writing f16 output instead of f32.
No clash with open PRs
The only open-PR edit to the matmul area is #2320 flipping
mod mfa→pub mod mfaon line 1 ofmatmul/mod.rs. This PR touches theggml_gemmkernels,output_dt, the matmul lowering, and adds a test at the end ofmatmul/mod.rs— no overlap. Confirmed mergeable.Files
metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal— templated activation/output dtypemetal/src/kernels/matmul/ggml_gemm/mod.rs—output_dt+ dispatchmetal/src/transform.rs— drop activation upcastmetal/src/kernels/matmul/mod.rs—mmm_ggml_prop_q4_f16testexamples/causal_llm/src/bin/complete_bench.rs— decode benchmark🤖 Generated with Claude Code