File size: 1,976 Bytes
5238467
 
 
 
 
9d7284e
5238467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Utility for loading the models from HF."""
from pathlib import Path
import typing as tp

from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
import torch

from audiocraft.models import builders, MusicGen

# Maps the short model name accepted by `get_pretrained` to the identifier
# that is handed to the loaders below (and from there to `hf_hub_download`
# as its `repo_id`).
MODEL_CHECKPOINTS_MAP = {
    "small": "facebook/musicgen-small",
    "medium": "facebook/musicgen-medium",
    "large": "facebook/musicgen-large",
    "melody": "facebook/musicgen-melody",
}


def _get_state_dict(file_or_url: tp.Union[Path, str],
                    filename="state_dict.bin", device='cpu'):
    """Load a checkpoint package from a local path, an HTTPS URL or the HF Hub.

    ``file_or_url`` is resolved in this order:

    1. a string starting with ``https://`` is downloaded with
       ``torch.hub.load_state_dict_from_url`` (``filename`` is ignored);
    2. an existing local file is loaded directly (``filename`` is ignored);
    3. an existing local directory: ``filename`` inside it is loaded;
    4. anything else is treated as a Hugging Face Hub repo id and
       ``filename`` is fetched with ``hf_hub_download``.

    Note that a local file or directory whose relative path happens to equal
    a repo id takes precedence over the Hub.

    Args:
        file_or_url: Local path, HTTPS URL or Hub repo id (see above).
        filename: Checkpoint file name inside a directory or a Hub repo.
        device: Forwarded to ``map_location`` when deserializing.

    Returns:
        The object stored in the checkpoint, as deserialized by ``torch.load``.
    """
    print("loading", file_or_url, filename)
    location = str(file_or_url)

    if location.startswith("https://"):
        return torch.hub.load_state_dict_from_url(location, map_location=device)

    local = Path(location)
    if local.is_file():
        checkpoint = str(local)
    elif local.is_dir():
        checkpoint = str(local / filename)
    else:
        # Not a URL and nothing on disk: assume a Hub repo id.
        checkpoint = hf_hub_download(repo_id=location, filename=filename)

    # NOTE(security): torch.load unpickles the file; only point this function
    # at checkpoints from sources you trust.
    return torch.load(checkpoint, map_location=device)


def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Build the compression model and restore its weights from a checkpoint.

    The checkpoint package is read from ``compression_state_dict.bin`` via
    ``_get_state_dict``. Its ``'xp.cfg'`` entry becomes the model config
    (with ``device`` overridden by the argument) and its ``'best_state'``
    entry is loaded as the model weights.

    Args:
        file_or_url: Checkpoint location, forwarded to ``_get_state_dict``.
        device: Target device; stored in the config as a string.

    Returns:
        The model in eval mode, with the config attached as ``model.cfg``.
    """
    checkpoint = _get_state_dict(file_or_url, filename="compression_state_dict.bin")

    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)

    compression_model = builders.get_compression_model(config)
    compression_model.load_state_dict(checkpoint['best_state'])
    compression_model.eval()
    # Keep the config on the model so later code can look it up.
    compression_model.cfg = config
    return compression_model


def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Build the language model and restore its weights from a checkpoint.

    The checkpoint package is read with ``_get_state_dict`` using its default
    file name. Its ``'xp.cfg'`` entry becomes the model config and its
    ``'best_state'`` entry is loaded as the model weights.

    Config overrides applied before building:
      * ``device`` is set to ``str(device)``;
      * on ``'cpu'``: ``dtype`` is ``'float32'``,
        ``transformer_lm.memory_efficient`` is turned off and
        ``transformer_lm.custom`` is turned on;
      * on any other device: ``dtype`` is ``'float16'``.

    Args:
        file_or_url: Checkpoint location, forwarded to ``_get_state_dict``.
        device: Target device; stored in the config as a string.

    Returns:
        The model in eval mode, with the config attached as ``model.cfg``.
    """
    checkpoint = _get_state_dict(file_or_url)

    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)
    if config.device != 'cpu':
        config.dtype = 'float16'
    else:
        # NOTE(review): presumably the memory-efficient transformer path is
        # not usable on CPU -- confirm against the builders module.
        config.transformer_lm.memory_efficient = False
        config.transformer_lm.custom = True
        config.dtype = 'float32'

    lm = builders.get_lm_model(config)
    lm.load_state_dict(checkpoint['best_state'])
    lm.eval()
    # Keep the config on the model so later code can look it up.
    lm.cfg = config
    return lm


def get_pretrained(name: str = 'small', device='cuda'):
    """Assemble a ``MusicGen`` instance from a named pretrained checkpoint.

    Args:
        name: Key of ``MODEL_CHECKPOINTS_MAP`` selecting the checkpoint.
        device: Device passed to both model loaders.

    Returns:
        ``MusicGen(name, compression_model, lm)``.

    Raises:
        KeyError: If ``name`` is not a key of ``MODEL_CHECKPOINTS_MAP``.
    """
    checkpoint_id = MODEL_CHECKPOINTS_MAP[name]
    # Both models are loaded from the same checkpoint identifier.
    codec = load_compression_model(checkpoint_id, device=device)
    language_model = load_lm_model(checkpoint_id, device=device)
    return MusicGen(name, codec, language_model)