# [Extraction residue from the hosting page, not part of the module:]
# Spaces: Runtime error / Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility to export a training checkpoint to a lightweight release checkpoint.
"""
from pathlib import Path | |
import typing as tp | |
from omegaconf import OmegaConf, DictConfig | |
import torch | |
def _clean_lm_cfg(cfg: DictConfig):
    """Patch an LM experiment config in place before it is exported.

    Struct mode is switched off on `cfg` while it is edited and switched back
    on before returning. Inside the `transformer_lm` section, `card` is forced
    to 2048 and `n_q` to 4, and the experimental parameters that are no longer
    supported are removed when they are present.

    Args:
        cfg: experiment config containing a `transformer_lm` section.
            It is modified in place.

    Returns:
        The same `cfg` object, after modification.
    """
    OmegaConf.set_struct(cfg, False)
    lm_cfg = cfg['transformer_lm']
    # This used to be set automatically in the LM solver, need a more robust solution
    # for the future.
    lm_cfg['card'] = 2048
    lm_cfg['n_q'] = 4
    # Experimental params no longer supported.
    bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                  'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
    for name in bad_params:
        # A config that never carried one of these entries is already clean:
        # skip the missing key instead of aborting the whole export.
        try:
            del lm_cfg[name]
        except KeyError:
            pass
    OmegaConf.set_struct(cfg, True)
    return cfg
def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export an EnCodec training checkpoint to a lightweight release file.

    The release file holds two entries: `best_state`, taken from
    `pkg['ema']['state']['model']` of the loaded checkpoint, and `xp.cfg`,
    the checkpoint's experiment config serialized to YAML. The file is named
    `<sig>.th`, where `sig` is the name of the folder containing the checkpoint.

    Args:
        checkpoint_path: path to the training checkpoint. Its parent folder
            name must be exactly 8 characters long.
        out_folder: folder that receives the exported file; it is created
            (including parents) if it does not exist yet.

    Returns:
        Path of the written release checkpoint.

    Raises:
        AssertionError: if the parent folder name is not 8 characters long.
    """
    sig = Path(checkpoint_path).parent.name
    # Raised explicitly instead of using `assert` so the check is not stripped
    # under `python -O`; exception type and message are unchanged.
    if len(sig) != 8:
        raise AssertionError("Not a valid Dora signature")
    # NOTE: torch.load unpickles the file -- only use on trusted checkpoints.
    pkg = torch.load(checkpoint_path, map_location='cpu')
    new_pkg = {
        'best_state': pkg['ema']['state']['model'],
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
    }
    out_dir = Path(out_folder)
    # torch.save does not create missing folders, so make sure the target exists.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file
def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export a language-model training checkpoint to a lightweight release file.

    The release file holds two entries: `best_state`, taken from
    `pkg['fsdp_best_state']['model']` of the loaded checkpoint, and `xp.cfg`,
    the checkpoint's experiment config after it has been passed through
    `_clean_lm_cfg` and serialized to YAML. The file is named `<sig>.th`,
    where `sig` is the name of the folder containing the checkpoint.

    Args:
        checkpoint_path: path to the training checkpoint. Its parent folder
            name must be exactly 8 characters long.
        out_folder: folder that receives the exported file; it is created
            (including parents) if it does not exist yet.

    Returns:
        Path of the written release checkpoint.

    Raises:
        AssertionError: if the parent folder name is not 8 characters long.
    """
    sig = Path(checkpoint_path).parent.name
    # Raised explicitly instead of using `assert` so the check is not stripped
    # under `python -O`; exception type and message are unchanged.
    if len(sig) != 8:
        raise AssertionError("Not a valid Dora signature")
    # NOTE: torch.load unpickles the file -- only use on trusted checkpoints.
    pkg = torch.load(checkpoint_path, map_location='cpu')
    new_pkg = {
        'best_state': pkg['fsdp_best_state']['model'],
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])),
    }
    out_dir = Path(out_folder)
    # torch.save does not create missing folders, so make sure the target exists.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file