|
|
|
|
|
|
|
from typing import Dict, List |
|
|
|
import onnxruntime |
|
import soundfile |
|
import torch |
|
|
|
|
|
def display(sess):
    """Print an inference session's input and output descriptors to stdout.

    Every entry of ``sess.get_inputs()`` is printed on its own line, followed
    by a line of ten dashes, followed by every entry of ``sess.get_outputs()``.
    """
    for descriptor in sess.get_inputs():
        print(descriptor)

    separator = "-" * 10
    print(separator)

    for descriptor in sess.get_outputs():
        print(descriptor)
|
|
|
|
|
class OnnxModel:
    """Wrapper around an ONNX Runtime session.

    Exposes three values from the model's custom metadata as attributes and
    runs inference on a single token-ID sequence.

    Attributes:
      session_opts: The ``onnxruntime.SessionOptions`` used to build the session.
      model: The ``onnxruntime.InferenceSession``.
      add_blank: ``int(meta["add_blank"])``. Callers such as ``generate``
        check its truthiness to decide whether to intersperse blanks.
      sample_rate: ``int(meta["sample_rate"])``. It is used as the WAV
        sample rate when writing output in ``main``.
      punctuation: ``meta["punctuation"]`` split on whitespace.
    """

    def __init__(
        self,
        model: str,
    ):
        """Create the inference session and read the model metadata.

        Args:
          model:
            Path to the ONNX model file, passed straight to
            ``onnxruntime.InferenceSession``.

        Prints the session's inputs/outputs and the full metadata map.
        """
        session_opts = onnxruntime.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4

        self.session_opts = session_opts

        self.model = onnxruntime.InferenceSession(
            model,
            sess_options=self.session_opts,
        )
        display(self.model)

        meta = self.model.get_modelmeta().custom_metadata_map
        # Metadata values are strings; convert them to the types used here.
        self.add_blank = int(meta["add_blank"])
        self.sample_rate = int(meta["sample_rate"])
        self.punctuation = meta["punctuation"].split()
        print(meta)

    def __call__(
        self,
        x: torch.Tensor,
        sid: int,
        *,
        noise_scale: float = 1.0,
        length_scale: float = 1.0,
        noise_scale_w: float = 1.0,
    ) -> torch.Tensor:
        """Run the model on a single (unbatched) token sequence.

        Args:
          x:
            An int64 tensor of shape (L,) containing token IDs.
          sid:
            Integer fed to the model's 6th input (presumably the speaker ID).
          noise_scale, length_scale, noise_scale_w:
            Scalars fed to the model's 3rd, 4th and 5th inputs as float32
            tensors of shape (1,). The defaults of 1.0 match the values that
            were previously hard-coded.

        Returns:
          The model's first output, converted to a tensor with every
          size-1 dimension removed.
        """
        # Inputs are fed by position, so this assumes the exported model
        # declares them in the order (x, x_length, noise_scale, length_scale,
        # noise_scale_w, sid). NOTE(review): confirm against the exporter.
        input_names = [node.name for node in self.model.get_inputs()]
        output_name = self.model.get_outputs()[0].name

        x = x.unsqueeze(0)  # add a batch dimension: (L,) -> (1, L)
        x_length = torch.tensor([x.shape[1]], dtype=torch.int64)

        feeds = {
            input_names[0]: x.numpy(),
            input_names[1]: x_length.numpy(),
            input_names[2]: torch.tensor([noise_scale], dtype=torch.float32).numpy(),
            input_names[3]: torch.tensor([length_scale], dtype=torch.float32).numpy(),
            input_names[4]: torch.tensor([noise_scale_w], dtype=torch.float32).numpy(),
            input_names[5]: torch.tensor([sid], dtype=torch.int64).numpy(),
        }

        y = self.model.run([output_name], feeds)[0]
        return torch.from_numpy(y).squeeze()
|
|
|
|
|
def read_lexicon(filename: str = "./lexicon.txt") -> Dict[str, List[str]]:
    """Read a whitespace-separated pronunciation lexicon.

    Each non-blank line has the form ``word phone1 phone2 ...``. If a word
    appears on several lines, the last line wins. Blank lines are skipped.

    Args:
      filename:
        Path to the lexicon file, read as UTF-8. Defaults to
        ``./lexicon.txt``.

    Returns:
      A dict mapping each word to its list of phone strings.
    """
    ans = dict()
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # A blank (or whitespace-only) line has no word. Skipping it
                # avoids an IndexError on e.g. a trailing empty line.
                continue
            ans[fields[0]] = fields[1:]
    return ans
|
|
|
|
|
def read_tokens(filename: str = "./tokens.txt") -> Dict[str, int]:
    """Read a token table mapping each token symbol to its integer ID.

    Each non-blank line is ``token id``. A line with a single field is taken
    to be the ID of the space token ``" "``: that token's symbol is
    whitespace, so ``strip``/``split`` remove it. Blank lines are skipped.

    Args:
      filename:
        Path to the token file, read as UTF-8. Defaults to ``./tokens.txt``.

    Returns:
      A dict mapping token string to integer ID.

    Raises:
      ValueError: if a line has more than two fields, or if the ID is not an
        integer.
    """
    ans = dict()
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.strip().split()
            if not fields:
                # Skip blank lines (e.g. a trailing empty line).
                continue
            if len(fields) == 1:
                token = " "
                idx = fields[0]
            elif len(fields) == 2:
                token, idx = fields
            else:
                # Explicit error rather than `assert`, which `python -O` strips.
                raise ValueError(f"Malformed line in {filename!r}: {line!r}")
            ans[token] = int(idx)
    return ans
|
|
|
|
|
def convert_lexicon(lexicon, tokens):
    """Convert every lexicon entry from phone strings to token IDs, in place.

    Args:
      lexicon:
        Dict mapping word -> list of phone strings (as returned by
        ``read_lexicon``). It is mutated: each value becomes the list of
        ``tokens[phone]`` IDs.
      tokens:
        Dict mapping token string -> integer ID (as returned by
        ``read_tokens``).

    A word that uses a phone missing from ``tokens`` is reported with
    ``print("skip", word)`` and removed from ``lexicon``. Every remaining
    value is then a list of ints, so downstream code never receives a stray
    phone string in place of an ID.
    """
    # Iterate over a snapshot of the keys because entries may be deleted.
    for w in list(lexicon):
        try:
            lexicon[w] = [tokens[phone] for phone in lexicon[w]]
        except KeyError:
            print("skip", w)
            del lexicon[w]
|
|
|
|
|
""" |
|
skip rapprochement |
|
skip croissants |
|
skip aix-en-provence |
|
skip provence |
|
skip croissant |
|
skip denouement |
|
skip hola |
|
skip blanc |
|
""" |
|
|
|
|
|
def get_text(text, lexicon, tokens, punctuation):
    """Convert a sentence into a flat list of token IDs.

    The text is lowercased and split on whitespace. For each word:

    * Leading punctuation characters are emitted immediately, in order.
    * Trailing punctuation characters are stripped and remembered.
    * The remaining word is looked up in ``lexicon``. If it is missing, the
      program prints ``ignore <word>`` and the word contributes nothing more:
      no trailing punctuation and no separator.
    * Otherwise the word's IDs and then its trailing punctuation are
      emitted. A word made only of punctuation emits only that punctuation.
    * ``tokens[" "]`` is appended as a separator unless this is the last word.

    Args:
      text:
        Input sentence.
      lexicon:
        Dict mapping lowercase word -> list of token IDs (as produced by
        ``convert_lexicon``).
      tokens:
        Dict mapping token string -> ID. It must contain ``" "`` and every
        punctuation character that occurs at a word boundary.
      punctuation:
        Container of single-character punctuation strings.

    Returns:
      List of token IDs.
    """
    words = text.lower().split()
    ans = []
    for i, w in enumerate(words):
        # `while w` guards against indexing an empty string once a
        # punctuation-only word (e.g. ",") has been fully consumed.
        while w and w[0] in punctuation:
            ans.append(tokens[w[0]])
            w = w[1:]

        trailing = []
        while w and w[-1] in punctuation:
            trailing.append(tokens[w[-1]])
            w = w[:-1]
        trailing.reverse()  # collected right-to-left; restore text order

        if w:
            if w not in lexicon:
                print("ignore", w)
                continue
            ans.extend(lexicon[w])

        # Append unconditionally: a punctuation token whose ID is 0 must not
        # be dropped (the old `if punct:` check was falsy for 0).
        ans.extend(trailing)

        if i != len(words) - 1:
            ans.append(tokens[" "])
    return ans
|
|
|
|
|
def generate(model, text, lexicon, tokens, sid):
    """Synthesize ``text`` with ``model`` and return whatever the model returns.

    The text is tokenized with ``get_text``, using ``model.punctuation``. If
    ``model.add_blank`` is truthy, ID 0 is inserted before, between and after
    all tokens: ``[0, t0, 0, t1, ..., 0]``. NOTE(review): ID 0 is presumably
    the blank token; confirm against tokens.txt.

    Args:
      model: Callable such as ``OnnxModel``, exposing ``punctuation`` and
        ``add_blank``.
      text: Input sentence.
      lexicon: Word -> token-ID list mapping.
      tokens: Token string -> ID mapping.
      sid: Forwarded to the model as ``sid=``.
    """
    token_ids = get_text(
        text,
        lexicon,
        tokens,
        model.punctuation,
    )

    if model.add_blank:
        interleaved = [0]
        for token_id in token_ids:
            interleaved.extend((token_id, 0))
        token_ids = interleaved

    return model(torch.tensor(token_ids, dtype=torch.int64), sid=sid)
|
|
|
|
|
def main():
    """Synthesize three demo sentences and write each one to a WAV file.

    The model, lexicon and tokens are read from the current directory.
    Output files are ``test-0.wav``, ``test-1.wav`` and ``test-2.wav``,
    produced with speaker IDs 0, 1 and 2 respectively.
    """
    model = OnnxModel("./vits-vctk.onnx")

    lexicon = read_lexicon()
    tokens = read_tokens()
    convert_lexicon(lexicon, tokens)

    # (sentence, speaker id, output path), processed in this order.
    demos = [
        ("Liliana, our most beautiful and lovely assistant", 0, "test-0.wav"),
        (
            "Ask not what your country can do for you; ask what you can do for your country.",
            1,
            "test-1.wav",
        ),
        (
            "Success is not final, failure is not fatal, it is the courage to continue that counts!",
            2,
            "test-2.wav",
        ),
    ]
    for sentence, speaker, wav_path in demos:
        audio = generate(model, sentence, lexicon, tokens, sid=speaker)
        soundfile.write(wav_path, audio.numpy(), model.sample_rate)
|
|
|
|
|
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|