Yilin0601 committed on
Commit
e2fc711
·
verified ·
1 Parent(s): 1ee4794

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -35
app.py CHANGED
@@ -5,17 +5,17 @@ import librosa
5
  from transformers import pipeline, VitsModel, AutoTokenizer
6
  import scipy # if needed for processing
7
 
8
- # -----------------------------------------------
9
  # 1. ASR Pipeline (English)
10
- # -----------------------------------------------
11
  asr = pipeline(
12
  "automatic-speech-recognition",
13
  model="facebook/wav2vec2-base-960h"
14
  )
15
 
16
- # -----------------------------------------------
17
  # 2. Translation Models (3 languages)
18
- # -----------------------------------------------
19
  translation_models = {
20
  "Spanish": "Helsinki-NLP/opus-mt-en-es",
21
  "Chinese": "Helsinki-NLP/opus-mt-en-zh",
@@ -28,34 +28,32 @@ translation_tasks = {
28
  "Japanese": "translation_en_to_ja"
29
  }
30
 
31
- # -----------------------------------------------
32
- # 3. TTS Model Configurations (All VITS)
33
- # -----------------------------------------------
34
- # Make sure these model IDs exist on Hugging Face.
 
35
  tts_config = {
36
  "Spanish": {
37
- "model_id": "facebook/mms-tts-spa",
38
- "architecture": "vits"
39
- },
40
- "Chinese": {
41
- "model_id": "facebook/mms-tts-che",
42
  "architecture": "vits"
43
  },
 
44
  "Japanese": {
45
- "model_id": "facebook/mms-tts-jpn",
46
  "architecture": "vits"
47
  }
48
  }
49
 
50
- # -----------------------------------------------
51
  # 4. Caches
52
- # -----------------------------------------------
53
  translator_cache = {}
54
  tts_model_cache = {} # store (model, tokenizer, architecture)
55
 
56
- # -----------------------------------------------
57
  # 5. Translator Helper
58
- # -----------------------------------------------
59
  def get_translator(lang):
60
  if lang in translator_cache:
61
  return translator_cache[lang]
@@ -65,25 +63,27 @@ def get_translator(lang):
65
  translator_cache[lang] = translator
66
  return translator
67
 
68
- # -----------------------------------------------
69
  # 6. TTS Loading Helper
70
- # -----------------------------------------------
71
  def get_tts_model(lang):
72
  """
73
  Loads (model, tokenizer, architecture) from Hugging Face once, then caches.
 
74
  """
75
  if lang in tts_model_cache:
76
  return tts_model_cache[lang]
77
 
78
  config = tts_config.get(lang)
79
  if config is None:
 
80
  raise ValueError(f"No TTS config found for language: {lang}")
81
 
82
  model_id = config["model_id"]
83
  arch = config["architecture"]
84
 
85
  try:
86
- # Since arch == "vits" for all three languages, we load VitsModel + AutoTokenizer
87
  model = VitsModel.from_pretrained(model_id)
88
  tokenizer = AutoTokenizer.from_pretrained(model_id)
89
  except Exception as e:
@@ -92,9 +92,9 @@ def get_tts_model(lang):
92
  tts_model_cache[lang] = (model, tokenizer, arch)
93
  return tts_model_cache[lang]
94
 
95
- # -----------------------------------------------
96
  # 7. TTS Inference Helper
97
- # -----------------------------------------------
98
  def run_tts_inference(lang, text):
99
  """
100
  Generates waveform using the loaded TTS model and tokenizer.
@@ -119,14 +119,14 @@ def run_tts_inference(lang, text):
119
  sample_rate = 16000
120
  return (sample_rate, waveform)
121
 
122
- # -----------------------------------------------
123
  # 8. Prediction Function
124
- # -----------------------------------------------
125
  def predict(audio, text, target_language):
126
  """
127
  1. Obtain English text (from text input or ASR).
128
  2. Translate English -> target_language.
129
- 3. Run VITS-based TTS for that language.
130
  """
131
  # Step 1: English text
132
  if text.strip():
@@ -142,7 +142,7 @@ def predict(audio, text, target_language):
142
  if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
143
  audio_data = np.mean(audio_data, axis=1)
144
 
145
- # Resample to 16k
146
  if sample_rate != 16000:
147
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
148
 
@@ -160,17 +160,20 @@ def predict(audio, text, target_language):
160
  except Exception as e:
161
  return english_text, f"Translation error: {e}", None
162
 
163
- # Step 3: TTS
164
  try:
 
 
 
165
  sample_rate, waveform = run_tts_inference(target_language, translated_text)
166
  except Exception as e:
167
  return english_text, translated_text, f"TTS error: {e}"
168
 
169
  return english_text, translated_text, (sample_rate, waveform)
170
 
171
- # -----------------------------------------------
172
  # 9. Gradio Interface
173
- # -----------------------------------------------
174
  iface = gr.Interface(
175
  fn=predict,
176
  inputs=[
@@ -181,19 +184,20 @@ iface = gr.Interface(
181
  outputs=[
182
  gr.Textbox(label="English Transcription"),
183
  gr.Textbox(label="Translation (Target Language)"),
184
- gr.Audio(label="Synthesized Speech in Target Language")
185
  ],
186
  title="Multimodal Language Learning Aid (MMS TTS / VITS)",
187
  description=(
188
  "This app:\n"
189
  "1. Transcribes English speech (via ASR) or accepts English text.\n"
190
  "2. Translates to Spanish, Chinese, or Japanese (Helsinki-NLP).\n"
191
- "3. Synthesizes speech with VITS-based MMS TTS models.\n\n"
192
- "Note: Ensure the MMS model IDs exist on Hugging Face. If not, you'll see an error.\n"
193
- "Record/upload English audio or enter text, then select a target language."
194
  ),
195
  allow_flagging="never"
196
  )
197
 
198
  if __name__ == "__main__":
199
- iface.launch()
 
 
 
5
  from transformers import pipeline, VitsModel, AutoTokenizer
6
  import scipy # if needed for processing
7
 
8
+ # ------------------------------------------------------
9
  # 1. ASR Pipeline (English)
10
+ # ------------------------------------------------------
11
  asr = pipeline(
12
  "automatic-speech-recognition",
13
  model="facebook/wav2vec2-base-960h"
14
  )
15
 
16
+ # ------------------------------------------------------
17
  # 2. Translation Models (3 languages)
18
+ # ------------------------------------------------------
19
  translation_models = {
20
  "Spanish": "Helsinki-NLP/opus-mt-en-es",
21
  "Chinese": "Helsinki-NLP/opus-mt-en-zh",
 
28
  "Japanese": "translation_en_to_ja"
29
  }
30
 
31
+ # ------------------------------------------------------
32
+ # 3. TTS Model Configurations
33
+ # NOTE: MMS does not provide a Mandarin TTS model,
34
+ # so we skip TTS for Chinese.
35
+ # ------------------------------------------------------
36
  tts_config = {
37
  "Spanish": {
38
+ "model_id": "facebook/mms-tts-spa", # MMS Spanish
 
 
 
 
39
  "architecture": "vits"
40
  },
41
+ "Chinese": None, # No MMS TTS for Chinese
42
  "Japanese": {
43
+ "model_id": "facebook/mms-tts-jpn", # MMS Japanese
44
  "architecture": "vits"
45
  }
46
  }
47
 
48
+ # ------------------------------------------------------
49
  # 4. Caches
50
+ # ------------------------------------------------------
51
  translator_cache = {}
52
  tts_model_cache = {} # store (model, tokenizer, architecture)
53
 
54
+ # ------------------------------------------------------
55
  # 5. Translator Helper
56
+ # ------------------------------------------------------
57
  def get_translator(lang):
58
  if lang in translator_cache:
59
  return translator_cache[lang]
 
63
  translator_cache[lang] = translator
64
  return translator
65
 
66
+ # ------------------------------------------------------
67
  # 6. TTS Loading Helper
68
+ # ------------------------------------------------------
69
  def get_tts_model(lang):
70
  """
71
  Loads (model, tokenizer, architecture) from Hugging Face once, then caches.
72
+ If no config is found (e.g. for Chinese), raises ValueError.
73
  """
74
  if lang in tts_model_cache:
75
  return tts_model_cache[lang]
76
 
77
  config = tts_config.get(lang)
78
  if config is None:
79
+ # No TTS model for this language
80
  raise ValueError(f"No TTS config found for language: {lang}")
81
 
82
  model_id = config["model_id"]
83
  arch = config["architecture"]
84
 
85
  try:
86
+ # Every configured language uses the "vits" architecture, so load VitsModel + AutoTokenizer
87
  model = VitsModel.from_pretrained(model_id)
88
  tokenizer = AutoTokenizer.from_pretrained(model_id)
89
  except Exception as e:
 
92
  tts_model_cache[lang] = (model, tokenizer, arch)
93
  return tts_model_cache[lang]
94
 
95
+ # ------------------------------------------------------
96
  # 7. TTS Inference Helper
97
+ # ------------------------------------------------------
98
  def run_tts_inference(lang, text):
99
  """
100
  Generates waveform using the loaded TTS model and tokenizer.
 
119
  sample_rate = 16000
120
  return (sample_rate, waveform)
121
 
122
+ # ------------------------------------------------------
123
  # 8. Prediction Function
124
+ # ------------------------------------------------------
125
  def predict(audio, text, target_language):
126
  """
127
  1. Obtain English text (from text input or ASR).
128
  2. Translate English -> target_language.
129
+ 3. Run VITS-based TTS for that language (if available).
130
  """
131
  # Step 1: English text
132
  if text.strip():
 
142
  if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
143
  audio_data = np.mean(audio_data, axis=1)
144
 
145
+ # Resample to 16k if needed
146
  if sample_rate != 16000:
147
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
148
 
 
160
  except Exception as e:
161
  return english_text, f"Translation error: {e}", None
162
 
163
+ # Step 3: TTS (skip if no config for language)
164
  try:
165
+ if tts_config[target_language] is None:
166
+ # No TTS model is configured for this language (e.g. Chinese), so skip synthesis
167
+ return english_text, translated_text, None
168
  sample_rate, waveform = run_tts_inference(target_language, translated_text)
169
  except Exception as e:
170
  return english_text, translated_text, f"TTS error: {e}"
171
 
172
  return english_text, translated_text, (sample_rate, waveform)
173
 
174
+ # ------------------------------------------------------
175
  # 9. Gradio Interface
176
+ # ------------------------------------------------------
177
  iface = gr.Interface(
178
  fn=predict,
179
  inputs=[
 
184
  outputs=[
185
  gr.Textbox(label="English Transcription"),
186
  gr.Textbox(label="Translation (Target Language)"),
187
+ gr.Audio(label="Synthesized Speech (if available)")
188
  ],
189
  title="Multimodal Language Learning Aid (MMS TTS / VITS)",
190
  description=(
191
  "This app:\n"
192
  "1. Transcribes English speech (via ASR) or accepts English text.\n"
193
  "2. Translates to Spanish, Chinese, or Japanese (Helsinki-NLP).\n"
194
+ "3. Synthesizes speech with VITS-based MMS TTS models for Spanish/Japanese.\n\n"
195
+ "Note: MMS does NOT currently provide a Mandarin TTS model, so TTS is skipped for Chinese."
 
196
  ),
197
  allow_flagging="never"
198
  )
199
 
200
  if __name__ == "__main__":
201
+ # If running locally, uncomment:
202
+ # iface.launch()
203
+ iface.launch(server_name="0.0.0.0", server_port=7860)