Commit 9ef73e2 (parent: 4263bcd) committed by clementruhm

reload model on synthesis if needed

Files changed (1):
  app.py  +29 -14
app.py CHANGED

@@ -15,6 +15,10 @@ from huggingface_hub import hf_hub_download, list_repo_files
 
 # global tts module, initialized from a model selected
 tts = None
+# path to the model that is currently used in tts
+cur_model_path = None
+# cache of speakers, maps model name to speaker list
+model_to_speakers = dict()
 model_repo_dir = "data"
 for name in list_repo_files(repo_id="balacoon/tts"):
     hf_hub_download(

@@ -58,14 +62,20 @@ def main():
 
         def set_model(model_name_str: str):
             """
-            gets value from `model_name`, loads model,
-            re-initializes tts object, gets list of
-            speakers that model supports and set them to `speaker`
+            gets value from `model_name`. either
+            uses cached list of speakers for the given model name
+            or loads the addon and checks what are the speakers.
             """
-            model_path = os.path.join(model_repo_dir, model_name_str)
-            global tts
-            tts = TTS(model_path)
-            speakers = tts.get_speakers()
+            if model_name_str in model_to_speakers:
+                speakers = model_to_speakers[model_name_str]
+            else:
+                # need to load this model to learn the list of speakers
+                model_path = os.path.join(model_repo_dir, model_name_str)
+                tts = TTS(model_path)
+                cur_model_path = model_path
+                speakers = tts.get_speakers()
+                model_to_speakers[model_name_str] = speakers
+
             value = speakers[-1]
             return gr.Dropdown.update(
                 choices=speakers, value=value, visible=True

@@ -78,23 +88,28 @@ def main():
         with gr.Row(variant="panel"):
             audio = gr.Audio()
 
-        def synthesize_audio(text_str: str, speaker_str: str = ""):
+        def synthesize_audio(text_str: str, model_name_str: str, speaker_str: str):
             """
             gets utterance to synthesize from `text` Textbox
             and speaker name from `speaker` dropdown list.
             speaker name might be empty for single-speaker models.
             Synthesizes the waveform and updates `audio` with it.
             """
-            if not text_str:
-                logging.info("text or speaker are not provided")
+            if not text_str or not model_name_str or not speaker_str:
+                logging.info("text, model name or speaker are not provided")
                 return None
-            global tts
+            expected_model_path = os.path.join(model_repo_dir, model_name_str)
+            if expected_model_path != cur_model_path:
+                # reload model
+                tts = TTS(expected_model_path)
+                cur_model_path = expected_model_path
             if len(text_str) > 1024:
+                # truncate the text
                 text_str = text_str[:1024]
-            samples = cast(TTS, tts).synthesize(text_str, speaker_str)
-            return gr.Audio.update(value=(cast(TTS, tts).get_sampling_rate(), samples))
+            samples = tts.synthesize(text_str, speaker_str)
+            return gr.Audio.update(value=(tts.get_sampling_rate(), samples))
 
-        generate.click(synthesize_audio, inputs=[text, speaker], outputs=audio)
+        generate.click(synthesize_audio, inputs=[text, model_name, speaker], outputs=audio)
 
     demo.queue(concurrency_count=1).launch()
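
For context, below is a minimal standalone sketch of the pattern this commit introduces: cache the speaker list per model and only re-instantiate the TTS engine when the requested model differs from the one currently loaded. It is an illustration, not the Space's code; the `TTSManager` class and its method names are made up here, while `TTS`, `get_speakers`, `synthesize` and `get_sampling_rate` come from the `balacoon_tts` package the app imports. Note that in the Gradio callbacks above, the module-level `tts` and `cur_model_path` would also need `global` declarations to persist between calls; the sketch sidesteps that by keeping the state in a small class.

    import os

    from balacoon_tts import TTS  # same package the app uses


    class TTSManager:
        """Keeps one loaded TTS model and reloads it only when the selection changes."""

        def __init__(self, model_repo_dir: str):
            self.model_repo_dir = model_repo_dir
            self.tts = None              # currently loaded engine
            self.cur_model_path = None   # path of the model behind self.tts
            self.model_to_speakers = {}  # model name -> cached speaker list

        def _ensure_loaded(self, model_name: str) -> TTS:
            expected_model_path = os.path.join(self.model_repo_dir, model_name)
            if expected_model_path != self.cur_model_path:
                # selection changed (or nothing loaded yet): reload the addon
                self.tts = TTS(expected_model_path)
                self.cur_model_path = expected_model_path
            return self.tts

        def get_speakers(self, model_name: str):
            # serve from cache when possible, otherwise load the model once and ask it
            if model_name not in self.model_to_speakers:
                self.model_to_speakers[model_name] = self._ensure_loaded(model_name).get_speakers()
            return self.model_to_speakers[model_name]

        def synthesize(self, model_name: str, text: str, speaker: str):
            tts = self._ensure_loaded(model_name)
            samples = tts.synthesize(text[:1024], speaker)  # same 1024-char truncation as the app
            return tts.get_sampling_rate(), samples

With such a wrapper, both the dropdown handler and the synthesis handler would call the same object, so at most one model stays resident and switching models in the UI triggers exactly one reload.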