|
import gradio as gr |
|
import torch |
|
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor |
|
import torchaudio |
|
import multiprocessing as mp |
|
|
|
|
|
# Pretrained English ASR checkpoint (CTC head on Wav2Vec2-base, 960h LibriSpeech).
model_name = "facebook/wav2vec2-base-960h"

# Load the feature-extractor/tokenizer pair and the acoustic model once at
# import time. NOTE(review): worker processes created by mp.Pool below reuse
# these globals cheaply only on fork-based platforms; on spawn (Windows/macOS
# default) each worker re-imports this module and reloads the model — confirm
# the target platform.
processor = Wav2Vec2Processor.from_pretrained(model_name)

model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
|
|
|
|
def process_chunk(chunk, sample_rate):
    """Transcribe a single waveform chunk with the module-level Wav2Vec2 model.

    Args:
        chunk: torch.Tensor of shape (channels, samples) as produced by
            ``torchaudio.load`` (slicing along dim 1 preserves that layout).
        sample_rate: sampling rate of ``chunk`` in Hz.

    Returns:
        The greedy CTC transcription of the chunk as a string.
    """
    # Wav2Vec2-base-960h expects 16 kHz input; resample anything else.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        chunk = resampler(chunk)

    # Downmix multi-channel audio to mono. Without this, a stereo
    # (2, samples) tensor survives squeeze() as 2-D, the processor treats it
    # as a batch of two, and batch_decode(...)[0] silently drops channel 1.
    if chunk.dim() > 1:
        chunk = chunk.mean(dim=0)

    chunk = chunk.squeeze().numpy()

    input_values = processor(chunk, sampling_rate=16000, return_tensors="pt").input_values

    # Pure inference: disable autograd to save memory/time.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decode: most likely token per frame, then collapse repeats
    # and blanks inside batch_decode.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]

    return transcription
|
|
|
|
|
def speech_recognition(audio_path):
    """Transcribe an audio file, processing 30-second windows in parallel.

    Args:
        audio_path: path to the audio file to transcribe.

    Returns:
        The concatenated transcription of all windows, stripped of
        leading/trailing whitespace.
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # Cut the waveform along the time axis into windows of at most 30 s.
    window = 30 * sample_rate
    segments = []
    for start in range(0, waveform.size(1), window):
        segments.append(waveform[:, start:start + window])

    # Fan the windows out across all CPU cores; starmap preserves order,
    # so the pieces come back in playback order.
    with mp.Pool(mp.cpu_count()) as workers:
        pieces = workers.starmap(process_chunk, [(seg, sample_rate) for seg in segments])

    return " ".join(pieces).strip()
|
|
|
|
|
# --- Gradio UI wiring -------------------------------------------------------
# Accept an uploaded/recorded audio file as a filesystem path (matches the
# `audio_path` parameter of `speech_recognition`).
inputs = gr.Audio(type="filepath", label="Input Audio")

outputs = gr.Textbox(label="Transcription")

interface = gr.Interface(
    fn=speech_recognition,
    inputs=inputs,
    outputs=outputs,
    title="Speech Recognition using Wav2Vec2",
    # Fixed user-facing typos: "a audio" -> "an audio", "assignement" -> "assignment".
    description="Upload an audio file or record the audio to get the transcription using the Wav2Vec2 model.",
    article="This assignment is developed by Pranshu Swaroop",
)
|
|
|
|
|
if __name__ == "__main__":

    # Start the Gradio web server only when run as a script, not on import.
    interface.launch()
|
|