Sakil commited on
Commit
17fa61e
Β·
1 Parent(s): 54bb083

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -19
app.py CHANGED
@@ -1,23 +1,84 @@
 
 
1
  import librosa
 
2
  import torch
3
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
4
  import gradio as gr
5
- from transformers import pipeline
6
- import IPython.display as display
7
- import soundfile as sf
8
- def speech_text(audio_file):
9
- tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
10
- model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
11
- speech, rate = librosa.load(audio_file,sr=16000)
12
- display.Audio(audio_file, autoplay=True)
13
- print(rate)
14
- input_values = tokenizer(speech, return_tensors ='pt').input_values
15
- #Store logits (non-normalized predictions)
16
- logits = model(input_values).logits
17
- #Store predicted id's
18
- predicted_ids = torch.argmax(logits, dim =-1)
19
- transcriptions = tokenizer.decode(predicted_ids[0])
20
- return transcriptions
21
- iface = gr.Interface(speech_text,inputs="audio",outputs="text",title='Sakil Transcription',description="Transcription")
22
- iface.launch(inline=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
 
1
+ #Importing all the necessary packages
2
+ import nltk
3
  import librosa
4
+ import IPython.display
5
  import torch
 
6
  import gradio as gr
7
+ from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
8
+ nltk.download("punkt")
9
+ #Loading the model and the tokenizer
10
+ model_name = "facebook/wav2vec2-base-960h"
11
+ tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)#omdel_name
12
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
13
+
14
+ def load_data(input_file):
15
+ """ Function for resampling to ensure that the speech input is sampled at 16KHz.
16
+ """
17
+ #read the file
18
+ speech, sample_rate = librosa.load(input_file)
19
+ #make it 1-D
20
+ if len(speech.shape) > 1:
21
+ speech = speech[:,0] + speech[:,1]
22
+ #Resampling at 16KHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 KHz.
23
+ if sample_rate !=16000:
24
+ speech = librosa.resample(speech, sample_rate,16000)
25
+ #speeches = librosa.effects.split(speech)
26
+ return speech
27
+ def correct_casing(input_sentence):
28
+ """ This function is for correcting the casing of the generated transcribed text
29
+ """
30
+ sentences = nltk.sent_tokenize(input_sentence)
31
+ return (' '.join([s.replace(s[0],s[0].capitalize(),1) for s in sentences]))
32
+
33
+ def asr_transcript(input_file):
34
+ """This function generates transcripts for the provided audio input
35
+ """
36
+ speech = load_data(input_file)
37
+ #Tokenize
38
+ input_values = tokenizer(speech, return_tensors="pt").input_values
39
+ #Take logits
40
+ logits = model(input_values).logits
41
+ #Take argmax
42
+ predicted_ids = torch.argmax(logits, dim=-1)
43
+ #Get the words from predicted word ids
44
+ transcription = tokenizer.decode(predicted_ids[0])
45
+ #Output is all upper case
46
+ transcription = correct_casing(transcription.lower())
47
+ return transcription
48
+ def asr_transcript_long(input_file,tokenizer=tokenizer, model=model ):
49
+ transcript = ""
50
+ # Ensure that the sample rate is 16k
51
+ sample_rate = librosa.get_samplerate(input_file)
52
+
53
+ # Stream over 10 seconds chunks rather than load the full file
54
+ stream = librosa.stream(
55
+ input_file,
56
+ block_length=20, #number of seconds to split the batch
57
+ frame_length=sample_rate, #16000,
58
+ hop_length=sample_rate, #16000
59
+ )
60
+
61
+ for speech in stream:
62
+ if len(speech.shape) > 1:
63
+ speech = speech[:, 0] + speech[:, 1]
64
+ if sample_rate !=16000:
65
+ speech = librosa.resample(speech, sample_rate,16000)
66
+ input_values = tokenizer(speech, return_tensors="pt").input_values
67
+ logits = model(input_values).logits
68
+
69
+ predicted_ids = torch.argmax(logits, dim=-1)
70
+ transcription = tokenizer.decode(predicted_ids[0])
71
+ #transcript += transcription.lower()
72
+ transcript += correct_casing(transcription.lower())
73
+ #transcript += " "
74
+
75
+ return transcript[:3800]
76
+ gr.Interface(asr_transcript_long,
77
+ #inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Please record your voice"),
78
+ inputs = gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Upload your audio file here"),
79
+ outputs = gr.outputs.Textbox(type="str",label="Output Text"),
80
+ title="English Audio Transcriptor",
81
+ description = "This tool transcribes your audio to the text",
82
+ examples = [["Batman1_dialogue.wav"], ["Batman2_dialogue.wav"], ["Batman3_dialogue.wav"],["catwoman_dialogue.wav"]], theme="grass").launch()
83
+
84