File size: 3,615 Bytes
da0005f
 
 
 
 
 
20d39bb
 
 
da0005f
dc3ecb8
21907eb
20d39bb
 
 
da0005f
cab3a0b
da0005f
 
20d39bb
 
da0005f
 
 
b79461c
20d39bb
 
cab3a0b
da0005f
 
20d39bb
da0005f
 
 
 
 
 
 
 
 
 
 
 
 
 
cab3a0b
da0005f
20d39bb
 
da0005f
 
 
 
 
 
cab3a0b
20d39bb
 
 
da0005f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cab3a0b
da0005f
 
 
 
 
 
 
17ded20
da0005f
 
17ded20
da0005f
 
 
 
 
 
 
 
 
 
b79461c
da0005f
 
 
17ded20
da0005f
17ded20
da0005f
59cfd1e
da0005f
cab3a0b
da0005f
 
 
 
 
 
17ded20
 
20d39bb
 
da0005f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
# coding: utf-8

# In[ ]:


#Importing all the necessary packages
import nltk
import librosa
import IPython.display
import torch
import gradio as gr
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
# nltk.sent_tokenize (used by correct_casing) needs the Punkt sentence model.
# Newer NLTK releases (>= 3.9) load it from the "punkt_tab" resource instead of
# the pickled "punkt" one, so fetch both to work across NLTK versions.
nltk.download("punkt")
nltk.download("punkt_tab")


# In[ ]:


# Load the wav2vec2 CTC checkpoint together with its matching tokenizer.
# Alternative checkpoint that was tried: "facebook/wav2vec2-large-xlsr-53".
model_name = "facebook/wav2vec2-base-960h"
tokenizer, model = (
    Wav2Vec2Tokenizer.from_pretrained(model_name),
    Wav2Vec2ForCTC.from_pretrained(model_name),
)


# In[ ]:


def load_data(input_file, target_sr=16000):
    """Load an audio file as a mono waveform sampled at ``target_sr`` Hz.

    wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at
    16 kHz, hence the default target rate.

    Args:
        input_file: path to the audio file (anything ``librosa.load`` accepts).
        target_sr: sample rate, in Hz, of the returned waveform.

    Returns:
        1-D array of audio samples at ``target_sr`` Hz.
    """
    # Let librosa downmix to mono and resample in a single step. Loading at the
    # target rate directly avoids a second resampling pass (librosa.load would
    # otherwise resample to its 22050 Hz default first).
    speech, _ = librosa.load(input_file, sr=target_sr, mono=True)
    return speech


# In[ ]:


def correct_casing(input_sentence):
    """Upper-case the first character of each sentence in ``input_sentence``.

    Sentences are detected with ``nltk.sent_tokenize`` and joined back together
    with single spaces.
    """
    sentences = nltk.sent_tokenize(input_sentence)
    capitalised = (sentence[0].capitalize() + sentence[1:] for sentence in sentences)
    return " ".join(capitalised)


# In[ ]:


def asr_transcript(input_file, tokenizer=tokenizer, model=model):
    """Transcribe a whole audio file in a single forward pass.

    Args:
        input_file: path to the audio file; it is loaded with ``load_data``.
        tokenizer: wav2vec2 tokenizer used to build model inputs and to decode
            predicted ids. Defaults to the module-level tokenizer.
        model: wav2vec2 CTC model. Defaults to the module-level model.

    Returns:
        The transcription, lower-cased and then re-capitalised per sentence.
    """
    speech = load_data(input_file)
    input_values = tokenizer(speech, return_tensors="pt").input_values
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(input_values).logits
    # Greedy decoding: take the most likely token at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.decode(predicted_ids[0])
    # The model emits all upper-case text; lower-case it, then restore
    # sentence-initial capitals.
    return correct_casing(transcription.lower())


# In[ ]:


def asr_transcript_long(input_file, tokenizer=tokenizer, model=model, max_chars=3800):
    """Transcribe a (possibly long) audio file block by block.

    The file is streamed in blocks of roughly 20 seconds instead of being
    loaded whole. Each block is resampled to 16 kHz when needed and
    transcribed separately. The per-block transcripts are joined with spaces.

    Args:
        input_file: path to the audio file, or ``None`` when nothing was
            submitted (the Gradio audio input is optional).
        tokenizer: wav2vec2 tokenizer; defaults to the module-level tokenizer.
        model: wav2vec2 CTC model; defaults to the module-level model.
        max_chars: the returned transcript is truncated to this many characters.

    Returns:
        The transcript, at most ``max_chars`` characters long; ``""`` when
        ``input_file`` is ``None``.
    """
    if input_file is None:
        return ""

    target_sr = 16000
    # librosa.stream reads at the file's native rate (it does not resample), so
    # one second of audio is `sample_rate` samples.
    sample_rate = librosa.get_samplerate(input_file)

    # Frames are one second long with hop == frame length (no overlap), and each
    # block holds block_length=20 frames, i.e. about 20 seconds of audio.
    stream = librosa.stream(
        input_file,
        block_length=20,
        frame_length=sample_rate,
        hop_length=sample_rate,
        mono=True,
    )

    parts = []
    for speech in stream:
        if sample_rate != target_sr:
            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=target_sr)
        input_values = tokenizer(speech, return_tensors="pt").input_values
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        text = correct_casing(tokenizer.decode(predicted_ids[0]).lower())
        if text:
            parts.append(text)

    # Separate blocks with a space so words at block boundaries do not fuse.
    return " ".join(parts)[:max_chars]


# In[ ]:


# Gradio UI: one optional uploaded audio file in, the long-form transcript out.
# A microphone input (gr.inputs.Audio(source="microphone", ...)) is an alternative.
audio_input = gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Upload your file here")
text_output = gr.outputs.Textbox(type="str", label="Output Text")
example_files = [["Test_File1.wav"], ["Test_File2.wav"], ["Test_File3.wav"]]

interface = gr.Interface(
    asr_transcript_long,
    inputs=audio_input,
    outputs=text_output,
    title="Transcript and Translate",
    description="This application displays transcribed text for given audio input",
    examples=example_files,
    theme="grass",
)
interface.launch()