From b2a23f1efb18aaa2fbec044dfe52b184026fe29e Mon Sep 17 00:00:00 2001 From: LINKLIU Date: Thu, 8 May 2025 11:37:16 +0800 Subject: [PATCH] load models from local disk --- zonos/model.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/zonos/model.py b/zonos/model.py index ccb713b0..86584e8a 100644 --- a/zonos/model.py +++ b/zonos/model.py @@ -1,4 +1,6 @@ import json +import os +from pathlib import Path from typing import Callable import safetensors @@ -58,8 +60,17 @@ def device(self) -> torch.device: def from_pretrained( cls, repo_id: str, revision: str | None = None, device: str = DEFAULT_DEVICE, **kwargs ) -> "Zonos": - config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision) - model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision) + cwd = Path.cwd() + base_path = Path(cwd) / "models" + normalized = os.path.normpath(repo_id) + sub_path = os.path.basename(normalized) + config_path = base_path / sub_path / "config.json" + model_path = base_path / sub_path / "model.safetensors" + + # the below is the old ways + # config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision) + # model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision) + return cls.from_local(config_path, model_path, device, **kwargs) @classmethod