rayl-aoit committed on
Commit
c2dd837
1 Parent(s): 954e8fb

Update app.py

Files changed (1)
  1. app.py +160 -6
app.py CHANGED
@@ -1,14 +1,28 @@
 import gradio as gr
 import langcodes
-from transformers import pipeline, VitsModel, AutoTokenizer, set_seed
-from huggingface_hub import InferenceClient
-from langdetect import detect, DetectorFactory
 import torch
 import uuid
+import json
+import librosa
+import os
+import tempfile
+import soundfile as sf
 import scipy.io.wavfile as wav
 
-playground = gr.Blocks()
+from transformers import pipeline, VitsModel, AutoTokenizer, set_seed
+from huggingface_hub import InferenceClient
+from langdetect import detect, DetectorFactory
+from nemo.collections.asr.models import EncDecMultiTaskModel
+
+# Constants
+SAMPLE_RATE = 16000  # Hz
+
+# load ASR model
+canary_model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b')
+decode_cfg = canary_model.cfg.decoding
+decode_cfg.beam.beam_size = 1
+canary_model.change_decoding_strategy(decode_cfg)
 
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 image_pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
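Reviewer's note on the new imports above: they imply extra runtime dependencies for the Space (NeMo for the Canary ASR model, librosa and soundfile for resampling and writing the converted audio). The commit itself does not touch a requirements file, so the snippet below is only an illustrative import check; the module names mirror the import lines in the hunk, nothing else is taken from the commit.

```python
# Illustrative only: verify the dependencies implied by the new imports are installed.
import importlib

for module in ("nemo.collections.asr.models", "librosa", "soundfile"):
    try:
        importlib.import_module(module)
        print(f"{module}: OK")
    except ImportError as err:
        print(f"{module}: missing ({err})")
```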
@@ -17,6 +31,87 @@ ner_pipe = pipeline("ner", model="dslim/bert-base-NER")
 tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
 tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
 
+# Function to convert audio to text using ASR
+def gen_text(audio_filepath, action, source_lang, target_lang):
+    if audio_filepath is None:
+        raise gr.Error("Please provide some input audio.")
+
+    utt_id = uuid.uuid4()
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # Convert the input to 16 kHz mono
+        data, sr = librosa.load(audio_filepath, sr=None, mono=True)
+        if sr != SAMPLE_RATE:
+            data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
+        converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
+        sf.write(converted_audio_filepath, data, SAMPLE_RATE)
+
+        # Build a one-record manifest describing the transcription/translation task
+        duration = len(data) / SAMPLE_RATE
+        manifest_data = {
+            "audio_filepath": converted_audio_filepath,
+            "taskname": action,
+            "source_lang": source_lang,
+            "target_lang": source_lang if action == "asr" else target_lang,
+            "pnc": "no",
+            "answer": "predict",
+            "duration": str(duration),
+        }
+        manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
+        with open(manifest_filepath, 'w') as fout:
+            fout.write(json.dumps(manifest_data))
+
+        # Transcribe (or translate) the audio
+        predicted_text = canary_model.transcribe(manifest_filepath)[0]
+        # Buffered inference for long recordings, kept for reference:
+        # if duration < 40:
+        #     predicted_text = canary_model.transcribe(manifest_filepath)[0]
+        # else:
+        #     predicted_text = get_buffered_pred_feat_multitaskAED(
+        #         frame_asr,
+        #         canary_model.cfg.preprocessor,
+        #         model_stride_in_secs,
+        #         canary_model.device,
+        #         manifest=manifest_filepath,
+        #     )[0].text
+
+    return predicted_text
+
+# Function to convert text to speech using TTS
+def gen_translated_speech(text, lang):
+    set_seed(555)  # Make it deterministic
+    match lang:
+        case "en":
+            model = "facebook/mms-tts-eng"
+        case "fr":
+            model = "facebook/mms-tts-fra"
+        case "de":
+            model = "facebook/mms-tts-deu"
+        case "es":
+            model = "facebook/mms-tts-spa"
+        case _:
+            model = "facebook/mms-tts"
+
+    # load the TTS model for the requested language
+    tts_model = VitsModel.from_pretrained(model)
+    tts_tokenizer = AutoTokenizer.from_pretrained(model)
+
+    input_text = tts_tokenizer(text, return_tensors="pt")
+    with torch.no_grad():
+        outputs = tts_model(**input_text)
+    waveform_np = outputs.waveform[0].cpu().numpy()
+    output_file = f"{str(uuid.uuid4())}.wav"
+    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
+    return output_file
+
+# Root function for the Gradio interface: transcribe, translate, then speak
+def start_process(audio_filepath, source_lang, target_lang):
+    transcription = gen_text(audio_filepath, "asr", source_lang, target_lang)
+    print("Done transcribing")
+    translation = gen_text(audio_filepath, "s2t_translation", source_lang, target_lang)
+    print("Done translating")
+    audio_output_filepath = gen_translated_speech(translation, target_lang)
+    print("Done speaking")
+    return transcription, translation, audio_output_filepath
+
+
 def gen_speech(text):
     set_seed(555)  # Make it deterministic
     input_text = tts_tokenizer(text, return_tensors="pt")
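The three helpers added above form the full speech-translation path: Canary ASR for transcription, Canary `s2t_translation` for the translated text, and a per-language MMS-TTS checkpoint for speech synthesis. A minimal local smoke test could look like the sketch below; it is not part of the commit, assumes the code runs inside the Space checkout (module name `app` for app.py), and uses `audio/sample_en.wav`, one of the example clips referenced later in this diff.

```python
# Sketch of a local smoke test for the new helpers (not part of the commit).
# Importing app also loads the models and builds the Gradio UI, so this is slow.
from app import start_process

transcription, translation, speech_path = start_process("audio/sample_en.wav", "en", "fr")
print("EN transcription:", transcription)
print("FR translation:", translation)
print("Synthesized FR speech written to:", speech_path)
```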
@@ -119,9 +214,68 @@ def create_playground_footer():
     **To Learn More about 🤗 Hugging Face, [Click Here](https://huggingface.co/docs)**
     """)
 
+# Create Gradio interface
+playground = gr.Blocks()
 
 with playground:
     create_playground_header()
+    with gr.Tabs():
+        ## ================================================================================================================================
+        ## Speech Translator
+        ## ================================================================================================================================
+        with gr.TabItem("Speech Translator"):
+            gr.Markdown("""
+            ## Your AI Translation Assistant
+            ### Takes audio input from the user, transcribes and translates it, then converts the translation back to speech.
+            - category: [Automatic Speech Recognition](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition), model: [nvidia/canary-1b](https://huggingface.co/nvidia/canary-1b)
+            - category: [Text-to-Speech](https://huggingface.co/models?pipeline_tag=text-to-speech), model: [facebook/mms-tts](https://huggingface.co/facebook/mms-tts)
+            """)
+
+            with gr.Row():
+                with gr.Column():
+                    source_lang = gr.Dropdown(
+                        choices=["en", "de", "es", "fr"], value="en", label="Source Language"
+                    )
+                with gr.Column():
+                    target_lang = gr.Dropdown(
+                        choices=["en", "de", "es", "fr"], value="fr", label="Target Language"
+                    )
+
+            with gr.Row():
+                with gr.Column():
+                    input_audio = gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")
+                with gr.Column():
+                    translated_speech = gr.Audio(type="filepath", label="Generated Speech")
+
+            with gr.Row():
+                with gr.Column():
+                    transcribed_text = gr.Textbox(label="Transcription")
+                with gr.Column():
+                    translated_text = gr.Textbox(label="Translation")
+
+            with gr.Row():
+                with gr.Column():
+                    submit_button = gr.Button(value="Start Process", variant="primary")
+                with gr.Column():
+                    clear_button = gr.ClearButton(components=[input_audio, source_lang, target_lang, transcribed_text, translated_text, translated_speech], value="Clear")
+
+            with gr.Row():
+                gr.Examples(
+                    examples=[
+                        ["audio/sample_en.wav", "en", "fr"],
+                        ["audio/sample_fr.wav", "fr", "de"],
+                        ["audio/sample_de.wav", "de", "es"],
+                        ["audio/sample_es.wav", "es", "en"]
+                    ],
+                    inputs=[input_audio, source_lang, target_lang],
+                    outputs=[transcribed_text, translated_text, translated_speech],
+                    run_on_click=True, cache_examples=True, fn=start_process
+                )
+
+            submit_button.click(start_process, inputs=[input_audio, source_lang, target_lang], outputs=[transcribed_text, translated_text, translated_speech])
+
+
+
     with gr.Tabs():
         ## ================================================================================================================================
         ## Image Captioning
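The hunk above only defines and wires the new "Speech Translator" tab into the existing `playground` Blocks; no launch call appears in the diff. For local testing, the usual pattern at the bottom of a Gradio app.py is sketched below; this is an assumption about how the Space is run, not something shown in the commit.

```python
# Sketch only: run the Blocks app locally. The Space presumably launches
# `playground` further down in app.py, outside the lines shown in this diff.
if __name__ == "__main__":
    playground.launch()
```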
@@ -148,8 +302,8 @@ with playground:
 
             gr.Examples(
                 examples=[
-                    ["lion-dog-costume.jpg"],
-                    ["dog-halloween.jpeg"]
+                    ["image/lion-dog-costume.jpg"],
+                    ["image/dog-halloween.jpeg"]
                 ],
                 inputs=[img],
                 outputs=[generated_textbox, audio_output],
 
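The last hunk moves the image-captioning examples under an `image/` folder, matching the `audio/` layout used by the new tab's examples. Since `gr.Examples` is configured with `cache_examples=True`, all referenced files must exist in the repository at startup. The check below is purely illustrative (not part of the commit) and lists only the asset paths that appear in this diff.

```python
# Illustrative only: verify the example assets referenced by the two
# gr.Examples blocks in this diff exist in the repository checkout.
import os

EXAMPLE_ASSETS = [
    "image/lion-dog-costume.jpg",
    "image/dog-halloween.jpeg",
    "audio/sample_en.wav",
    "audio/sample_fr.wav",
    "audio/sample_de.wav",
    "audio/sample_es.wav",
]

for path in EXAMPLE_ASSETS:
    print(path, "found" if os.path.exists(path) else "MISSING")
```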