# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic },
#   organization={IEEE}
# }

# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
#

import logging
import os
import re
import yaml
import torch
from collections import OrderedDict

import datetime


def _info_path(path: str) -> str:
    """Return the path of the YAML side-car file that belongs to ``path``.

    ``foo.pt`` maps to ``foo.yaml``.  Any other name gets ``.yaml`` appended,
    so the result is never equal to ``path`` itself (which would make the
    YAML metadata overwrite, or be read from, the checkpoint file).
    """
    if path.endswith(".pt"):
        return path[: -len(".pt")] + ".yaml"
    return path + ".yaml"


def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    """Load a state dict from ``path`` into ``model`` and return its metadata.

    The weights are loaded with ``strict=False``, so missing or unexpected
    keys do not raise.

    Args:
        model: module that receives the weights.
        path: checkpoint file written by ``torch.save``.

    Returns:
        The contents of the YAML side-car file next to the checkpoint, or an
        empty dict when that file does not exist or is empty.
    """
    if torch.cuda.is_available():
        logging.info("Checkpoint: loading from checkpoint %s for GPU", path)
        checkpoint = torch.load(path)
    else:
        logging.info("Checkpoint: loading from checkpoint %s for CPU", path)
        checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint, strict=False)

    info_path = _info_path(path)
    configs = {}
    if os.path.exists(info_path):
        with open(info_path, "r") as fin:
            # NOTE(review): FullLoader is kept for compatibility with existing
            # side-car files; do not point this at untrusted YAML.
            configs = yaml.load(fin, Loader=yaml.FullLoader)
    # An empty YAML document parses to None; keep the dict return contract.
    return configs or {}


def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    """Save the state dict of ``model`` to ``path`` plus a YAML side-car file.

    Args:
        model: module to save. ``DataParallel`` and
            ``DistributedDataParallel`` wrappers are unwrapped so the saved
            keys carry no ``module.`` prefix.
        path: destination of the checkpoint.
        infos (dict or None): any info you want to save. The dict is copied,
            so the caller's object is left untouched.
    """
    logging.info("Checkpoint: save to checkpoint %s", path)
    wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, wrappers) else model
    torch.save(target.state_dict(), path)

    # Copy so that adding the timestamp does not leak into the caller's dict.
    infos = dict(infos) if infos else {}
    infos["save_time"] = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S")
    with open(_info_path(path), "w") as fout:
        data = yaml.dump(infos)
        fout.write(data)


def filter_modules(model_state_dict, modules):
    """Keep only the module prefixes that occur in ``model_state_dict``.

    Args:
        model_state_dict: mapping whose keys are parameter names.
        modules: iterable of name prefixes to look for.

    Returns:
        List of the prefixes from ``modules`` (in the given order) that at
        least one key starts with. Prefixes without a match are reported
        through ``logging.warning`` and dropped.
    """
    new_mods = []
    incorrect_mods = []
    mods_model = model_state_dict.keys()

    for mod in modules:
        if any(key.startswith(mod) for key in mods_model):
            new_mods.append(mod)
        else:
            incorrect_mods.append(mod)

    if incorrect_mods:
        logging.warning(
            "module(s) %s don't match or (partially match) "
            "available modules in model.",
            incorrect_mods,
        )
        logging.warning("for information, the existing modules in model are:")
        logging.warning("%s", mods_model)

    return new_mods


def load_trained_modules(model: torch.nn.Module, args):
    """Initialise selected sub-modules of ``model`` from a pre-trained file.

    Args:
        model: module to initialise.
        args: object with attributes ``enc_init`` (path of the pre-trained
            state dict) and ``enc_init_mods`` (key prefixes to copy).

    Returns:
        An empty dict.
    """
    # Load encoder modules with pre-trained model(s).
    enc_model_path = args.enc_init
    enc_modules = args.enc_init_mods
    main_state_dict = model.state_dict()
    logging.warning("model(s) found for pre-initialization")

    # An unset path (None / "") is handled like a missing file rather than
    # letting os.path.isfile raise TypeError.
    if enc_model_path and os.path.isfile(enc_model_path):
        logging.info(
            "Checkpoint: loading from checkpoint %s for CPU", enc_model_path
        )
        model_state_dict = torch.load(enc_model_path, map_location="cpu")
        modules = filter_modules(model_state_dict, enc_modules)
        partial_state_dict = OrderedDict(
            (key, value)
            for key, value in model_state_dict.items()
            if any(key.startswith(m) for m in modules)
        )
        main_state_dict.update(partial_state_dict)
    else:
        logging.warning("model was not found : %s", enc_model_path)

    model.load_state_dict(main_state_dict)
    configs = {}
    return configs