Vaibhav Srivastav committed on
Commit 57926d1
1 Parent(s): 6e7462e

Adding Wav2Vec2 code

Files changed (2)
  1. app.py +55 -4
  2. test.wav +0 -0
app.py CHANGED
@@ -1,7 +1,58 @@
+ #Importing all the necessary packages
+ import nltk
+ import librosa
+ import torch
  import gradio as gr
+ from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
 
- def greet(name):
-     return "Hello " + name + "!!"
+ nltk.download("punkt")
 
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ #Loading the model and the tokenizer
+ model_name = "facebook/wav2vec2-base-960h"
+ tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
+
+ def load_data(input_file):
+
+     """ Function for resampling to ensure that the speech input is sampled at 16KHz.
+     """
+     #read the file
+     speech, sample_rate = librosa.load(input_file)
+     #make it 1-D
+     if len(speech.shape) > 1:
+         speech = speech[:,0] + speech[:,1]
+     #Resampling at 16KHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 KHz.
+     if sample_rate !=16000:
+         speech = librosa.resample(speech, sample_rate,16000)
+     return speech
+
+
+ def correct_casing(input_sentence):
+     """ This function is for correcting the casing of the letters in the sentence
+     """
+     sentences = nltk.sent_tokenize(input_sentence)
+     return (' '.join([s.replace(s[0],s[0].capitalize(),1) for s in sentences]))
+
+ def asr_transcript(input_file):
+     """This function generates transcripts for the provided audio input
+     """
+     speech = load_data(input_file)
+
+     #Tokenize
+     input_values = tokenizer(speech, return_tensors="pt").input_values
+     #Take logits
+     logits = model(input_values).logits
+     #Take argmax
+     predicted_ids = torch.argmax(logits, dim=-1)
+     #Get the words from predicted word ids
+     transcription = tokenizer.decode(predicted_ids[0])
+     #Output is all upper case
+     transcription = correct_casing(transcription.lower())
+     return transcription
+
+ gr.Interface(asr_transcript,
+              inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Speaker"),
+              outputs = gr.outputs.Textbox(label="Output Text"),
+              title="ASR using Wav2Vec 2.0",
+              description = "Wav2Vec2 in-action",
+              examples = [["test.wav"]], theme="grass").launch()
test.wav ADDED
Binary file (165 kB).
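
For a quick local check (not part of this commit), the same pipeline can be exercised on the bundled sample from a plain Python shell instead of the Gradio UI. A minimal sketch, assuming test.wav sits next to app.py and the packages imported above are installed; it mirrors the steps in asr_transcript() rather than importing app.py (which would launch the interface):

    # Sketch only: run the Wav2Vec2 transcription steps on test.wav.
    import torch
    import librosa
    from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC

    model_name = "facebook/wav2vec2-base-960h"
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
    model = Wav2Vec2ForCTC.from_pretrained(model_name)

    # Load the sample already downmixed to mono and resampled to 16 kHz,
    # matching what load_data() does in app.py.
    speech, _ = librosa.load("test.wav", sr=16000, mono=True)
    input_values = tokenizer(speech, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    print(tokenizer.decode(predicted_ids[0]))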