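"""Gradio app that transcribes speech with facebook/wav2vec2-base-960h.

Long recordings are split into 30-second chunks that are transcribed in parallel
with a multiprocessing pool, then the chunk transcriptions are joined.
"""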
import gradio as gr
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import multiprocessing as mp

# Load the Wav2Vec2 model and processor
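# wav2vec2-base-960h is an English CTC model fine-tuned on 960 hours of
# LibriSpeech; it expects 16 kHz mono audio as input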
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# Function to process a single chunk of audio
def process_chunk(chunk, sample_rate):
    # Resample the audio to 16000 Hz if necessary
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        chunk = resampler(chunk)

    # Ensure the audio is in the correct format (1-D numpy array)
    chunk = chunk.squeeze().numpy()

    # Process the audio into the format expected by the model
    input_values = processor(chunk, sampling_rate=16000, return_tensors="pt").input_values

    # Perform inference
    with torch.no_grad():
        logits = model(input_values).logits

    # Decode the logits to get the predicted text
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]

    return transcription

# Function to perform speech recognition on the entire audio file
def speech_recognition(audio_path):
    # Load the audio file
    waveform, sample_rate = torchaudio.load(audio_path)

    # Mix multi-channel audio down to mono so each chunk has shape (1, samples)
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Split the waveform into chunks of 30 seconds
    chunk_length = 30 * sample_rate  # 30 seconds in samples
    chunks = [waveform[:, i:i + chunk_length] for i in range(0, waveform.size(1), chunk_length)]

    # Use multiprocessing to process chunks in parallel; worker processes inherit
    # the already-loaded model when the start method is fork (the Linux default)
    with mp.Pool(mp.cpu_count()) as pool:
        results = pool.starmap(process_chunk, [(chunk, sample_rate) for chunk in chunks])

    # Combine the chunk transcriptions
    transcription = " ".join(results)
    return transcription.strip()

# Create the Gradio interface
inputs = gr.Audio(type="filepath", label="Input Audio")
outputs = gr.Textbox(label="Transcription")

interface = gr.Interface(
    fn=speech_recognition,
    inputs=inputs,
    outputs=outputs,
    title="Speech Recognition using Wav2Vec2",
    description="Upload an audio file or record audio to get its transcription using the Wav2Vec2 model.",
    article="This assignment was developed by Pranshu Swaroop.",
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()