# whisper-nepali / app.py
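"""Gradio demo: Nepali speech recognition with fine-tuned Whisper checkpoints."""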
import gradio as gr
import numpy as np
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Fine-tuned Nepali Whisper checkpoints on the Hugging Face Hub
model_urls = [
    "kiranpantha/whisper-tiny-ne",
    "kiranpantha/whisper-base-ne",
    "kiranpantha/whisper-small-np",
    "kiranpantha/whisper-medium-nepali",
    "kiranpantha/whisper-large-v3-nepali",
    "kiranpantha/whisper-large-v3-turbo-nepali",
]
# Map each fine-tuned checkpoint to the base Whisper repo whose processor it uses
processor_mappings = {
    "kiranpantha/whisper-tiny-ne": "openai/whisper-tiny",
    "kiranpantha/whisper-base-ne": "openai/whisper-base",
    "kiranpantha/whisper-small-np": "openai/whisper-small",
    "kiranpantha/whisper-medium-nepali": "openai/whisper-medium",
    "kiranpantha/whisper-large-v3-nepali": "openai/whisper-large-v3",
    "kiranpantha/whisper-large-v3-turbo-nepali": "openai/whisper-large-v3",
}
# Cache models and processors
model_cache = {}
def load_model(model_name):
    """Load and cache the processor and model, placing the model on GPU when available."""
    if model_name not in model_cache:
        # Resolve the processor repo; fall back to the model's own repo if unmapped
        processor_name = processor_mappings.get(model_name, model_name)
        processor = AutoProcessor.from_pretrained(processor_name)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(device)
        model.eval()
        model_cache[model_name] = (processor, model, device)
    return model_cache[model_name]
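# The first call for a checkpoint downloads it; later calls return the cached triple, e.g.:
#   processor, model, device = load_model(model_urls[0])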
def transcribe_audio(model_name, audio_chunk):
    """Transcribe a Gradio audio input (file path or (sample_rate, array) tuple) to Nepali text."""
    try:
        print("Received audio_chunk:", type(audio_chunk), audio_chunk)
        if audio_chunk is None:
            return "Error: No audio received"
        if isinstance(audio_chunk, str):
            # Uploaded file path: torchaudio returns a (channels, samples) tensor
            audio_tensor, sample_rate = torchaudio.load(audio_chunk)
            audio_array = audio_tensor.mean(dim=0).numpy()  # downmix to mono
        elif isinstance(audio_chunk, tuple) and isinstance(audio_chunk[1], np.ndarray):
            # Microphone input: Gradio passes (sample_rate, samples), channels on the last axis
            sample_rate, audio_array = audio_chunk
            if audio_array.ndim == 2:
                audio_array = audio_array.mean(axis=1)  # downmix to mono
        else:
            return "Error: Invalid audio input format"
        # Microphone PCM is typically int16; scale to float32 in [-1, 1] for the feature extractor
        if np.issubdtype(audio_array.dtype, np.integer):
            audio_array = audio_array / 32768.0
        audio_array = audio_array.astype(np.float32)
        # Whisper models expect 16 kHz audio, so resample if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio_array = resampler(torch.from_numpy(audio_array).unsqueeze(0)).squeeze(0).numpy()
        # Load (or fetch from cache) the selected model
        processor, model, device = load_model(model_name)
        # Extract log-Mel input features
        inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(device)
        # Decode with the Nepali transcription prompt
        with torch.no_grad():
            generated_ids = model.generate(
                input_features,
                forced_decoder_ids=processor.get_decoder_prompt_ids(language="ne", task="transcribe"),
                max_length=448,
            )
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Nepali Speech Recognition with Whisper Models")
    model_dropdown = gr.Dropdown(choices=model_urls, label="Select Model", value=model_urls[0])
    audio_input = gr.Audio(type="numpy", label="🎤 Record your voice here")
    output_text = gr.Textbox(label="📄 Transcription Output")
    transcribe_button = gr.Button("Transcribe")
    transcribe_button.click(
        fn=transcribe_audio,
        inputs=[model_dropdown, audio_input],
        outputs=output_text,
    )

demo.launch()
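# To expose a temporary public URL instead of localhost only, use demo.launch(share=True)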