TeraSpace committed on
Commit
151e67b
1 Parent(s): 1959e6a

Update infer_onnx.py

Browse files
Files changed (1) hide show
  1. infer_onnx.py +8 -4
infer_onnx.py CHANGED
@@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
6
  from num2words import num2words
7
  import re
8
  from transliterate import translit
 
9
 
10
  class TTS:
11
  def __init__(self, model_name: str, save_path: str = "./model", add_time_to_end: float = 0.8) -> None:
@@ -22,7 +23,9 @@ class TTS:
22
  )
23
 
24
  self.model = onnxruntime.InferenceSession(os.path.join(model_dir, "exported/model.onnx"), providers=['CPUExecutionProvider'])
25
-
 
 
26
  if os.path.exists(os.path.join(model_dir, "exported/dictionary.txt")):
27
  from tokenizer import TokenizerG2P
28
  print("Use g2p")
@@ -43,9 +46,9 @@ class TTS:
43
  return audio_with_silence
44
 
45
 
46
- def save_wav(self, audio, path:str):
47
  '''save audio to wav'''
48
- scipy.io.wavfile.write(path, 22050, audio)
49
 
50
 
51
  def _intersperse(self, lst, item):
@@ -83,5 +86,6 @@ class TTS:
83
  "sid": None,
84
  },
85
  )[0][0,0][0]
86
- audio = self._add_silent(audio, silence_duration = self.add_time_to_end)
 
87
  return audio
 
6
  from num2words import num2words
7
  import re
8
  from transliterate import translit
9
+ import json
10
 
11
  class TTS:
12
  def __init__(self, model_name: str, save_path: str = "./model", add_time_to_end: float = 0.8) -> None:
 
23
  )
24
 
25
  self.model = onnxruntime.InferenceSession(os.path.join(model_dir, "exported/model.onnx"), providers=['CPUExecutionProvider'])
26
+ with open(os.path.join(model_dir, "exported/config.json")) as config_file:
27
+ self.config = json.load(config_file)["model_config"]
28
+
29
  if os.path.exists(os.path.join(model_dir, "exported/dictionary.txt")):
30
  from tokenizer import TokenizerG2P
31
  print("Use g2p")
 
46
  return audio_with_silence
47
 
48
 
49
+ def save_wav(self, audio, path:str, sample_rate: int = 22050):
50
  '''save audio to wav'''
51
+ scipy.io.wavfile.write(path, sample_rate, audio)
52
 
53
 
54
  def _intersperse(self, lst, item):
 
86
  "sid": None,
87
  },
88
  )[0][0,0][0]
89
+
90
+ audio = self._add_silent(audio, silence_duration = self.add_time_to_end, sample_rate=self.config["samplerate"])
91
  return audio