import gradio as gr
import numpy as np
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Fine-tuned Nepali Whisper checkpoints (Hugging Face model IDs)
model_urls = [
    "kiranpantha/whisper-tiny-ne",
    "kiranpantha/whisper-base-ne",
    "kiranpantha/whisper-small-np",
    "kiranpantha/whisper-medium-nepali",
    "kiranpantha/whisper-large-v3-nepali",
    "kiranpantha/whisper-large-v3-turbo-nepali",
]
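# The checkpoints are ordered from smallest/fastest (tiny) to largest/most
# accurate (large-v3); the dropdown below defaults to the first entry.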
# Map each fine-tuned checkpoint to the base Whisper processor it derives from
# (the fine-tuned repos may not ship their own processor/tokenizer files)
processor_mappings = {
    "kiranpantha/whisper-tiny-ne": "openai/whisper-tiny",
    "kiranpantha/whisper-base-ne": "openai/whisper-base",
    "kiranpantha/whisper-small-np": "openai/whisper-small",
    "kiranpantha/whisper-medium-nepali": "openai/whisper-medium",
    "kiranpantha/whisper-large-v3-nepali": "openai/whisper-large-v3",
    "kiranpantha/whisper-large-v3-turbo-nepali": "openai/whisper-large-v3",
}
# Cache models and processors so each checkpoint is only loaded once
model_cache = {}

def load_model(model_name):
    """Loads and caches the model and processor with proper device management."""
    if model_name not in model_cache:
        # Fall back to the checkpoint itself if no base processor is mapped
        processor_name = processor_mappings.get(model_name, model_name)
        processor = AutoProcessor.from_pretrained(processor_name)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(device)
        model.eval()
        model_cache[model_name] = (processor, model, device)
    return model_cache[model_name]
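# Optional warm-up sketch (not part of the original flow): preloading every
# checkpoint at startup avoids a slow first request, at the cost of holding
# all of them in memory (the large checkpoints need several GB each).
# for name in model_urls:
#     load_model(name)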
def transcribe_audio(model_name, audio_chunk):
    try:
        print("Received audio_chunk:", type(audio_chunk), audio_chunk)  # debug logging
        if audio_chunk is None:
            return "Error: No audio received"
        if isinstance(audio_chunk, str):
            # Filepath case (defensive; gr.Audio below uses type="numpy").
            # torchaudio.load returns a (channels, samples) float tensor.
            audio_tensor, sample_rate = torchaudio.load(audio_chunk)
            audio_array = audio_tensor.mean(dim=0).numpy()  # downmix to mono
        elif isinstance(audio_chunk, tuple) and isinstance(audio_chunk[1], np.ndarray):
            # Microphone/upload case: Gradio passes (sample_rate, samples);
            # stereo arrays are shaped (n_samples, n_channels).
            sample_rate, audio_array = audio_chunk
            if audio_array.ndim == 2:
                audio_array = audio_array.mean(axis=1)  # downmix to mono
        else:
            return "Error: Invalid audio input format"
        # Gradio delivers microphone audio as int16 PCM, but Whisper's feature
        # extractor expects float waveforms in [-1, 1]
        if np.issubdtype(audio_array.dtype, np.integer):
            audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
        else:
            audio_array = audio_array.astype(np.float32)
        # Whisper expects 16 kHz input; resample if necessary
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio_array = resampler(torch.from_numpy(audio_array)).numpy()
        # Load the (cached) model and processor
        processor, model, device = load_model(model_name)
        # Convert the waveform into log-mel input features
        inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(device)
        # Generate with the decoder prompt pinned to Nepali transcription
        with torch.no_grad():
            generated_ids = model.generate(
                input_features,
                forced_decoder_ids=processor.get_decoder_prompt_ids(language="ne", task="transcribe"),
                max_length=448,
            )
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Nepali Speech Recognition with Whisper Models")
    model_dropdown = gr.Dropdown(choices=model_urls, label="Select Model", value=model_urls[0])
    audio_input = gr.Audio(type="numpy", label="🎤 Record your voice here")
    output_text = gr.Textbox(label="📝 Transcription Output")
    transcribe_button = gr.Button("Transcribe")
    transcribe_button.click(
        fn=transcribe_audio,
        inputs=[model_dropdown, audio_input],
        outputs=output_text,
    )
demo.launch()
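# Optional: on shared hardware, chaining demo.queue() before launch(), i.e.
# demo.queue().launch(), lets Gradio queue concurrent transcription requests
# instead of running them all at once.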