piecurus committed on
Commit
20d39bb
1 Parent(s): 7dc1211

app commit with wav2vec-base-960h v3

Files changed (1)
  1. app.py +64 -42
app.py CHANGED
@@ -1,44 +1,66 @@
- import soundfile as sf
  import torch
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
- from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
  import gradio as gr
- import sox
-
- def convert(inputfile, outfile):
-     sox_tfm = sox.Transformer()
-     sox_tfm.set_output_format(
-         file_type="wav", channels=1, encoding="signed-integer", rate=16000, bits=16
-     )
-     sox_tfm.build(inputfile, outfile)
-
-
-
- processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
- model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
-
- def parse_transcription(wav_file):
-     filename = wav_file.name.split('.')[0]
-     convert(wav_file.name, filename + "16k.wav")
-     speech, _ = sf.read(filename + "16k.wav")
-     input_values = processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
-     logits = model(input_values).logits
-     predicted_ids = torch.argmax(logits, dim=-1)
-     transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
-     return transcription,
-
-
-
- output1 = gr.outputs.Textbox(label="Transcription in English: ")
- output2 = gr.outputs.Textbox(label="Validated Transcription in English")
-
- input_ = gr.inputs.Audio(source="microphone", type="file")
- #gr.Interface(parse_transcription, inputs = input_, outputs="text",
- #             analytics_enabled=False, show_tips=False, enable_queue=True).launch(inline=False);
-
- gr.Interface(parse_transcription, inputs = input_, outputs=[output1, output2], analytics_enabled=False,
-              show_tips=False,
-              theme='huggingface',
-              layout='vertical',
-              title="Piecurus Test on Speech Transcription",
-              description="This is a live demo for Speech to Text Translation. Models used: facebook/wav2vec2-base-960h", enable_queue=True).launch(inline=False)
+ #References: 1. https://www.kdnuggets.com/2021/03/speech-text-wav2vec.html
+ #2. https://www.youtube.com/watch?v=4CoVcsxZphE
+ #3. https://www.analyticsvidhya.com/blog/2021/02/hugging-face-introduces-the-first-automatic-speech-recognition-model-wav2vec2/
+
+ #Importing all the necessary packages
+ import nltk
+ import librosa
  import torch
  import gradio as gr
+ from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
+ nltk.download("punkt")
+
+ #Loading the model and the tokenizer
+ model_name = "facebook/wav2vec2-base-960h"
+ tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
+
+
+ def load_data(input_file):
+
+     """ Function for resampling to ensure that the speech input is sampled at 16 kHz.
+     """
+     #read the file
+     speech, sample_rate = librosa.load(input_file)
+     #make it 1-D (sum the two channels if the audio is stereo)
+     if len(speech.shape) > 1:
+         speech = speech[:, 0] + speech[:, 1]
+     #Resample to 16 kHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 kHz
+     if sample_rate != 16000:
+         speech = librosa.resample(speech, sample_rate, 16000)
+     return speech
+
+
+
+ def correct_casing(input_sentence):
+     """ This function corrects the casing of the generated transcribed text
+     """
+     sentences = nltk.sent_tokenize(input_sentence)
+     return ' '.join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])
+
+
+
+ def asr_transcript(input_file):
+     """This function generates a transcript for the provided audio input
+     """
+     speech = load_data(input_file)
+
+     #Tokenize
+     input_values = tokenizer(speech, return_tensors="pt").input_values
+     #Take logits
+     logits = model(input_values).logits
+     #Take argmax
+     predicted_ids = torch.argmax(logits, dim=-1)
+     #Get the words from the predicted word ids
+     transcription = tokenizer.decode(predicted_ids[0])
+     #Output is all upper case
+     transcription = correct_casing(transcription.lower())
+     return transcription
+
+
+ gr.Interface(asr_transcript,
+              inputs=gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Please record your voice"),
+              outputs=gr.outputs.Textbox(label="Output Text"),
+              title="ASR using Wav2Vec 2.0",
+              description="This application displays the transcribed text for the given audio input",
+              examples=[["Test_File1.wav"], ["Test_File2.wav"], ["Test_File3.wav"]], theme="grass").launch()
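
For anyone wanting to smoke-test the new pipeline outside the Gradio UI, here is a minimal sketch that loads the same checkpoint and transcribes a local file. It is not part of the commit: "test.wav" is a hypothetical local file, and librosa.load is used to resample to 16 kHz in a single step.

# Minimal local smoke test for the new pipeline (sketch, not part of the commit;
# "test.wav" is a hypothetical local file).
import librosa
import torch
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC

model_name = "facebook/wav2vec2-base-960h"
tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# librosa.load can resample on read, collapsing load_data's load-then-resample
# steps into one call.
speech, _ = librosa.load("test.wav", sr=16000, mono=True)

input_values = tokenizer(speech, return_tensors="pt").input_values
with torch.no_grad():  # inference only; no gradients needed
    logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
print(tokenizer.decode(predicted_ids[0]))

One caveat on the committed load_data: the positional call librosa.resample(speech, sample_rate, 16000) works on the librosa versions of that era, but recent releases (0.10 and later) require keyword arguments, i.e. librosa.resample(speech, orig_sr=sample_rate, target_sr=16000).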