Skip to content

metal: f16 activations/output for the GGML matmul (drop f32 round-trips)#2366

Open
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:perf/metal-ggml-f16-roundtrip
Open

metal: f16 activations/output for the GGML matmul (drop f32 round-trips)#2366
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:perf/metal-ggml-f16-roundtrip

Conversation

@czoli1976

Copy link
Copy Markdown
Contributor

What

The GGML matmul kernels hardcoded f32 output, and the q4_0 / f16-weight GEMV+GEMM paths required f32 activations. So a q40ef16 model (Q4_0 weights + f16 activations — the common on-device LLM layout, e.g. the Qwen builds in examples/causal_llm) bounced through f32 at every matmul: the metal transform inserted a f16→f32 cast on the activation and a f32→f16 cast on the output.

This makes the GGML matmul keep f16 end-to-end:

  • ggml_mm_mv.metal
    • kernel_mul_mv output pointer is now the activation type T1 (f16 activations → f16 output; *_f32 paths unchanged).
    • The q4_0 GEMV (mul_vec_q_n_f32_impl) is templated on the activation/output type — new kernel_mul_mv_q4_0_f16, still accumulating in f32.
    • The GEMM (kernel_mul_mm) is templated on the activation/output type: f16 activations are converted to f32 in threadgroup memory on load, and f16 output is written through the f32 simdgroup scratch (simdgroup_store can only target a float buffer). New kernel_mul_mm_f16_f16 / kernel_mul_mm_q4_0_f16.
  • ggml_gemm/mod.rs: output_dt returns the activation dtype; the GEMV/GEMM dispatch and dtype guards accept f16 activations and select the f16 kernels.
  • transform.rs: drop the forced f16→f32 activation upcast; with output_dt following the activation, the post-matmul f32→f16 cast also becomes a no-op and is no longer inserted.

Correctness

  • All 53 tract-metal GPU tests pass, including a new mmm_ggml_prop_q4_f16 prop test (q4_0 weights × f16 activations vs an f32 CPU reference).
  • End-to-end on Qwen3-1.7B q40ef16 (Metal), greedy decode output is identical before/after.

Benchmark

Qwen3-1.7B q40ef16, Metal decode, via the added examples/causal_llm complete_bench bin (mean of 3 × 96 tokens):

tok/s ms/token
baseline (f32 round-trip) ~41.6 24.0
this change (f16 direct) ~45.6 21.9

≈10% faster decode, identical output. The win comes from dropping ~2 cast dispatches per matmul (~200 matmuls/token) plus reading f16 activations / writing f16 output instead of f32.

No clash with open PRs

The only open-PR edit to the matmul area is #2320 flipping mod mfapub mod mfa on line 1 of matmul/mod.rs. This PR touches the ggml_gemm kernels, output_dt, the matmul lowering, and adds a test at the end of matmul/mod.rs — no overlap. Confirmed mergeable.

Files

  • metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal — templated activation/output dtype
  • metal/src/kernels/matmul/ggml_gemm/mod.rsoutput_dt + dispatch
  • metal/src/transform.rs — drop activation upcast
  • metal/src/kernels/matmul/mod.rsmmm_ggml_prop_q4_f16 test
  • examples/causal_llm/src/bin/complete_bench.rs — decode benchmark

🤖 Generated with Claude Code

The GGML matmul kernels hardcoded f32 output, and the q4_0 / f16-weight
GEMV+GEMM paths required f32 activations. So a q40ef16 model (Q4_0 weights, f16
activations — the common on-device LLM layout) bounced every matmul through
f32: the transform inserted a f16->f32 cast on the activation and a f32->f16
cast on the output.

Make the output dtype follow the activation dtype and let the kernels consume
f16 activations directly:

- ggml_mm_mv.metal: the mul_mv output pointer is now the activation type T1
  (f16 activations -> f16 output); the q4_0 GEMV is templated on the
  activation/output type (new kernel_mul_mv_q4_0_f16, accumulating in f32); the
  GEMM (kernel_mul_mm) is templated on the activation/output type, converting
  f16 activations to f32 in threadgroup memory and writing f16 output through the
  f32 simdgroup scratch (simdgroup_store only targets float). New
  kernel_mul_mm_f16_f16 / kernel_mul_mm_q4_0_f16 instantiations.
- ggml_gemm/mod.rs: output_dt returns the activation dtype; the GEMV/GEMM
  dispatch and dtype guards accept f16 activations and pick the f16 kernels.
- transform.rs: drop the forced f16->f32 activation upcast; output_dt now makes
  the post-matmul f32->f16 cast a no-op too.

Correctness: all 53 tract-metal GPU tests pass, including a new
mmm_ggml_prop_q4_f16 prop test (q4_0 weights x f16 activations vs f32 CPU
reference). End-to-end on Qwen3-1.7B q40ef16 (Metal), greedy output is identical
before/after.

Benchmark (Qwen3-1.7B q40ef16, Metal decode, examples/causal_llm complete_bench,
mean of 3 x 96 tokens):

  baseline (f32 round-trip): ~41.6 tok/s  (24.0 ms/token)
  this change (f16 direct) : ~45.6 tok/s  (21.9 ms/token)   ~10% faster

No clash with sonos#2320 (it only flips `mod mfa` -> `pub mod mfa`; this touches the
ggml_gemm kernels, output_dt and the matmul lowering).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant