5 changes: 5 additions & 0 deletions src/f5_tts/api.py
@@ -32,6 +32,11 @@ def __init__(
        device=None,
        hf_cache_dir=None,
    ):
        """
        Note: The official pre-trained checkpoints contain ONLY Exponential Moving Average (EMA)
        weights under the `ema_model.transformer.` prefix. If writing a custom loader bypassing
        this API, ensure you strip this prefix to avoid silent random initialization.
        """
        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch
17 changes: 17 additions & 0 deletions src/f5_tts/train/README.md
@@ -64,6 +64,23 @@ If use tensorboard as logger, install it first with `pip install tensorboard`.

<ins>Setting `use_ema = True` might be harmful for early-stage finetuned checkpoints</ins> (which have gone through only a few updates, so the EMA weights are still dominated by the pretrained ones). Try turning it off with the finetune Gradio option or `load_model(..., use_ema=False)` and see if that gives better results.
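
To get a rough feel for why EMA weights lag behind after only a few updates, the sketch below computes how much of the EMA average still comes from the starting (pretrained) weights. It assumes a plain update `ema = decay * ema + (1 - decay) * weights` with a constant decay; the value `0.9999` is an illustrative assumption, not necessarily what the trainer uses, and any warmup schedule would change the numbers.

```python
def pretrained_share(decay: float, num_updates: int) -> float:
    """Fraction of the EMA weights still contributed by the initial weights.

    Assumes a constant-decay update ema = decay * ema + (1 - decay) * w,
    under which the initial weights keep a coefficient of decay ** num_updates.
    """
    return decay ** num_updates


if __name__ == "__main__":
    decay = 0.9999  # assumed value, for illustration only
    for n in (100, 1_000, 10_000, 100_000):
        print(f"{n:>7} updates: {pretrained_share(decay, n):.4f} of EMA is pretrained")
```

Under these assumptions, the pretrained share stays close to 1 for the first few hundred updates, which is the regime where comparing against `use_ema=False` is most likely to be worthwhile.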

### ⚠️ Note on Custom Checkpoint Loading

The published base checkpoints (e.g., `model_1250000.safetensors`) store weights exclusively under the Exponential Moving Average (EMA) shadow key prefix (`ema_model.transformer.*`).

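If you bypass the built-in API wrapper and write a custom state-dict loader, a naive load will not match any parameter names. With `strict=False`, `load_state_dict` skips mismatched keys without raising an error, so the model silently keeps its random initialization. Your loader must select the prefixed keys and strip the prefix: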

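```python
from safetensors.torch import load_file

state = load_file("model_1250000.safetensors")

# Keep only the EMA entries and strip the prefix so the keys map directly
# to standard transformer keys.
prefix = "ema_model.transformer."
cleaned_state = {
    k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)
}

# `model` is assumed to be an already-constructed transformer backbone.
# strict=False does not raise on mismatches, so inspect the returned keys.
result = model.load_state_dict(cleaned_state, strict=False)
print("missing keys:", result.missing_keys)
```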
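
Because a silent mismatch is the failure mode here, it can help to compare key names before loading anything. The following is a minimal sketch on plain dictionaries, with no model or checkpoint file involved; the helper names `strip_ema_prefix` and `key_report` and the toy key names are hypothetical and not part of the F5-TTS API.

```python
EMA_PREFIX = "ema_model.transformer."


def strip_ema_prefix(state, prefix=EMA_PREFIX):
    """Return a new dict holding only the prefixed entries, with the prefix removed."""
    return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}


def key_report(expected_keys, loaded_keys):
    """Summarize how the loaded key names line up with what the model expects."""
    expected, loaded = set(expected_keys), set(loaded_keys)
    return {
        "matched": sorted(expected & loaded),
        "missing": sorted(expected - loaded),
        "unexpected": sorted(loaded - expected),
    }


if __name__ == "__main__":
    # Toy stand-ins: real checkpoints hold tensors and many more keys.
    checkpoint = {
        "ema_model.transformer.layer.weight": 1,
        "ema_model.transformer.layer.bias": 2,
        "unrelated.entry": 3,  # made-up key without the EMA prefix
    }
    model_keys = ["layer.weight", "layer.bias"]

    # Naive load: nothing matches, which strict=False would hide.
    print(key_report(model_keys, checkpoint.keys()))
    # After stripping the prefix, every expected key is present.
    print(key_report(model_keys, strip_ema_prefix(checkpoint).keys()))
```

An empty `matched` list in the first report is the signature of the silent fallback described above. With a real model, `model.state_dict().keys()` would supply the expected names.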

### 3. W&B Logging

The `wandb/` directory will be created under the path from which you run the training/finetuning scripts.