Spaces:
Running
Running
import sys,os | |
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
import torch | |
import argparse | |
from omegaconf import OmegaConf | |
from vits.models import SynthesizerInfer | |
def load_model(checkpoint_path, model):
    """Load the ``"model_g"`` weights of a checkpoint into ``model``.

    The checkpoint is read onto the CPU. For every key of the target's own
    state dict, the checkpoint value is used when the checkpoint contains
    that key; otherwise the target's current value is kept. Checkpoint keys
    that the target does not have are ignored.

    Args:
        checkpoint_path: Path to an existing checkpoint file readable by
            ``torch.load``; it must contain a ``"model_g"`` entry.
        model: Object to populate. If it exposes a ``module`` attribute,
            that inner object's state dict is read and loaded instead.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If the checkpoint has no ``"model_g"`` entry.
    """
    assert os.path.isfile(checkpoint_path), \
        "checkpoint file not found: %s" % checkpoint_path
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model_g"]

    # Operate on the inner module when the model exposes one.
    target = model.module if hasattr(model, "module") else model

    new_state_dict = {}
    for k, v in target.state_dict().items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Only a missing key falls back to the current value; any other
            # error (or Ctrl-C) must propagate instead of being swallowed.
            new_state_dict[k] = v

    target.load_state_dict(new_state_dict)
    return model
def save_pretrain(checkpoint_path, save_path):
    """Write a copy of a checkpoint holding only ``model_g`` and ``model_d``.

    Args:
        checkpoint_path: Path to an existing checkpoint file; it is loaded
            onto the CPU and must contain both ``"model_g"`` and
            ``"model_d"`` entries.
        save_path: Destination path for the reduced checkpoint.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If either expected entry is missing from the checkpoint.
    """
    assert os.path.isfile(checkpoint_path)
    source = torch.load(checkpoint_path, map_location="cpu")
    # Keep just the two model entries; everything else is dropped.
    kept = {name: source[name] for name in ('model_g', 'model_d')}
    torch.save(kept, save_path)
def save_model(model, checkpoint_path):
    """Save the model's state dict to ``checkpoint_path`` under ``"model_g"``.

    Args:
        model: Object providing ``state_dict()``. If it exposes a ``module``
            attribute, that inner object's state dict is saved instead.
        checkpoint_path: Destination path passed to ``torch.save``.
    """
    # Fall back to the model itself when there is no inner ``module``.
    net = getattr(model, 'module', model)
    torch.save({'model_g': net.state_dict()}, checkpoint_path)
def main(args):
    """Build the inference model, load checkpoint weights, and re-save them.

    Args:
        args: Parsed command-line namespace providing ``config`` (path to a
            config file loadable by ``OmegaConf.load``) and
            ``checkpoint_path`` (path to the checkpoint to read).

    The result is always written to ``sovits5.0.pth`` in the current working
    directory.
    """
    hp = OmegaConf.load(args.config)

    # The two leading constructor arguments are derived from the data config.
    # NOTE(review): presumably spectrogram bin count and frames per segment;
    # confirm against SynthesizerInfer's signature.
    spec_bins = hp.data.filter_length // 2 + 1
    seg_frames = hp.data.segment_size // hp.data.hop_length
    model = SynthesizerInfer(spec_bins, seg_frames, hp)

    # Optional extra export, disabled by default:
    # save_pretrain(args.checkpoint_path, "sovits5.0.pretrain.pth")
    load_model(args.checkpoint_path, model)
    save_model(model, "sovits5.0.pth")
if __name__ == '__main__':
    # Command-line entry point. Both options are mandatory: there is no
    # fallback to a config stored in the checkpoint, so the help text must
    # not promise one.
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True,
                        help="yaml file for config (required).")
    parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
                        help="path of checkpoint pt file for evaluation")
    args = parser.parse_args()
    main(args)