#!/usr/bin/env python3
"""Export a VITS-fast-fine-tuning checkpoint to ONNX.

Usage:
  NAME=<model-name> python3 ./export_onnx.py

The checkpoint and config are looked up from NAME (either one of the
hard-coded pairs below or G_<NAME>_latest.pth / G_<NAME>_latest.json in the
current directory). The exported model is written to
vits-zh-hf-fanchen-<NAME>.onnx.
"""
import sys

# Make the VITS-fast-fine-tuning repo importable (provides `utils` and `models`).
sys.path.insert(0, "VITS-fast-fine-tuning")

import os
from pathlib import Path
from typing import Any, Dict

import onnx
import torch
import utils
from models import SynthesizerTrn


class OnnxModel(torch.nn.Module):
    """Wrap SynthesizerTrn so that only the generated audio is returned."""

    def __init__(self, model: SynthesizerTrn):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=1,
        length_scale=1,
        noise_scale_w=1.0,
        sid=0,
        max_len=None,
    ):
        # SynthesizerTrn.infer() returns a tuple; keep only the waveform.
        return self.model.infer(
            x=x,
            x_lengths=x_lengths,
            sid=sid,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            max_len=max_len,
        )[0]


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


@torch.no_grad()
def main():
    name = os.environ.get("NAME", None)
    if not name:
        print("Please provide the environment variable NAME")
        return

    print("name", name)

    # A few models use non-standard filenames; everything else follows
    # the G_<NAME>_latest.{pth,json} convention.
    if name == "C":
        model_path = "G_C.pth"
        config_path = "G_C.json"
    elif name == "ZhiHuiLaoZhe":
        model_path = "G_lkz_lao_new_new1_latest.pth"
        config_path = "G_lkz_lao_new_new1_latest.json"
    elif name == "ZhiHuiLaoZhe_new":
        model_path = "G_lkz_unity_onnx_new1_latest.pth"
        config_path = "G_lkz_unity_onnx_new1_latest.json"
    else:
        model_path = f"G_{name}_latest.pth"
        config_path = f"G_{name}_latest.json"

    print(name, model_path, config_path)

    hps = utils.get_hparams_from_file(config_path)

    net_g = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    _ = net_g.eval()

    _ = utils.load_checkpoint(model_path, net_g, None)

    # Dummy inputs used only to trace the model for export.
    x = torch.randint(low=1, high=50, size=(50,), dtype=torch.int64)
    x = x.unsqueeze(0)

    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_w = torch.tensor([1], dtype=torch.float32)
    sid = torch.tensor([0], dtype=torch.int64)

    model = OnnxModel(net_g)

    opset_version = 13

    filename = f"vits-zh-hf-fanchen-{name}.onnx"

    torch.onnx.export(
        model,
        (x, x_length, noise_scale, length_scale, noise_scale_w, sid),
        filename,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_length",
            "noise_scale",
            "length_scale",
            "noise_scale_w",
            "sid",
        ],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},  # N is also known as batch_size
            "x_length": {0: "N"},
            "y": {0: "N", 2: "L"},
        },
    )

    meta_data = {
        "model_type": "vits",
        "comment": f"hf-vits-models-fanchen-{name}",
        "language": "Chinese",
        "add_blank": int(hps.data.add_blank),
        "n_speakers": int(hps.data.n_speakers),
        "sample_rate": hps.data.sampling_rate,
        "punctuation": ", . : ; ! ? , 。 : ; ! ? 、",
    }
    print("meta_data", meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)


if __name__ == "__main__":
    main()