unijoh commited on
Commit
382e37a
1 Parent(s): 778a8fe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torchaudio
3
+ import torch
4
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
5
+ from transformers import Speech2Text2Processor, Speech2Text2ForConditionalGeneration
6
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
7
+
8
+ # Load the models
9
+ asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")
10
+ asr_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")
11
+
12
+ tts_model = Speech2Text2ForConditionalGeneration.from_pretrained("facebook/mms-tts")
13
+ tts_processor = Speech2Text2Processor.from_pretrained("facebook/mms-tts")
14
+
15
+ lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-1024")
16
+ lid_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-lid-1024")
17
+
18
+ # ASR Function
19
+ def asr_transcribe(audio):
20
+ inputs = asr_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
21
+ with torch.no_grad():
22
+ logits = asr_model(**inputs).logits
23
+ predicted_ids = torch.argmax(logits, dim=-1)
24
+ transcription = asr_processor.batch_decode(predicted_ids)
25
+ return transcription[0]
26
+
27
+ # TTS Function
28
+ def tts_synthesize(text):
29
+ inputs = tts_processor(text, return_tensors="pt", padding=True)
30
+ with torch.no_grad():
31
+ generated_ids = tts_model.generate(**inputs)
32
+ audio = tts_processor.batch_decode(generated_ids, skip_special_tokens=True)
33
+ return audio[0]
34
+
35
+ # Language ID Function
36
+ def identify_language(audio):
37
+ inputs = lid_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
38
+ with torch.no_grad():
39
+ logits = lid_model(**inputs).logits
40
+ predicted_ids = torch.argmax(logits, dim=-1)
41
+ language = lid_processor.batch_decode(predicted_ids)
42
+ return language[0]
43
+
44
+ # Define the Gradio interfaces
45
+ with gr.Blocks() as demo:
46
+ with gr.Tab("ASR"):
47
+ gr.Markdown("## Automatic Speech Recognition (ASR)")
48
+ audio_input = gr.Audio(source="microphone", type="numpy")
49
+ text_output = gr.Textbox(label="Transcription")
50
+ gr.Button("Clear", clear_audio_input)
51
+ gr.Button("Submit", fn=asr_transcribe, inputs=audio_input, outputs=text_output)
52
+
53
+ with gr.Tab("TTS"):
54
+ gr.Markdown("## Text-to-Speech (TTS)")
55
+ text_input = gr.Textbox(label="Text")
56
+ audio_output = gr.Audio(label="Audio Output")
57
+ gr.Button("Clear", clear_text_input)
58
+ gr.Button("Submit", fn=tts_synthesize, inputs=text_input, outputs=audio_output)
59
+
60
+ with gr.Tab("Language ID"):
61
+ gr.Markdown("## Language Identification (LangID)")
62
+ audio_input = gr.Audio(source="microphone", type="numpy")
63
+ language_output = gr.Textbox(label="Identified Language")
64
+ gr.Button("Clear", clear_audio_input)
65
+ gr.Button("Submit", fn=identify_language, inputs=audio_input, outputs=language_output)
66
+
67
+ demo.launch()