diff --git a/pyproject.toml b/pyproject.toml index 80423d673..62d929b54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,9 @@ eval = [ "zhconv", "zhon", ] +peft = [ + "peft>=0.14.0", +] [project.urls] Homepage = "https://github.com/SWivid/F5-TTS" diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index 782923222..a1a018b72 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -53,6 +53,7 @@ def __init__( is_local_vocoder: bool = False, # use local path vocoder local_vocoder_path: str = "", # local vocoder path model_cfg_dict: dict = dict(), # training config + peft_config: object | None = None, # peft.PeftConfig instance (LoraConfig / LoHaConfig); None disables PEFT ): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) @@ -103,10 +104,24 @@ def __init__( self.model = model - if self.is_main: + # PEFT (LoRA / LoHa) wrap. Frozen base + small trainable adapter. EMA disabled under PEFT — + # tracking ema across base+adapter bloats state and degrades adapter learning. + self.peft_enabled = peft_config is not None + if self.peft_enabled: + from peft import get_peft_model + + self.model = get_peft_model(self.model, peft_config) + + if self.is_main and not self.peft_enabled: self.ema_model = EMA(model, include_online_model=False, **ema_kwargs) self.ema_model.to(self.accelerator.device) + elif self.is_main: + self.ema_model = None + if self.is_main: + if self.peft_enabled: + # peft helper prints "trainable params: X || all params: Y || trainable%: Z" + self.model.print_trainable_parameters() print(f"Using logger: {logger}") if grad_accumulation_steps > 1: print( @@ -135,12 +150,14 @@ def __init__( self.duration_predictor = duration_predictor + # Use self.model.parameters() so PEFT-wrapped runs see only trainable adapter params. + # For non-PEFT runs this matches the prior behavior (all params trainable). if bnb_optimizer: import bitsandbytes as bnb - self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate) + self.optimizer = bnb.optim.AdamW8bit(self.model.parameters(), lr=learning_rate) else: - self.optimizer = AdamW(model.parameters(), lr=learning_rate, fused=True) + self.optimizer = AdamW(self.model.parameters(), lr=learning_rate, fused=True) self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) @property @@ -153,12 +170,17 @@ def save_checkpoint(self, update, last=False): checkpoint = dict( model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), optimizer_state_dict=self.optimizer.state_dict(), - ema_model_state_dict=self.ema_model.state_dict(), scheduler_state_dict=self.scheduler.state_dict(), update=update, ) + if self.ema_model is not None: + checkpoint["ema_model_state_dict"] = self.ema_model.state_dict() if not os.path.exists(self.checkpoint_path): os.makedirs(self.checkpoint_path) + # PEFT: also save adapter-only safetensors directory for portability (~MB vs GB) + if self.peft_enabled: + adapter_dir = f"{self.checkpoint_path}/adapter_{'last' if last else update}" + self.accelerator.unwrap_model(self.model).save_pretrained(adapter_dir) if last: self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") print(f"Saved last checkpoint at update {update}") @@ -224,11 +246,12 @@ def load_checkpoint(self): ) # patch for backward compatibility, 305e3ea - for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]: - if key in checkpoint["ema_model_state_dict"]: - del checkpoint["ema_model_state_dict"][key] + if "ema_model_state_dict" in checkpoint: + for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]: + if key in checkpoint["ema_model_state_dict"]: + del checkpoint["ema_model_state_dict"][key] - if self.is_main: + if self.is_main and self.ema_model is not None and "ema_model_state_dict" in checkpoint: self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"]) if "update" in checkpoint or "step" in checkpoint: @@ -249,13 +272,19 @@ def load_checkpoint(self): if self.scheduler: self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) update = checkpoint["update"] - else: + elif "ema_model_state_dict" in checkpoint: checkpoint["model_state_dict"] = { k.replace("ema_model.", ""): v for k, v in checkpoint["ema_model_state_dict"].items() if k not in ["initted", "update", "step"] } - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) + # PEFT-wrapped model has prefixed keys (base_model.model.…); pretrained ckpts don't. + # Load non-strict and rely on CLI-level pre-load when PEFT enabled (see finetune_cli.py). + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"], strict=not self.peft_enabled + ) + update = 0 + else: update = 0 del checkpoint @@ -384,7 +413,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int self.optimizer.zero_grad() if self.accelerator.sync_gradients: - if self.is_main: + if self.is_main and self.ema_model is not None: self.ema_model.update() global_update += 1 diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py index cdf42a9ac..ec635f362 100644 --- a/src/f5_tts/train/finetune_cli.py +++ b/src/f5_tts/train/finetune_cli.py @@ -71,6 +71,34 @@ def parse_args(): action="store_true", help="Use 8-bit Adam optimizer from bitsandbytes", ) + parser.add_argument( + "--peft_method", + type=str, + default="none", + choices=["none", "lora", "loha"], + help="Parameter-efficient finetune adapter (frozen base, trainable adapter). 'none' = full finetune.", + ) + parser.add_argument( + "--peft_rank", + type=int, + default=8, + help="Adapter rank r. Sensible: LoRA r=16-32, LoHa r=4-8 (LyCORIS rule r<=sqrt(dim)).", + ) + parser.add_argument( + "--peft_alpha", + type=int, + default=8, + help="Adapter scaling alpha. Common choice: alpha = r (LoRA) or alpha = r (LoHa).", + ) + parser.add_argument( + "--peft_target_modules", + type=str, + default=None, + help=( + "Comma-separated module name suffixes to adapt. " + "Default targets DiT/UNetT attention+FFN linears: 'to_q,to_k,to_v,to_out.0,ff.ff.0.0,ff.ff.2'." + ), + ) return parser.parse_args() @@ -138,7 +166,10 @@ def main(): else: ckpt_path = args.pretrain - if args.finetune: + # PEFT runs need the base loaded INTO the bare model before adapter wrap (handled below). + # Skip the pretrained-copy mechanic: load_checkpoint would otherwise try to load the + # un-prefixed pretrained state into the PEFT-wrapped (prefixed) model. + if args.finetune and args.peft_method == "none": if not os.path.isdir(checkpoint_path): os.makedirs(checkpoint_path, exist_ok=True) @@ -180,6 +211,73 @@ def main(): vocab_char_map=vocab_char_map, ) + # --- PEFT: build adapter config + pre-load base weights --- + peft_config = None + if args.peft_method != "none": + if args.peft_target_modules: + targets = [s.strip() for s in args.peft_target_modules.split(",")] + else: + # DiT/UNetT share attention (to_q/k/v/out.0) and FFN (ff.ff.0.0 / ff.ff.2) module naming. + targets = ["to_q", "to_k", "to_v", "to_out.0", "ff.ff.0.0", "ff.ff.2"] + # Always exclude AdaLN-Zero modulation + final zero-init linears — adapting them breaks + # F5-TTS init contract (NaN within first steps). + excludes = [ + "attn_norm", + "ff_norm", + "norm_out", + "proj_out", + "time_embed", + "text_embed", + "input_embed", + "long_skip_connection", + ] + if args.peft_method == "lora": + from peft import LoraConfig + + peft_config = LoraConfig( + r=args.peft_rank, + lora_alpha=args.peft_alpha, + target_modules=targets, + exclude_modules=excludes, + lora_dropout=0.0, + bias="none", + ) + elif args.peft_method == "loha": + from peft import LoHaConfig + + peft_config = LoHaConfig( + r=args.peft_rank, + alpha=args.peft_alpha, + target_modules=targets, + exclude_modules=excludes, + rank_dropout=0.0, + module_dropout=0.0, + use_effective_conv2d=False, + ) + + if args.finetune: + import torch + from safetensors.torch import load_file as _load_safetensors + + if ckpt_path.endswith(".safetensors"): + sd = _load_safetensors(ckpt_path) + else: + sd_pt = torch.load(ckpt_path, weights_only=True, map_location="cpu") + if "ema_model_state_dict" in sd_pt: + sd = { + k.replace("ema_model.", ""): v + for k, v in sd_pt["ema_model_state_dict"].items() + if k not in ("initted", "update", "step") + } + elif "model_state_dict" in sd_pt: + sd = sd_pt["model_state_dict"] + else: + sd = sd_pt + for k in ("mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"): + sd.pop(k, None) + missing, unexpected = model.load_state_dict(sd, strict=False) + print(f"PEFT pretrain loaded from {ckpt_path}: missing={len(missing)} unexpected={len(unexpected)}") + trainer = Trainer( model, args.epochs, @@ -200,6 +298,7 @@ def main(): log_samples=args.log_samples, last_per_updates=args.last_per_updates, bnb_optimizer=args.bnb_optimizer, + peft_config=peft_config, ) train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)