# NOTE: "Spaces: Running" header below was a Hugging Face Spaces page
# banner captured by copy/paste — it is not part of the program.
import gradio as gr
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor

# Language codes offered in the UI, mapped to display names.
# Replace with the full set of languages your fine-tuned model supports.
ASR_LANGUAGES = {"eng": "English", "swh": "Swahili"}

# Checkpoint identifier — must match the model ID used for fine-tuning.
MODEL_ID = "facebook/mms-1b-all"

# Load the processor and model once at import time; transcribe() reuses them
# across requests instead of reloading per call.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
def transcribe(audio_path, language):
    """Transcribe an audio file using the fine-tuned MMS model.

    Args:
        audio_path: Path to an audio file readable by ``soundfile``.
        language: Dropdown selection formatted as ``"code (Name)"`` (e.g.
            ``"eng (English)"``); the ISO code before the first space selects
            the tokenizer vocabulary and language adapter. May be falsy, in
            which case the model's current language setting is kept.

    Returns:
        The decoded transcription as a string.
    """
    if language:
        # The dropdown value is "code (Name)"; the code is the first token.
        target_lang = language.split(" ")[0]
        processor.tokenizer.set_target_lang(target_lang)
        if target_lang != "eng":
            # Swap in the language-specific adapter weights.
            # NOTE(review): reloading the adapter on every request is slow —
            # consider caching the last-loaded language.
            model.load_adapter(target_lang)

    audio, samplerate = sf.read(audio_path)
    # sf.read returns a (frames, channels) array for multi-channel files;
    # down-mix to mono since the model consumes a 1-D waveform.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # NOTE(review): MMS checkpoints expect 16 kHz input — the processor is
    # told the file's native rate here; confirm inputs are 16 kHz or resample.
    inputs = processor(audio, sampling_rate=samplerate, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(ids)
# Gradio interface wiring the audio + language inputs to transcribe().
mms_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        # type="filepath" guarantees transcribe() receives a path string,
        # which it passes to sf.read().
        gr.Audio(type="filepath"),
        gr.Dropdown(
            [f"{k} ({v})" for k, v in ASR_LANGUAGES.items()],
            label="Language",
            # Default must match the choice format "code (Name)";
            # the original "eng English" was not a valid choice.
            value="eng (English)",
        ),
    ],
    outputs="text",
    title="Speech-to-Text Transcription",
    description="Transcribe audio input into text.",
    allow_flagging="never",
)
# Wrap the interface in a Blocks container so further components can be
# added alongside it later.
with gr.Blocks() as demo:
    mms_transcribe.render()

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.queue()
    demo.launch()