Vaibhav Srivastav committed
Commit 0448aa2
1 Parent(s): 57926d1

adding pyctcdecode code

Files changed (1):
app.py +23 -27
app.py CHANGED
@@ -1,58 +1,54 @@
-#Importing all the necessary packages
 import nltk
 import librosa
 import torch
 import gradio as gr
-from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
+from pyctcdecode import build_ctcdecoder
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 
 nltk.download("punkt")
 
 #Loading the model and the tokenizer
 model_name = "facebook/wav2vec2-base-960h"
-tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
+processor = Wav2Vec2Processor.from_pretrained(model_name)
 model = Wav2Vec2ForCTC.from_pretrained(model_name)
 
-def load_data(input_file):
-
-    """ Function for resampling to ensure that the speech input is sampled at 16KHz.
-    """
+def load_data(input_file):
     #read the file
     speech, sample_rate = librosa.load(input_file)
     #make it 1-D
     if len(speech.shape) > 1:
         speech = speech[:,0] + speech[:,1]
-    #Resampling at 16KHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 KHz.
+    #resampling to 16KHz
     if sample_rate !=16000:
         speech = librosa.resample(speech, sample_rate,16000)
     return speech
 
 
-def correct_casing(input_sentence):
-    """ This function is for correcting the casing of the letters in the sentence
-    """
+def fix_transcription_casing(input_sentence):
     sentences = nltk.sent_tokenize(input_sentence)
     return (' '.join([s.replace(s[0],s[0].capitalize(),1) for s in sentences]))
 
-def asr_transcript(input_file):
-    """This function generates transcripts for the provided audio input
-    """
+def predict_and_decode(input_file):
     speech = load_data(input_file)
-
-    #Tokenize
-    input_values = tokenizer(speech, return_tensors="pt").input_values
-    #Take logits
+    #tokenize
+    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
     logits = model(input_values).logits
-    #Take argmax
-    predicted_ids = torch.argmax(logits, dim=-1)
-    #Get the words from predicted word ids
-    transcription = tokenizer.decode(predicted_ids[0])
+    vocab_list = list(processor.tokenizer.get_vocab().keys())
+    # #Take argmax
+    # predicted_ids = torch.argmax(logits, dim=-1)
+    # #Get the words from predicted word ids
+    # transcription = tokenizer.decode(predicted_ids[0])
+    decoder = build_ctcdecoder(vocab_list)
+    pred = decoder.decode(logits[0].detach().numpy()) #decode expects a (time, vocab) numpy array
+
    #Output is all upper case
-    transcription = correct_casing(transcription.lower())
-    return transcription
+    transcribed_text = fix_transcription_casing(pred.lower())
+    return transcribed_text
 
-gr.Interface(asr_transcript,
+gr.Interface(predict_and_decode,
              inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Speaker"),
              outputs = gr.outputs.Textbox(label="Output Text"),
-             title="ASR using Wav2Vec 2.0",
+             title="ASR using Wav2Vec 2.0 & pyctcdecode",
             description = "Wav2Vec2 in-action",
-             examples = [["test.wav"]], theme="grass").launch()
+             layout = "horizontal",
+             examples = [["test.wav"]], theme="huggingface").launch()
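
The change above swaps greedy argmax decoding for pyctcdecode's beam-search CTC decoder, built over the model's character vocabulary. Below is a minimal standalone sketch of the new decode path for trying it outside Gradio; loading the audio directly at 16 kHz (instead of the app's load_data resampling) and the no_grad wrapper are simplifications, and test.wav stands in for any audio file:

import librosa
import torch
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

#build a beam-search decoder over the CTC vocabulary (no language model attached)
vocab_list = list(processor.tokenizer.get_vocab().keys())
decoder = build_ctcdecoder(vocab_list)

#load audio at the 16 kHz rate the model was trained on
speech, _ = librosa.load("test.wav", sr=16000)
input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
with torch.no_grad():
    logits = model(input_values).logits

#decode() takes a single (time, vocab) numpy array, so drop the batch dimension
print(decoder.decode(logits[0].numpy()))

Without a kenlm_model_path argument, build_ctcdecoder runs plain beam search over the acoustic scores, so the output should closely track the old argmax decode; the practical payoff of this commit is that a KenLM language model can later be attached through the same decoder.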