From bad0a07d7e7c953001fc5e095ce3a1ac816bf7a6 Mon Sep 17 00:00:00 2001 From: vidit-shrimali Date: Wed, 17 Jun 2026 19:38:12 +0530 Subject: [PATCH] docs: clarify EMA weights checkpoint loading (#1292) --- src/f5_tts/api.py | 5 +++++ src/f5_tts/train/README.md | 17 +++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py index 782466deb..1c92211de 100644 --- a/src/f5_tts/api.py +++ b/src/f5_tts/api.py @@ -32,6 +32,11 @@ def __init__( device=None, hf_cache_dir=None, ): + """ + Note: The official pre-trained checkpoints contain ONLY Exponential Moving Average (EMA) + weights under the `ema_model.transformer.` prefix. If writing a custom loader bypassing + this API, ensure you strip this prefix to avoid silent random initialization. + """ model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch diff --git a/src/f5_tts/train/README.md b/src/f5_tts/train/README.md index 5ec91e2f9..05259357f 100644 --- a/src/f5_tts/train/README.md +++ b/src/f5_tts/train/README.md @@ -64,6 +64,23 @@ If use tensorboard as logger, install it first with `pip install tensorboard`. The `use_ema = True` might be harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off with finetune gradio option or `load_model(..., use_ema=False)`, see if offer better results. +### ⚠️ Note on Custom Checkpoint Loading + +The published base checkpoints (e.g., `model_1250000.safetensors`) store weights exclusively under the Exponential Moving Average (EMA) shadow key prefix (`ema_model.transformer.*`). + +If you bypass the built-in API wrapper and write a custom state-dict loader, a naive load will cause the model to fall back to random initialization silently. You must invert your filter to strip the prefix: + +```python +from safetensors.torch import load_file + +state = load_file("model_1250000.safetensors") +# Strip the EMA prefix to map directly to standard transformer keys +cleaned_state = {k[len("ema_model.transformer."):]: v + for k, v in state.items() + if k.startswith("ema_model.transformer.")} +model.load_state_dict(cleaned_state, strict=False) +``` + ### 3. W&B Logging The `wandb/` dir will be created under path you run training/finetuning scripts.