"""Gradio demo for Meta's MMS models: speech recognition, speech synthesis
and spoken-language identification, one tab each."""

import gradio as gr
import torch
import torchaudio
from transformers import AutoFeatureExtractor, AutoTokenizer, VitsModel
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# Retained from the original import list; the TTS path below uses VitsModel.
from transformers import Speech2Text2Processor, Speech2Text2ForConditionalGeneration
from transformers import Wav2Vec2ForSequenceClassification

# Sampling rate the wav2vec2-based MMS checkpoints are fed with.
TARGET_SAMPLE_RATE = 16_000

ASR_MODEL_ID = "facebook/mms-1b-all"
# MMS-TTS is published as one VITS checkpoint per language
# ("facebook/mms-tts-<iso639-3>"); English is used here.
TTS_MODEL_ID = "facebook/mms-tts-eng"
LID_MODEL_ID = "facebook/mms-lid-1024"

# Load the models once at import time so every request reuses them.
asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_ID)
asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL_ID)

# NOTE: VitsModel requires a transformers release that ships VITS support.
tts_model = VitsModel.from_pretrained(TTS_MODEL_ID)
tts_processor = AutoTokenizer.from_pretrained(TTS_MODEL_ID)

lid_model = Wav2Vec2ForSequenceClassification.from_pretrained(LID_MODEL_ID)
# The LID checkpoint is a classifier without a tokenizer, so only a feature
# extractor is loaded for it.
lid_processor = AutoFeatureExtractor.from_pretrained(LID_MODEL_ID)


def _prepare_audio(audio):
    """Convert a Gradio ``type="numpy"`` audio value into model input.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by ``gr.Audio``,
            or ``None`` when nothing was recorded.

    Returns:
        A 1-D float32 NumPy array sampled at ``TARGET_SAMPLE_RATE``, or
        ``None`` when there is no audio to process.
    """
    if audio is None:
        return None

    sample_rate, samples = audio
    waveform = torch.as_tensor(samples)
    if waveform.numel() == 0:
        return None

    if waveform.is_floating_point():
        waveform = waveform.to(torch.float32)
    else:
        # Integer PCM (Gradio normally delivers int16): scale to [-1, 1).
        full_scale = float(torch.iinfo(waveform.dtype).max) + 1.0
        waveform = waveform.to(torch.float32) / full_scale

    if waveform.ndim > 1:
        # Multi-channel audio arrives as (samples, channels); downmix to mono.
        waveform = waveform.mean(dim=-1)

    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, TARGET_SAMPLE_RATE
        )

    return waveform.numpy()


# ASR Function
def asr_transcribe(audio):
    """Transcribe recorded speech to text.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio``, or ``None``.

    Returns:
        The decoded transcription, or an empty string when no audio was given.
    """
    waveform = _prepare_audio(audio)
    if waveform is None:
        return ""

    inputs = asr_processor(
        waveform,
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        logits = asr_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.batch_decode(predicted_ids)
    return transcription[0]


# TTS Function
def tts_synthesize(text):
    """Synthesize speech for ``text``.

    Args:
        text: The text to speak.

    Returns:
        ``(sample_rate, waveform)`` in the form ``gr.Audio`` accepts as an
        output value, or ``None`` when ``text`` is empty or whitespace.
    """
    if not text or not text.strip():
        return None

    inputs = tts_processor(text, return_tensors="pt")
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform
    # Drop the batch dimension: the audio component expects a 1-D array.
    return tts_model.config.sampling_rate, waveform.squeeze(0).cpu().numpy()


# Language ID Function
def identify_language(audio):
    """Identify the language spoken in a recording.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio``, or ``None``.

    Returns:
        The label of the highest-scoring class from the model's ``id2label``
        table, or an empty string when no audio was given.
    """
    waveform = _prepare_audio(audio)
    if waveform is None:
        return ""

    inputs = lid_processor(
        waveform,
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        logits = lid_model(**inputs).logits
    language_id = torch.argmax(logits, dim=-1)[0].item()
    return lid_model.config.id2label[language_id]


def clear_audio_input():
    """Return the value that resets an audio component."""
    return None


def clear_text_input():
    """Return the value that resets a text box."""
    return ""


# Define the Gradio interfaces
# NOTE: ``source="microphone"`` is the Gradio 3.x spelling; Gradio 4 renamed
# the parameter to ``sources=["microphone"]``.
with gr.Blocks() as demo:
    with gr.Tab("ASR"):
        gr.Markdown("## Automatic Speech Recognition (ASR)")
        audio_input = gr.Audio(source="microphone", type="numpy")
        text_output = gr.Textbox(label="Transcription")
        asr_clear_button = gr.Button("Clear")
        asr_submit_button = gr.Button("Submit")
        # Callbacks are attached through .click(); gr.Button itself takes no
        # fn/inputs/outputs arguments.
        asr_clear_button.click(
            fn=clear_audio_input, inputs=None, outputs=audio_input
        )
        asr_submit_button.click(
            fn=asr_transcribe, inputs=audio_input, outputs=text_output
        )

    with gr.Tab("TTS"):
        gr.Markdown("## Text-to-Speech (TTS)")
        text_input = gr.Textbox(label="Text")
        audio_output = gr.Audio(label="Audio Output")
        tts_clear_button = gr.Button("Clear")
        tts_submit_button = gr.Button("Submit")
        tts_clear_button.click(
            fn=clear_text_input, inputs=None, outputs=text_input
        )
        tts_submit_button.click(
            fn=tts_synthesize, inputs=text_input, outputs=audio_output
        )

    with gr.Tab("Language ID"):
        gr.Markdown("## Language Identification (LangID)")
        lid_audio_input = gr.Audio(source="microphone", type="numpy")
        language_output = gr.Textbox(label="Identified Language")
        lid_clear_button = gr.Button("Clear")
        lid_submit_button = gr.Button("Submit")
        lid_clear_button.click(
            fn=clear_audio_input, inputs=None, outputs=lid_audio_input
        )
        lid_submit_button.click(
            fn=identify_language, inputs=lid_audio_input, outputs=language_output
        )


if __name__ == "__main__":
    demo.launch()