csukuangfj committed on
Commit
ebb01fc
1 Parent(s): 97b6152
Files changed (2) hide show
  1. app.py +14 -4
  2. model.py +7 -5
app.py CHANGED
@@ -71,10 +71,10 @@ def build_html_output(s: str, style: str = "result_item_success"):
71
  """
72
 
73
 
74
- def process(language: str, repo_id: str, text: str, sid: str):
75
- logging.info(f"Input text: {text}. sid: {sid}")
76
  sid = int(sid)
77
- tts = get_pretrained_model(repo_id)
78
 
79
  start = time.time()
80
  audio = tts.generate(text, sid=sid)
@@ -97,7 +97,7 @@ def process(language: str, repo_id: str, text: str, sid: str):
97
  """
98
 
99
  logging.info(info)
100
- logging.info(f"\nrepo_id: {repo_id}\ntext: {text}\nsid: {sid}")
101
 
102
  filename = str(uuid.uuid4())
103
  filename = f"{filename}.wav"
@@ -153,6 +153,15 @@ with demo:
153
  value="0",
154
  placeholder="Speaker ID. Valid only for mult-speaker model",
155
  )
 
 
 
 
 
 
 
 
 
156
  input_button = gr.Button("Submit")
157
 
158
  output_audio = gr.Audio(label="Output")
@@ -166,6 +175,7 @@ with demo:
166
  model_dropdown,
167
  input_text,
168
  input_sid,
 
169
  ],
170
  outputs=[
171
  output_audio,
 
71
  """
72
 
73
 
74
+ def process(language: str, repo_id: str, text: str, sid: str, speed: float):
75
+ logging.info(f"Input text: {text}. sid: {sid}, speed: {speed}")
76
  sid = int(sid)
77
+ tts = get_pretrained_model(repo_id, speed)
78
 
79
  start = time.time()
80
  audio = tts.generate(text, sid=sid)
 
97
  """
98
 
99
  logging.info(info)
100
+ logging.info(f"\nrepo_id: {repo_id}\ntext: {text}\nsid: {sid}\nspeed: {speed}")
101
 
102
  filename = str(uuid.uuid4())
103
  filename = f"{filename}.wav"
 
153
  value="0",
154
  placeholder="Speaker ID. Valid only for mult-speaker model",
155
  )
156
+
157
+ input_speed = gr.Slider(
158
+ minimum=0.1,
159
+ maximum=10,
160
+ value=1,
161
+ step=0.1,
162
+ label="Speed (larger->faster; smaller->slower)",
163
+ )
164
+
165
  input_button = gr.Button("Submit")
166
 
167
  output_audio = gr.Audio(label="Output")
 
175
  model_dropdown,
176
  input_text,
177
  input_sid,
178
+ input_speed,
179
  ],
180
  outputs=[
181
  output_audio,
model.py CHANGED
@@ -34,7 +34,7 @@ def get_file(
34
 
35
 
36
  @lru_cache(maxsize=10)
37
- def _get_vits_vctk(repo_id: str) -> sherpa_onnx.OfflineTts:
38
  assert repo_id == "csukuangfj/vits-vctk"
39
 
40
  model = get_file(
@@ -61,6 +61,7 @@ def _get_vits_vctk(repo_id: str) -> sherpa_onnx.OfflineTts:
61
  model=model,
62
  lexicon=lexicon,
63
  tokens=tokens,
 
64
  ),
65
  provider="cpu",
66
  debug=False,
@@ -73,7 +74,7 @@ def _get_vits_vctk(repo_id: str) -> sherpa_onnx.OfflineTts:
73
 
74
 
75
  @lru_cache(maxsize=10)
76
- def _get_vits_zh_aishell3(repo_id: str) -> sherpa_onnx.OfflineTts:
77
  assert repo_id == "csukuangfj/vits-zh-aishell3"
78
 
79
  model = get_file(
@@ -100,6 +101,7 @@ def _get_vits_zh_aishell3(repo_id: str) -> sherpa_onnx.OfflineTts:
100
  model=model,
101
  lexicon=lexicon,
102
  tokens=tokens,
 
103
  ),
104
  provider="cpu",
105
  debug=False,
@@ -112,11 +114,11 @@ def _get_vits_zh_aishell3(repo_id: str) -> sherpa_onnx.OfflineTts:
112
 
113
 
114
  @lru_cache(maxsize=10)
115
- def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineTts:
116
  if repo_id in chinese_models:
117
- return chinese_models[repo_id](repo_id)
118
  elif repo_id in english_models:
119
- return english_models[repo_id](repo_id)
120
  else:
121
  raise ValueError(f"Unsupported repo_id: {repo_id}")
122
 
 
34
 
35
 
36
  @lru_cache(maxsize=10)
37
+ def _get_vits_vctk(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
38
  assert repo_id == "csukuangfj/vits-vctk"
39
 
40
  model = get_file(
 
61
  model=model,
62
  lexicon=lexicon,
63
  tokens=tokens,
64
+ length_scale=1.0 / speed,
65
  ),
66
  provider="cpu",
67
  debug=False,
 
74
 
75
 
76
  @lru_cache(maxsize=10)
77
+ def _get_vits_zh_aishell3(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
78
  assert repo_id == "csukuangfj/vits-zh-aishell3"
79
 
80
  model = get_file(
 
101
  model=model,
102
  lexicon=lexicon,
103
  tokens=tokens,
104
+ length_scale=1.0 / speed,
105
  ),
106
  provider="cpu",
107
  debug=False,
 
114
 
115
 
116
  @lru_cache(maxsize=10)
117
+ def get_pretrained_model(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
118
  if repo_id in chinese_models:
119
+ return chinese_models[repo_id](repo_id, speed)
120
  elif repo_id in english_models:
121
+ return english_models[repo_id](repo_id, speed)
122
  else:
123
  raise ValueError(f"Unsupported repo_id: {repo_id}")
124