"""Gradio demo: transcribe uploaded or recorded audio with Wav2Vec2 (CTC)."""

import multiprocessing as mp  # unused since chunks are transcribed sequentially; safe to remove

import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Load the Wav2Vec2 model and processor once, at import time.
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# Sample rate the audio is converted to before being fed to the model.
TARGET_SAMPLE_RATE = 16000
# Length of each chunk passed to the model, in seconds.
CHUNK_SECONDS = 30
# Wav2Vec2's convolutional feature encoder has a 400-sample (25 ms at 16 kHz)
# receptive field; shorter inputs make its first conv layer raise.
_MIN_CHUNK_SAMPLES = 400


def process_chunk(chunk, sample_rate):
    """Transcribe a single chunk of audio.

    Args:
        chunk: Audio tensor for one chunk. Size-1 dimensions are squeezed away,
            so a ``(1, frames)`` mono tensor becomes a 1-D sample array.
        sample_rate: Sample rate of ``chunk`` in Hz. It is resampled to
            ``TARGET_SAMPLE_RATE`` when it differs.

    Returns:
        The decoded transcription of the chunk (may be an empty string).
    """
    if sample_rate != TARGET_SAMPLE_RATE:
        chunk = torchaudio.functional.resample(chunk, sample_rate, TARGET_SAMPLE_RATE)

    samples = chunk.squeeze().numpy()
    input_values = processor(
        samples, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt"
    ).input_values

    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: most likely token per frame, then collapse via the processor.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]


def _load_mono_16k(audio_path):
    """Load ``audio_path`` as a ``(1, frames)`` mono tensor at ``TARGET_SAMPLE_RATE``."""
    waveform, sample_rate = torchaudio.load(audio_path)  # (channels, frames)
    # Downmix: without this, squeeze() leaves (channels, frames) and the processor
    # treats the channels as a batch, so only channel 0 would be transcribed.
    waveform = waveform.mean(dim=0, keepdim=True)
    # Resample once for the whole recording rather than once per chunk.
    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
    return waveform


def _split_into_chunks(waveform, chunk_samples, min_tail_samples):
    """Split a ``(1, frames)`` waveform along time into chunks of ``chunk_samples``.

    A trailing chunk shorter than ``min_tail_samples`` is appended to the
    previous chunk. This keeps the end of the recording from being fed to the
    model as a tiny fragment.
    """
    chunks = list(torch.split(waveform, chunk_samples, dim=1))
    if len(chunks) > 1 and chunks[-1].size(1) < min_tail_samples:
        tail = chunks.pop()
        chunks[-1] = torch.cat([chunks[-1], tail], dim=1)
    return chunks


def speech_recognition(audio_path):
    """Transcribe an entire audio file.

    The audio is downmixed to mono, resampled to 16 kHz, and split into
    ``CHUNK_SECONDS``-second chunks. Each chunk is transcribed and the results
    are joined with single spaces.

    NOTE: chunks are cut at fixed boundaries, so a word straddling a boundary
    may be mis-recognized.

    Args:
        audio_path: Path to the audio file, or ``None`` when nothing was provided.

    Returns:
        The transcription, or ``""`` when there is no audio or nothing was recognized.
    """
    if not audio_path:
        return ""

    waveform = _load_mono_16k(audio_path)
    chunk_samples = CHUNK_SECONDS * TARGET_SAMPLE_RATE
    # Tails shorter than one second are merged into the preceding chunk.
    chunks = _split_into_chunks(waveform, chunk_samples, min_tail_samples=TARGET_SAMPLE_RATE)

    # Chunks are processed sequentially on purpose:
    # - PyTorch already uses multiple threads inside each forward pass.
    # - A process pool oversubscribes those cores and pickles every chunk.
    # - Under the "spawn" start method, each worker re-imports this module and reloads the model.
    transcriptions = [
        process_chunk(chunk, TARGET_SAMPLE_RATE)
        for chunk in chunks
        if chunk.size(1) >= _MIN_CHUNK_SAMPLES
    ]

    # Drop empty chunk results so the output has no doubled spaces.
    return " ".join(text.strip() for text in transcriptions if text.strip())


# Create the Gradio interface.
inputs = gr.Audio(type="filepath", label="Input Audio")
outputs = gr.Textbox(label="Transcription")

interface = gr.Interface(
    fn=speech_recognition,
    inputs=inputs,
    outputs=outputs,
    title="Speech Recognition using Wav2Vec2",
    description=(
        "Upload an audio file or record audio "
        "to get the transcription using the Wav2Vec2 model."
    ),
    article="This assignment is developed by Pranshu Swaroop",
)

# Launch the interface.
if __name__ == "__main__":
    interface.launch()