#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

"""
Generate speech with a VITS model exported to ONNX (./vits-ljs.onnx).

The script reads ./lexicon.txt and ./tokens.txt, converts a few test
sentences to token IDs, runs the model with onnxruntime, and writes the
results to test-0.wav, test-1.wav, and test-2.wav.
"""

from typing import Dict, List

import onnxruntime
import soundfile
import torch


def display(sess):
    """Print the inputs and outputs of the given onnxruntime session."""
    for i in sess.get_inputs():
        print(i)
    print("-" * 10)
    for o in sess.get_outputs():
        print(o)


class OnnxModel:
    def __init__(
        self,
        model: str,
    ):
        session_opts = onnxruntime.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4
        self.session_opts = session_opts

        self.model = onnxruntime.InferenceSession(
            model,
            sess_options=self.session_opts,
        )
        display(self.model)

        # Configuration saved into the model's metadata at export time.
        meta = self.model.get_modelmeta().custom_metadata_map
        self.add_blank = int(meta["add_blank"])
        self.sample_rate = int(meta["sample_rate"])
        self.punctuation = meta["punctuation"].split()
        print(meta)

    def __call__(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x:
            An int64 tensor of shape (L,) containing token IDs.
        Returns:
          A 1-D float32 tensor containing the generated audio samples.
        """
        # Add a batch dimension: (L,) -> (1, L)
        x = x.unsqueeze(0)
        x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
        noise_scale = torch.tensor([1], dtype=torch.float32)
        length_scale = torch.tensor([1], dtype=torch.float32)
        noise_scale_w = torch.tensor([1], dtype=torch.float32)

        y = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_length.numpy(),
                self.model.get_inputs()[2].name: noise_scale.numpy(),
                self.model.get_inputs()[3].name: length_scale.numpy(),
                self.model.get_inputs()[4].name: noise_scale_w.numpy(),
            },
        )[0]
        return torch.from_numpy(y).squeeze()


def read_lexicon() -> Dict[str, List[str]]:
    """Map each word in ./lexicon.txt to its list of phones."""
    ans = dict()
    with open("./lexicon.txt", encoding="utf-8") as f:
        for line in f:
            w_p = line.split()
            w = w_p[0]
            p = w_p[1:]
            ans[w] = p
    return ans


def read_tokens() -> Dict[str, int]:
    """Map each token in ./tokens.txt to its integer ID."""
    ans = dict()
    with open("./tokens.txt", encoding="utf-8") as f:
        for line in f:
            t_i = line.strip().split()
            if len(t_i) == 1:
                # A line with a single field is the space token.
                token = " "
                idx = t_i[0]
            else:
                assert len(t_i) == 2, (t_i, line)
                token = t_i[0]
                idx = t_i[1]
            ans[token] = int(idx)
    return ans


def convert_lexicon(lexicon, tokens):
    """Convert each word's phones to token IDs in place. Words containing
    a phone that is not in the token table are left unchanged and are
    later ignored by get_text()."""
    for w in lexicon:
        phones = lexicon[w]
        try:
            p = [tokens[i] for i in phones]
            lexicon[w] = p
        except Exception:
            #  print("skip", w)
            continue
    # Words skipped by the loop above:
    #   rapprochement, croissants, aix-en-provence, provence,
    #   croissant, denouement, hola, blanc


def get_text(text, lexicon, tokens, punctuation):
    """Convert a sentence to a list of token IDs."""
    text = text.lower().split()
    ans = []
    for i in range(len(text)):
        w = text[i]
        punct = None

        # Handle a punctuation mark attached to the start of the word.
        if w[0] in punctuation:
            ans.append(tokens[w[0]])
            w = w[1:]

        # Remember a trailing punctuation mark and strip it from the word.
        if w[-1] in punctuation:
            punct = tokens[w[-1]]
            w = w[:-1]

        if w in lexicon:
            ans.extend(lexicon[w])
            if punct is not None:
                ans.append(punct)
            if i != len(text) - 1:
                ans.append(tokens[" "])
            continue

        print("ignore", w)
    return ans


def generate(model, text, lexicon, tokens):
    x = get_text(
        text,
        lexicon,
        tokens,
        model.punctuation,
    )
    if model.add_blank:
        # Interleave a blank token (ID 0) between tokens and at both ends,
        # as the model expects when it was exported with add_blank set.
        x2 = [0] * (2 * len(x) + 1)
        x2[1::2] = x
        x = x2

    x = torch.tensor(x, dtype=torch.int64)
    y = model(x)
    return y


def main():
    model = OnnxModel("./vits-ljs.onnx")
    lexicon = read_lexicon()
    tokens = read_tokens()
    convert_lexicon(lexicon, tokens)

    text = "Liliana, our most beautiful and lovely assistant"
    y = generate(model, text, lexicon, tokens)
    soundfile.write("test-0.wav", y.numpy(), model.sample_rate)

    text = "Ask not what your country can do for you; ask what you can do for your country."
    y = generate(model, text, lexicon, tokens)
    soundfile.write("test-1.wav", y.numpy(), model.sample_rate)

    text = "Success is not final, failure is not fatal, it is the courage to continue that counts!"
    y = generate(model, text, lexicon, tokens)
    soundfile.write("test-2.wav", y.numpy(), model.sample_rate)


if __name__ == "__main__":
    main()
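# Expected formats of the auxiliary files, inferred from read_lexicon()
# and read_tokens() above:
#
#   lexicon.txt:  <word> <phone> <phone> ...   (one word per line)
#   tokens.txt:   <token> <id>                 (one token per line; a line
#                                               with a single field is the
#                                               space token)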