Spaces:
Runtime error
Runtime error
File size: 2,288 Bytes
58fbdee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from collections import OrderedDict
import torch
import utils
from models import SynthesizerTrn
def copyStateDict(state_dict):
if list(state_dict.keys())[0].startswith('module'):
start_idx = 1
else:
start_idx = 0
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = ','.join(k.split('.')[start_idx:])
new_state_dict[name] = v
return new_state_dict
def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
hps = utils.get_hparams_from_file(config)
net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model)
optim_g = torch.optim.AdamW(net_g.parameters(),
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps)
state_dict_g = torch.load(input_model, map_location="cpu")
new_dict_g = copyStateDict(state_dict_g)
keys = []
for k, v in new_dict_g['model'].items():
if "enc_q" in k: continue # noqa: E701
keys.append(k)
new_dict_g = {k: new_dict_g['model'][k].half() for k in keys} if ishalf else {k: new_dict_g['model'][k] for k in keys}
torch.save(
{
'model': new_dict_g,
'iteration': 0,
'optimizer': optim_g.state_dict(),
'learning_rate': 0.0001
}, output_model)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-c",
"--config",
type=str,
default='configs/config.json')
parser.add_argument("-i", "--input", type=str)
parser.add_argument("-o", "--output", type=str, default=None)
parser.add_argument('-hf', '--half', action='store_true', default=False, help='Save as FP16')
args = parser.parse_args()
output = args.output
if output is None:
import os.path
filename, ext = os.path.splitext(args.input)
half = "_half" if args.half else ""
output = filename + "_release" + half + ext
removeOptimizer(args.config, args.input, args.half, output) |