csukuangfj committed on
Commit
bd651de
1 Parent(s): 2d99181
Files changed (1)
  1. export-onnx-zh-hf-fanchen-models.py +0 -144
export-onnx-zh-hf-fanchen-models.py DELETED
@@ -1,144 +0,0 @@
- #!/usr/bin/env python3
-
- import sys
-
- sys.path.insert(0, "VITS-fast-fine-tuning")
-
- import os
- from pathlib import Path
- from typing import Any, Dict
-
- import onnx
- import torch
- import utils
- from models import SynthesizerTrn
-
-
- class OnnxModel(torch.nn.Module):
-     def __init__(self, model: SynthesizerTrn):
-         super().__init__()
-         self.model = model
-
-     def forward(
-         self,
-         x,
-         x_lengths,
-         noise_scale=1,
-         length_scale=1,
-         noise_scale_w=1.0,
-         sid=0,
-         max_len=None,
-     ):
-         return self.model.infer(
-             x=x,
-             x_lengths=x_lengths,
-             sid=sid,
-             noise_scale=noise_scale,
-             length_scale=length_scale,
-             noise_scale_w=noise_scale_w,
-             max_len=max_len,
-         )[0]
-
-
- def add_meta_data(filename: str, meta_data: Dict[str, Any]):
-     """Add meta data to an ONNX model. It is changed in-place.
-
-     Args:
-       filename:
-         Filename of the ONNX model to be changed.
-       meta_data:
-         Key-value pairs.
-     """
-     model = onnx.load(filename)
-     for key, value in meta_data.items():
-         meta = model.metadata_props.add()
-         meta.key = key
-         meta.value = str(value)
-
-     onnx.save(model, filename)
-
-
- @torch.no_grad()
- def main():
-     name = os.environ.get("NAME", None)
-     if not name:
-         print("Please provide the environment variable NAME")
-         return
-
-     print("name", name)
-
-     if name == "C":
-         model_path = "G_C.pth"
-         config_path = "G_C.json"
-     elif name == "ZhiHuiLaoZhe":
-         model_path = "G_lkz_lao_new_new1_latest.pth"
-         config_path = "G_lkz_lao_new_new1_latest.json"
-     elif name == "ZhiHuiLaoZhe_new":
-         model_path = "G_lkz_unity_onnx_new1_latest.pth"
-         config_path = "G_lkz_unity_onnx_new1_latest.json"
-     else:
-         model_path = f"G_{name}_latest.pth"
-         config_path = f"G_{name}_latest.json"
-
-     print(name, model_path, config_path)
-     hps = utils.get_hparams_from_file(config_path)
-     net_g = SynthesizerTrn(
-         len(hps.symbols),
-         hps.data.filter_length // 2 + 1,
-         hps.train.segment_size // hps.data.hop_length,
-         n_speakers=hps.data.n_speakers,
-         **hps.model,
-     )
-     _ = net_g.eval()
-     _ = utils.load_checkpoint(model_path, net_g, None)
-
-     x = torch.randint(low=1, high=50, size=(50,), dtype=torch.int64)
-     x = x.unsqueeze(0)
-
-     x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
-     noise_scale = torch.tensor([1], dtype=torch.float32)
-     length_scale = torch.tensor([1], dtype=torch.float32)
-     noise_scale_w = torch.tensor([1], dtype=torch.float32)
-     sid = torch.tensor([0], dtype=torch.int64)
-
-     model = OnnxModel(net_g)
-
-     opset_version = 13
-
-     filename = f"vits-zh-hf-fanchen-{name}.onnx"
-
-     torch.onnx.export(
-         model,
-         (x, x_length, noise_scale, length_scale, noise_scale_w, sid),
-         filename,
-         opset_version=opset_version,
-         input_names=[
-             "x",
-             "x_length",
-             "noise_scale",
-             "length_scale",
-             "noise_scale_w",
-             "sid",
-         ],
-         output_names=["y"],
-         dynamic_axes={
-             "x": {0: "N", 1: "L"},  # n_audio is also known as batch_size
-             "x_length": {0: "N"},
-             "y": {0: "N", 2: "L"},
-         },
-     )
-     meta_data = {
-         "model_type": "vits",
-         "comment": f"hf-vits-models-fanchen-{name}",
-         "language": "Chinese",
-         "add_blank": int(hps.data.add_blank),
-         "n_speakers": int(hps.data.n_speakers),
-         "sample_rate": hps.data.sampling_rate,
-         "punctuation": ", . : ; ! ? , 。 : ; ! ? 、",
-     }
-     print("meta_data", meta_data)
-     add_meta_data(filename=filename, meta_data=meta_data)
-
-
- if __name__ == "__main__":
-     main()
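
For reference, the deleted exporter names the graph inputs x, x_length, noise_scale, length_scale, noise_scale_w and sid, and stamps the metadata via add_meta_data(). A minimal sketch of loading such an exported file with onnxruntime to inspect the metadata and run a dummy inference; the file name vits-zh-hf-fanchen-C.onnx and the random token ids are illustrative assumptions, not part of this commit:

import numpy as np
import onnxruntime as ort

# Illustrative file name; the script writes vits-zh-hf-fanchen-{NAME}.onnx.
sess = ort.InferenceSession("vits-zh-hf-fanchen-C.onnx")

# Metadata written by add_meta_data() (model_type, language, sample_rate, ...).
print(sess.get_modelmeta().custom_metadata_map)

# Dummy inputs mirroring the tensors used during export; real token ids would
# come from the model's symbol table after text normalization.
x = np.random.randint(low=1, high=50, size=(1, 50), dtype=np.int64)
inputs = {
    "x": x,
    "x_length": np.array([x.shape[1]], dtype=np.int64),
    "noise_scale": np.array([1.0], dtype=np.float32),
    "length_scale": np.array([1.0], dtype=np.float32),
    "noise_scale_w": np.array([1.0], dtype=np.float32),
    "sid": np.array([0], dtype=np.int64),
}
y = sess.run(["y"], inputs)[0]
print(y.shape)  # expected (N, 1, L): a batch of waveform samples

The input/output names and the (N, 1, L) output layout follow from the input_names, output_names and dynamic_axes passed to torch.onnx.export above.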