Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 77 additions & 5 deletions src/f5_tts/model/cfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def __init__(
def device(self):
return next(self.parameters()).device

@torch.no_grad()
def sample(
self,
cond: float["b n d"] | float["b nw"],
Expand All @@ -99,8 +98,80 @@ def sample(
duplicate_test=False,
t_inter=0.1,
edit_mask=None,
cache: bool = True,
):
self.eval()
"""Generate a mel-spectrogram via the CFM flow-matching ODE solve.

Args:
cache: When True (default, backward-compatible), the entire
sample() body runs under torch.no_grad(), self.eval() is
called, and the DiT transformer's KV cache is reused across
ODE steps for ~2x faster inference.

When False, none of that happens: the body runs with grad
tracking enabled, self.eval() is skipped (so the caller's
training mode is preserved -- dropout stays on for a
fresh-from-train() student), the transformer's KV cache is
bypassed, and the output has a populated grad_fn so that
loss.backward() works through the ODE solve.

Use cache=False for any gradient-based application:
consistency / progressive distillation, adversarial
distillation, RLHF on generated mels, or differentiable
rendering pipelines that feed the generated mel into a
downstream loss. ~2x slower per forward pass than
cache=True (you trade inference-cache reuse for grad-graph
construction), but enables the entire grad-based ecosystem
of techniques on top of CFM.
"""
# The sample() body was historically wrapped in @torch.no_grad()
# which disabled grad tracking unconditionally. We now gate
# grad-disabling on the cache flag: cache=True keeps the inference
# fast-path (no grad, eval mode, KV cache reuse); cache=False enables
# grad tracking, leaves the caller's train/eval state intact, and
# bypasses the KV cache so the autograd graph is preserved.
# We use set_grad_enabled with try/finally rather than a `with` block
# so the existing function body doesn't need to be re-indented.
if cache:
self.eval()
_prev_grad_enabled = torch.is_grad_enabled()
torch.set_grad_enabled(not cache)
try:
return self._sample_after_grad_setup(
cond=cond, text=text, duration=duration, lens=lens,
steps=steps, cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef, seed=seed,
max_duration=max_duration, vocoder=vocoder,
use_epss=use_epss, no_ref_audio=no_ref_audio,
duplicate_test=duplicate_test, t_inter=t_inter,
edit_mask=edit_mask, cache=cache,
)
finally:
torch.set_grad_enabled(_prev_grad_enabled)

def _sample_after_grad_setup(
self,
cond,
text,
duration,
*,
lens,
steps,
cfg_strength,
sway_sampling_coef,
seed,
max_duration,
vocoder,
use_epss,
no_ref_audio,
duplicate_test,
t_inter,
edit_mask,
cache,
):
"""Body of sample(); called after grad mode + eval mode are set up.
Split out so the public sample() can manage grad mode via try/finally
without re-indenting this code."""
# raw wave

if cond.ndim == 2:
Expand Down Expand Up @@ -173,7 +244,7 @@ def fn(t, x):
mask=mask,
drop_audio_cond=False,
drop_text=False,
cache=True,
cache=cache,
)
return pred

Expand All @@ -185,7 +256,7 @@ def fn(t, x):
time=t,
mask=mask,
cfg_infer=True,
cache=True,
cache=cache,
)
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
return pred + (pred - null_pred) * cfg_strength
Expand Down Expand Up @@ -216,7 +287,8 @@ def fn(t, x):
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
self.transformer.clear_cache()
if cache:
self.transformer.clear_cache()

sampled = trajectory[-1]
out = sampled
Expand Down