# asr-whisper / app.py
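"""Real-time speech-to-text demo built around openai/whisper-small.

Microphone audio streamed from the Gradio UI is buffered, split into
overlapping windows, and transcribed with Whisper; new transcriptions
are streamed back to the textbox as they are produced.
"""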
import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
import librosa
from collections import deque
import spaces
# Model settings
MODEL_ID = "openai/whisper-small"
DEVICE = "cpu"  # Keep the model on CPU; ZeroGPU attaches a GPU only inside @spaces.GPU-decorated calls
WINDOW_SECONDS = 1.0 # Window size for transcription
OVERLAP_SECONDS = 0.5 # Overlap between windows
RATE = 16000 # Whisper expects 16kHz audio
# Initialize Whisper model and processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID, low_cpu_mem_usage=True, use_safetensors=True
).to(DEVICE)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Global state
audio_buffer = deque()
buffer_duration = 0.0
last_transcription = ""
def process_audio_chunk(audio_chunk):
    """Process a single audio chunk and update the buffer."""
    global audio_buffer, buffer_duration
    # Convert the chunk to a float32 numpy array if it is not one already
    if not isinstance(audio_chunk, np.ndarray):
        audio_array = np.array(audio_chunk, dtype=np.float32)
    else:
        audio_array = audio_chunk  # Already a numpy array with the correct dtype
    audio_buffer.append(audio_array)
    buffer_duration += len(audio_array) / RATE
    return audio_array
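
# Sliding-window transcription: each pass decodes WINDOW_SECONDS of audio, then
# advances by WINDOW_SECONDS - OVERLAP_SECONDS so consecutive windows share
# OVERLAP_SECONDS of context and words at window edges are less likely to be cut off.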
def transcribe_audio():
    """Process the audio buffer with a sliding window and yield transcriptions."""
    global audio_buffer, buffer_duration, last_transcription
    window_samples = int(WINDOW_SECONDS * RATE)
    overlap_samples = int(OVERLAP_SECONDS * RATE)
    step_samples = window_samples - overlap_samples  # Step size for the sliding window
    while buffer_duration >= WINDOW_SECONDS:
        # Concatenate the buffer and trim it to one window
        audio_window = np.concatenate(list(audio_buffer))
        audio_window = audio_window[:window_samples]
        # Run Whisper on the window
        transcription = ""
        try:
            # The buffer already holds mono float32 audio at 16 kHz, so it can be
            # passed to the processor directly; reloading it through librosa.load
            # would fail, since librosa.load expects a file path, not an array.
            audio_input = audio_window.astype(np.float32)
            inputs = processor(audio_input, sampling_rate=RATE, return_tensors="pt").to(DEVICE)
            with torch.no_grad():
                predicted_ids = model.generate(inputs["input_features"])
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
        except Exception as e:
            print(f"Error during transcription: {e}")
        # Yield the transcription if it differs from the previous one
        if transcription and transcription != last_transcription:
            last_transcription = transcription
            yield transcription
        # Slide the window: drop samples up to the step size. This runs even when
        # transcription failed, so the loop always advances instead of retrying
        # the same window forever.
        samples_to_remove = step_samples
        while samples_to_remove > 0 and audio_buffer:
            if len(audio_buffer[0]) > samples_to_remove:
                audio_buffer[0] = audio_buffer[0][samples_to_remove:]
                buffer_duration -= samples_to_remove / RATE
                break
            else:
                samples_to_remove -= len(audio_buffer[0])
                buffer_duration -= len(audio_buffer[0]) / RATE
                audio_buffer.popleft()

@spaces.GPU
def audio_stream(audio):
    """Handle streaming audio input from Gradio."""
    # Gradio delivers streaming audio as a (sample_rate, data) tuple
    sample_rate, audio_data = audio
    audio_data = np.asarray(audio_data)
    # Collapse stereo chunks to mono
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)
    # Gradio streams integer PCM; scale it to float32 in [-1, 1] for librosa and Whisper
    if np.issubdtype(audio_data.dtype, np.integer):
        audio_data = audio_data.astype(np.float32) / 32768.0
    else:
        audio_data = audio_data.astype(np.float32)
    # Resample to 16 kHz if needed
    if sample_rate != RATE:
        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=RATE)
    # Buffer the chunk
    process_audio_chunk(audio_data)
    # Transcribe and yield results
    for transcription in transcribe_audio():
        yield transcription

# Initialize application state
def init_app():
    """Initialize the application state."""
    global audio_buffer, buffer_duration, last_transcription
    audio_buffer = deque()
    buffer_duration = 0.0
    last_transcription = ""
    return "Transcription is active. Speak into the microphone."

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Real-Time Speech-to-Text with Whisper")
    gr.Markdown("Record audio using the microphone and see transcriptions in real time. Hosted on Hugging Face Spaces with ZeroGPU.")
    audio_input = gr.Audio(sources=["microphone"], streaming=True, label="Speak Here")
    output_text = gr.Textbox(label="Transcription", value="Transcription is active. Speak into the microphone.", interactive=False)
    demo.load(init_app, outputs=output_text)
    audio_input.stream(audio_stream, inputs=audio_input, outputs=output_text)
# Launch the app
demo.launch()
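
# To run locally (assumed dependencies, inferred from the imports above:
# gradio, numpy, torch, transformers, librosa, spaces):
#     python app.py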