|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Utility to export a training checkpoint to a lightweight release checkpoint. |
|
""" |
|
|
|
from pathlib import Path |
|
import typing as tp |
|
|
|
from omegaconf import OmegaConf |
|
import torch |
|
|
|
from audiocraft import __version__ |
|
|
|
|
|
def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Write a lightweight release package from an EnCodec training checkpoint.

    Use this when you trained your own EnCodec model. Only the `model` entry
    of the checkpoint's `best_state` is kept, alongside the experiment config
    serialized to YAML, the current audiocraft version and an `exported` flag.

    Args:
        checkpoint_path: Training checkpoint to read; it is loaded on CPU.
        out_file: Where to write the release package. Missing parent
            directories are created.

    Returns:
        `out_file`, exactly as it was passed in.
    """
    # NOTE(review): `torch.load` unpickles the file; only use on trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    model_state = checkpoint['best_state']['model']
    config_yaml = OmegaConf.to_yaml(checkpoint['xp.cfg'])
    release = dict([
        ('best_state', model_state),
        ('xp.cfg', config_yaml),
        ('version', __version__),
        ('exported', True),
    ])

    target_dir = Path(out_file).parent
    target_dir.mkdir(parents=True, exist_ok=True)
    torch.save(release, out_file)
    return out_file
|
|
|
|
|
def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
    """Export a compression model (potentially EnCodec) from a pretrained model.

    This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.

    Two cases are handled:

    * `pretrained_encodec` is a path to an existing file: the file is loaded
      (on CPU), checked to look like an already exported package, and
      re-saved to `out_file`.
    * Otherwise `pretrained_encodec` is treated as the name of a pretrained
      model. In that case, this will not actually include a copy of the
      model, simply the reference to the model used.

    Args:
        pretrained_encodec: Path to an exported checkpoint, or the name of a
            pretrained model. Do not include the `//pretrained/` prefix. For
            instance if you trained a model with `facebook/encodec_32khz`,
            just put that as a name. Same for `dac_44khz`.
        out_file: Where to write the package. Missing parent directories are
            created.

    Returns:
        `out_file`, exactly as it was passed in (consistent with
        `export_encodec` and `export_lm`).

    Raises:
        AssertionError: If the file exists but lacks one of the keys expected
            in an exported package.
    """
    if Path(pretrained_encodec).exists():
        # Map tensors to CPU so the export also works on hosts without the
        # device the checkpoint was saved from, like the other export helpers.
        # NOTE(review): `torch.load` unpickles the file; only use on trusted checkpoints.
        pkg = torch.load(pretrained_encodec, map_location='cpu')
        for key in ('best_state', 'xp.cfg', 'version', 'exported'):
            assert key in pkg, f"{pretrained_encodec} is missing the '{key}' entry of an exported checkpoint."
    else:
        pkg = {
            'pretrained': pretrained_encodec,
            'exported': True,
            'version': __version__,
        }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(pkg, out_file)
    return out_file
|
|
|
|
|
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given MusicGen or AudioGen checkpoint.

    The `model` entry of `fsdp_best_state` is used when that entry is present
    and non-empty; otherwise the `model` entry of `best_state` is used. The
    exported package also carries the experiment config serialized to YAML,
    the current audiocraft version and an `exported` flag.

    Args:
        checkpoint_path: Training checkpoint to read; it is loaded on CPU.
        out_file: Where to write the release package. Missing parent
            directories are created.

    Returns:
        `out_file`, exactly as it was passed in.

    Raises:
        AssertionError: If neither `fsdp_best_state` nor `best_state` holds
            a non-empty state.
        KeyError: If the checkpoint has no `best_state` entry to fall back on.
    """
    # NOTE(review): `torch.load` unpickles the file; only use on trusted checkpoints.
    pkg = torch.load(checkpoint_path, 'cpu')
    # A missing `fsdp_best_state` key is treated like an empty one so the
    # `best_state` fallback is reached instead of failing with a KeyError.
    # Assumes `pkg` is a plain dict, as the key lookups below suggest.
    fsdp_best_state = pkg.get('fsdp_best_state')
    if fsdp_best_state:
        best_state = fsdp_best_state['model']
    else:
        assert pkg['best_state'], "Checkpoint has neither a 'fsdp_best_state' nor a 'best_state' to export."
        best_state = pkg['best_state']['model']
    new_pkg = {
        'best_state': best_state,
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
        'version': __version__,
        'exported': True,
    }

    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(new_pkg, out_file)
    return out_file
|
|