# Source: leo4.0/onnx/onnx_export_48k.py (2.64 kB)
# Commit d18bd24 (chilge): "Duplicate from innnky/sovits4"
import argparse
import time
import numpy as np
import onnx
from onnxsim import simplify
import onnxruntime as ort
import onnxoptimizer
import torch
from model_onnx_48k import SynthesizerTrn
import utils
from hubert import hubert_model_onnx
def _select_export_device():
    """Return the CUDA device when one is available, otherwise the CPU.

    Exporting does not require a GPU, so falling back keeps the script
    usable on CPU-only machines instead of failing in ``.to(device)``.
    """
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _export_hubert(device):
    """Export the soft HuBERT model loaded from ``hubert/model.pt``.

    The graph is written to ``hubert3.0.onnx`` in the current working
    directory with one input (``source``) and one output (``embed``).
    """
    hubert_soft = hubert_model_onnx.hubert_soft("hubert/model.pt")
    # Example input handed to the exporter; axis 2 is declared dynamic
    # below, so the exported graph is not tied to a length of 16000.
    test_input = torch.rand(1, 1, 16000)
    torch.onnx.export(
        hubert_soft.to(device),
        test_input.to(device),
        "hubert3.0.onnx",
        dynamic_axes={"source": {2: "sample_length"}},
        verbose=False,
        opset_version=13,
        input_names=["source"],
        output_names=["embed"],
    )


def _export_synthesizer(path, device):
    """Export the ``SynthesizerTrn`` checkpoint stored under ``checkpoints/<path>``.

    Reads ``config.json`` and ``model.pth`` from that directory and writes
    ``model.onnx`` next to them.
    """
    hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
    SVCVITS = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    # NOTE(review): the third argument is passed as None; presumably the
    # optimizer slot, which is not needed for export -- confirm in utils.
    utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)
    SVCVITS.eval().to(device)
    # Gradients are never needed for export.
    for param in SVCVITS.parameters():
        param.requires_grad = False

    # Example inputs for the exporter.
    # NOTE(review): shapes look like (batch, frames, features) for the
    # hidden units and (batch, frames) for pitch -- confirm against
    # model_onnx_48k.SynthesizerTrn.forward.
    test_hidden_unit = torch.rand(1, 50, 256)
    test_lengths = torch.LongTensor([50])
    test_pitch = torch.rand(1, 50)
    test_sid = torch.LongTensor([0])

    torch.onnx.export(
        SVCVITS,
        (
            test_hidden_unit.to(device),
            test_lengths.to(device),
            test_pitch.to(device),
            test_sid.to(device),
        ),
        f"checkpoints/{path}/model.onnx",
        # Axes 0 and 1 of hidden_unit and axis 1 of pitch are left dynamic.
        dynamic_axes={
            "hidden_unit": [0, 1],
            "pitch": [1],
        },
        do_constant_folding=False,
        opset_version=16,
        verbose=False,
        input_names=["hidden_unit", "lengths", "pitch", "sid"],
        output_names=["audio"],
    )


def main(HubertExport, NetExport):
    """Export the HuBERT encoder and/or the synthesizer network to ONNX.

    Args:
        HubertExport: When truthy, export the soft HuBERT model to
            ``hubert3.0.onnx``.
        NetExport: When truthy, export the ``SynthesizerTrn`` checkpoint
            found under ``checkpoints/NyaruTaffy`` to ``model.onnx`` in
            that same directory.

    Both exports run on CUDA when it is available and on the CPU otherwise.
    With both flags falsy the function does nothing.
    """
    path = "NyaruTaffy"
    if HubertExport:
        _export_hubert(_select_export_device())
    if NetExport:
        _export_synthesizer(path, _select_export_device())
# Script entry point: skip the HuBERT export and export only the synthesizer.
if __name__ == '__main__':
    main(HubertExport=False, NetExport=True)