PranshuSwaroop4's picture
Create app.py
8815ddf verified
import gradio as gr
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import multiprocessing as mp
# Load the pretrained Wav2Vec2 CTC checkpoint once at import time.
# `processor` and `model` are module-level globals used by process_chunk();
# both are built from the same checkpoint name so they stay in sync.
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
# Function to process a single chunk of audio
def process_chunk(chunk: torch.Tensor, sample_rate: int) -> str:
    """Transcribe one chunk of audio with the module-level Wav2Vec2 model.

    Args:
        chunk: Audio samples, either 1-D ``(time,)`` or 2-D
            ``(channels, time)`` as sliced from the loaded waveform.
        sample_rate: Sampling rate of ``chunk`` in Hz.

    Returns:
        The decoded transcription, or an empty string when the chunk is
        empty or too short for the model to process.
    """
    target_rate = 16000
    # Minimum number of 16 kHz samples accepted by the convolutional feature
    # encoder of wav2vec2-base (its receptive field, 25 ms). Shorter inputs
    # make the forward pass raise, which would abort the whole transcription
    # because of a tiny trailing chunk.
    min_samples = 400

    if chunk.numel() == 0:
        return ""

    # Down-mix multi-channel audio to mono. A 2-D array handed to the
    # processor is interpreted as a batch of separate recordings, so without
    # this only the first channel would be transcribed.
    if chunk.dim() > 1:
        chunk = chunk.mean(dim=0)

    # Resample the audio to 16000 Hz if necessary
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        chunk = resampler(chunk)

    if chunk.numel() < min_samples:
        return ""

    # Process the audio to the format expected by the model
    input_values = processor(chunk.numpy(), sampling_rate=target_rate, return_tensors="pt").input_values

    # Perform inference without tracking gradients
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: pick the most likely token per frame, then let the
    # processor collapse it into text
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]
# Function to perform speech recognition on the entire audio
def speech_recognition(audio_path) -> str:
    """Transcribe an audio file by processing it in 30-second chunks.

    Args:
        audio_path: Path to the audio file. May be ``None`` or empty when
            the caller has no audio to offer (e.g. the UI was submitted
            without a recording).

    Returns:
        The chunk transcriptions joined by single spaces, or an empty
        string when there is no audio to transcribe.
    """
    # Nothing to transcribe: avoid handing an invalid path to torchaudio
    if not audio_path:
        return ""

    # Load the audio file
    waveform, sample_rate = torchaudio.load(audio_path)

    # Split the waveform into chunks of 30 seconds
    chunk_length = 30 * sample_rate  # 30 seconds in samples
    chunks = [
        waveform[:, start:start + chunk_length]
        for start in range(0, waveform.size(1), chunk_length)
    ]
    if not chunks:
        return ""

    if len(chunks) == 1:
        # A single chunk gains nothing from a process pool; run it in-process
        results = [process_chunk(chunks[0], sample_rate)]
    else:
        # Process chunks in parallel, but never start more workers than
        # there are chunks to hand out
        workers = min(mp.cpu_count(), len(chunks))
        with mp.Pool(workers) as pool:
            results = pool.starmap(process_chunk, [(chunk, sample_rate) for chunk in chunks])

    # Combine the transcriptions, skipping chunks that produced no text so
    # they do not leave doubled spaces in the result
    transcription = " ".join(text for text in results if text)
    return transcription.strip()
# Create the Gradio interface: one audio input (delivered to the handler as a
# file path) and one text output showing the transcription.
inputs = gr.Audio(type="filepath", label="Input Audio")
outputs = gr.Textbox(label="Transcription")
interface = gr.Interface(
    fn=speech_recognition,
    inputs=inputs,
    outputs=outputs,
    title="Speech Recognition using Wav2Vec2",
    description="Upload an audio file or record the audio to get the transcription using the Wav2Vec2 model.",
    article="This assignment is developed by Pranshu Swaroop",
)
# Launch the interface only when run as a script, so importing this module
# (e.g. from a worker process) does not start another server.
if __name__ == "__main__":
    interface.launch()