# VITS2-Chinese / export_onnx.py
# Uploaded by kevinwang676 via huggingface_hub ("Upload folder using huggingface_hub", commit de332ab).
import argparse
from pathlib import Path
from typing import Optional
import torch
import utils
from models import SynthesizerTrn
from text.symbols import symbols
# ONNX operator-set version passed as ``opset_version`` to torch.onnx.export.
OPSET_VERSION: int = 15
def _parse_args() -> argparse.Namespace:
    """Parse the command line and convert every path argument to ``Path``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path", required=True, help="Path to model weights (.pth)"
    )
    parser.add_argument(
        "--config-path", required=True, help="Path to model config (.json)"
    )
    parser.add_argument("--output", required=True, help="Path to output model (.onnx)")
    args = parser.parse_args()
    args.model_path = Path(args.model_path)
    args.config_path = Path(args.config_path)
    args.output = Path(args.output)
    return args


def _build_generator(config_path: Path, model_path: Path) -> tuple:
    """Build ``SynthesizerTrn`` from a config file and load its checkpoint.

    Args:
        config_path: JSON config read by ``utils.get_hparams_from_file``.
        model_path: ``.pth`` checkpoint loaded into the generator.

    Returns:
        ``(model_g, hps)``: the generator in eval mode, and the loaded
        hyper-parameters. ``hps.data.use_mel_posterior_encoder`` is set to
        match the encoder choice made here.
    """
    hps = utils.get_hparams_from_file(config_path)

    # The mel posterior encoder is opt-in via the model config. If the key is
    # missing or falsy, fall back to the linear-spectrogram encoder.
    use_mel = (
        "use_mel_posterior_encoder" in hps.model.keys()
        and hps.model.use_mel_posterior_encoder
    )
    if use_mel:
        print("Using mel posterior encoder for VITS2")
        posterior_channels = 80  # vits2
        hps.data.use_mel_posterior_encoder = True
    else:
        print("Using lin posterior encoder for VITS1")
        posterior_channels = hps.data.filter_length // 2 + 1
        hps.data.use_mel_posterior_encoder = False

    model_g = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    model_g.eval()
    utils.load_checkpoint(model_path, model_g, None)
    return model_g, hps


def _export_io_spec(
    sequences: torch.Tensor,
    sequence_lengths: torch.Tensor,
    scales: torch.Tensor,
    sid: Optional[torch.Tensor] = None,
) -> tuple:
    """Build the example args, input names and dynamic axes for export.

    The example args and the input names are built together, so they always
    have the same length. ``sid`` is passed and named only when it is not
    None. torch.onnx.export assigns ``input_names`` by position, so a
    ``"sid"`` name without a matching traced input would not label a real
    graph input.

    Returns:
        ``(args, input_names, dynamic_axes)``, ready for torch.onnx.export.
    """
    export_args = [sequences, sequence_lengths, scales]
    input_names = ["input", "input_lengths", "scales"]
    dynamic_axes = {
        "input": {0: "batch_size", 1: "phonemes"},
        "input_lengths": {0: "batch_size"},
        "output": {0: "batch_size", 1: "time1", 2: "time2"},
    }
    if sid is not None:
        export_args.append(sid)
        input_names.append("sid")
    return tuple(export_args), input_names, dynamic_axes


def main() -> None:
    """Export a VITS/VITS2 generator checkpoint to an ONNX model.

    Command-line arguments: ``--model-path`` (weights ``.pth``),
    ``--config-path`` (config ``.json``) and ``--output`` (destination
    ``.onnx``). The output's parent directory is created if needed.
    """
    # Seed first: model construction and the dummy input below both draw from
    # the global torch RNG, so this keeps the traced example reproducible.
    torch.manual_seed(1234)
    args = _parse_args()
    args.output.parent.mkdir(parents=True, exist_ok=True)

    model_g, hps = _build_generator(args.config_path, args.model_path)

    def infer_forward(text, text_lengths, scales, sid=None):
        """Route ``forward`` to ``model_g.infer`` for tracing.

        ``scales`` packs ``[noise_scale, length_scale, noise_scale_w]`` so the
        three sampling knobs become one runtime tensor input.
        """
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio = model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )[0]
        return audio

    # torch.onnx.export calls the module, which dispatches to this override.
    model_g.forward = infer_forward

    dummy_input_length = 50
    sequences = torch.randint(
        low=0, high=len(symbols), size=(1, dummy_input_length), dtype=torch.long
    )
    sequence_lengths = torch.LongTensor([sequences.size(1)])
    # Only multi-speaker models take a speaker id.
    sid: Optional[torch.Tensor] = (
        torch.LongTensor([0]) if hps.data.n_speakers > 1 else None
    )
    # noise, length, noise_w
    scales = torch.FloatTensor([0.667, 1.0, 0.8])

    dummy_input, input_names, dynamic_axes = _export_io_spec(
        sequences, sequence_lengths, scales, sid
    )

    torch.onnx.export(
        model=model_g,
        args=dummy_input,
        f=str(args.output),
        verbose=False,
        opset_version=OPSET_VERSION,
        input_names=input_names,
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    print(f"Exported model to {args.output}")


if __name__ == "__main__":
    main()