File size: 3,455 Bytes
17fa61e
 
1a6ae54
17fa61e
1a6ae54
 
17fa61e
 
5d74bcf
17fa61e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfcd066
17fa61e
1a6ae54
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#Importing all the necessary packages
import nltk
import librosa
import IPython.display
import torch
import gradio as gr
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
# Sentence-splitting data for nltk.sent_tokenize (used by correct_casing).
# NLTK >= 3.9 loads the "punkt_tab" resource instead of the pickled "punkt"
# one. Older releases don't know "punkt_tab"; for them nltk.download just
# returns False (raise_on_error defaults to False). Fetching both therefore
# works across NLTK versions.
nltk.download("punkt")
nltk.download("punkt_tab")

# Load the pretrained wav2vec2 CTC checkpoint and its tokenizer once, at import
# time, so every transcription call reuses the same objects.
model_name = "facebook/wav2vec2-base-960h"
tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

def load_data(input_file):
    """Load an audio file as a mono waveform sampled at 16 kHz.

    wav2vec2-base-960h is pretrained and fine-tuned on 16 kHz speech, so its
    input must be at that rate.

    Args:
        input_file: path (or file-like object) accepted by ``librosa.load``.

    Returns:
        1-D numpy array holding the waveform at 16 kHz.
    """
    # Ask librosa for the target rate directly.
    # - Without ``sr``, librosa.load resamples to its 22,050 Hz default, which
    #   would force a second, lossy resampling pass.
    # - With mono=True, librosa averages the channels, so the result is always
    #   1-D and no manual channel mixing is needed.
    speech, _ = librosa.load(input_file, sr=16000, mono=True)
    return speech
def correct_casing(input_sentence):
    """Capitalise the first character of every sentence in the given text.

    Sentences are located with ``nltk.sent_tokenize`` and then glued back
    together with single spaces.
    """
    fixed_sentences = []
    for sentence in nltk.sent_tokenize(input_sentence):
        # The first occurrence of sentence[0] is index 0 itself, so swapping
        # in the capitalised character there rewrites only the leading letter.
        fixed_sentences.append(sentence[0].capitalize() + sentence[1:])
    return " ".join(fixed_sentences)

def asr_transcript(input_file):
    """Transcribe an audio file in a single forward pass.

    The whole file is loaded into memory and fed to the model at once, so this
    suits short clips; asr_transcript_long streams longer audio in blocks.

    Args:
        input_file: path to the audio file to transcribe.

    Returns:
        The transcript, lower-cased with sentence starts capitalised.
    """
    speech = load_data(input_file)
    # Turn the raw waveform into the model's input tensor.
    input_values = tokenizer(speech, return_tensors="pt").input_values
    # Inference only: don't build an autograd graph nobody will backpropagate.
    with torch.no_grad():
        logits = model(input_values).logits
    # Greedy decoding: keep the most likely token id for every output frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.decode(predicted_ids[0])
    # The decoded text comes out upper-case; lower it, then re-capitalise
    # sentence starts.
    return correct_casing(transcription.lower())
def asr_transcript_long(input_file, tokenizer=tokenizer, model=model):
    """Transcribe an audio file of arbitrary length by streaming it in blocks.

    The file is read block by block with ``librosa.stream`` instead of being
    loaded whole. Each block is resampled to 16 kHz if needed and transcribed
    on its own. The per-block texts are then joined with single spaces.

    Args:
        input_file: path to an audio file readable by ``librosa.stream``, or
            None when no file was provided (the Gradio input is optional).
        tokenizer: wav2vec2 tokenizer used to build model inputs and decode ids.
        model: wav2vec2 CTC model producing per-frame logits.

    Returns:
        The transcript, truncated to its first 3800 characters. Returns an
        empty string when ``input_file`` is None.
    """
    if input_file is None:
        return ""

    sample_rate = librosa.get_samplerate(input_file)

    # frame_length == hop_length == sample_rate makes each frame one second of
    # audio with no overlap. block_length=20 frames gives blocks of ~20 seconds
    # (the last one may be shorter). librosa.stream yields mono blocks by
    # default, so every block is 1-D.
    stream = librosa.stream(
        input_file,
        block_length=20,
        frame_length=sample_rate,
        hop_length=sample_rate,
    )

    pieces = []
    # Inference only: skip autograd bookkeeping for every block.
    with torch.no_grad():
        for speech in stream:
            # The model expects 16 kHz input. librosa >= 0.10 only accepts the
            # rates as keywords.
            if sample_rate != 16000:
                speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
            input_values = tokenizer(speech, return_tensors="pt").input_values
            logits = model(input_values).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            text = correct_casing(tokenizer.decode(predicted_ids[0]).lower())
            # Silent blocks decode to "". Skip them so the join doesn't
            # produce double spaces.
            if text:
                pieces.append(text)

    # Separate blocks with a space; otherwise the last word of one block fuses
    # with the first word of the next.
    transcript = " ".join(pieces)
    # Output is capped at 3800 characters (reason not stated in this file).
    return transcript[:3800]
# Web UI: the user uploads an audio file and gets its transcript back as text.
# (A microphone source could be used instead of "upload" for live recording.)
# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces of
# older Gradio releases. Confirm the pinned Gradio version still provides them.
demo = gr.Interface(
    fn=asr_transcript_long,
    inputs=gr.inputs.Audio(
        source="upload",
        type="filepath",
        optional=True,
        label="Upload your audio file here",
    ),
    outputs=gr.outputs.Textbox(type="str", label="Output Text"),
    title="English Audio Transcriptor",
    description="This tool transcribes your audio to the text",
    examples=[
        ["Batman1_dialogue.wav"],
        ["batman2_dialogue.wav"],
        ["batman3_dialogue.wav"],
        ["catwoman_dialogue.wav"],
    ],
    theme="grass",
)
demo.launch()