clementruhm committed
Commit f4fe081
1 Parent(s): 6473463

app.py: add locker

Files changed (1)
  1. app.py +32 -22
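Besides the lock, the diff below moves model files into "/data" (the persistent-storage mount on Hugging Face Spaces) and skips hf_hub_download for files that are already on disk, so a restart does not re-fetch every model. A minimal standalone sketch of that skip-if-present startup loop, assuming huggingface_hub is installed and the balacoon/tts repo is reachable (the sync_repo_files helper is an illustrative name, not part of the repo; see the first hunk of the diff below):

# Sketch of the skip-if-present download loop from the first hunk below.
# sync_repo_files is an illustrative name; the repo id and directory mirror app.py.
import os

from huggingface_hub import hf_hub_download, list_repo_files


def sync_repo_files(repo_id: str = "balacoon/tts", local_dir: str = "/data") -> None:
    """Download every file of repo_id into local_dir, skipping files already on disk."""
    for name in list_repo_files(repo_id=repo_id):
        if not os.path.isfile(os.path.join(local_dir, name)):
            hf_hub_download(repo_id=repo_id, filename=name, local_dir=local_dir)


if __name__ == "__main__":
    sync_repo_files()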
app.py CHANGED
@@ -8,24 +8,28 @@ import os
 import glob
 import logging
 from typing import cast
+from threading import Lock
 
 import gradio as gr
 from balacoon_tts import TTS
 from huggingface_hub import hf_hub_download, list_repo_files
 
+# locker that disallows access to the tts object from more than one thread
+locker = Lock()
 # global tts module, initialized from a model selected
 tts = None
 # path to the model that is currently used in tts
 cur_model_path = None
 # cache of speakers, maps model name to speaker list
 model_to_speakers = dict()
-model_repo_dir = "data"
+model_repo_dir = "/data"
 for name in list_repo_files(repo_id="balacoon/tts"):
-    hf_hub_download(
-        repo_id="balacoon/tts",
-        filename=name,
-        local_dir=model_repo_dir,
-    )
+    if not os.path.isfile(os.path.join(model_repo_dir, name)):
+        hf_hub_download(
+            repo_id="balacoon/tts",
+            filename=name,
+            local_dir=model_repo_dir,
+        )
 
 
 def main():
@@ -70,13 +74,16 @@ def main():
         if model_name_str in model_to_speakers:
             speakers = model_to_speakers[model_name_str]
         else:
-            global tts, cur_model_path
-            # need to load this model to learn the list of speakers
-            model_path = os.path.join(model_repo_dir, model_name_str)
-            tts = TTS(model_path)
-            cur_model_path = model_path
-            speakers = tts.get_speakers()
-            model_to_speakers[model_name_str] = speakers
+            global tts, cur_model_path, model_to_speakers, locker
+            with locker:
+                # need to load this model to learn the list of speakers
+                model_path = os.path.join(model_repo_dir, model_name_str)
+                if tts is not None:
+                    del tts
+                tts = TTS(model_path)
+                cur_model_path = model_path
+                speakers = tts.get_speakers()
+                model_to_speakers[model_name_str] = speakers
 
         value = speakers[-1]
         return gr.Dropdown.update(
@@ -101,15 +108,18 @@ def main():
             logging.info("text, model name or speaker are not provided")
             return None
         expected_model_path = os.path.join(model_repo_dir, model_name_str)
-        global tts, cur_model_path
-        if expected_model_path != cur_model_path:
-            # reload model
-            tts = TTS(expected_model_path)
-            cur_model_path = expected_model_path
-        if len(text_str) > 1024:
-            # truncate the text
-            text_str = text_str[:1024]
-        samples = tts.synthesize(text_str, speaker_str)
+        global tts, cur_model_path, locker
+        with locker:
+            if expected_model_path != cur_model_path:
+                # reload model
+                if tts is not None:
+                    del tts
+                tts = TTS(expected_model_path)
+                cur_model_path = expected_model_path
+            if len(text_str) > 1024:
+                # truncate the text
+                text_str = text_str[:1024]
+            samples = tts.synthesize(text_str, speaker_str)
         return gr.Audio.update(value=(tts.get_sampling_rate(), samples))
 
     generate.click(synthesize_audio, inputs=[text, model_name, speaker], outputs=audio)
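The core of the commit is serializing every use of the shared global tts engine behind a single threading.Lock, so concurrent Gradio callbacks cannot reload the model or synthesize with it at the same time. A minimal sketch of that pattern, with a stand-in FakeEngine class in place of balacoon_tts.TTS (FakeEngine and synthesize_locked are illustrative names, not part of the repo):

# Sketch of the lock-guarded lazy reload pattern this commit introduces.
# FakeEngine stands in for balacoon_tts.TTS; synthesize_locked is illustrative.
import threading

locker = threading.Lock()  # serializes all access to the shared engine
engine = None              # shared, lazily (re)loaded engine, like the global tts
engine_path = None         # which model the current engine was built from


class FakeEngine:
    """Stand-in for balacoon_tts.TTS: expensive to build, not shared safely across threads."""

    def __init__(self, path: str):
        self.path = path

    def synthesize(self, text: str) -> str:
        return f"[{self.path}] {text}"


def synthesize_locked(text: str, model_path: str) -> str:
    """Reload the engine if the requested model changed, then synthesize, all under one lock."""
    global engine, engine_path
    with locker:
        if engine is None or engine_path != model_path:
            engine = FakeEngine(model_path)  # rebuild only when the model changes
            engine_path = model_path
        return engine.synthesize(text[:1024])  # same 1024-character cap as app.py


if __name__ == "__main__":
    # Two concurrent callers contend for the same shared engine without racing.
    threads = [
        threading.Thread(target=lambda m=m: print(synthesize_locked("hello", m)))
        for m in ("model_a.addon", "model_b.addon")
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

A single global lock means requests are handled one at a time, which appears to be the trade-off app.py accepts: keeping one engine instance in memory and guarding it is simpler and cheaper than holding one engine per worker thread.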