Spaces:
Sleeping
Sleeping
import logging | |
from pathlib import Path | |
from typing import Union | |
import torch | |
# Name of the training run whose checkpoint files this module fetches; it is
# used both as a path segment in the Hugging Face URL and as the local
# sub-directory under ``model_repo``.
RUN_NAME: str = "enhancer_stage2"

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
def get_source_url(relpath):
    """Build the Hugging Face download URL for one file of the run.

    Args:
        relpath: Location of the file relative to the run directory inside
            the ``ResembleAI/resemble-enhance`` model repository. It is
            interpolated into the URL verbatim (no quoting is applied).

    Returns:
        The ``resolve/main`` URL for the file, with ``?download=true``
        appended.
    """
    repo_base = "https://huggingface.co/ResembleAI/resemble-enhance/resolve/main"
    return f"{repo_base}/{RUN_NAME}/{relpath}?download=true"
def get_target_path(relpath: Union[str, Path], run_dir: Union[str, Path, None] = None):
    """Return the local filesystem path for a file of the run.

    Args:
        relpath: File location relative to the run directory.
        run_dir: Root directory of the run. When ``None``, defaults to
            ``<this file's grandparent dir>/model_repo/<RUN_NAME>``.

    Returns:
        ``Path(run_dir) / relpath`` (an empty ``relpath`` yields the run
        directory itself).
    """
    # The default is only computed when needed, so RUN_NAME is not read when
    # the caller supplies an explicit run directory.
    base = (
        Path(__file__).parent.parent / "model_repo" / RUN_NAME
        if run_dir is None
        else run_dir
    )
    return Path(base) / relpath
def download(run_dir: Union[str, Path, None] = None):
    """Download any checkpoint files of the run that are not already on disk.

    For each required file, the local destination comes from
    ``get_target_path`` and the remote location from ``get_source_url``.
    Files whose destination already exists are skipped, so repeated calls
    only fetch what is missing.

    Args:
        run_dir: Directory to store the run files in. ``None`` selects the
            default location chosen by ``get_target_path``.

    Returns:
        The run directory as a ``Path``.
    """
    # Fixed set of files needed to load the model; a tuple since it is
    # never modified.
    relpaths = (
        "hparams.yaml",
        "ds/G/latest",
        "ds/G/default/mp_rank_00_model_states.pt",
    )
    for relpath in relpaths:
        path = get_target_path(relpath, run_dir=run_dir)
        if path.exists():
            # Only existence is checked, not size or content: a file already
            # present at this path is trusted as-is.
            logger.debug("Found %s, skipping download", path)
            continue
        url = get_source_url(relpath)
        logger.info("Downloading %s to %s", url, path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.hub.download_url_to_file(url, str(path))
    return get_target_path("", run_dir=run_dir)