File size: 1,907 Bytes
5238467 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility to export a training checkpoint to a lightweight release checkpoint.
"""
from pathlib import Path
import typing as tp
from omegaconf import OmegaConf, DictConfig
import torch
def _clean_lm_cfg(cfg: DictConfig):
    """Patch a training LM config in place so it can be shipped with a release.

    Forces ``transformer_lm.card`` to 2048 and ``transformer_lm.n_q`` to 4, and
    drops experimental ``transformer_lm`` parameters that are no longer
    supported. Parameters that are already absent are skipped.

    Args:
        cfg: Experiment config containing a ``transformer_lm`` section.
            It is modified in place.

    Returns:
        The same ``cfg`` object, with struct mode enabled.
    """
    # Struct mode has to be off to add or delete keys.
    OmegaConf.set_struct(cfg, False)
    try:
        lm_cfg = cfg['transformer_lm']
        # This used to be set automatically in the LM solver, need a more robust solution
        # for the future.
        lm_cfg['card'] = 2048
        lm_cfg['n_q'] = 4
        # Experimental params no longer supported.
        bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                      'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
        for name in bad_params:
            try:
                del lm_cfg[name]
            except KeyError:
                # The config never had this parameter: nothing to remove.
                pass
    finally:
        # Re-enable struct mode even if patching failed, so the config is not
        # left silently accepting unknown keys.
        OmegaConf.set_struct(cfg, True)
    return cfg
def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export a training checkpoint to a lightweight release checkpoint.

    The checkpoint is loaded on CPU and reduced to two entries: ``best_state``
    (taken from ``pkg['ema']['state']['model']``) and ``xp.cfg`` (the experiment
    config serialized to YAML). The result is saved as ``<out_folder>/<sig>.th``,
    where ``sig`` is the name of the directory containing the checkpoint.

    Args:
        checkpoint_path: Path to the training checkpoint. Its parent directory
            name must be an 8-character Dora signature.
        out_folder: Destination folder; created if it does not exist.

    Returns:
        Path of the written release checkpoint.

    Raises:
        AssertionError: If the parent directory name is not 8 characters long.
    """
    sig = Path(checkpoint_path).parent.name
    if len(sig) != 8:
        # Raised explicitly instead of using `assert` so the check still runs
        # under `python -O`; the exception type is unchanged for callers.
        raise AssertionError("Not a valid Dora signature")
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['ema']['state']['model'],
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
    }
    out_dir = Path(out_folder)
    # torch.save does not create missing parent directories.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file
def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export a training LM checkpoint to a lightweight release checkpoint.

    The checkpoint is loaded on CPU and reduced to two entries: ``best_state``
    (taken from ``pkg['fsdp_best_state']['model']``) and ``xp.cfg`` (the
    experiment config, passed through ``_clean_lm_cfg`` and serialized to YAML).
    The result is saved as ``<out_folder>/<sig>.th``, where ``sig`` is the name
    of the directory containing the checkpoint.

    Args:
        checkpoint_path: Path to the training checkpoint. Its parent directory
            name must be an 8-character Dora signature.
        out_folder: Destination folder; created if it does not exist.

    Returns:
        Path of the written release checkpoint.

    Raises:
        AssertionError: If the parent directory name is not 8 characters long.
    """
    sig = Path(checkpoint_path).parent.name
    if len(sig) != 8:
        # Raised explicitly instead of using `assert` so the check still runs
        # under `python -O`; the exception type is unchanged for callers.
        raise AssertionError("Not a valid Dora signature")
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['fsdp_best_state']['model'],
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])),
    }
    out_dir = Path(out_folder)
    # torch.save does not create missing parent directories.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file
|