"""Produce a release copy of a SynthesizerTrn generator checkpoint.

The release copy drops every ``enc_q`` weight, optionally casts floating-point
tensors to FP16, resets the iteration counter to 0 and replaces the optimizer
state with that of a freshly constructed AdamW optimizer.
"""
import os
from collections import OrderedDict

import torch

import utils
from models import SynthesizerTrn
from text.symbols import symbols
from tools.log import logger


def copyStateDict(state_dict):
    """Return a copy of *state_dict*, stripping a leading ``module`` key component.

    If the first key starts with ``"module"`` (presumably the prefix added by a
    DataParallel-style wrapper -- verify against the saving code), the first
    dot-separated component is removed from every key. Otherwise all keys are
    copied unchanged. Insertion order is preserved.

    Args:
        state_dict: mapping whose keys are dot-separated strings.

    Returns:
        An ``OrderedDict`` with the (possibly) shortened keys and the original values.
    """
    # next(iter(...), "") avoids IndexError on an empty mapping.
    first_key = next(iter(state_dict), "")
    start_idx = 1 if first_key.startswith("module") else 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Re-join with "." -- the same separator the key was split on.
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict


def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
    """Write a compressed copy of a generator checkpoint to *output_model*.

    Args:
        config: path of the hyper-parameter file read by
            ``utils.get_hparams_from_file``.
        input_model: checkpoint path. The loaded object must contain a ``"model"``
            entry that maps parameter names to tensors.
        ishalf: if true, cast floating-point tensors to FP16. Non-floating
            tensors (e.g. integer buffers) are left untouched.
        output_model: destination path for the new checkpoint.
    """
    hps = utils.get_hparams_from_file(config)
    # The network is built only so that a fresh AdamW optimizer can be created
    # over its parameters. Its randomly initialised weights are never saved.
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    # NOTE(security): torch.load unpickles the file -- only run on trusted checkpoints.
    state_dict_g = torch.load(input_model, map_location="cpu")
    new_dict_g = copyStateDict(state_dict_g)

    # Drop every tensor whose name contains "enc_q".
    model_state = {
        k: v for k, v in new_dict_g["model"].items() if "enc_q" not in k
    }
    if ishalf:
        model_state = {
            k: v.half() if v.is_floating_point() else v
            for k, v in model_state.items()
        }

    torch.save(
        {
            "model": model_state,
            "iteration": 0,
            "optimizer": optim_g.state_dict(),
            # NOTE(review): hard-coded, and differs from hps.train.learning_rate
            # used for the optimizer above -- confirm which value loaders expect.
            "learning_rate": 0.0001,
        },
        output_model,
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="configs/config.json")
    parser.add_argument("-i", "--input", type=str, required=True)
    parser.add_argument("-o", "--output", type=str, default=None)
    parser.add_argument(
        "-hf", "--half", action="store_true", default=False, help="Save as FP16"
    )
    args = parser.parse_args()

    output = args.output
    if output is None:
        # Default output: <input stem>_release[_half]<input extension>
        filename, ext = os.path.splitext(args.input)
        half = "_half" if args.half else ""
        output = f"{filename}_release{half}{ext}"

    removeOptimizer(args.config, args.input, args.half, output)
    # Log message: "Model compressed successfully, output model: <path>"
    logger.info(f"压缩模型成功, 输出模型: {os.path.abspath(output)}")