diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index e001c2ec6..0a16a46bf 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -80,7 +80,6 @@ def __init__( def device(self): return next(self.parameters()).device - @torch.no_grad() def sample( self, cond: float["b n d"] | float["b nw"], @@ -99,8 +98,80 @@ def sample( duplicate_test=False, t_inter=0.1, edit_mask=None, + cache: bool = True, ): - self.eval() + """Generate a mel-spectrogram via the CFM flow-matching ODE solve. + + Args: + cache: When True (default, backward-compatible), the entire + sample() body runs under torch.no_grad(), self.eval() is + called, and the DiT transformer's KV cache is reused across + ODE steps for ~2x faster inference. + + When False, none of that happens: the body runs with grad + tracking enabled, self.eval() is skipped (so the caller's + training mode is preserved -- dropout stays on for a + fresh-from-train() student), the transformer's KV cache is + bypassed, and the output has a populated grad_fn so that + loss.backward() works through the ODE solve. + + Use cache=False for any gradient-based application: + consistency / progressive distillation, adversarial + distillation, RLHF on generated mels, or differentiable + rendering pipelines that feed the generated mel into a + downstream loss. ~2x slower per forward pass than + cache=True (you trade inference-cache reuse for grad-graph + construction), but enables the entire grad-based ecosystem + of techniques on top of CFM. + """ + # The sample() body was historically wrapped in @torch.no_grad() + # which disabled grad tracking unconditionally. We now gate + # grad-disabling on the cache flag: cache=True keeps the inference + # fast-path (no grad, eval mode, KV cache reuse); cache=False enables + # grad tracking, leaves the caller's train/eval state intact, and + # bypasses the KV cache so the autograd graph is preserved. + # We use set_grad_enabled with try/finally rather than a `with` block + # so the existing function body doesn't need to be re-indented. + if cache: + self.eval() + _prev_grad_enabled = torch.is_grad_enabled() + torch.set_grad_enabled(not cache) + try: + return self._sample_after_grad_setup( + cond=cond, text=text, duration=duration, lens=lens, + steps=steps, cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, seed=seed, + max_duration=max_duration, vocoder=vocoder, + use_epss=use_epss, no_ref_audio=no_ref_audio, + duplicate_test=duplicate_test, t_inter=t_inter, + edit_mask=edit_mask, cache=cache, + ) + finally: + torch.set_grad_enabled(_prev_grad_enabled) + + def _sample_after_grad_setup( + self, + cond, + text, + duration, + *, + lens, + steps, + cfg_strength, + sway_sampling_coef, + seed, + max_duration, + vocoder, + use_epss, + no_ref_audio, + duplicate_test, + t_inter, + edit_mask, + cache, + ): + """Body of sample(); called after grad mode + eval mode are set up. + Split out so the public sample() can manage grad mode via try/finally + without re-indenting this code.""" # raw wave if cond.ndim == 2: @@ -173,7 +244,7 @@ def fn(t, x): mask=mask, drop_audio_cond=False, drop_text=False, - cache=True, + cache=cache, ) return pred @@ -185,7 +256,7 @@ def fn(t, x): time=t, mask=mask, cfg_infer=True, - cache=True, + cache=cache, ) pred, null_pred = torch.chunk(pred_cfg, 2, dim=0) return pred + (pred - null_pred) * cfg_strength @@ -216,7 +287,8 @@ def fn(t, x): t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) trajectory = odeint(fn, y0, t, **self.odeint_kwargs) - self.transformer.clear_cache() + if cache: + self.transformer.clear_cache() sampled = trajectory[-1] out = sampled