zyingt's picture
Upload 685 files
0d80816
raw
history blame contribute delete
No virus
3.52 kB
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
# ```
#
import os
import argparse
import glob
import yaml
import numpy as np
import torch
def get_args():
    """Parse the command-line options of the checkpoint-averaging script.

    Options:
        --dst_model: required, path the averaged model is written to.
        --src_path: required, directory holding the source checkpoints.
        --val_best: flag; select checkpoints by validation loss instead of
            by modification time.
        --num: number of checkpoints to average (default 5).
        --min_epoch / --max_epoch: inclusive epoch range considered when
            --val_best is given (defaults 0 and 65536).

    Returns:
        The parsed ``argparse.Namespace``. It is also printed to stdout.
    """
    parser = argparse.ArgumentParser(description="average model")
    parser.add_argument("--dst_model", required=True, help="averaged model")
    parser.add_argument(
        "--src_path", required=True, help="src model path for average"
    )
    parser.add_argument(
        "--val_best",
        action="store_true",
        help="average the checkpoints with the lowest validation loss "
        "instead of the most recently modified ones",
    )
    parser.add_argument(
        "--num", default=5, type=int, help="nums for averaged model"
    )
    parser.add_argument(
        "--min_epoch",
        default=0,
        type=int,
        help="min epoch used for averaging model",
    )
    parser.add_argument(
        "--max_epoch",
        default=65536,
        type=int,
        help="max epoch used for averaging model",
    )
    args = parser.parse_args()
    print(args)
    return args
def main():
    """Average several checkpoints and save the result.

    Checkpoints are picked from ``args.src_path`` in one of two ways:

    * with ``--val_best``: each yaml file matched by the glob below is read
      for its ``epoch`` and ``cv_loss`` entries, epochs outside
      ``[min_epoch, max_epoch]`` are dropped, and the ``num`` epochs with
      the smallest ``cv_loss`` are selected as ``<src_path>/<epoch>.pt``;
    * otherwise: the ``num`` most recently modified files matching
      ``<src_path>/[0-9]*.pt``.

    Every selected checkpoint is loaded on CPU, the entries are summed key
    by key, divided by ``num`` and written to ``args.dst_model`` with
    ``torch.save``.

    Raises:
        ValueError: if ``--num`` is smaller than 1, if ``--val_best`` finds
            no usable yaml file, or if the number of selected checkpoints
            differs from ``--num``.
    """
    args = get_args()
    num = args.num
    if num < 1:
        # With num == 0 the slice ``path_list[-0:]`` would select *all* files.
        raise ValueError("--num must be a positive integer, got {}".format(num))

    if args.val_best:
        # NOTE: "[!train]" is a glob character class, not a negated word: it
        # matches files whose FIRST character is none of t, r, a, i, n.
        yamls = glob.glob("{}/[!train]*.yaml".format(args.src_path))
        val_scores = []
        for y in yamls:
            with open(y, "r") as f:
                # NOTE(security): FullLoader is not meant for untrusted input;
                # only point --src_path at directories you trust.
                dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
            loss = dic_yaml["cv_loss"]
            epoch = dic_yaml["epoch"]
            if args.min_epoch <= epoch <= args.max_epoch:
                val_scores.append([epoch, loss])
        if not val_scores:
            # An empty list would become a 1-D array and break the 2-D
            # indexing below with an opaque IndexError.
            raise ValueError(
                "no yaml file with an epoch in [{}, {}] found in {}".format(
                    args.min_epoch, args.max_epoch, args.src_path
                )
            )
        val_scores = np.array(val_scores)
        # Ascending sort on the loss column: best (lowest) loss first.
        sort_idx = np.argsort(val_scores[:, -1])
        best_scores = val_scores[sort_idx][:num]
        print("best val scores = " + str(best_scores[:, 1]))
        print("selected epochs = " + str(best_scores[:, 0].astype(np.int64)))
        path_list = [
            args.src_path + "/{}.pt".format(int(epoch))
            for epoch in best_scores[:, 0]
        ]
    else:
        path_list = glob.glob("{}/[0-9]*.pt".format(args.src_path))
        # Oldest first, so the tail holds the most recently written files.
        path_list = sorted(path_list, key=os.path.getmtime)
        path_list = path_list[-num:]
    print(path_list)

    # Explicit check instead of ``assert`` so it survives ``python -O``;
    # dividing by ``num`` is only correct if exactly ``num`` files are summed.
    if len(path_list) != num:
        raise ValueError(
            "expected {} checkpoints to average but found {}".format(
                num, len(path_list)
            )
        )

    avg = None
    for path in path_list:
        print("Processing {}".format(path))
        # NOTE(security): torch.load unpickles the file; load trusted
        # checkpoints only.
        states = torch.load(path, map_location=torch.device("cpu"))
        if avg is None:
            # The first checkpoint doubles as the accumulator.
            avg = states
        else:
            for k in avg:
                # Skip None entries, mirroring the division step below.
                if avg[k] is not None:
                    avg[k] += states[k]

    # average
    for k in avg:
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)

    print("Saving to {}".format(args.dst_model))
    torch.save(avg, args.dst_model)
if __name__ == "__main__":
    # Run the averaging only when executed as a script, not when imported.
    main()