From 31fceb858bb4d37d80549d5cbf898048c8e2d218 Mon Sep 17 00:00:00 2001 From: Raivis Dejus Date: Sat, 17 Jan 2026 09:03:14 +0200 Subject: [PATCH 1/5] Adding F5TTS_base model and fix for use of custom checkpoints --- src/f5_tts/infer/infer_gradio.py | 164 +++++++++++++++++++++---------- 1 file changed, 110 insertions(+), 54 deletions(-) diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py index fb24cad93..9519e17b0 100644 --- a/src/f5_tts/infer/infer_gradio.py +++ b/src/f5_tts/infer/infer_gradio.py @@ -35,6 +35,7 @@ def gpu_decorator(func): return func +from f5_tts.api import F5TTS from f5_tts.infer.utils_infer import ( infer_process, load_model, @@ -46,9 +47,17 @@ def gpu_decorator(func): ) from f5_tts.model import DiT, UNetT +# Mapping from UI model names to API model names +MODEL_NAME_MAP = { + "F5-TTS_v1": "F5TTS_v1_Base", + "F5-TTS": "F5TTS_Base", + "E2-TTS": "E2TTS_Base", +} + DEFAULT_TTS_MODEL = "F5-TTS_v1" tts_model_choice = DEFAULT_TTS_MODEL +custom_model_enabled = False DEFAULT_TTS_MODEL_CFG = [ "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors", @@ -56,6 +65,12 @@ def gpu_decorator(func): json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)), ] +F5TTS_BASE_CFG = [ + "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt", + "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt", + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)), +] + # load models @@ -68,6 +83,12 @@ def load_f5tts(): return load_model(DiT, F5TTS_model_cfg, ckpt_path) +def load_f5tts_base(): + ckpt_path = str(cached_path(F5TTS_BASE_CFG[0])) + F5TTS_model_cfg = json.loads(F5TTS_BASE_CFG[2]) + return load_model(DiT, F5TTS_model_cfg, ckpt_path) + + def load_e2tts(): ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")) E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1) @@ -88,6 +109,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None): F5TTS_ema_model = load_f5tts() +F5TTS_base_ema_model = None E2TTS_ema_model = load_e2tts() if USING_SPACES else None custom_ema_model, pre_custom_path = None, "" @@ -159,22 +181,40 @@ def infer( ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info) - if model == DEFAULT_TTS_MODEL: + # Handle custom model - model can be a tuple (base_model, ckpt_path, vocab_path, model_cfg) + if isinstance(model, tuple): + assert not USING_SPACES, "Only official checkpoints allowed in Spaces." + base_model, ckpt_path, vocab_path, model_cfg = model + global custom_ema_model, pre_custom_path + # Cache key includes base_model to reload if model type changes + cache_key = (base_model, ckpt_path, vocab_path) + if pre_custom_path != cache_key: + # Map UI model name to API model name + api_model_name = MODEL_NAME_MAP.get(base_model, "F5TTS_v1_Base") + show_info(f"Loading Custom TTS model (base: {api_model_name})...") + # Use F5TTS API for proper model loading (same as finetune_gradio.py and CLI) + custom_tts_api = F5TTS( + model=api_model_name, + ckpt_file=ckpt_path, + vocab_file=vocab_path if vocab_path else "", + ) + custom_ema_model = custom_tts_api.ema_model + pre_custom_path = cache_key + ema_model = custom_ema_model + elif model == DEFAULT_TTS_MODEL: ema_model = F5TTS_ema_model + elif model == "F5-TTS": + global F5TTS_base_ema_model + if F5TTS_base_ema_model is None: + show_info("Loading F5-TTS model...") + F5TTS_base_ema_model = load_f5tts_base() + ema_model = F5TTS_base_ema_model elif model == "E2-TTS": global E2TTS_ema_model if E2TTS_ema_model is None: show_info("Loading E2-TTS model...") E2TTS_ema_model = load_e2tts() ema_model = E2TTS_ema_model - elif isinstance(model, tuple) and model[0] == "Custom": - assert not USING_SPACES, "Only official checkpoints allowed in Spaces." - global custom_ema_model, pre_custom_path - if pre_custom_path != model[1]: - show_info("Loading Custom TTS model...") - custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3]) - pre_custom_path = model[1] - ema_model = custom_ema_model final_wave, final_sample_rate, combined_spectrogram = infer_process( ref_audio, @@ -221,7 +261,7 @@ def infer( ) gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1) generate_btn = gr.Button("Synthesize", variant="primary") - with gr.Accordion("Advanced Settings", open=True) as adv_settn: + with gr.Accordion("Advanced Settings", open=False): with gr.Row(): ref_text_input = gr.Textbox( label="Reference Text", @@ -269,17 +309,6 @@ def infer( info="Set the duration of the cross-fade between audio clips.", ) - def collapse_accordion(): - return gr.Accordion(open=False) - - # Workaround for https://github.com/SWivid/F5-TTS/issues/1239#issuecomment-3677987413 - # i.e. to set gr.Accordion(open=True) by default, then collapse manually Blocks loaded - app_tts.load( - fn=collapse_accordion, - inputs=None, - outputs=adv_settn, - ) - audio_output = gr.Audio(label="Synthesized Audio") spectrogram_output = gr.Image(label="Spectrogram") @@ -588,7 +617,7 @@ def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, spee label="Cherry-pick Interface", lines=10, max_lines=40, - buttons=["copy"], # show_copy_button=True if gradio<6.0 + show_copy_button=True, interactive=False, visible=False, ) @@ -827,9 +856,7 @@ def load_chat_model(chat_model_name): lines=2, ) - chatbot_interface = gr.Chatbot( - label="Conversation" - ) # type="messages" hard-coded and no need to pass in since gradio 6.0 + chatbot_interface = gr.Chatbot(label="Conversation", type="messages") with gr.Row(): with gr.Column(): @@ -866,10 +893,6 @@ def process_audio_input(conv_state, audio_path, text): @gpu_decorator def generate_text_response(conv_state, system_prompt): """Generate text response from AI""" - for single_state in conv_state: - if isinstance(single_state["content"], list): - assert len(single_state["content"]) == 1 and single_state["content"][0]["type"] == "text" - single_state["content"] = single_state["content"][0]["text"] system_prompt_state = [{"role": "system", "content": system_prompt}] response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state) @@ -883,7 +906,7 @@ def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, ran if not conv_state or not ref_audio: return None, ref_text, seed_input - last_ai_response = conv_state[-1]["content"][0]["text"] + last_ai_response = conv_state[-1]["content"] if not last_ai_response or conv_state[-1]["role"] != "assistant": return None, ref_text, seed_input @@ -988,35 +1011,51 @@ def load_last_used_custom(): last_used_custom.parent.mkdir(parents=True, exist_ok=True) return DEFAULT_TTS_MODEL_CFG - def switch_tts_model(new_choice): - global tts_model_choice - if new_choice == "Custom": # override in case webpage is refreshed - custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom() - tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg) + def switch_tts_model(new_choice, use_custom, custom_ckpt_path, custom_vocab_path, custom_model_cfg): + global tts_model_choice, custom_model_enabled + custom_model_enabled = use_custom + if use_custom and custom_ckpt_path: + tts_model_choice = (new_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg) + else: + tts_model_choice = new_choice + return None # no UI updates needed + + def toggle_custom_model(use_custom, model_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg): + global tts_model_choice, custom_model_enabled + custom_model_enabled = use_custom + if use_custom: + last_custom = load_last_used_custom() + if use_custom and custom_ckpt_path: + tts_model_choice = (model_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg) + else: + tts_model_choice = (model_choice, last_custom[0], last_custom[1], last_custom[2]) return ( - gr.update(visible=True, value=custom_ckpt_path), - gr.update(visible=True, value=custom_vocab_path), - gr.update(visible=True, value=custom_model_cfg), + gr.update(visible=True, value=last_custom[0]), + gr.update(visible=True, value=last_custom[1]), + gr.update(visible=True, value=last_custom[2]), ) else: - tts_model_choice = new_choice + tts_model_choice = model_choice return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) - def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg): + def set_custom_model(model_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg): global tts_model_choice - tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg) + tts_model_choice = (model_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg) with open(last_used_custom, "w", encoding="utf-8") as f: f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n") with gr.Row(): + choose_tts_model = gr.Radio( + choices=[DEFAULT_TTS_MODEL, "F5-TTS", "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL + ) if not USING_SPACES: - choose_tts_model = gr.Radio( - choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL + use_custom_model = gr.Checkbox( + label="Use Custom Model", + value=False, + info="Load a custom checkpoint with the selected base model architecture", ) else: - choose_tts_model = gr.Radio( - choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL - ) + use_custom_model = gr.Checkbox(label="Use Custom Model", value=False, visible=False) custom_ckpt_path = gr.Dropdown( choices=[DEFAULT_TTS_MODEL_CFG[0]], value=load_last_used_custom()[0], @@ -1033,7 +1072,7 @@ def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg): ) custom_model_cfg = gr.Dropdown( choices=[ - DEFAULT_TTS_MODEL_CFG[2], + DEFAULT_TTS_MODEL_CFG[2], # F5-TTS v1 Base json.dumps( dict( dim=1024, @@ -1045,7 +1084,7 @@ def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg): conv_layers=4, pe_attn_head=1, ) - ), + ), # F5-TTS v1 with extra params json.dumps( dict( dim=768, @@ -1057,33 +1096,49 @@ def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg): conv_layers=4, pe_attn_head=1, ) - ), + ), # F5-TTS Small + json.dumps( + dict( + dim=1024, + depth=24, + heads=16, + ff_mult=4, + text_mask_padding=False, + pe_attn_head=1, + ) + ), # E2-TTS Base ], value=load_last_used_custom()[2], allow_custom_value=True, - label="Config: in a dictionary form", + label="Config (optional, uses base model config if empty)", visible=False, ) choose_tts_model.change( switch_tts_model, - inputs=[choose_tts_model], + inputs=[choose_tts_model, use_custom_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + outputs=None, + show_progress="hidden", + ) + use_custom_model.change( + toggle_custom_model, + inputs=[use_custom_model, choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) custom_ckpt_path.change( set_custom_model, - inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg], + inputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) custom_vocab_path.change( set_custom_model, - inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg], + inputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) custom_model_cfg.change( set_custom_model, - inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg], + inputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) @@ -1125,6 +1180,7 @@ def main(port, host, share, api, root_path, inbrowser): server_name=host, server_port=port, share=share, + show_api=api, root_path=root_path, inbrowser=inbrowser, ) From 378228f632e3e36c9f5c9106f3c7197ea47f955c Mon Sep 17 00:00:00 2001 From: Raivis Dejus Date: Sat, 17 Jan 2026 10:22:21 +0200 Subject: [PATCH 2/5] Adding Shared community models to the custom model dropdown --- src/f5_tts/infer/infer_gradio.py | 170 +++++++++++++++++++++---------- 1 file changed, 115 insertions(+), 55 deletions(-) diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py index 9519e17b0..a4017a096 100644 --- a/src/f5_tts/infer/infer_gradio.py +++ b/src/f5_tts/infer/infer_gradio.py @@ -71,6 +71,71 @@ def gpu_decorator(func): json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)), ] +# Shared community models from SHARED.md +# Format: "Display Name": [model_path, vocab_path, config_json, base_model] +SHARED_MODELS = { + "F5-TTS v1 Base (zh & en)": [ + "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors", + "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt", + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)), + "F5-TTS_v1", + ], + "F5-TTS Base (zh & en)": [ + "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", + "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt", + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + "F5-TTS", + ], + "F5-TTS Finnish": [ + "hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors", + "hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt", + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + "F5-TTS", + ], + "F5-TTS French": [ + "hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt", + "hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt", + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + "F5-TTS", + ], + "F5-TTS German": [ + "hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt", + "hf://hvoss-techfak/F5-TTS-German/vocab.txt", + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + "F5-TTS", + ], + "F5-TTS Hindi (Small)": [ + "hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors", + "hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt", + json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + "F5-TTS", + ], + "F5-TTS Italian": [ + "hf://alien79/F5-TTS-italian/model_159600.safetensors", + "hf://alien79/F5-TTS-italian/vocab.txt", + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + "F5-TTS", + ], + "F5-TTS Japanese": [ + "hf://Jmica/F5TTS/JA_21999120/model_21999120.pt", + "hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt", + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + "F5-TTS", + ], + "F5-TTS Russian": [ + "hf://hotstone228/F5-TTS-Russian/model_last.safetensors", + "hf://hotstone228/F5-TTS-Russian/vocab.txt", + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + "F5-TTS", + ], + "F5-TTS Latvian": [ + "hf://RaivisDejus/F5-TTS-Latvian/model.safetensors", + "hf://RaivisDejus/F5-TTS-Latvian/vocab.txt", + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + "F5-TTS", + ], +} + # load models @@ -192,11 +257,14 @@ def infer( # Map UI model name to API model name api_model_name = MODEL_NAME_MAP.get(base_model, "F5TTS_v1_Base") show_info(f"Loading Custom TTS model (base: {api_model_name})...") + # Resolve hf:// paths to local cache + resolved_ckpt = str(cached_path(ckpt_path)) if ckpt_path.startswith("hf://") else ckpt_path + resolved_vocab = str(cached_path(vocab_path)) if vocab_path and vocab_path.startswith("hf://") else vocab_path # Use F5TTS API for proper model loading (same as finetune_gradio.py and CLI) custom_tts_api = F5TTS( model=api_model_name, - ckpt_file=ckpt_path, - vocab_file=vocab_path if vocab_path else "", + ckpt_file=resolved_ckpt, + vocab_file=resolved_vocab if resolved_vocab else "", ) custom_ema_model = custom_tts_api.ema_model pre_custom_path = cache_key @@ -1025,18 +1093,14 @@ def toggle_custom_model(use_custom, model_choice, custom_ckpt_path, custom_vocab custom_model_enabled = use_custom if use_custom: last_custom = load_last_used_custom() - if use_custom and custom_ckpt_path: + if custom_ckpt_path: tts_model_choice = (model_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg) else: tts_model_choice = (model_choice, last_custom[0], last_custom[1], last_custom[2]) - return ( - gr.update(visible=True, value=last_custom[0]), - gr.update(visible=True, value=last_custom[1]), - gr.update(visible=True, value=last_custom[2]), - ) + return gr.update(visible=True) else: tts_model_choice = model_choice - return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) + return gr.update(visible=False) def set_custom_model(model_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg): global tts_model_choice @@ -1056,64 +1120,60 @@ def set_custom_model(model_choice, custom_ckpt_path, custom_vocab_path, custom_m ) else: use_custom_model = gr.Checkbox(label="Use Custom Model", value=False, visible=False) + # Build dropdown choices from shared models + shared_model_choices = [cfg[0] for cfg in SHARED_MODELS.values()] + shared_vocab_choices = list(set(cfg[1] for cfg in SHARED_MODELS.values())) + shared_config_choices = list(set(cfg[2] for cfg in SHARED_MODELS.values())) + + with gr.Row(visible=False) as custom_model_row: + custom_model_select = gr.Dropdown( + choices=list(SHARED_MODELS.keys()), + value=None, + allow_custom_value=False, + label="Shared Models (auto-fills paths below)", + scale=2, + ) custom_ckpt_path = gr.Dropdown( - choices=[DEFAULT_TTS_MODEL_CFG[0]], + choices=shared_model_choices, value=load_last_used_custom()[0], allow_custom_value=True, - label="Model: local_path | hf://user_id/repo_id/model_ckpt", - visible=False, + label="Model: local_path | hf://...", + scale=3, ) custom_vocab_path = gr.Dropdown( - choices=[DEFAULT_TTS_MODEL_CFG[1]], + choices=shared_vocab_choices, value=load_last_used_custom()[1], allow_custom_value=True, - label="Vocab: local_path | hf://user_id/repo_id/vocab_file", - visible=False, + label="Vocab: local_path | hf://...", + scale=2, ) custom_model_cfg = gr.Dropdown( - choices=[ - DEFAULT_TTS_MODEL_CFG[2], # F5-TTS v1 Base - json.dumps( - dict( - dim=1024, - depth=22, - heads=16, - ff_mult=2, - text_dim=512, - text_mask_padding=False, - conv_layers=4, - pe_attn_head=1, - ) - ), # F5-TTS v1 with extra params - json.dumps( - dict( - dim=768, - depth=18, - heads=12, - ff_mult=2, - text_dim=512, - text_mask_padding=False, - conv_layers=4, - pe_attn_head=1, - ) - ), # F5-TTS Small - json.dumps( - dict( - dim=1024, - depth=24, - heads=16, - ff_mult=4, - text_mask_padding=False, - pe_attn_head=1, - ) - ), # E2-TTS Base - ], + choices=shared_config_choices, value=load_last_used_custom()[2], allow_custom_value=True, - label="Config (optional, uses base model config if empty)", - visible=False, + label="Config (optional)", + scale=2, ) + def on_shared_model_select(model_name): + if model_name and model_name in SHARED_MODELS: + cfg = SHARED_MODELS[model_name] + # cfg = [model_path, vocab_path, config_json, base_model] + return ( + gr.update(value=cfg[3]), # Update radio button to correct base model + gr.update(value=cfg[0]), # Model path + gr.update(value=cfg[1]), # Vocab path + gr.update(value=cfg[2]), # Config + ) + return gr.update(), gr.update(), gr.update(), gr.update() + + custom_model_select.change( + on_shared_model_select, + inputs=[custom_model_select], + outputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + show_progress="hidden", + ) + choose_tts_model.change( switch_tts_model, inputs=[choose_tts_model, use_custom_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], @@ -1123,7 +1183,7 @@ def set_custom_model(model_choice, custom_ckpt_path, custom_vocab_path, custom_m use_custom_model.change( toggle_custom_model, inputs=[use_custom_model, choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], - outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg], + outputs=[custom_model_row], show_progress="hidden", ) custom_ckpt_path.change( From 4ea23a9f0a72a5062b8d4a3bee77a0ea2c487967 Mon Sep 17 00:00:00 2001 From: Raivis Dejus Date: Sat, 17 Jan 2026 11:29:11 +0200 Subject: [PATCH 3/5] Fix for inference with default base model --- src/f5_tts/infer/infer_gradio.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py index a4017a096..70d57194b 100644 --- a/src/f5_tts/infer/infer_gradio.py +++ b/src/f5_tts/infer/infer_gradio.py @@ -66,9 +66,9 @@ def gpu_decorator(func): ] F5TTS_BASE_CFG = [ - "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt", + "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)), + json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), ] # Shared community models from SHARED.md From df98e215b04b8ad3ef6e0e2ddd9af48cf236c699 Mon Sep 17 00:00:00 2001 From: Raivis Dejus Date: Sat, 17 Jan 2026 13:27:01 +0200 Subject: [PATCH 4/5] Moving F5 Base to custom model options only --- src/f5_tts/infer/infer_gradio.py | 42 ++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py index 70d57194b..2f5cc76ad 100644 --- a/src/f5_tts/infer/infer_gradio.py +++ b/src/f5_tts/infer/infer_gradio.py @@ -1079,38 +1079,38 @@ def load_last_used_custom(): last_used_custom.parent.mkdir(parents=True, exist_ok=True) return DEFAULT_TTS_MODEL_CFG - def switch_tts_model(new_choice, use_custom, custom_ckpt_path, custom_vocab_path, custom_model_cfg): + def switch_tts_model(new_choice, use_custom, custom_base, custom_ckpt_path, custom_vocab_path, custom_model_cfg): global tts_model_choice, custom_model_enabled custom_model_enabled = use_custom if use_custom and custom_ckpt_path: - tts_model_choice = (new_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg) + tts_model_choice = (custom_base, custom_ckpt_path, custom_vocab_path, custom_model_cfg) else: tts_model_choice = new_choice return None # no UI updates needed - def toggle_custom_model(use_custom, model_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg): + def toggle_custom_model(use_custom, model_choice, custom_base, custom_ckpt_path, custom_vocab_path, custom_model_cfg): global tts_model_choice, custom_model_enabled custom_model_enabled = use_custom if use_custom: last_custom = load_last_used_custom() if custom_ckpt_path: - tts_model_choice = (model_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg) + tts_model_choice = (custom_base, custom_ckpt_path, custom_vocab_path, custom_model_cfg) else: - tts_model_choice = (model_choice, last_custom[0], last_custom[1], last_custom[2]) + tts_model_choice = (custom_base, last_custom[0], last_custom[1], last_custom[2]) return gr.update(visible=True) else: tts_model_choice = model_choice return gr.update(visible=False) - def set_custom_model(model_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg): + def set_custom_model(custom_base, custom_ckpt_path, custom_vocab_path, custom_model_cfg): global tts_model_choice - tts_model_choice = (model_choice, custom_ckpt_path, custom_vocab_path, custom_model_cfg) + tts_model_choice = (custom_base, custom_ckpt_path, custom_vocab_path, custom_model_cfg) with open(last_used_custom, "w", encoding="utf-8") as f: f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n") with gr.Row(): choose_tts_model = gr.Radio( - choices=[DEFAULT_TTS_MODEL, "F5-TTS", "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL + choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL ) if not USING_SPACES: use_custom_model = gr.Checkbox( @@ -1133,6 +1133,13 @@ def set_custom_model(model_choice, custom_ckpt_path, custom_vocab_path, custom_m label="Shared Models (auto-fills paths below)", scale=2, ) + custom_base_model = gr.Dropdown( + choices=["F5-TTS_v1", "F5-TTS"], + value="F5-TTS_v1", + allow_custom_value=False, + label="Base Model", + scale=1, + ) custom_ckpt_path = gr.Dropdown( choices=shared_model_choices, value=load_last_used_custom()[0], @@ -1160,7 +1167,7 @@ def on_shared_model_select(model_name): cfg = SHARED_MODELS[model_name] # cfg = [model_path, vocab_path, config_json, base_model] return ( - gr.update(value=cfg[3]), # Update radio button to correct base model + gr.update(value=cfg[3]), # Update base model dropdown gr.update(value=cfg[0]), # Model path gr.update(value=cfg[1]), # Vocab path gr.update(value=cfg[2]), # Config @@ -1170,35 +1177,40 @@ def on_shared_model_select(model_name): custom_model_select.change( on_shared_model_select, inputs=[custom_model_select], - outputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + outputs=[custom_base_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) choose_tts_model.change( switch_tts_model, - inputs=[choose_tts_model, use_custom_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + inputs=[choose_tts_model, use_custom_model, custom_base_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], outputs=None, show_progress="hidden", ) use_custom_model.change( toggle_custom_model, - inputs=[use_custom_model, choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + inputs=[use_custom_model, choose_tts_model, custom_base_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], outputs=[custom_model_row], show_progress="hidden", ) + custom_base_model.change( + set_custom_model, + inputs=[custom_base_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + show_progress="hidden", + ) custom_ckpt_path.change( set_custom_model, - inputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + inputs=[custom_base_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) custom_vocab_path.change( set_custom_model, - inputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + inputs=[custom_base_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) custom_model_cfg.change( set_custom_model, - inputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + inputs=[custom_base_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) From 2a49364ce5fbd43a2ecebc70556f1edcddc08475 Mon Sep 17 00:00:00 2001 From: Raivis Dejus Date: Sat, 17 Jan 2026 13:39:40 +0200 Subject: [PATCH 5/5] Fixed linter issues --- src/f5_tts/infer/infer_gradio.py | 157 ++++++++++++++++++++++++++++--- 1 file changed, 143 insertions(+), 14 deletions(-) diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py index 2f5cc76ad..7811babd5 100644 --- a/src/f5_tts/infer/infer_gradio.py +++ b/src/f5_tts/infer/infer_gradio.py @@ -47,6 +47,7 @@ def gpu_decorator(func): ) from f5_tts.model import DiT, UNetT + # Mapping from UI model names to API model names MODEL_NAME_MAP = { "F5-TTS_v1": "F5TTS_v1_Base", @@ -68,7 +69,18 @@ def gpu_decorator(func): F5TTS_BASE_CFG = [ "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + json.dumps( + dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), ] # Shared community models from SHARED.md @@ -83,55 +95,154 @@ def gpu_decorator(func): "F5-TTS Base (zh & en)": [ "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + json.dumps( + dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), "F5-TTS", ], "F5-TTS Finnish": [ "hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors", "hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + json.dumps( + dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), "F5-TTS", ], "F5-TTS French": [ "hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt", "hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + json.dumps( + dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), "F5-TTS", ], "F5-TTS German": [ "hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt", "hf://hvoss-techfak/F5-TTS-German/vocab.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + json.dumps( + dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), "F5-TTS", ], "F5-TTS Hindi (Small)": [ "hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors", "hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt", - json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + json.dumps( + dict( + dim=768, + depth=18, + heads=12, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), "F5-TTS", ], "F5-TTS Italian": [ "hf://alien79/F5-TTS-italian/model_159600.safetensors", "hf://alien79/F5-TTS-italian/vocab.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + json.dumps( + dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), "F5-TTS", ], "F5-TTS Japanese": [ "hf://Jmica/F5TTS/JA_21999120/model_21999120.pt", "hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + json.dumps( + dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), "F5-TTS", ], "F5-TTS Russian": [ "hf://hotstone228/F5-TTS-Russian/model_last.safetensors", "hf://hotstone228/F5-TTS-Russian/vocab.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + json.dumps( + dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), "F5-TTS", ], "F5-TTS Latvian": [ "hf://RaivisDejus/F5-TTS-Latvian/model.safetensors", "hf://RaivisDejus/F5-TTS-Latvian/vocab.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1)), + json.dumps( + dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), "F5-TTS", ], } @@ -259,7 +370,9 @@ def infer( show_info(f"Loading Custom TTS model (base: {api_model_name})...") # Resolve hf:// paths to local cache resolved_ckpt = str(cached_path(ckpt_path)) if ckpt_path.startswith("hf://") else ckpt_path - resolved_vocab = str(cached_path(vocab_path)) if vocab_path and vocab_path.startswith("hf://") else vocab_path + resolved_vocab = ( + str(cached_path(vocab_path)) if vocab_path and vocab_path.startswith("hf://") else vocab_path + ) # Use F5TTS API for proper model loading (same as finetune_gradio.py and CLI) custom_tts_api = F5TTS( model=api_model_name, @@ -1088,7 +1201,9 @@ def switch_tts_model(new_choice, use_custom, custom_base, custom_ckpt_path, cust tts_model_choice = new_choice return None # no UI updates needed - def toggle_custom_model(use_custom, model_choice, custom_base, custom_ckpt_path, custom_vocab_path, custom_model_cfg): + def toggle_custom_model( + use_custom, model_choice, custom_base, custom_ckpt_path, custom_vocab_path, custom_model_cfg + ): global tts_model_choice, custom_model_enabled custom_model_enabled = use_custom if use_custom: @@ -1183,13 +1298,27 @@ def on_shared_model_select(model_name): choose_tts_model.change( switch_tts_model, - inputs=[choose_tts_model, use_custom_model, custom_base_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + inputs=[ + choose_tts_model, + use_custom_model, + custom_base_model, + custom_ckpt_path, + custom_vocab_path, + custom_model_cfg, + ], outputs=None, show_progress="hidden", ) use_custom_model.change( toggle_custom_model, - inputs=[use_custom_model, choose_tts_model, custom_base_model, custom_ckpt_path, custom_vocab_path, custom_model_cfg], + inputs=[ + use_custom_model, + choose_tts_model, + custom_base_model, + custom_ckpt_path, + custom_vocab_path, + custom_model_cfg, + ], outputs=[custom_model_row], show_progress="hidden", )