"""
https://alexander-stasiuk.medium.com/pytorch-weights-averaging-e2c0fa611a0c
"""
import os
import torch
from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
from Utility.storage_config import MODELS_DIR

def load_net_toucan(path):
    # a ToucanTTS checkpoint stores the model weights, the model config and a default utterance embedding
    check_dict = torch.load(path, map_location=torch.device("cpu"))
    net = ToucanTTS(weights=check_dict["model"], config=check_dict["config"])
    return net, check_dict["default_emb"]

def load_net_bigvgan(path):
    # a vocoder checkpoint only stores the generator weights; there is no default embedding
    check_dict = torch.load(path, map_location=torch.device("cpu"))
    net = HiFiGAN(weights=check_dict["generator"])
    return net, None

def get_n_recent_checkpoints_paths(checkpoint_dir, n=5):
    print("selecting checkpoints...")
    checkpoint_list = list()
    for el in os.listdir(checkpoint_dir):
        if el.endswith(".pt") and el.startswith("checkpoint_"):
            try:
                # extract the step count from a filename of the form checkpoint_<step>.pt
                checkpoint_list.append(int(el.split(".")[0].split("_")[1]))
            except ValueError:
                # the filename does not follow the expected pattern, so skip it
                pass
    if len(checkpoint_list) == 0:
        return None
    elif len(checkpoint_list) < n:
        n = len(checkpoint_list)
    checkpoint_list.sort(reverse=True)
    return [os.path.join(checkpoint_dir, "checkpoint_{}.pt".format(step)) for step in checkpoint_list[:n]]
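
# Note: the selection above assumes the step-suffixed naming scheme produced during
# training, e.g. checkpoint_80000.pt, checkpoint_90000.pt, checkpoint_100000.pt;
# with n=3 those would be returned with the highest step count first.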

def average_checkpoints(list_of_checkpoint_paths, load_func):
    # COLLECT CHECKPOINTS
    if list_of_checkpoint_paths is None or len(list_of_checkpoint_paths) == 0:
        return None
    checkpoints_weights = {}
    model = None
    default_embed = None
    # LOAD CHECKPOINTS
    for path_to_checkpoint in list_of_checkpoint_paths:
        print("loading model {}".format(path_to_checkpoint))
        model, default_embed = load_func(path=path_to_checkpoint)
        checkpoints_weights[path_to_checkpoint] = dict(model.named_parameters())
    # AVERAGE CHECKPOINTS
    params = model.named_parameters()
    dict_params = dict(params)
    checkpoint_amount = len(checkpoints_weights)
    print("averaging...")
    for name in dict_params.keys():
        custom_params = None
        for _, checkpoint_parameters in checkpoints_weights.items():
            if custom_params is None:
                # clone here, so that summing does not modify the first checkpoint's tensor in place
                custom_params = checkpoint_parameters[name].data.clone()
            else:
                custom_params += checkpoint_parameters[name].data
        dict_params[name].data.copy_(custom_params / checkpoint_amount)
    model_dict = model.state_dict()
    model_dict.update(dict_params)
    model.load_state_dict(model_dict)
    model.eval()
    return model, default_embed
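
# Per parameter tensor, the function above computes the plain arithmetic mean over
# the N loaded checkpoints, i.e. theta_avg = (1 / N) * sum_i theta_i, following the
# weight-averaging recipe from the article linked in the module docstring.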

def save_model_for_use(model, name="", default_embed=None, dict_name="model"):
    print("saving model...")
    torch.save({dict_name: model.state_dict(), "default_emb": default_embed, "config": model.config}, name)
    print("...done!")

def make_best_in_all():
    for model_dir in os.listdir(MODELS_DIR):
        if os.path.isdir(os.path.join(MODELS_DIR, model_dir)):
            if "ToucanTTS" in model_dir:
                checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=os.path.join(MODELS_DIR, model_dir), n=3)
                if checkpoint_paths is None:
                    continue
                averaged_model, default_embed = average_checkpoints(checkpoint_paths, load_func=load_net_toucan)
                save_model_for_use(model=averaged_model, default_embed=default_embed, name=os.path.join(MODELS_DIR, model_dir, "best.pt"))

def count_parameters(net):
    return sum(p.numel() for p in net.parameters() if p.requires_grad)
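
# Example usage (a sketch; "ToucanTTS_Meta" is a hypothetical model directory name):
#
#     net, default_emb = load_net_toucan(os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt"))
#     print("trainable parameters: {}".format(count_parameters(net)))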

if __name__ == '__main__':
    make_best_in_all()