Gael Le Lan
Initial commit
9d0d223
raw
history blame
9.51 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility functions to load from the checkpoints.
Each checkpoint is a torch.saved dict with the following keys:
- 'xp.cfg': the hydra config as dumped during training. This should be used
to rebuild the object using the audiocraft.models.builders functions,
- 'best_state': a readily loadable best state for the model, including
the conditioner. The model obtained from `xp.cfg` should be compatible
with this state dict. In the case of a LM, the encodec model would not be
bundled along but instead provided separately.
These functions also support loading from a remote location, either an `https://` URL through the Torch Hub API or a Hugging Face Hub repository id.
They also support overriding some parameters, in particular the device and dtype
of the returned model.
"""
from pathlib import Path
from huggingface_hub import hf_hub_download
import typing as tp
import os
from omegaconf import OmegaConf, DictConfig
import torch
import audiocraft
from . import builders
from .encodec import CompressionModel
def get_audiocraft_cache_dir() -> tp.Optional[str]:
    """Return the cache directory configured through the environment.

    Returns:
        The value of the `AUDIOCRAFT_CACHE_DIR` environment variable, or
        None when that variable is not set.
    """
    # `Mapping.get` already defaults to None for a missing key.
    return os.environ.get('AUDIOCRAFT_CACHE_DIR')
def _get_state_dict(
file_or_url_or_id: tp.Union[Path, str],
filename: tp.Optional[str] = None,
device='cpu',
cache_dir: tp.Optional[str] = None,
):
if cache_dir is None:
cache_dir = get_audiocraft_cache_dir()
# Return the state dict either from a file or url
file_or_url_or_id = str(file_or_url_or_id)
assert isinstance(file_or_url_or_id, str)
if os.path.isfile(file_or_url_or_id):
return torch.load(file_or_url_or_id, map_location=device)
if os.path.isdir(file_or_url_or_id):
file = f"{file_or_url_or_id}/{filename}"
return torch.load(file, map_location=device)
elif file_or_url_or_id.startswith('https://'):
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
else:
assert filename is not None, "filename needs to be defined if using HF checkpoints"
file = hf_hub_download(
repo_id=file_or_url_or_id,
filename=filename,
cache_dir=cache_dir,
library_name="audiocraft",
library_version=audiocraft.__version__,
)
return torch.load(file, map_location=device)
def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Fetch the raw checkpoint package of a compression model.

    Args:
        file_or_url_or_id: Local path, URL or Hugging Face repo id of the checkpoint.
        cache_dir: Optional cache directory forwarded to `_get_state_dict`.

    Returns:
        The loaded checkpoint, looked up under `compression_state_dict.bin`
        when a file name is needed to resolve the source.
    """
    checkpoint_name = "compression_state_dict.bin"
    return _get_state_dict(
        file_or_url_or_id,
        filename=checkpoint_name,
        cache_dir=cache_dir,
    )
def load_compression_model(
    file_or_url_or_id: tp.Union[Path, str],
    device="cpu",
    cache_dir: tp.Optional[str] = None,
):
    """Build a compression model from a checkpoint and put it in eval mode.

    Args:
        file_or_url_or_id: Local path, URL or Hugging Face repo id of the checkpoint.
        device: Device the model is created on.
        cache_dir: Optional cache directory used when downloading the checkpoint.

    Returns:
        Either the model returned by `CompressionModel.get_pretrained` (when
        the checkpoint contains a 'pretrained' entry), or a model built from
        the checkpoint config with its 'best_state' weights loaded.
    """
    package = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)

    if 'pretrained' not in package:
        # Regular checkpoint: rebuild the model from the stored config.
        config = OmegaConf.create(package['xp.cfg'])
        config.device = str(device)
        compression_model = builders.get_compression_model(config)
        compression_model.load_state_dict(package["best_state"])
        compression_model.eval()
        return compression_model

    # The checkpoint only points at a pretrained model; delegate its loading.
    return CompressionModel.get_pretrained(package['pretrained'], device=device)
def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Fetch the raw checkpoint package of a language model.

    Args:
        file_or_url_or_id: Local path, URL or Hugging Face repo id of the checkpoint.
        cache_dir: Optional cache directory forwarded to `_get_state_dict`.

    Returns:
        The loaded checkpoint, looked up under `state_dict.bin` when a file
        name is needed to resolve the source.
    """
    checkpoint_name = "state_dict.bin"
    return _get_state_dict(
        file_or_url_or_id,
        filename=checkpoint_name,
        cache_dir=cache_dir,
    )
def _delete_param(cfg: DictConfig, full_name: str):
    """Remove the entry designated by a dotted path from a config, if present.

    Args:
        cfg: Config to edit in place.
        full_name: Dot-separated path of the entry, e.g. `a.b.c`.

    Nothing happens when an intermediate key of the path is missing. Note that
    the struct flag of the node holding the entry is set to True afterwards,
    whatever its value was before the call.
    """
    *parent_keys, leaf_key = full_name.split('.')

    # Walk down to the node that should hold the entry.
    node = cfg
    for key in parent_keys:
        if key not in node:
            return
        node = node[key]

    # Struct mode is lifted for the duration of the deletion.
    OmegaConf.set_struct(node, False)
    if leaf_key in node:
        del node[leaf_key]
    OmegaConf.set_struct(node, True)
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build a language model from a checkpoint and put it in eval mode.

    Args:
        file_or_url_or_id: Local path, URL or Hugging Face repo id of the checkpoint.
        device: Device the model is created on. The dtype is 'float32' on
            'cpu' and 'float16' on any other device.
        cache_dir: Optional cache directory used when downloading the checkpoint.

    Returns:
        The model with its 'best_state' weights loaded; the config used to
        build it is attached as `model.cfg`.
    """
    package = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)

    config = OmegaConf.create(package['xp.cfg'])
    config.device = str(device)
    config.dtype = 'float32' if config.device == 'cpu' else 'float16'

    # Entries dropped from the stored config before the model is built.
    # NOTE(review): presumably training-only options; confirm.
    for removed_key in (
        'conditioners.self_wav.chroma_stem.cache_path',
        'conditioners.args.merge_text_conditions_p',
        'conditioners.args.drop_desc_p',
    ):
        _delete_param(config, removed_key)

    lm = builders.get_lm_model(config)
    lm.load_state_dict(package['best_state'])
    lm.eval()
    lm.cfg = config
    return lm
def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int,
                         device='cpu', cache_dir: tp.Optional[str] = None):
    """Build a MAGNeT language model from a checkpoint and put it in eval mode.

    Args:
        file_or_url_or_id: Local path, URL or Hugging Face repo id of the checkpoint.
        compression_model_frame_rate: Frame rate written into
            `transformer_lm.compression_model_framerate` of the config.
        device: Device the model is created on. The dtype is 'float32' on
            'cpu' and 'float16' on any other device.
        cache_dir: Optional cache directory used when downloading the checkpoint.

    Returns:
        The model with its 'best_state' weights loaded; the config used to
        build it is attached as `model.cfg`.
    """
    package = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)

    config = OmegaConf.create(package['xp.cfg'])
    config.device = str(device)
    config.dtype = 'float32' if config.device == 'cpu' else 'float16'

    for removed_key in (
        'conditioners.args.merge_text_conditions_p',
        'conditioners.args.drop_desc_p',
    ):
        _delete_param(config, removed_key)

    # Values copied into the `transformer_lm` section before building.
    config.transformer_lm.compression_model_framerate = compression_model_frame_rate
    config.transformer_lm.segment_duration = config.dataset.segment_duration
    config.transformer_lm.span_len = config.masking.span_len

    # MAGNeT models v1 support only xformers backend.
    from audiocraft.modules.transformer import set_efficient_attention_backend

    if config.transformer_lm.memory_efficient:
        set_efficient_attention_backend("xformers")

    lm = builders.get_lm_model(config)
    lm.load_state_dict(package['best_state'])
    lm.eval()
    lm.cfg = config
    return lm
def load_dit_model_melodyflow(file_or_url_or_id: tp.Union[Path, str],
                              device='cpu', cache_dir: tp.Optional[str] = None):
    """Build a MelodyFlow DiT model from a checkpoint and put it in eval mode.

    Args:
        file_or_url_or_id: Local path, URL or Hugging Face repo id of the checkpoint.
        device: Device the model is created on. The dtype is 'float32' on
            'cpu' and 'mps', and 'bfloat16' on any other device.
        cache_dir: Optional cache directory used when downloading the checkpoint.

    Returns:
        The model with its 'best_state' weights loaded; the config used to
        build it is attached as `model.cfg`.
    """
    # The checkpoint is fetched through the LM checkpoint loader.
    package = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)

    config = OmegaConf.create(package['xp.cfg'])
    config.device = str(device)
    config.dtype = 'float32' if config.device in ('cpu', 'mps') else 'bfloat16'

    for removed_key in (
        'conditioners.args.merge_text_conditions_p',
        'conditioners.args.drop_desc_p',
    ):
        _delete_param(config, removed_key)

    dit = builders.get_dit_model(config)
    dit.load_state_dict(package['best_state'])
    dit.eval()
    dit.cfg = config
    return dit
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
                  filename: tp.Optional[str] = None,
                  cache_dir: tp.Optional[str] = None):
    """Fetch the raw checkpoint package consumed by `load_diffusion_models`.

    Args:
        file_or_url_or_id: Local path, URL or Hugging Face repo id of the checkpoint.
        filename: Checkpoint file name, forwarded unchanged to `_get_state_dict`.
        cache_dir: Optional cache directory forwarded to `_get_state_dict`.

    Returns:
        The loaded checkpoint.
    """
    return _get_state_dict(
        file_or_url_or_id,
        filename=filename,
        cache_dir=cache_dir,
    )
def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
                          device='cpu',
                          filename: tp.Optional[str] = None,
                          cache_dir: tp.Optional[str] = None):
    """Build one diffusion model and one processor per band of a checkpoint.

    The checkpoint is expected to provide 'sample_rate', 'n_bands' and, under
    each integer key in `range(n_bands)`, the entries 'cfg', 'model_state'
    and 'processor_state'.

    Args:
        file_or_url_or_id: Local path, URL or Hugging Face repo id of the checkpoint.
        device: Device the models and processors are moved to.
        filename: Checkpoint file name, forwarded to `load_mbd_ckpt`.
        cache_dir: Optional cache directory used when downloading the checkpoint.

    Returns:
        A tuple `(models, processors, cfgs)` of three lists with one element
        per band, in band order.
    """
    package = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)

    models, processors, cfgs = [], [], []
    sample_rate = package['sample_rate']

    for band_idx in range(package['n_bands']):
        band_cfg = package[band_idx]['cfg']

        diffusion_model = builders.get_diffusion_model(band_cfg)
        diffusion_model.load_state_dict(package[band_idx]['model_state'])
        diffusion_model.to(device)

        band_processor = builders.get_processor(cfg=band_cfg.processor, sample_rate=sample_rate)
        band_processor.load_state_dict(package[band_idx]['processor_state'])
        band_processor.to(device)

        models.append(diffusion_model)
        processors.append(band_processor)
        cfgs.append(band_cfg)

    return models, processors, cfgs
def load_audioseal_models(
    file_or_url_or_id: tp.Union[Path, str],
    device="cpu",
    filename: tp.Optional[str] = None,
    cache_dir: tp.Optional[str] = None,
):
    """Build a watermark model from separate detector and generator checkpoints.

    The weights are read from `detector_{filename}.pth` and
    `generator_{filename}.pth`, each of which must hold its state dict under
    the 'model' key. The model config is read from `{filename}.yaml` next to
    the checkpoints (local directory, `https://` URL or Hugging Face repo);
    if it cannot be loaded for any reason, the `default.yaml` watermark config
    located relative to this source file is used instead.

    Args:
        file_or_url_or_id: Local directory, `https://` URL prefix or Hugging
            Face repo id holding the checkpoints.
        device: Device the checkpoints are mapped to and the model is moved to.
        filename: Common stem of the checkpoint and config file names.
            NOTE(review): defaults to None, which yields names such as
            `detector_None.pth`; callers presumably always pass it.
        cache_dir: Optional cache directory used for Hugging Face downloads.

    Returns:
        The watermark model with generator and detector weights loaded, on `device`.

    Raises:
        AssertionError: If one of the checkpoints has no 'model' entry.
    """

    def load_part_state(part: str):
        # Load one of the two checkpoints and return its model state dict.
        ckpt = _get_state_dict(
            file_or_url_or_id,
            filename=f"{part}_{filename}.pth",
            device=device,
            cache_dir=cache_dir,
        )
        assert (
            "model" in ckpt
        ), f"No model state dict found in {file_or_url_or_id}/{part}_{filename}.pth"
        return ckpt["model"]

    detector_state = load_part_state("detector")
    generator_state = load_part_state("generator")

    def load_model_config():
        # Look for `{filename}.yaml` locally, then by URL, then on the HF Hub.
        config_name = f"{filename}.yaml"
        # `file_or_url_or_id` may be a Path, which has no `startswith`.
        location = str(file_or_url_or_id)
        local_config = Path(file_or_url_or_id).joinpath(config_name)
        if local_config.is_file():
            return OmegaConf.load(local_config)
        elif location.startswith("https://"):
            import requests  # type: ignore

            resp = requests.get(f"{location}/{config_name}")
            # Without this, the body of an error page would be parsed as the
            # model config instead of triggering the default-config fallback.
            resp.raise_for_status()
            return OmegaConf.create(resp.text)
        else:
            file = hf_hub_download(
                repo_id=location,
                filename=config_name,
                cache_dir=cache_dir,
                library_name="audiocraft",
                library_version=audiocraft.__version__,
            )
            return OmegaConf.load(file)

    try:
        cfg = load_model_config()
    except Exception:
        # Deliberate best effort: any failure to get the checkpoint-specific
        # config falls back to the default watermark config.
        cfg_fp = (
            Path(__file__)
            .parents[2]
            .joinpath("config", "model", "watermark", "default.yaml")
        )
        cfg = OmegaConf.load(cfg_fp)

    OmegaConf.resolve(cfg)
    model = builders.get_watermark_model(cfg)
    model.generator.load_state_dict(generator_state)
    model.detector.load_state_dict(detector_state)
    return model.to(device)