PranshuSwaroop4 committed on
Commit
8815ddf
1 Parent(s): d1722d4

Create app.py

Files changed (1)
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
+ import gradio as gr
+ import torch
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ import torchaudio
+ import multiprocessing as mp
+
+ # Load the Wav2Vec2 model and processor
+ model_name = "facebook/wav2vec2-base-960h"
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
+
+ # Transcribe a single chunk of audio
+ def process_chunk(chunk, sample_rate):
+     # Resample the audio to 16000 Hz if necessary
+     if sample_rate != 16000:
+         resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+         chunk = resampler(chunk)
+
+     # Downmix to mono and convert to the 1-D numpy array the processor expects
+     chunk = chunk.mean(dim=0).numpy()
+
+     # Convert the raw audio into the model's input format
+     input_values = processor(chunk, sampling_rate=16000, return_tensors="pt").input_values
+
+     # Perform inference without tracking gradients
+     with torch.no_grad():
+         logits = model(input_values).logits
+
+     # Decode the logits to get the predicted text
+     predicted_ids = torch.argmax(logits, dim=-1)
+     transcription = processor.batch_decode(predicted_ids)[0]
+
+     return transcription
+
+ # Perform speech recognition on the entire audio file
+ def speech_recognition(audio_path):
+     # Load the audio file
+     waveform, sample_rate = torchaudio.load(audio_path)
+
+     # Split the waveform into chunks of 30 seconds
+     chunk_length = 30 * sample_rate  # 30 seconds in samples
+     chunks = [waveform[:, i:i + chunk_length] for i in range(0, waveform.size(1), chunk_length)]
+
+     # Use multiprocessing to process chunks in parallel
+     with mp.Pool(mp.cpu_count()) as pool:
+         results = pool.starmap(process_chunk, [(chunk, sample_rate) for chunk in chunks])
+
+     # Combine the per-chunk transcriptions
+     transcription = " ".join(results)
+
+     return transcription.strip()
+
+ # Create the Gradio interface
+ inputs = gr.Audio(type="filepath", label="Input Audio")
+ outputs = gr.Textbox(label="Transcription")
+
+ interface = gr.Interface(
+     fn=speech_recognition,
+     inputs=inputs,
+     outputs=outputs,
+     title="Speech Recognition using Wav2Vec2",
+     description="Upload an audio file or record audio to get a transcription from the Wav2Vec2 model.",
+     article="This assignment was developed by Pranshu Swaroop",
+ )
+
+ # Launch the interface
+ if __name__ == "__main__":
+     interface.launch()
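
One design note on the parallelism: `mp.Pool` counts on the worker processes inheriting the already-loaded `model` and `processor`, which holds under Linux's default `fork` start method but not under `spawn` (the default on macOS and Windows), where each worker re-imports the module and reloads the model. If the pool proves slow or unstable on the host, a sequential fallback is a minimal sketch worth keeping in mind; the name `speech_recognition_sequential` and the file `sample.wav` below are illustrative, not part of this commit.

# A minimal sequential alternative to the mp.Pool version above (hypothetical;
# reuses process_chunk from app.py, and "sample.wav" is a placeholder path).
def speech_recognition_sequential(audio_path):
    # Load the audio and split it into 30-second chunks, as in app.py
    waveform, sample_rate = torchaudio.load(audio_path)
    chunk_length = 30 * sample_rate
    chunks = [waveform[:, i:i + chunk_length] for i in range(0, waveform.size(1), chunk_length)]
    # Transcribe chunks one at a time and join the results
    return " ".join(process_chunk(chunk, sample_rate) for chunk in chunks).strip()

print(speech_recognition_sequential("sample.wav"))

This trades the per-chunk parallelism for predictable memory use, since only one copy of the model runs inference at a time.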