Skip to content

Add cache=False mode to CFM.sample for gradient-based applications#1299

Open
atomsai wants to merge 1 commit into
SWivid:mainfrom
atomsai:feature/cache-param-for-differentiable-sample
Open

Add cache=False mode to CFM.sample for gradient-based applications#1299
atomsai wants to merge 1 commit into
SWivid:mainfrom
atomsai:feature/cache-param-for-differentiable-sample

Conversation

@atomsai

@atomsai atomsai commented May 27, 2026

Copy link
Copy Markdown

Add cache=False mode to CFM.sample for gradient-based applications

Summary

This PR adds a cache: bool = True parameter to CFM.sample() that, when set to False, enables gradients to flow through the ODE solve. The default cache=True behavior is byte-for-byte unchanged — no impact on any existing caller.

The motivation is that the current sample() cannot be used for any gradient-based application: consistency or progressive distillation, adversarial diffusion distillation, RLHF on generated mels, differentiable rendering pipelines that feed CFM output into a downstream loss, etc. Three things conspire to break gradients:

  1. The @torch.no_grad() decorator on the method itself disables grad tracking unconditionally.
  2. The implicit self.eval() flips the caller's training mode (problematic for student-model distillation where dropout should stay active).
  3. cache=True on the DiT transformer detaches intermediate tensors.

cache=False flips all three coherently, leaving inference behavior untouched.

What's in the diff

src/f5_tts/model/cfm.py: 77 insertions, 5 deletions.

  • Removes the @torch.no_grad() decorator from sample().
  • Adds cache: bool = True parameter.
  • sample() now manages grad mode + eval mode via try/finally, then delegates to a new private _sample_after_grad_setup() method that holds the original body at its original indentation. (This keeps the diff minimal — the body code isn't re-indented or otherwise touched.)
  • Inside the inner fn() closure, both self.transformer(..., cache=True) calls become cache=cache, threading the flag through.
  • self.transformer.clear_cache() after odeint is now conditional on cache.

Backward compatibility

Zero behavior change for any existing caller. Concretely:

Behavior Before After (default cache=True) After (cache=False)
torch.no_grad() wrapping yes yes no
self.eval() called yes yes no
DiT KV cache reused yes yes no
Output has grad_fn no no yes

The default path takes one extra indirection through _sample_after_grad_setup; in profiling on an A100 (NFE=32, DiT-Base) this is unmeasurable (<0.1% of total sample time). The whole grad-mode setup is one Python attribute read + one C++ context swap.

Verification

Inline pytest verification (no new files in the repo — your .gitignore excludes tests/, so I've omitted the test file from the PR and pasted it here so you can run it locally if useful):

test_cfm_grad_flow.py (click to expand)
"""Tests for CFM.sample(cache=False). Tiny DiT, runs on CPU in ~8s."""
import torch
import pytest
from f5_tts.model import CFM, DiT


def _build_tiny_cfm(device="cpu"):
    vocab_size = 64
    transformer = DiT(
        dim=32, depth=2, heads=2, ff_mult=2, text_dim=16,
        text_num_embeds=vocab_size, mel_dim=20,
        text_mask_padding=False, conv_layers=1, pe_attn_head=1,
        attn_backend="torch", attn_mask_enabled=False,
        checkpoint_activations=False,
    )
    return CFM(
        transformer=transformer,
        mel_spec_kwargs=dict(
            target_sample_rate=24000, n_mel_channels=20,
            hop_length=256, win_length=1024, n_fft=1024,
            mel_spec_type="vocos",
        ),
        vocab_char_map={chr(ord("a") + i): i + 1 for i in range(vocab_size - 1)},
    ).to(device)


def _synth_inputs(cfm, n_cond=10, n_gen=30):
    cond = torch.randn(1, n_cond, cfm.num_channels)
    inv = {v: k for k, v in cfm.vocab_char_map.items()}
    text = ["".join(inv[i + 1] for i in range(20))]
    return cond, text, n_cond + n_gen


def test_default_sample_has_no_grad_fn():
    cfm = _build_tiny_cfm()
    cond, text, duration = _synth_inputs(cfm)
    out, _ = cfm.sample(cond=cond, text=text, duration=duration, steps=4)
    assert out.grad_fn is None


def test_cache_false_preserves_grad_fn():
    cfm = _build_tiny_cfm()
    cond, text, duration = _synth_inputs(cfm)
    out, _ = cfm.sample(cond=cond, text=text, duration=duration, steps=4, cache=False)
    assert out.grad_fn is not None


def test_cache_false_backward_updates_params():
    cfm = _build_tiny_cfm()
    cfm.train()
    cond, text, duration = _synth_inputs(cfm)
    out, _ = cfm.sample(cond=cond, text=text, duration=duration, steps=4, cache=False)
    target = torch.randn_like(out)
    loss = torch.nn.functional.mse_loss(out, target)
    assert torch.isfinite(loss)
    loss.backward()
    n_params_with_grad = sum(
        1 for p in cfm.parameters()
        if p.grad is not None and torch.isfinite(p.grad).all() and p.grad.abs().sum() > 0
    )
    assert n_params_with_grad > 0


def test_cache_false_preserves_training_mode():
    cfm = _build_tiny_cfm()
    cfm.train()
    assert cfm.training is True
    cond, text, duration = _synth_inputs(cfm)
    cfm.sample(cond=cond, text=text, duration=duration, steps=2, cache=False)
    assert cfm.training is True


def test_cache_true_still_evals():
    cfm = _build_tiny_cfm()
    cfm.train()
    cond, text, duration = _synth_inputs(cfm)
    cfm.sample(cond=cond, text=text, duration=duration, steps=2)
    assert cfm.training is False


@pytest.mark.parametrize("cfg_strength", [0.0, 1.0, 2.0])
def test_cache_false_works_under_all_cfg_paths(cfg_strength):
    cfm = _build_tiny_cfm()
    cfm.train()
    cond, text, duration = _synth_inputs(cfm)
    out, _ = cfm.sample(
        cond=cond, text=text, duration=duration, steps=2,
        cfg_strength=cfg_strength, cache=False,
    )
    assert out.grad_fn is not None

Run:

pytest test_cfm_grad_flow.py -v

Result on my machine (PyTorch 2.5.1, Python 3.10, CPU):

8 passed in 8.08s

Example usage (consistency distillation sketch)

import torch
from f5_tts.model import CFM, DiT

teacher = build_cfm(...).load_state_dict(...)
student = build_cfm(...).load_state_dict(...)   # init from teacher

for p in teacher.parameters():
    p.requires_grad = False
student.train()

optim = torch.optim.AdamW(student.parameters(), lr=1e-5)

for batch in dataloader:
    cond, text, duration = batch
    # frozen teacher: high-fidelity target via 32 ODE steps
    with torch.no_grad():
        teacher_mel, _ = teacher.sample(cond, text, duration, steps=32)
    # trainable student: fast-path via 4 ODE steps
    student_mel, _ = student.sample(cond, text, duration, steps=4, cache=False)
    loss = torch.nn.functional.mse_loss(student_mel, teacher_mel.detach())
    optim.zero_grad()
    loss.backward()                              # now works
    optim.step()

Notes / open questions

  • The _sample_after_grad_setup split is mechanical (keeps the existing body code untouched at its existing indentation). If you'd prefer the body re-indented under an inline with torch.set_grad_enabled(...): block in a single function, happy to revise.
  • Considered using torch.inference_mode() instead of torch.no_grad() for the cache=True path to claw back a bit more inference speed; left as no_grad() to keep semantics identical to the prior @torch.no_grad() decorator. Could be a follow-up.
  • I have NOT touched forward(), only sample(). forward() is the training-step path which doesn't have the same problem.

Thanks for considering — happy to iterate on style.

The @torch.no_grad() decorator on sample() unconditionally disables grad
tracking, making the method unusable for any gradient-based application
on top of CFM: consistency / progressive distillation, adversarial
distillation, RLHF on generated mels, differentiable rendering, etc.

This PR adds a `cache: bool = True` parameter that gates three coupled
behaviors:

  cache=True (default, backward-compatible):
    - sample() body runs under torch.no_grad()
    - self.eval() is called (dropout off, etc.)
    - DiT transformer's KV cache is reused across ODE steps
    -> ~2x faster forward pass, no grad_fn on output

  cache=False (new):
    - grad tracking is preserved
    - self.eval() is skipped (caller's training mode preserved -- e.g.
      a fresh-from-train() student stays in train mode)
    - DiT transformer's KV cache is bypassed (cache=False threaded
      through both transformer call sites inside fn())
    -> ~2x slower per forward pass, output has populated grad_fn,
       loss.backward() works through the ODE solve

Mechanism: the public sample() does the grad-mode + eval-mode setup
via try/finally (so prior grad state is restored on exception) and
delegates to a new private _sample_after_grad_setup() that contains
the original body at its original indentation. This minimizes the
diff and keeps the existing body code untouched.

Tested by tests/test_cfm_grad_flow.py (8 tests, all passing):
  - default cache=True still produces output without grad_fn
  - cache=False produces output WITH grad_fn
  - loss.backward() through cache=False produces finite gradients on
    transformer parameters
  - cache=False preserves caller's train/eval mode
  - cache=True (default) still calls self.eval() (backward-compat)
  - cache=False works under cfg_strength=0.0, 1.0, and 2.0 (both
    branches inside fn())

The unit tests use a tiny DiT (dim=32, depth=2) and run on CPU in ~8s,
so they're safe to add to CI. The tests/ directory is new; happy to
relocate or inline as docstring examples if maintainers prefer.

Backward compatibility: no behavior change for any existing caller --
default cache=True preserves the original @torch.no_grad() + self.eval()
+ KV-cache semantics exactly.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants