"""Export the HuBERT model and/or the SynthesizerTrn network to ONNX files."""

import argparse
import time

import numpy as np
import onnx
from onnxsim import simplify
import onnxruntime as ort
import onnxoptimizer
import torch

from model_onnx import SynthesizerTrn
import utils
from hubert import hubert_model_onnx


def _select_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    # The export was previously pinned to "cuda", which fails on CPU-only
    # machines as soon as a module or tensor is moved to the device.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _export_hubert(device):
    """Export the model returned by ``utils.get_hubert_model()`` to ``hubert3.0.onnx``.

    Args:
        device: ``torch.device`` on which the model and the dummy input are
            placed while the export traces the model.
    """
    hubert_soft = utils.get_hubert_model()

    # Dummy input used only for the export; axis 2 is declared dynamic below
    # under the name "sample_length".
    test_input = torch.rand(1, 1, 16000)

    torch.onnx.export(
        hubert_soft.to(device),
        test_input.to(device),
        "hubert3.0.onnx",
        dynamic_axes={
            "source": {
                2: "sample_length"
            }
        },
        verbose=False,
        opset_version=13,
        input_names=["source"],
        output_names=["embed"],
    )


def _export_synthesizer(path, device):
    """Export the ``SynthesizerTrn`` checkpoint under ``checkpoints/<path>/`` to ONNX.

    Reads ``checkpoints/<path>/config.json`` and ``checkpoints/<path>/model.pth``
    and writes ``checkpoints/<path>/model.onnx``.

    Args:
        path: Name of the directory below ``checkpoints/`` that holds the
            config and the checkpoint.
        device: ``torch.device`` on which the model and the dummy inputs are
            placed while the export traces the model.
    """
    hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")

    # The two positional constructor arguments are derived from the data and
    # train sections of the config; the remaining ones come from hps.model.
    SVCVITS = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model)
    utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)

    # Inference mode on the target device, with gradients disabled on every
    # parameter before the export.
    SVCVITS.eval().to(device)
    for param in SVCVITS.parameters():
        param.requires_grad = False

    # Dummy inputs used only for the export.
    test_hidden_unit = torch.rand(1, 50, 256)
    test_lengths = torch.LongTensor([50])
    test_pitch = torch.rand(1, 50)
    test_sid = torch.LongTensor([0])

    torch.onnx.export(
        SVCVITS,
        (
            test_hidden_unit.to(device),
            test_lengths.to(device),
            test_pitch.to(device),
            test_sid.to(device)
        ),
        f"checkpoints/{path}/model.onnx",
        dynamic_axes={
            "hidden_unit": [0, 1],
            "pitch": [1]
        },
        do_constant_folding=False,
        opset_version=16,
        verbose=False,
        input_names=["hidden_unit", "lengths", "pitch", "sid"],
        output_names=["audio", ],
    )


def main(HubertExport, NetExport, path="NyaruTaffy"):
    """Run the requested ONNX exports.

    Args:
        HubertExport: When truthy, export the HuBERT model to
            ``hubert3.0.onnx`` in the current working directory.
        NetExport: When truthy, export the ``SynthesizerTrn`` checkpoint found
            under ``checkpoints/<path>/`` to ``checkpoints/<path>/model.onnx``.
        path: Name of the checkpoint directory below ``checkpoints/``.
            Defaults to ``"NyaruTaffy"``, the value that used to be hard-coded.
    """
    device = _select_device()

    if HubertExport:
        _export_hubert(device)

    if NetExport:
        _export_synthesizer(path, device)


if __name__ == '__main__':
    main(False, True)