# metaambod / app.py
# Author: unijoh · commit 382e37a "Create app.py" (verified) · 2.82 kB · virus scan: no virus
import gradio as gr
import numpy as np
import torch
import torchaudio
# NOTE(review): Speech2Text2* are no longer used by this app and are deprecated
# in recent transformers releases; they can be dropped once confirmed unused.
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    Speech2Text2ForConditionalGeneration,
    Speech2Text2Processor,
    VitsModel,
    Wav2Vec2ForCTC,
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2Processor,
)
# Load the models.
# Every checkpoint below is fetched from the Hugging Face Hub and loaded eagerly
# at import time.
ASR_CHECKPOINT = "facebook/mms-1b-all"
# MMS text-to-speech ships one VITS checkpoint per language ("facebook/mms-tts-<lang>");
# the bare "facebook/mms-tts" repo is not a transformers model. English is used
# here to match the ASR model's default English adapter; change the suffix to
# synthesize another language.
TTS_CHECKPOINT = "facebook/mms-tts-eng"
LID_CHECKPOINT = "facebook/mms-lid-1024"
# Sampling rate the original code hard-coded for the MMS speech models.
TARGET_SAMPLING_RATE = 16_000

asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_CHECKPOINT)
asr_processor = Wav2Vec2Processor.from_pretrained(ASR_CHECKPOINT)
# VITS is a waveform generator, not a seq2seq text model: it is paired with a
# tokenizer, and its forward pass returns audio directly (there is no generate()).
tts_model = VitsModel.from_pretrained(TTS_CHECKPOINT)
tts_processor = AutoTokenizer.from_pretrained(TTS_CHECKPOINT)
# The LID checkpoint is a classifier without a CTC vocabulary, so it only needs
# a feature extractor (Wav2Vec2Processor would also try to load a tokenizer).
lid_model = Wav2Vec2ForSequenceClassification.from_pretrained(LID_CHECKPOINT)
lid_processor = AutoFeatureExtractor.from_pretrained(LID_CHECKPOINT)


def _prepare_audio(audio):
    """Convert a Gradio ``type="numpy"`` recording into a mono float32 16 kHz waveform.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by
            ``gr.Audio(type="numpy")``, or ``None`` when nothing was recorded.

    Returns:
        A 1-D ``np.float32`` array sampled at ``TARGET_SAMPLING_RATE``.

    Raises:
        gr.Error: if no (or empty) audio was provided; Gradio shows the message
            to the user instead of a stack trace.
    """
    if audio is None:
        raise gr.Error("Please record some audio first.")
    sample_rate, samples = audio
    waveform = np.asarray(samples)
    if waveform.size == 0:
        raise gr.Error("The recording is empty.")
    if np.issubdtype(waveform.dtype, np.integer):
        # Gradio delivers integer PCM (typically int16); scale into [-1, 1).
        # NOTE(review): unsigned PCM would also need a DC offset; not handled.
        scale = np.iinfo(waveform.dtype).max + 1
        waveform = waveform.astype(np.float32) / scale
    else:
        waveform = waveform.astype(np.float32)
    if waveform.ndim > 1:
        # Assumes Gradio's (samples, channels) layout for multi-channel audio;
        # average the channels down to mono.
        waveform = waveform.mean(axis=1)
    if sample_rate != TARGET_SAMPLING_RATE:
        waveform = torchaudio.functional.resample(
            torch.from_numpy(waveform),
            orig_freq=int(sample_rate),
            new_freq=TARGET_SAMPLING_RATE,
        ).numpy()
    return waveform


# ASR Function
def asr_transcribe(audio):
    """Transcribe a recording with the MMS CTC speech-recognition model.

    Args:
        audio: Gradio ``(sample_rate, samples)`` tuple, or ``None``.

    Returns:
        The greedy (argmax) CTC transcription as a string.

    Raises:
        gr.Error: if no audio was provided.
    """
    waveform = _prepare_audio(audio)
    inputs = asr_processor(
        waveform,
        sampling_rate=TARGET_SAMPLING_RATE,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        logits = asr_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return asr_processor.batch_decode(predicted_ids)[0]


# TTS Function
def tts_synthesize(text):
    """Synthesize speech for *text* with the MMS VITS model.

    Args:
        text: The text to speak.

    Returns:
        ``(sampling_rate, waveform)`` where ``waveform`` is a 1-D float numpy
        array; this tuple is the value format ``gr.Audio`` accepts as output.

    Raises:
        gr.Error: if *text* is empty or whitespace only (the model cannot
            synthesize an empty token sequence).
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text first.")
    inputs = tts_processor(text, return_tensors="pt")
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform
    # Take the single batch item.
    return tts_model.config.sampling_rate, waveform[0].cpu().numpy()


# Language ID Function
def identify_language(audio):
    """Identify the spoken language of a recording with the MMS LID classifier.

    Args:
        audio: Gradio ``(sample_rate, samples)`` tuple, or ``None``.

    Returns:
        The label of the highest-scoring class from ``lid_model.config.id2label``
        (per the MMS docs, a language code such as ``"eng"``).

    Raises:
        gr.Error: if no audio was provided.
    """
    waveform = _prepare_audio(audio)
    inputs = lid_processor(
        waveform, sampling_rate=TARGET_SAMPLING_RATE, return_tensors="pt"
    )
    with torch.no_grad():
        logits = lid_model(**inputs).logits
    predicted_id = torch.argmax(logits, dim=-1)[0].item()
    # Classifier outputs are class indices, not tokens, so they are mapped via
    # the config's label table rather than batch_decode.
    return lid_model.config.id2label[predicted_id]
# Define the Gradio interfaces.
# Uses the Gradio 4 API: ``gr.Audio(sources=[...])``, ``gr.ClearButton`` and
# event wiring through ``Button.click`` (``gr.Button`` itself takes no ``fn``).
with gr.Blocks() as demo:
    with gr.Tab("ASR"):
        gr.Markdown("## Automatic Speech Recognition (ASR)")
        asr_audio_input = gr.Audio(sources=["microphone"], type="numpy")
        text_output = gr.Textbox(label="Transcription")
        # Resets both the recording and the previous result.
        gr.ClearButton([asr_audio_input, text_output], value="Clear")
        gr.Button("Submit").click(
            fn=asr_transcribe, inputs=asr_audio_input, outputs=text_output
        )
    with gr.Tab("TTS"):
        gr.Markdown("## Text-to-Speech (TTS)")
        text_input = gr.Textbox(label="Text")
        audio_output = gr.Audio(label="Audio Output")
        gr.ClearButton([text_input, audio_output], value="Clear")
        gr.Button("Submit").click(
            fn=tts_synthesize, inputs=text_input, outputs=audio_output
        )
    with gr.Tab("Language ID"):
        gr.Markdown("## Language Identification (LangID)")
        # Distinct name so this tab's input does not shadow the ASR tab's.
        lid_audio_input = gr.Audio(sources=["microphone"], type="numpy")
        language_output = gr.Textbox(label="Identified Language")
        gr.ClearButton([lid_audio_input, language_output], value="Clear")
        gr.Button("Submit").click(
            fn=identify_language, inputs=lid_audio_input, outputs=language_output
        )
demo.launch()