""" https://alexander-stasiuk.medium.com/pytorch-weights-averaging-e2c0fa611a0c """ import os import torch from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN from Utility.storage_config import MODELS_DIR def load_net_toucan(path): check_dict = torch.load(path, map_location=torch.device("cpu")) net = ToucanTTS(weights=check_dict["model"], config=check_dict["config"]) return net, check_dict["default_emb"] def load_net_bigvgan(path): check_dict = torch.load(path, map_location=torch.device("cpu")) net = HiFiGAN(weights=check_dict["generator"]) return net, None def get_n_recent_checkpoints_paths(checkpoint_dir, n=5): print("selecting checkpoints...") checkpoint_list = list() for el in os.listdir(checkpoint_dir): if el.endswith(".pt") and el.startswith("checkpoint_"): try: checkpoint_list.append(int(el.split(".")[0].split("_")[1])) except RuntimeError: pass if len(checkpoint_list) == 0: return None elif len(checkpoint_list) < n: n = len(checkpoint_list) checkpoint_list.sort(reverse=True) return [os.path.join(checkpoint_dir, "checkpoint_{}.pt".format(step)) for step in checkpoint_list[:n]] def average_checkpoints(list_of_checkpoint_paths, load_func): # COLLECT CHECKPOINTS if list_of_checkpoint_paths is None or len(list_of_checkpoint_paths) == 0: return None checkpoints_weights = {} model = None default_embed = None # LOAD CHECKPOINTS for path_to_checkpoint in list_of_checkpoint_paths: print("loading model {}".format(path_to_checkpoint)) model, default_embed = load_func(path=path_to_checkpoint) checkpoints_weights[path_to_checkpoint] = dict(model.named_parameters()) # AVERAGE CHECKPOINTS params = model.named_parameters() dict_params = dict(params) checkpoint_amount = len(checkpoints_weights) print("averaging...") for name in dict_params.keys(): custom_params = None for _, checkpoint_parameters in checkpoints_weights.items(): if custom_params is None: custom_params = checkpoint_parameters[name].data else: custom_params += checkpoint_parameters[name].data dict_params[name].data.copy_(custom_params / checkpoint_amount) model_dict = model.state_dict() model_dict.update(dict_params) model.load_state_dict(model_dict) model.eval() return model, default_embed def save_model_for_use(model, name="", default_embed=None, dict_name="model"): print("saving model...") torch.save({dict_name: model.state_dict(), "default_emb": default_embed, "config": model.config}, name) print("...done!") def make_best_in_all(): for model_dir in os.listdir(MODELS_DIR): if os.path.isdir(os.path.join(MODELS_DIR, model_dir)): if "ToucanTTS" in model_dir: checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=os.path.join(MODELS_DIR, model_dir), n=3) if checkpoint_paths is None: continue averaged_model, default_embed = average_checkpoints(checkpoint_paths, load_func=load_net_toucan) save_model_for_use(model=averaged_model, default_embed=default_embed, name=os.path.join(MODELS_DIR, model_dir, "best.pt")) def count_parameters(net): return sum(p.numel() for p in net.parameters() if p.requires_grad) if __name__ == '__main__': make_best_in_all()