MusicGen / hf_loading.py
adefossez's picture
update demos
9d7284e
raw history blame
No virus
1.98 kB
"""Utility for loading the models from HF."""
from pathlib import Path
import typing as tp
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
import torch
from audiocraft.models import builders, MusicGen
# Short model name -> repo id string. `get_pretrained` looks names up here and
# the resulting id is ultimately handed to `hf_hub_download` as `repo_id`.
MODEL_CHECKPOINTS_MAP = dict(
    small="facebook/musicgen-small",
    medium="facebook/musicgen-medium",
    large="facebook/musicgen-large",
    melody="facebook/musicgen-melody",
)
def _get_state_dict(file_or_url: tp.Union[Path, str],
                    filename="state_dict.bin", device='cpu'):
    """Load a checkpoint package from a local path or a Hugging Face Hub repo.

    Args:
        file_or_url: Where to find the checkpoint. Resolved in this order:
            1. an existing directory: `filename` inside it is used;
            2. an existing file: that file is loaded directly;
            3. anything else: treated as a Hub repo id and `filename` is
               fetched with `hf_hub_download`.
        filename: Checkpoint file name inside the directory or Hub repo.
        device: Forwarded to `torch.load` as `map_location`.

    Returns:
        The object deserialized by `torch.load`.
    """
    print("loading", file_or_url, filename)

    local_path = Path(file_or_url)
    if local_path.is_dir():
        local_path = local_path / filename

    if local_path.is_file():
        checkpoint_path = str(local_path)
    else:
        # Not a usable local checkpoint: keep the original behaviour and treat
        # the argument as a Hub repo id.
        checkpoint_path = hf_hub_download(repo_id=str(file_or_url), filename=filename)

    # NOTE(security): torch.load may unpickle arbitrary objects; only point
    # this at checkpoints from sources you trust.
    return torch.load(checkpoint_path, map_location=device)
def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Build the compression model and restore its weights from a checkpoint.

    The checkpoint package is fetched via `_get_state_dict` under the file name
    "compression_state_dict.bin". Its 'xp.cfg' entry becomes the model config
    (with `device` overridden) and its 'best_state' entry is loaded as weights.

    Args:
        file_or_url: Checkpoint location, forwarded to `_get_state_dict`.
        device: Written into the config as `str(device)` before the model is built.

    Returns:
        The model in eval mode, with the config attached as `model.cfg`.
    """
    checkpoint = _get_state_dict(file_or_url, filename="compression_state_dict.bin")

    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)

    compression_model = builders.get_compression_model(config)
    compression_model.load_state_dict(checkpoint['best_state'])
    compression_model.eval()
    # Keep the config on the model so callers can inspect it later.
    compression_model.cfg = config
    return compression_model
def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Build the language model and restore its weights from a checkpoint.

    The checkpoint package is fetched via `_get_state_dict` with its default
    file name. Its 'xp.cfg' entry becomes the model config and its
    'best_state' entry is loaded as weights.

    Args:
        file_or_url: Checkpoint location, forwarded to `_get_state_dict`.
        device: Written into the config as `str(device)`. When that string is
            exactly 'cpu' the config is switched to float32 and the transformer
            options are adjusted; any other device uses float16.

    Returns:
        The model in eval mode, with the config attached as `model.cfg`.
    """
    checkpoint = _get_state_dict(file_or_url)

    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)

    running_on_cpu = config.device == 'cpu'
    if running_on_cpu:
        # NOTE(review): presumably the memory-efficient attention path is not
        # usable on CPU, hence the custom implementation -- confirm.
        config.transformer_lm.memory_efficient = False
        config.transformer_lm.custom = True
    config.dtype = 'float32' if running_on_cpu else 'float16'

    lm = builders.get_lm_model(config)
    lm.load_state_dict(checkpoint['best_state'])
    lm.eval()
    # Keep the config on the model so callers can inspect it later.
    lm.cfg = config
    return lm
def get_pretrained(name: str = 'small', device='cuda'):
    """Assemble a `MusicGen` instance from a named pretrained checkpoint.

    Args:
        name: Key into `MODEL_CHECKPOINTS_MAP`; an unknown key raises `KeyError`.
        device: Forwarded to both model loaders.

    Returns:
        `MusicGen(name, compression_model, lm)` built from the loaded models.
    """
    repo_id = MODEL_CHECKPOINTS_MAP[name]
    # Compression model first, then the language model (same order as before).
    codec = load_compression_model(repo_id, device=device)
    language_model = load_lm_model(repo_id, device=device)
    return MusicGen(name, codec, language_model)