# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility functions to load models from checkpoints.
Each checkpoint is a dict saved with torch.save, with the following keys:
- 'xp.cfg': the hydra config as dumped during training. This should be used
    to rebuild the object using the audiocraft.models.builders functions,
- 'best_state': a readily loadable best state for the model, including
    the conditioner. The model obtained from 'xp.cfg' should be compatible
    with this state dict. In the case of an LM, the EnCodec model would not
    be bundled along but instead provided separately.

Those functions also support loading from a remote location with the Torch Hub API.
They also support overriding some parameters, in particular the device and dtype
of the returned model.
"""

import os
import typing as tp
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from . import builders

# Mapping from short model names to their Hugging Face Hub repositories.
HF_MODEL_CHECKPOINTS_MAP = {
    "small": "facebook/musicgen-small",
    "medium": "facebook/musicgen-medium",
    "large": "facebook/musicgen-large",
    "melody": "facebook/musicgen-melody",
}


def _get_state_dict(
    file_or_url_or_id: tp.Union[Path, str],
    filename: tp.Optional[str] = None,
    device='cpu',
    cache_dir: tp.Optional[str] = None,
):
    """Return the state dict from a local file, a URL, or a known HF model name."""
    file_or_url_or_id = str(file_or_url_or_id)

    if os.path.isfile(file_or_url_or_id):
        return torch.load(file_or_url_or_id, map_location=device)
    elif file_or_url_or_id.startswith('https://'):
        return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
    elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
        assert filename is not None, "filename needs to be defined if using HF checkpoints"
        repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
        file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
        return torch.load(file, map_location=device)
    else:
        raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")


def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
                           cache_dir: tp.Optional[str] = None):
    """Load a compression model (e.g. EnCodec) and its weights from a checkpoint."""
    pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    model = builders.get_compression_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()
    return model


def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
                  cache_dir: tp.Optional[str] = None):
    """Load a language model and its weights from a checkpoint."""
    pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    if cfg.device == 'cpu':
        # Memory-efficient attention kernels are unavailable on CPU; fall back
        # to the custom transformer implementation and full float32 precision.
        cfg.transformer_lm.memory_efficient = False
        cfg.transformer_lm.custom = True
        cfg.dtype = 'float32'
    else:
        cfg.dtype = 'float16'
    model = builders.get_lm_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()
    model.cfg = cfg
    return model
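

# Example usage (an illustrative sketch, not part of the original module): the
# two loaders are typically called together to rebuild a full MusicGen stack
# from the Hugging Face Hub. The 'small' name below is one of the keys in
# HF_MODEL_CHECKPOINTS_MAP; the import path assumes this file lives at
# audiocraft/models/loaders.py as in the upstream repository, and the download
# requires network access on first use.
#
#     from audiocraft.models.loaders import load_compression_model, load_lm_model
#
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     compression_model = load_compression_model('small', device=device)
#     lm = load_lm_model('small', device=device)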