# (Removed scrape artifact: Hugging Face Spaces page banner
#  "Spaces: Running on A10G" — not part of this module.)
"""Utility for loading the models from HF.""" | |
import os | |
from pathlib import Path | |
import typing as tp | |
from omegaconf import OmegaConf | |
from huggingface_hub import hf_hub_download, login | |
import torch | |
from audiocraft.models import builders, MusicGen | |
MODEL_CHECKPOINTS_MAP = { | |
"small": "facebook/musicgen-small", | |
"medium": "facebook/musicgen-medium", | |
"large": "facebook/musicgen-large", | |
"melody": "facebook/musicgen-melody", | |
} | |
login(os.environ['ACCESS_TOKEN']) | |
def _get_state_dict(file_or_url: tp.Union[Path, str],
                    filename: str = "state_dict.bin", device: str = 'cpu'):
    """Download *filename* from the HF Hub repo *file_or_url* and deserialize it.

    Args:
        file_or_url: HF Hub repo id; a ``Path`` is coerced to ``str``.
        filename: Checkpoint file name within the repo.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        The deserialized checkpoint object (a dict with keys such as
        ``'xp.cfg'`` and ``'best_state'``, per the callers in this module).
    """
    print("loading", file_or_url, filename)
    # Repo ids in MODEL_CHECKPOINTS_MAP are plain strings; a Path is accepted
    # for convenience and normalized here.
    repo_id = str(file_or_url)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted repos (consider weights_only=True).
    return torch.load(
        hf_hub_download(repo_id=repo_id, filename=filename), map_location=device)
def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Build the compression model from a checkpoint on the HF Hub.

    Args:
        file_or_url: HF Hub repo id holding ``compression_state_dict.bin``.
        device: Device the model config is pinned to (stringified into cfg).

    Returns:
        The compression model in eval mode, with its OmegaConf config
        attached as ``model.cfg`` for downstream consumers.
    """
    pkg = _get_state_dict(file_or_url, filename="compression_state_dict.bin")
    # The checkpoint embeds the full experiment config under 'xp.cfg'.
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    model = builders.get_compression_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()  # inference only
    model.cfg = cfg  # keep the config reachable from the model instance
    return model
def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Build the language model from a checkpoint on the HF Hub.

    Args:
        file_or_url: HF Hub repo id holding ``state_dict.bin``.
        device: Device the model config is pinned to; also selects dtype.

    Returns:
        The LM in eval mode, with its OmegaConf config attached as
        ``model.cfg``.
    """
    pkg = _get_state_dict(file_or_url)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    if cfg.device == 'cpu':
        # Memory-efficient attention kernels are GPU-only here; fall back to
        # the custom transformer implementation in float32 on CPU.
        cfg.transformer_lm.memory_efficient = False
        cfg.transformer_lm.custom = True
        cfg.dtype = 'float32'
    else:
        # Half precision on accelerator devices.
        cfg.dtype = 'float16'
    model = builders.get_lm_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()  # inference only
    model.cfg = cfg  # keep the config reachable from the model instance
    return model
def get_pretrained(name: str = 'small', device='cuda'):
    """Load a pretrained MusicGen model by short name.

    Args:
        name: One of the keys of ``MODEL_CHECKPOINTS_MAP``
            ('small', 'medium', 'large', 'melody').
        device: Device both sub-models are loaded onto.

    Returns:
        A ``MusicGen`` wrapping the compression model and the LM.

    Raises:
        KeyError: If *name* is not a known checkpoint (same exception type
            as the raw dict lookup, but with a descriptive message).
    """
    if name not in MODEL_CHECKPOINTS_MAP:
        raise KeyError(
            f"Unknown model name {name!r}; expected one of "
            f"{sorted(MODEL_CHECKPOINTS_MAP)}")
    model_id = MODEL_CHECKPOINTS_MAP[name]
    compression_model = load_compression_model(model_id, device=device)
    lm = load_lm_model(model_id, device=device)
    return MusicGen(name, compression_model, lm)