# vits-vctk / test.py
# Last commit: aeee641 "add onnx models" (by csukuangfj)
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
from typing import Dict, List
import onnxruntime
import soundfile
import torch
def display(sess):
    """Print every input description of ``sess``, a divider, then every output description.

    Args:
      sess:
        Any object with ``get_inputs()`` and ``get_outputs()`` methods returning
        iterables, e.g. an ``onnxruntime.InferenceSession``.
    """

    def _dump(nodes):
        # One printed line per node description.
        for node in nodes:
            print(node)

    _dump(sess.get_inputs())
    print("-" * 10)
    _dump(sess.get_outputs())
class OnnxModel:
    """Wrapper around an ONNX Runtime session for a VITS text-to-speech model.

    The model's custom metadata map must contain ``add_blank``,
    ``sample_rate`` and ``punctuation`` entries; they are exposed as
    ``self.add_blank`` (int), ``self.sample_rate`` (int) and
    ``self.punctuation`` (list of str).
    """

    def __init__(
        self,
        model: str,
    ):
        """Load the ONNX model and read its metadata.

        Prints the session's inputs/outputs and the raw metadata map.

        Args:
          model:
            Path to the ``.onnx`` file.
        """
        session_opts = onnxruntime.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4
        self.session_opts = session_opts

        self.model = onnxruntime.InferenceSession(
            model,
            sess_options=self.session_opts,
        )
        display(self.model)

        meta = self.model.get_modelmeta().custom_metadata_map
        # Metadata values are stored as strings, hence the conversions.
        self.add_blank = int(meta["add_blank"])
        self.sample_rate = int(meta["sample_rate"])
        # Whitespace-separated punctuation symbols.
        self.punctuation = meta["punctuation"].split()
        print(meta)

    def __call__(
        self,
        x: torch.Tensor,
        sid: int,
        *,
        noise_scale: float = 1.0,
        length_scale: float = 1.0,
        noise_scale_w: float = 1.0,
    ) -> torch.Tensor:
        """Run the model on a single token sequence.

        Args:
          x:
            A int64 tensor of shape (L,) holding token IDs.
          sid:
            Integer fed to the model's 6th input (the speaker ID, judging by
            its name).
          noise_scale, length_scale, noise_scale_w:
            Scalars fed to the model's 3rd, 4th and 5th inputs. The defaults
            of 1.0 reproduce the values that used to be hard-coded.
        Returns:
          The model's first output with all size-1 dimensions squeezed out
          (presumably the waveform samples).
        """
        # The feed order below must match the order of the graph's inputs.
        input_names = [node.name for node in self.model.get_inputs()]
        output_name = self.model.get_outputs()[0].name

        batch_x = x.unsqueeze(0)  # (1, L): the model expects a batch dim
        x_length = torch.tensor([x.shape[0]], dtype=torch.int64)
        noise_scale_t = torch.tensor([noise_scale], dtype=torch.float32)
        length_scale_t = torch.tensor([length_scale], dtype=torch.float32)
        noise_scale_w_t = torch.tensor([noise_scale_w], dtype=torch.float32)
        sid_t = torch.tensor([sid], dtype=torch.int64)

        y = self.model.run(
            [output_name],
            {
                input_names[0]: batch_x.numpy(),
                input_names[1]: x_length.numpy(),
                input_names[2]: noise_scale_t.numpy(),
                input_names[3]: length_scale_t.numpy(),
                input_names[4]: noise_scale_w_t.numpy(),
                input_names[5]: sid_t.numpy(),
            },
        )[0]
        return torch.from_numpy(y).squeeze()
def read_lexicon(filename: str = "./lexicon.txt") -> Dict[str, List[str]]:
    """Read a pronunciation lexicon.

    Each non-blank line has the form ``word phone1 phone2 ...`` (whitespace
    separated). Blank lines are skipped. If a word appears more than once,
    the last entry wins.

    Args:
      filename:
        Path to the lexicon file. Defaults to ``./lexicon.txt``.
    Returns:
      A dict mapping each word to its list of phone symbols.
    """
    ans = dict()
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # A blank line would otherwise make fields[0] raise IndexError.
                continue
            ans[fields[0]] = fields[1:]
    return ans
def read_tokens(filename: str = "./tokens.txt") -> Dict[str, int]:
    """Read the token table.

    Each non-blank line is ``token id``. A line holding only an ID is the
    space token: the space itself disappears when the line is stripped and
    split, so it is restored here. Blank lines are skipped.

    Args:
      filename:
        Path to the token file. Defaults to ``./tokens.txt``.
    Returns:
      A dict mapping each token to its integer ID.
    Raises:
      ValueError: if a line has more than two fields.
    """
    ans = dict()
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.strip().split()
            if not fields:
                continue
            if len(fields) == 1:
                # The token was " "; only its ID survived strip()/split().
                token, idx = " ", fields[0]
            elif len(fields) == 2:
                token, idx = fields
            else:
                # Raise instead of assert: asserts vanish under `python -O`.
                raise ValueError(f"Invalid line in {filename}: {line!r}")
            ans[token] = int(idx)
    return ans
def convert_lexicon(lexicon, tokens):
    """Replace, in place, each word's phone symbols with their token IDs.

    Words containing at least one phone that is missing from ``tokens`` are
    removed from ``lexicon``. Previously they were left with unconverted
    string phones, which later broke tensor construction.

    Args:
      lexicon:
        Dict mapping word -> list of phone symbols; modified in place.
      tokens:
        Dict mapping phone symbol -> integer token ID.
    """
    unknown = []
    for w, phones in lexicon.items():
        try:
            # Re-assigning an existing key is safe while iterating.
            lexicon[w] = [tokens[p] for p in phones]
        except KeyError:
            unknown.append(w)
    # Delete after iterating: removing keys mid-iteration raises RuntimeError.
    for w in unknown:
        del lexicon[w]


# Words previously reported as skipped (phones missing from tokens.txt):
#   rapprochement, croissants, aix-en-provence, provence, croissant,
#   denouement, hola, blanc
def get_text(text, lexicon, tokens, punctuation):
    """Convert ``text`` into a flat list of token IDs.

    The text is lower-cased and split on whitespace. For each word, leading
    and trailing punctuation characters (members of ``punctuation``) are
    peeled off. Leading ones are emitted immediately. If the remaining word
    is in ``lexicon``, its IDs are emitted, followed by the trailing
    punctuation and then a space token (except after the last word).
    Otherwise the word is reported with ``print("ignore", word)``, and its
    trailing punctuation is dropped. A "word" consisting only of punctuation
    just emits that punctuation.

    Args:
      text:
        Input sentence.
      lexicon:
        Dict mapping lower-case word -> list of token IDs.
      tokens:
        Dict mapping symbol -> token ID; must contain " " and every
        punctuation character that occurs in ``text``.
      punctuation:
        Container of single-character punctuation symbols.
    Returns:
      A list of token IDs (possibly empty).
    """
    words = text.lower().split()
    last = len(words) - 1
    ans = []
    for i, word in enumerate(words):
        start = 0
        while start < len(word) and word[start] in punctuation:
            ans.append(tokens[word[start]])
            start += 1

        end = len(word)
        while end > start and word[end - 1] in punctuation:
            end -= 1
        # Keep trailing punctuation in its original order, e.g. '!"'.
        trailing = [tokens[c] for c in word[end:]]

        w = word[start:end]
        if not w:
            # Standalone punctuation such as "," in "hi , bob".
            continue
        if w in lexicon:
            ans.extend(lexicon[w])
            ans.extend(trailing)
            if i != last:
                ans.append(tokens[" "])
            continue
        print("ignore", w)
    return ans
def generate(model, text, lexicon, tokens, sid):
    """Synthesize ``text`` with ``model`` for the given ``sid``.

    Args:
      model:
        A callable model exposing ``punctuation`` and ``add_blank``
        attributes; it is called as ``model(x, sid=sid)``.
      text:
        Sentence to synthesize.
      lexicon:
        Dict mapping word -> list of token IDs.
      tokens:
        Dict mapping symbol -> token ID.
      sid:
        Integer passed through to ``model`` as ``sid``.
    Returns:
      Whatever ``model`` returns for the encoded token tensor.
    Raises:
      ValueError: if ``text`` produces no tokens (e.g. every word is unknown).
    """
    x = get_text(
        text,
        lexicon,
        tokens,
        model.punctuation,
    )
    if not x:
        # Fail clearly instead of running the model on an empty/blank-only input.
        raise ValueError(f"No tokens produced for text: {text!r}")

    if model.add_blank:
        # Interleave ID 0 around every token: [0, t1, 0, t2, ..., tn, 0].
        x2 = [0] * (2 * len(x) + 1)
        x2[1::2] = x
        x = x2

    x = torch.tensor(x, dtype=torch.int64)
    return model(x, sid=sid)
def main():
    """Synthesize a few sample sentences and write them to WAV files.

    Loads ``./vits-vctk.onnx`` and the lexicon/token tables, then writes
    ``test-0.wav``, ``test-1.wav`` and ``test-2.wav`` (speakers 0, 1 and 2)
    at the model's sample rate.
    """
    model = OnnxModel("./vits-vctk.onnx")
    lexicon = read_lexicon()
    tokens = read_tokens()
    convert_lexicon(lexicon, tokens)

    # (output file, speaker ID, sentence)
    samples = [
        (
            "test-0.wav",
            0,
            "Liliana, our most beautiful and lovely assistant",
        ),
        (
            "test-1.wav",
            1,
            "Ask not what your country can do for you; ask what you can do for your country.",
        ),
        (
            "test-2.wav",
            2,
            "Success is not final, failure is not fatal, it is the courage to continue that counts!",
        ),
    ]
    for filename, sid, text in samples:
        y = generate(model, text, lexicon, tokens, sid=sid)
        soundfile.write(filename, y.numpy(), model.sample_rate)


if __name__ == "__main__":
    main()