Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Utility functions to load from the checkpoints. | |
Each checkpoint is a torch.saved dict with the following keys: | |
- 'xp.cfg': the hydra config as dumped during training. This should be used | |
  to rebuild the object using the audiocraft.models.builders functions.
- 'best_state': a readily loadable best state for the model, including
  the conditioner. The model obtained from `xp.cfg` should be compatible
  with this state dict. In the case of an LM, the EnCodec model is not
  bundled in the checkpoint but provided separately.
These functions also support loading from a remote location, either through the Torch Hub API
(for `https://` URLs) or from the Hugging Face Hub.
They also support overriding some parameters, in particular the device and dtype | |
of the returned model. | |
""" | |
from pathlib import Path | |
from huggingface_hub import hf_hub_download | |
import typing as tp | |
import os | |
from omegaconf import OmegaConf, DictConfig | |
import torch | |
import audiocraft | |
from . import builders | |
from .encodec import CompressionModel | |
def get_audiocraft_cache_dir() -> tp.Optional[str]:
    """Return the download cache directory set in ``AUDIOCRAFT_CACHE_DIR``.

    Returns:
        The environment variable's value, or ``None`` when it is not set.
    """
    return os.getenv('AUDIOCRAFT_CACHE_DIR')
def _get_state_dict(
    file_or_url_or_id: tp.Union[Path, str],
    filename: tp.Optional[str] = None,
    device='cpu',
    cache_dir: tp.Optional[str] = None,
):
    """Load a checkpoint from a local file, a local directory, a URL or a HF Hub repo.

    The source is resolved in this order:

    1. an existing local file: loaded directly (``filename`` is ignored);
    2. an existing local directory: ``<dir>/<filename>`` is loaded;
    3. a string starting with ``https://``: fetched through Torch Hub with hash
       checking (``filename`` is ignored);
    4. anything else is treated as a Hugging Face Hub repo id, from which
       ``filename`` is downloaded.

    Args:
        file_or_url_or_id: Local path, ``https://`` URL, or Hugging Face Hub repo id.
        filename: File to load inside a local directory or HF repo. Required in
            those two cases.
        device: Forwarded as ``map_location`` to the loader.
        cache_dir: HF Hub download cache. Defaults to ``get_audiocraft_cache_dir()``.

    Returns:
        The object deserialized by ``torch.load`` (a dict for audiocraft checkpoints).

    Raises:
        AssertionError: If ``filename`` is missing for a directory or HF Hub source.
    """
    if cache_dir is None:
        cache_dir = get_audiocraft_cache_dir()
    source = str(file_or_url_or_id)
    # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
    if os.path.isfile(source):
        return torch.load(source, map_location=device)
    if os.path.isdir(source):
        # Without this guard the path would silently become "<dir>/None".
        assert filename is not None, "filename needs to be defined when loading from a directory"
        return torch.load(os.path.join(source, filename), map_location=device)
    if source.startswith('https://'):
        return torch.hub.load_state_dict_from_url(source, map_location=device, check_hash=True)
    assert filename is not None, "filename needs to be defined if using HF checkpoints"
    file = hf_hub_download(
        repo_id=source,
        filename=filename,
        cache_dir=cache_dir,
        library_name="audiocraft",
        library_version=audiocraft.__version__,
    )
    return torch.load(file, map_location=device)
def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Load the raw compression-model checkpoint stored as ``compression_state_dict.bin``.

    ``file_or_url_or_id`` may be a file, a directory, a URL or a HF Hub repo id;
    ``cache_dir`` is forwarded to the HF Hub download.
    """
    ckpt_name = "compression_state_dict.bin"
    return _get_state_dict(file_or_url_or_id, filename=ckpt_name, cache_dir=cache_dir)
def load_compression_model(
    file_or_url_or_id: tp.Union[Path, str],
    device="cpu",
    cache_dir: tp.Optional[str] = None,
):
    """Build a compression model from its checkpoint.

    If the checkpoint contains a ``'pretrained'`` entry, loading is delegated to
    ``CompressionModel.get_pretrained``. Otherwise the model is rebuilt from the
    stored ``'xp.cfg'`` (with ``device`` overridden), its ``'best_state'`` weights
    are loaded, and it is switched to eval mode.
    """
    ckpt = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    if 'pretrained' not in ckpt:
        config = OmegaConf.create(ckpt['xp.cfg'])
        config.device = str(device)
        compression_model = builders.get_compression_model(config)
        compression_model.load_state_dict(ckpt["best_state"])
        compression_model.eval()
        return compression_model
    return CompressionModel.get_pretrained(ckpt['pretrained'], device=device)
def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Load the raw language-model checkpoint stored as ``state_dict.bin``.

    ``file_or_url_or_id`` may be a file, a directory, a URL or a HF Hub repo id;
    ``cache_dir`` is forwarded to the HF Hub download.
    """
    ckpt_name = "state_dict.bin"
    return _get_state_dict(file_or_url_or_id, filename=ckpt_name, cache_dir=cache_dir)
def _delete_param(cfg: DictConfig, full_name: str):
    """Remove the dotted key ``full_name`` (e.g. ``'a.b.c'``) from ``cfg`` in place, if present.

    This is a no-op when any intermediate key is missing, or when it holds a value
    that is not a ``DictConfig`` (e.g. ``None``). Previously that case crashed in
    ``OmegaConf.set_struct``. The struct flag of the node holding the key is lifted
    only for the deletion and then restored to its previous value. The old code
    forced it to ``True``, even on nodes that were not in struct mode.

    Args:
        cfg: Config to edit in place.
        full_name: Dot-separated path of the key to delete.
    """
    from omegaconf import open_dict

    *parents, leaf = full_name.split('.')
    node = cfg
    for part in parents:
        child = node[part] if part in node else None
        if not isinstance(child, DictConfig):
            return
        node = child
    # open_dict disables struct mode temporarily and restores the prior flag on exit.
    with open_dict(node):
        if leaf in node:
            del node[leaf]
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Rebuild a language model from its checkpoint.

    The stored ``'xp.cfg'`` is overridden with ``device`` and a dtype
    (``float32`` on ``'cpu'``, ``float16`` otherwise). A few conditioner keys
    are then stripped before ``builders.get_lm_model`` is called. The
    ``'best_state'`` weights are loaded, and the model is returned in eval mode
    with the final config attached as ``model.cfg``.
    """
    ckpt = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    lm_cfg = OmegaConf.create(ckpt['xp.cfg'])
    lm_cfg.device = str(device)
    lm_cfg.dtype = 'float32' if lm_cfg.device == 'cpu' else 'float16'
    # NOTE(review): presumably keys the builders no longer accept — confirm.
    for stale_key in (
        'conditioners.self_wav.chroma_stem.cache_path',
        'conditioners.args.merge_text_conditions_p',
        'conditioners.args.drop_desc_p',
    ):
        _delete_param(lm_cfg, stale_key)
    lm = builders.get_lm_model(lm_cfg)
    lm.load_state_dict(ckpt['best_state'])
    lm.eval()
    lm.cfg = lm_cfg
    return lm
def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int,
                         device='cpu', cache_dir: tp.Optional[str] = None):
    """Rebuild a MAGNeT language model from its checkpoint.

    Like ``load_lm_model``, this overrides device and dtype (``float32`` on
    ``'cpu'``, ``float16`` otherwise) and strips some conditioner keys. It also
    copies ``compression_model_frame_rate``, ``dataset.segment_duration`` and
    ``masking.span_len`` into the ``transformer_lm`` section. When
    ``transformer_lm.memory_efficient`` is set, the xformers attention backend is
    selected *before* the model is built. (The backend setting looks like module-level
    state in ``audiocraft.modules.transformer`` — confirm.)
    """
    ckpt = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    magnet_cfg = OmegaConf.create(ckpt['xp.cfg'])
    magnet_cfg.device = str(device)
    magnet_cfg.dtype = 'float32' if magnet_cfg.device == 'cpu' else 'float16'
    for stale_key in ('conditioners.args.merge_text_conditions_p', 'conditioners.args.drop_desc_p'):
        _delete_param(magnet_cfg, stale_key)
    magnet_cfg.transformer_lm.compression_model_framerate = compression_model_frame_rate
    magnet_cfg.transformer_lm.segment_duration = magnet_cfg.dataset.segment_duration
    magnet_cfg.transformer_lm.span_len = magnet_cfg.masking.span_len
    # MAGNeT models v1 support only xformers backend.
    from audiocraft.modules.transformer import set_efficient_attention_backend
    if magnet_cfg.transformer_lm.memory_efficient:
        set_efficient_attention_backend("xformers")
    magnet = builders.get_lm_model(magnet_cfg)
    magnet.load_state_dict(ckpt['best_state'])
    magnet.eval()
    magnet.cfg = magnet_cfg
    return magnet
def load_dit_model_melodyflow(file_or_url_or_id: tp.Union[Path, str],
                              device='cpu', cache_dir: tp.Optional[str] = None):
    """Rebuild a MelodyFlow DiT model from its checkpoint.

    The stored ``'xp.cfg'`` is overridden with ``device`` and a dtype
    (``float32`` on ``'cpu'`` or ``'mps'``, ``bfloat16`` otherwise), and some
    conditioner keys are stripped. The model is then built with
    ``builders.get_dit_model``, its ``'best_state'`` weights are loaded, and it
    is returned in eval mode with the config attached as ``model.cfg``.
    """
    ckpt = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    dit_cfg = OmegaConf.create(ckpt['xp.cfg'])
    dit_cfg.device = str(device)
    dit_cfg.dtype = 'float32' if dit_cfg.device in ('cpu', 'mps') else 'bfloat16'
    for stale_key in ('conditioners.args.merge_text_conditions_p', 'conditioners.args.drop_desc_p'):
        _delete_param(dit_cfg, stale_key)
    dit = builders.get_dit_model(dit_cfg)
    dit.load_state_dict(ckpt['best_state'])
    dit.eval()
    dit.cfg = dit_cfg
    return dit
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
                  filename: tp.Optional[str] = None,
                  cache_dir: tp.Optional[str] = None):
    """Load a MultiBand-Diffusion checkpoint named ``filename`` from a file, directory, URL or HF repo."""
    return _get_state_dict(
        file_or_url_or_id,
        filename=filename,
        cache_dir=cache_dir,
    )
def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
                          device='cpu',
                          filename: tp.Optional[str] = None,
                          cache_dir: tp.Optional[str] = None):
    """Rebuild the per-band diffusion models and processors from one checkpoint.

    The checkpoint provides ``'sample_rate'`` and ``'n_bands'``. For each band
    index ``i``, ``ckpt[i]`` holds a ``'cfg'``, a ``'model_state'`` and a
    ``'processor_state'``.

    Returns:
        ``(models, processors, cfgs)``: three lists with one entry per band, in
        band order. Models and processors are moved to ``device``.
    """
    ckpt = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
    sample_rate = ckpt['sample_rate']
    models, processors, cfgs = [], [], []
    for band_idx in range(ckpt['n_bands']):
        band = ckpt[band_idx]
        band_cfg = band['cfg']

        band_model = builders.get_diffusion_model(band_cfg)
        band_model.load_state_dict(band['model_state'])
        band_model.to(device)

        band_processor = builders.get_processor(cfg=band_cfg.processor, sample_rate=sample_rate)
        band_processor.load_state_dict(band['processor_state'])
        band_processor.to(device)

        models.append(band_model)
        processors.append(band_processor)
        cfgs.append(band_cfg)
    return models, processors, cfgs
def _load_audioseal_state(
    file_or_url_or_id: tp.Union[Path, str],
    ckpt_name: str,
    device,
    cache_dir: tp.Optional[str],
):
    """Load checkpoint ``ckpt_name`` located next to ``file_or_url_or_id`` and return its ``"model"`` entry."""
    source = str(file_or_url_or_id)
    if source.startswith("https://"):
        # `_get_state_dict` ignores `filename` for URLs, so treat the URL as a base
        # location (like the YAML config below). Otherwise the detector and the
        # generator would both be read from the very same URL.
        source = f"{source}/{ckpt_name}"
    ckpt = _get_state_dict(source, filename=ckpt_name, device=device, cache_dir=cache_dir)
    assert "model" in ckpt, f"No model state dict found in {file_or_url_or_id}/{ckpt_name}"
    return ckpt["model"]


def _load_audioseal_config(
    file_or_url_or_id: tp.Union[Path, str],
    filename: tp.Optional[str],
    cache_dir: tp.Optional[str],
):
    """Fetch ``<filename>.yaml`` from a local directory, a base ``https://`` URL or a HF Hub repo."""
    local_cfg = Path(file_or_url_or_id).joinpath(f"{filename}.yaml")
    if local_cfg.is_file():
        return OmegaConf.load(local_cfg)
    source = str(file_or_url_or_id)
    if source.startswith("https://"):
        import requests  # type: ignore
        # A timeout avoids hanging forever; raise_for_status avoids parsing an
        # HTTP error page as YAML.
        resp = requests.get(f"{source}/{filename}.yaml", timeout=60)
        resp.raise_for_status()
        return OmegaConf.create(resp.text)
    file = hf_hub_download(
        repo_id=source,
        filename=f"{filename}.yaml",
        cache_dir=cache_dir,
        library_name="audiocraft",
        library_version=audiocraft.__version__,
    )
    return OmegaConf.load(file)


def load_audioseal_models(
    file_or_url_or_id: tp.Union[Path, str],
    device="cpu",
    filename: tp.Optional[str] = None,
    cache_dir: tp.Optional[str] = None,
):
    """Build a watermark model and load its detector and generator weights.

    Three files sharing the stem ``filename`` are read from ``file_or_url_or_id``
    (a local directory, a base ``https://`` URL, or a HF Hub repo id):
    ``detector_<filename>.pth``, ``generator_<filename>.pth`` and
    ``<filename>.yaml``. If the YAML config cannot be loaded, a warning is logged
    and the bundled ``config/model/watermark/default.yaml`` is used instead.

    Args:
        file_or_url_or_id: Location holding the checkpoints and config.
        device: Device for loading the checkpoints and for the returned model.
        filename: Common stem of the checkpoint/config files.
        cache_dir: HF Hub download cache.

    Returns:
        The model from ``builders.get_watermark_model``, moved to ``device``.
    """
    detector_state = _load_audioseal_state(
        file_or_url_or_id, f"detector_{filename}.pth", device, cache_dir)
    generator_state = _load_audioseal_state(
        file_or_url_or_id, f"generator_{filename}.pth", device, cache_dir)

    try:
        cfg = _load_audioseal_config(file_or_url_or_id, filename, cache_dir)
    except Exception as exc:  # noqa: BLE001 - deliberate best-effort fallback
        import logging
        logging.getLogger(__name__).warning(
            "Could not load %s.yaml from %s (%s); falling back to the default watermark config.",
            filename, file_or_url_or_id, exc,
        )
        # NOTE(review): assumes parents[2] is the repository root containing `config/` — confirm.
        cfg_fp = (
            Path(__file__)
            .parents[2]
            .joinpath("config", "model", "watermark", "default.yaml")
        )
        cfg = OmegaConf.load(cfg_fp)

    OmegaConf.resolve(cfg)
    model = builders.get_watermark_model(cfg)
    model.generator.load_state_dict(generator_state)
    model.detector.load_state_dict(detector_state)
    return model.to(device)