Pierre Andrews commited on
Commit
92d98dc
1 Parent(s): 6f77ead

add toxicity mitigation to m4tv2

Browse files
Files changed (1) hide show
  1. app.py +32 -14
app.py CHANGED
@@ -60,11 +60,13 @@ if torch.cuda.is_available():
60
  else:
61
  device = torch.device("cpu")
62
  dtype = torch.float32
 
63
  translator = Translator(
64
  model_name_or_card="seamlessM4T_v2_large",
65
  vocoder_name_or_card="vocoder_v2",
66
  device=device,
67
  dtype=dtype,
 
68
  )
69
 
70
 
@@ -78,12 +80,16 @@ def preprocess_audio(input_audio: str) -> None:
78
  torchaudio.save(input_audio, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
79
 
80
 
81
- def run_s2st(input_audio: str, target_language: str) -> tuple[tuple[int, np.ndarray] | None, str]:
 
 
82
  preprocess_audio(input_audio)
 
83
  target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
84
  out_texts, out_audios = translator.predict(
85
  input=input_audio,
86
  task_str="S2ST",
 
87
  tgt_lang=target_language_code,
88
  )
89
  out_text = str(out_texts[0])
@@ -91,13 +97,15 @@ def run_s2st(input_audio: str, target_language: str) -> tuple[tuple[int, np.ndar
91
  return (int(AUDIO_SAMPLE_RATE), out_wav), out_text
92
 
93
 
94
- def run_s2tt(input_audio: str, target_language: str) -> str:
95
  preprocess_audio(input_audio)
 
96
  target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
97
  out_texts, _ = translator.predict(
98
  input=input_audio,
99
  task_str="S2TT",
100
  tgt_lang=target_language_code,
 
101
  )
102
  return str(out_texts[0])
103
 
@@ -144,6 +152,11 @@ with gr.Blocks() as demo_s2st:
144
  with gr.Column():
145
  with gr.Group():
146
  input_audio = gr.Audio(label="Input speech", type="filepath")
 
 
 
 
 
147
  target_language = gr.Dropdown(
148
  label="Target language",
149
  choices=S2ST_TARGET_LANGUAGE_NAMES,
@@ -162,12 +175,12 @@ with gr.Blocks() as demo_s2st:
162
 
163
  gr.Examples(
164
  examples=[
165
- ["assets/sample_input.mp3", "French"],
166
- ["assets/sample_input.mp3", "Mandarin Chinese"],
167
- ["assets/sample_input_2.mp3", "Hindi"],
168
- ["assets/sample_input_2.mp3", "Spanish"],
169
  ],
170
- inputs=[input_audio, target_language],
171
  outputs=[output_audio, output_text],
172
  fn=run_s2st,
173
  cache_examples=CACHE_EXAMPLES,
@@ -176,7 +189,7 @@ with gr.Blocks() as demo_s2st:
176
 
177
  btn.click(
178
  fn=run_s2st,
179
- inputs=[input_audio, target_language],
180
  outputs=[output_audio, output_text],
181
  api_name="s2st",
182
  )
@@ -186,6 +199,11 @@ with gr.Blocks() as demo_s2tt:
186
  with gr.Column():
187
  with gr.Group():
188
  input_audio = gr.Audio(label="Input speech", type="filepath")
 
 
 
 
 
189
  target_language = gr.Dropdown(
190
  label="Target language",
191
  choices=S2TT_TARGET_LANGUAGE_NAMES,
@@ -197,12 +215,12 @@ with gr.Blocks() as demo_s2tt:
197
 
198
  gr.Examples(
199
  examples=[
200
- ["assets/sample_input.mp3", "French"],
201
- ["assets/sample_input.mp3", "Mandarin Chinese"],
202
- ["assets/sample_input_2.mp3", "Hindi"],
203
- ["assets/sample_input_2.mp3", "Spanish"],
204
  ],
205
- inputs=[input_audio, target_language],
206
  outputs=output_text,
207
  fn=run_s2tt,
208
  cache_examples=CACHE_EXAMPLES,
@@ -211,7 +229,7 @@ with gr.Blocks() as demo_s2tt:
211
 
212
  btn.click(
213
  fn=run_s2tt,
214
- inputs=[input_audio, target_language],
215
  outputs=output_text,
216
  api_name="s2tt",
217
  )
 
60
  else:
61
  device = torch.device("cpu")
62
  dtype = torch.float32
63
+
64
  translator = Translator(
65
  model_name_or_card="seamlessM4T_v2_large",
66
  vocoder_name_or_card="vocoder_v2",
67
  device=device,
68
  dtype=dtype,
69
+ apply_mintox=True,
70
  )
71
 
72
 
 
80
  torchaudio.save(input_audio, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
81
 
82
 
83
+ def run_s2st(
84
+ input_audio: str, source_language: str, target_language: str
85
+ ) -> tuple[tuple[int, np.ndarray] | None, str]:
86
  preprocess_audio(input_audio)
87
+ source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
88
  target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
89
  out_texts, out_audios = translator.predict(
90
  input=input_audio,
91
  task_str="S2ST",
92
+ src_lang=source_language_code,
93
  tgt_lang=target_language_code,
94
  )
95
  out_text = str(out_texts[0])
 
97
  return (int(AUDIO_SAMPLE_RATE), out_wav), out_text
98
 
99
 
100
+ def run_s2tt(input_audio: str, source_language: str, target_language: str) -> str:
101
  preprocess_audio(input_audio)
102
+ source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
103
  target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
104
  out_texts, _ = translator.predict(
105
  input=input_audio,
106
  task_str="S2TT",
107
  tgt_lang=target_language_code,
108
+ src_lang=source_language_code,
109
  )
110
  return str(out_texts[0])
111
 
 
152
  with gr.Column():
153
  with gr.Group():
154
  input_audio = gr.Audio(label="Input speech", type="filepath")
155
+ source_language = gr.Dropdown(
156
+ label="Source language",
157
+ choices=ASR_TARGET_LANGUAGE_NAMES,
158
+ value="English",
159
+ )
160
  target_language = gr.Dropdown(
161
  label="Target language",
162
  choices=S2ST_TARGET_LANGUAGE_NAMES,
 
175
 
176
  gr.Examples(
177
  examples=[
178
+ ["assets/sample_input.mp3", "English", "French"],
179
+ ["assets/sample_input.mp3", "English", "Mandarin Chinese"],
180
+ ["assets/sample_input_2.mp3", "English", "Hindi"],
181
+ ["assets/sample_input_2.mp3", "English", "Spanish"],
182
  ],
183
+ inputs=[input_audio, source_language, target_language],
184
  outputs=[output_audio, output_text],
185
  fn=run_s2st,
186
  cache_examples=CACHE_EXAMPLES,
 
189
 
190
  btn.click(
191
  fn=run_s2st,
192
+ inputs=[input_audio, source_language, target_language],
193
  outputs=[output_audio, output_text],
194
  api_name="s2st",
195
  )
 
199
  with gr.Column():
200
  with gr.Group():
201
  input_audio = gr.Audio(label="Input speech", type="filepath")
202
+ source_language = gr.Dropdown(
203
+ label="Source language",
204
+ choices=ASR_TARGET_LANGUAGE_NAMES,
205
+ value="English",
206
+ )
207
  target_language = gr.Dropdown(
208
  label="Target language",
209
  choices=S2TT_TARGET_LANGUAGE_NAMES,
 
215
 
216
  gr.Examples(
217
  examples=[
218
+ ["assets/sample_input.mp3", "English", "French"],
219
+ ["assets/sample_input.mp3", "English", "Mandarin Chinese"],
220
+ ["assets/sample_input_2.mp3", "English", "Hindi"],
221
+ ["assets/sample_input_2.mp3", "English", "Spanish"],
222
  ],
223
+ inputs=[input_audio, source_language, target_language],
224
  outputs=output_text,
225
  fn=run_s2tt,
226
  cache_examples=CACHE_EXAMPLES,
 
229
 
230
  btn.click(
231
  fn=run_s2tt,
232
+ inputs=[input_audio, source_language, target_language],
233
  outputs=output_text,
234
  api_name="s2tt",
235
  )