| 
							 | 
						"""Utility for loading the models from HF.""" | 
					
					
						
						| 
							 | 
						from pathlib import Path | 
					
					
						
						| 
							 | 
						import typing as tp | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from omegaconf import OmegaConf | 
					
					
						
						| 
							 | 
						from huggingface_hub import hf_hub_download | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from audiocraft.models import builders, MusicGen | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						MODEL_CHECKPOINTS_MAP = { | 
					
					
						
						| 
							 | 
						    "small": "facebook/musicgen-small", | 
					
					
						
						| 
							 | 
						    "medium": "facebook/musicgen-medium", | 
					
					
						
						| 
							 | 
						    "large": "facebook/musicgen-large", | 
					
					
						
						| 
							 | 
						    "melody": "facebook/musicgen-melody", | 
					
					
						
						| 
							 | 
						} | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def _get_state_dict(file_or_url: tp.Union[Path, str], | 
					
					
						
						| 
							 | 
						                    filename="state_dict.bin", device='cpu'): | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    print("loading", file_or_url, filename) | 
					
					
						
						| 
							 | 
						    file_or_url = str(file_or_url) | 
					
					
						
						| 
							 | 
						    assert isinstance(file_or_url, str) | 
					
					
						
						| 
							 | 
						    return torch.load( | 
					
					
						
						| 
							 | 
						        hf_hub_download(repo_id=file_or_url, filename=filename), map_location=device) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'): | 
					
					
						
						| 
							 | 
						    pkg = _get_state_dict(file_or_url, filename="compression_state_dict.bin") | 
					
					
						
						| 
							 | 
						    cfg = OmegaConf.create(pkg['xp.cfg']) | 
					
					
						
						| 
							 | 
						    cfg.device = str(device) | 
					
					
						
						| 
							 | 
						    model = builders.get_compression_model(cfg) | 
					
					
						
						| 
							 | 
						    model.load_state_dict(pkg['best_state']) | 
					
					
						
						| 
							 | 
						    model.eval() | 
					
					
						
						| 
							 | 
						    model.cfg = cfg | 
					
					
						
						| 
							 | 
						    return model | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'): | 
					
					
						
						| 
							 | 
						    pkg = _get_state_dict(file_or_url) | 
					
					
						
						| 
							 | 
						    cfg = OmegaConf.create(pkg['xp.cfg']) | 
					
					
						
						| 
							 | 
						    cfg.device = str(device) | 
					
					
						
						| 
							 | 
						    if cfg.device == 'cpu': | 
					
					
						
						| 
							 | 
						        cfg.transformer_lm.memory_efficient = False | 
					
					
						
						| 
							 | 
						        cfg.transformer_lm.custom = True | 
					
					
						
						| 
							 | 
						        cfg.dtype = 'float32' | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        cfg.dtype = 'float16' | 
					
					
						
						| 
							 | 
						    model = builders.get_lm_model(cfg) | 
					
					
						
						| 
							 | 
						    model.load_state_dict(pkg['best_state']) | 
					
					
						
						| 
							 | 
						    model.eval() | 
					
					
						
						| 
							 | 
						    model.cfg = cfg | 
					
					
						
						| 
							 | 
						    return model | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_pretrained(name: str = 'small', device='cuda'): | 
					
					
						
						| 
							 | 
						    model_id = MODEL_CHECKPOINTS_MAP[name] | 
					
					
						
						| 
							 | 
						    compression_model = load_compression_model(model_id, device=device) | 
					
					
						
						| 
							 | 
						    lm = load_lm_model(model_id, device=device) | 
					
					
						
						| 
							 | 
						    return MusicGen(name, compression_model, lm) | 
					
					
						
						| 
							 | 
						
 |