|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Builds a .nemo file whose weights are the average of multiple .ckpt checkpoints (the .ckpt files must be in the same folder as the .nemo file; files ending in -last.ckpt are ignored).
|
|
|
Usage example for building *-averaged.nemo for a given .nemo file: |
|
|
|
NeMo/scripts/checkpoint_averaging/checkpoint_averaging.py my_model.nemo |
|
|
|
Usage example for building *-averaged.nemo files for all results in sub-directories under current path: |
|
|
|
find . -name '*.nemo' | grep -v -- "-averaged.nemo" | xargs NeMo/scripts/checkpoint_averaging/checkpoint_averaging.py |
|
|
|
|
|
NOTE: if you get the following error `AttributeError: Can't get attribute '???' on <module '__main__' from '???'>`

use --import_fname_list <FILE> and list every file that contains the missing classes.
|
""" |
|
|
|
import argparse |
|
import glob |
|
import importlib |
|
import os |
|
import sys |
|
|
|
import torch |
|
from tqdm.auto import tqdm |
|
|
|
from nemo.core import ModelPT |
|
from nemo.utils import logging, model_utils |
|
|
|
|
|
def _resolve_nemo_file(model_fname):
    """Return the .nemo file to process for one command-line argument.

    If ``model_fname`` ends with ``.nemo`` it is returned unchanged. Otherwise it is treated as a folder
    that must contain exactly one ``.nemo`` file; previously generated ``*-averaged.nemo`` outputs are ignored.

    Raises:
        RuntimeError: if the folder does not contain exactly one candidate ``.nemo`` file.
    """
    if model_fname.endswith(".nemo"):
        return model_fname

    nemo_files = [
        fn for fn in glob.glob(os.path.join(model_fname, "*.nemo")) if not fn.endswith("-averaged.nemo")
    ]
    if len(nemo_files) != 1:
        raise RuntimeError(f"Expected exactly one .nemo file but discovered {len(nemo_files)} .nemo files")
    return nemo_files[0]


def _find_checkpoints(model_folder_path):
    """Return the sorted paths of all ``*.ckpt`` files in ``model_folder_path``, excluding ``*-last.ckpt``."""
    # Sorting makes the summation order (and thus the floating-point result) reproducible;
    # os.listdir() returns entries in arbitrary order.
    return sorted(
        os.path.join(model_folder_path, x)
        for x in os.listdir(model_folder_path)
        if x.endswith('.ckpt') and not x.endswith('-last.ckpt')
    )


def _average_checkpoints(checkpoint_paths, device):
    """Average the ``'state_dict'`` entries of the checkpoints in ``checkpoint_paths``.

    Args:
        checkpoint_paths: non-empty list of checkpoint file paths; each must contain a ``'state_dict'`` entry.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        A state dict keyed like the first checkpoint. Tensors whose dtype name starts with ``torch.int`` are
        only accumulated (summed), not divided; e.g. counters such as BatchNorm ``num_batches_tracked``.
        Every other tensor is summed and divided by the number of checkpoints.

    Raises:
        RuntimeError: if a checkpoint has no ``'state_dict'`` or lacks a key present in the first checkpoint.
    """
    n = len(checkpoint_paths)
    avg_state = None

    logging.info(f"Averaging {n} checkpoints ...")

    for path in tqdm(checkpoint_paths, total=n, desc='Averaging checkpoints'):
        # NOTE(review): newer torch releases change the default of torch.load(weights_only=...);
        # confirm Lightning checkpoints still load with the installed torch version.
        checkpoint = torch.load(path, map_location=device)

        if 'state_dict' not in checkpoint:
            raise RuntimeError(f"Checkpoint from {path} does not include a state_dict.")
        state_dict = checkpoint['state_dict']

        if avg_state is None:
            # The first checkpoint's state dict becomes the running sum.
            avg_state = state_dict
            logging.info(f"Initialized average state dict with checkpoint:\n\t{path}")
            continue

        missing = [k for k in avg_state if k not in state_dict]
        if missing:
            raise RuntimeError(
                f"Checkpoint {path} is missing {len(missing)} key(s) present in the first checkpoint, "
                f"e.g. {missing[:5]}"
            )
        for k in avg_state:
            avg_state[k] = avg_state[k] + state_dict[k]

        logging.info(f"Updated average state dict with state from checkpoint:\n\t{path}")

    for k, value in avg_state.items():
        # Integer tensors are accumulated but not averaged (see docstring); only the rest is divided.
        if not str(value.dtype).startswith("torch.int"):
            avg_state[k] = value / n

    return avg_state


def main():
    """For each given .nemo file (or folder), average the sibling .ckpt files and save ``<name>-averaged.nemo``.

    For every positional argument:
    1. Resolve the .nemo file.
    2. Restore the model on CPU, using ``--class_path`` if given, else the ``target`` stored in its config.
    3. Average the ``state_dict`` of every ``*.ckpt`` in the same folder, skipping ``*-last.ckpt``.
    4. Load the result strictly into the model and save it next to the original.

    Raises:
        RuntimeError: if a folder argument does not hold exactly one .nemo file, no checkpoints are found,
            or a checkpoint is malformed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'model_fname_list',
        metavar='NEMO_FILE_OR_FOLDER',
        type=str,
        nargs='+',
        help='Input .nemo files (or folders who contains them) to parse',
    )
    parser.add_argument(
        '--import_fname_list',
        metavar='FILE',
        type=str,
        nargs='+',
        default=[],
        help='A list of Python file names to "from FILE import *" (Needed when some classes were defined in __main__ of a script)',
    )
    parser.add_argument(
        '--class_path', type=str, default='', help='A path to class "module.submodule.class" (if given)',
    )
    args = parser.parse_args()

    logging.info(
        "\n\nIMPORTANT:\nIf you get the following error:\n\t(AttributeError: Can't get attribute '???' on <module '__main__' from '???'>)\nuse:\n\t--import_fname_list\nfor all files that contain missing classes.\n\n"
    )

    # Expose classes from user scripts in this module's namespace so pickled references
    # (e.g. to classes defined in some script's __main__) can be resolved while loading.
    for import_fname in args.import_fname_list:
        logging.info(f"Importing * from {import_fname}")
        sys.path.insert(0, os.path.dirname(import_fname))
        globals().update(importlib.import_module(os.path.splitext(os.path.basename(import_fname))[0]).__dict__)

    device = torch.device("cpu")
    n_models = len(args.model_fname_list)

    for model_idx, model_arg in enumerate(args.model_fname_list, start=1):
        model_fname = _resolve_nemo_file(model_arg)

        # dirname() is "" for a bare file name such as "my_model.nemo"; os.listdir("") would fail,
        # so fall back to the current directory.
        model_folder_path = os.path.dirname(model_fname) or os.curdir
        stem, ext = os.path.splitext(model_fname)
        avg_model_fname = f"{stem}-averaged{ext}"

        logging.info(f"\n===> [{model_idx} / {n_models}] Parsing folder {model_folder_path}\n")

        model_cfg = ModelPT.restore_from(restore_path=model_fname, return_config=True)
        classpath = args.class_path or model_cfg.target
        imported_class = model_utils.import_class_by_path(classpath)
        logging.info(f"Loading model {model_fname}")
        nemo_model = imported_class.restore_from(restore_path=model_fname, map_location=device)

        checkpoint_paths = _find_checkpoints(model_folder_path)
        if not checkpoint_paths:
            raise RuntimeError(f"No .ckpt files (other than *-last.ckpt) found in {model_folder_path}")

        avg_state = _average_checkpoints(checkpoint_paths, device)
        nemo_model.load_state_dict(avg_state, strict=True)

        logging.info(f"Saving average model to:\n\t{avg_model_fname}")
        nemo_model.save_to(avg_model_fname)
|
|
if __name__ == "__main__":
    # Run the averaging CLI only when this file is executed as a script, not when imported.
    main()
|
|