diff --git a/gradio_interface.py b/gradio_interface.py index c112259c..4981fd63 100644 --- a/gradio_interface.py +++ b/gradio_interface.py @@ -1,3 +1,4 @@ +import gc import torch import torchaudio import gradio as gr @@ -14,12 +15,17 @@ SPEAKER_AUDIO_PATH = None +def unload_model(): + gc.collect() + torch.cuda.empty_cache() + return 'Unloaded' + def load_model_if_needed(model_choice: str): global CURRENT_MODEL_TYPE, CURRENT_MODEL if CURRENT_MODEL_TYPE != model_choice: if CURRENT_MODEL is not None: del CURRENT_MODEL - torch.cuda.empty_cache() + unload_model() print(f"Loading {model_choice} model...") CURRENT_MODEL = Zonos.from_pretrained(model_choice, device=device) CURRENT_MODEL.requires_grad_(False).eval()