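"""Gradio demo for Meta's MMS (Massively Multilingual Speech) models.

Three tabs: automatic speech recognition (ASR), text-to-speech (TTS), and
spoken language identification (LID). The TTS checkpoint is the Faroese
("fao") MMS voice.
"""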
import gradio as gr
import numpy as np
import torch
import torchaudio
from transformers import (AutoFeatureExtractor, AutoTokenizer, VitsModel,
                          Wav2Vec2ForCTC, Wav2Vec2ForSequenceClassification,
                          Wav2Vec2Processor)
# Load the MMS models.
# ASR: mms-1b-all defaults to its English adapter; pass target_lang (e.g.
# "fao") to both from_pretrained calls to transcribe another language.
asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")
asr_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")
# TTS: the MMS Faroese checkpoint (facebook/mms-tts-fao) is a VITS model
# with a plain tokenizer, not a seq2seq model.
tts_model = VitsModel.from_pretrained("facebook/mms-tts-fao")
tts_processor = AutoTokenizer.from_pretrained("facebook/mms-tts-fao")
# LID: classification over 1024 languages; ships a feature extractor only.
lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-1024")
lid_processor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-1024")
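
# gr.Audio with type="numpy" yields a (sample_rate, np.ndarray) tuple of raw
# PCM, while the MMS checkpoints expect mono float audio sampled at 16 kHz.
def to_16k_mono(audio):
    sr, wav = audio
    wav = torch.from_numpy(wav.astype(np.float32))
    if wav.ndim > 1:            # average stereo channels down to mono
        wav = wav.mean(dim=1)
    if wav.abs().max() > 1.0:   # rescale int16 PCM range to [-1.0, 1.0]
        wav = wav / 32768.0
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    return wav.numpy()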

# ASR: transcribe a recorded clip
def asr_transcribe(audio):
    wav = to_16k_mono(audio)
    inputs = asr_processor(wav, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = asr_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return asr_processor.batch_decode(predicted_ids)[0]

# TTS: synthesize speech with the VITS model
def tts_synthesize(text):
    inputs = tts_processor(text, return_tensors="pt")
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform
    # gr.Audio output accepts a (sample_rate, np.ndarray) tuple
    return tts_model.config.sampling_rate, waveform.squeeze().numpy()

# Language ID: classify the spoken language of a clip
def identify_language(audio):
    wav = to_16k_mono(audio)
    inputs = lid_processor(wav, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = lid_model(**inputs).logits
    lang_id = torch.argmax(logits, dim=-1)[0].item()
    return lid_model.config.id2label[lang_id]

# Clear functions
def clear_audio_input():
    return None

def clear_text_input():
    return ""

# Define the Gradio interface
with gr.Blocks() as demo:
    with gr.Tab("ASR"):
        gr.Markdown("## Automatic Speech Recognition (ASR)")
        asr_audio_input = gr.Audio(source="microphone", type="numpy")
        text_output = gr.Textbox(label="Transcription")
        # Buttons are wired up via .click(); gr.Button does not take fn/inputs/outputs itself.
        gr.Button("Clear").click(fn=clear_audio_input, inputs=[], outputs=asr_audio_input)
        gr.Button("Submit").click(fn=asr_transcribe, inputs=asr_audio_input, outputs=text_output)
    with gr.Tab("TTS"):
        gr.Markdown("## Text-to-Speech (TTS)")
        text_input = gr.Textbox(label="Text")
        audio_output = gr.Audio(label="Audio Output")
        gr.Button("Clear").click(fn=clear_text_input, inputs=[], outputs=text_input)
        gr.Button("Submit").click(fn=tts_synthesize, inputs=text_input, outputs=audio_output)
    with gr.Tab("Language ID"):
        gr.Markdown("## Language Identification (LangID)")
        lid_audio_input = gr.Audio(source="microphone", type="numpy")
        language_output = gr.Textbox(label="Identified Language")
        gr.Button("Clear").click(fn=clear_audio_input, inputs=[], outputs=lid_audio_input)
        gr.Button("Submit").click(fn=identify_language, inputs=lid_audio_input, outputs=language_output)

demo.launch()