File size: 1,906 Bytes
acc615e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Legacy functions used at the time of the first release, kept for reference.
"""

from pathlib import Path
import typing as tp

from omegaconf import OmegaConf, DictConfig
import torch


def _clean_lm_cfg(cfg: DictConfig, card: int = 2048, n_q: int = 4) -> DictConfig:
    """Patch a legacy LM experiment config so it matches the current LM parameters.

    The config is modified in place and also returned for convenience.

    Args:
        cfg: experiment config (``xp.cfg``) taken from a legacy LM checkpoint. It must
            contain a ``transformer_lm`` section.
        card: value written to ``transformer_lm.card``. Defaults to 2048, the value
            previously hard-coded here.
        n_q: value written to ``transformer_lm.n_q``. Defaults to 4, the value
            previously hard-coded here.

    Returns:
        The same ``cfg`` object, with struct mode enabled.
    """
    # Struct mode must be off to add/remove keys on the config.
    OmegaConf.set_struct(cfg, False)
    try:
        lm_cfg = cfg['transformer_lm']
        # This used to be set automatically in the LM solver, need a more robust solution
        # for the future.
        lm_cfg['card'] = card
        lm_cfg['n_q'] = n_q
        # Experimental params no longer supported. A legacy config that lacks some
        # of them is still valid, so missing keys are skipped instead of raising.
        bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                      'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
        for name in bad_params:
            lm_cfg.pop(name, None)
    finally:
        # Re-enable struct mode even if patching fails, so the config is never left
        # in a state that silently accepts unknown keys.
        OmegaConf.set_struct(cfg, True)
    return cfg


def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export a legacy EnCodec training checkpoint to a slim release file.

    The exported package contains only:
      - ``best_state``: the model weights from the checkpoint's ``ema`` state;
      - ``xp.cfg``: the experiment config serialized to YAML.

    Args:
        checkpoint_path: path to the training checkpoint. Its parent directory name is
            used as the Dora signature and must be exactly 8 characters long.
        out_folder: directory where ``<signature>.th`` is written. It is created if it
            does not exist.

    Returns:
        Path of the written file.

    Raises:
        AssertionError: if the parent directory name is not 8 characters long.
    """
    sig = Path(checkpoint_path).parent.name
    # Explicit raise rather than an `assert` statement so the check still runs under
    # `python -O`; AssertionError is kept so the exception type is unchanged.
    if len(sig) != 8:
        raise AssertionError("Not a valid Dora signature")
    # NOTE: torch.load unpickles arbitrary objects; only use on trusted checkpoints.
    # NOTE(review): `xp.cfg` is presumably an OmegaConf object; recent torch releases
    # default `weights_only=True`, which may refuse to unpickle it — confirm and pass
    # `weights_only=False` if needed.
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['ema']['state']['model'],
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
    }
    out_dir = Path(out_folder)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file


def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export a legacy LM training checkpoint to a slim release file.

    The exported package contains only:
      - ``best_state``: the model weights from the checkpoint's ``fsdp_best_state``;
      - ``xp.cfg``: the experiment config, patched by ``_clean_lm_cfg`` (sets
        ``card``/``n_q`` and removes no-longer-supported experimental params),
        serialized to YAML.

    Args:
        checkpoint_path: path to the training checkpoint. Its parent directory name is
            used as the Dora signature and must be exactly 8 characters long.
        out_folder: directory where ``<signature>.th`` is written. It is created if it
            does not exist.

    Returns:
        Path of the written file.

    Raises:
        AssertionError: if the parent directory name is not 8 characters long.
    """
    sig = Path(checkpoint_path).parent.name
    # Explicit raise rather than an `assert` statement so the check still runs under
    # `python -O`; AssertionError is kept so the exception type is unchanged.
    if len(sig) != 8:
        raise AssertionError("Not a valid Dora signature")
    # NOTE: torch.load unpickles arbitrary objects; only use on trusted checkpoints.
    # NOTE(review): `xp.cfg` is presumably an OmegaConf object; recent torch releases
    # default `weights_only=True`, which may refuse to unpickle it — confirm and pass
    # `weights_only=False` if needed.
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['fsdp_best_state']['model'],
        # _clean_lm_cfg mutates pkg['xp.cfg'] in place before it is serialized.
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])),
    }
    out_dir = Path(out_folder)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file