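# Streaming Ukrainian speech-to-text demo: a Gradio microphone stream feeds
# overlapping audio buffers to a NeMo Squeezeformer-CTC model, and the merged
# per-buffer logits are decoded into a running transcript.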
import gradio as gr
import numpy as np
import librosa
import torch

from math import ceil
import nemo.collections.asr as nemo_asr


# Load the pretrained Ukrainian Squeezeformer-CTC model on CPU.
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "theodotus/stt_uk_squeezeformer_ctc_sm", map_location="cpu"
)

# Make featurization deterministic for streaming (no dither, no padding),
# and freeze the whole model for inference.
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()


# Buffered streaming: each 3.2 s analysis buffer advances by a 0.8 s chunk,
# so consecutive buffers overlap by 2.4 s of context.
buffer_len = 3.2   # seconds of audio per model call
chunk_len = 0.8    # seconds of new audio per call
total_buffer = round(buffer_len * asr_model.cfg.sample_rate)
overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)
model_stride = 4   # encoder subsampling factor


# tokens_per_chunk: encoder output frames covering one fresh chunk.
# mid_delay: offset, counted back from the end of a buffer's logits, to the
# start of the chunk when it is centred inside the buffer.
model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)
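
# Worked example, assuming the config's usual 0.01 s window stride: one encoder
# frame then covers 0.04 s, so tokens_per_chunk = ceil(0.8 / 0.04) = 20 and
# mid_delay = ceil((0.8 + 2.4 / 2) / 0.04) = 50 frames.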



def resample(audio):
    # Load the recorded file and resample it to the model's sample rate (mono).
    audio_16k, _ = librosa.load(audio, sr=asr_model.cfg["sample_rate"],
                                mono=True, res_type='soxr_hq')
    return audio_16k


def model(audio_16k):
    # Run one buffer through the acoustic model and keep only the logits,
    # shaped (batch=1, frames, vocab).
    with torch.no_grad():
        logits, logits_len, greedy_predictions = asr_model.forward(
            input_signal=torch.tensor(audio_16k).unsqueeze(0),
            input_signal_length=torch.tensor([len(audio_16k)])
        )
    return logits


def decode_predictions(logits_list):
    logits_len = logits_list[0].shape[1]
    # Keep, from each buffer, only the frames that cover its fresh chunk,
    # taken from the middle of the buffer where the model has context on both
    # sides. The first buffer also keeps its leading frames, the last its tail.
    cut_logits = []
    for idx in range(len(logits_list)):
        start_cut = 0 if (idx == 0) else logits_len - 1 - mid_delay
        end_cut = -1 if (idx == len(logits_list) - 1) else logits_len - 1 - mid_delay + tokens_per_chunk
        logits = logits_list[idx][:, start_cut:end_cut]
        cut_logits.append(logits)

    # Concatenate the kept frames into one sequence and greedy CTC decode.
    logits = torch.cat(cut_logits, dim=1)
    logits_len = torch.tensor([logits.shape[1]])
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )

    return current_hypotheses[0]


def transcribe(audio, state):
    # state[0]: pending (not yet processed) audio samples;
    # state[1]: logits of every buffer processed so far.
    if state is None:
        state = [np.array([], dtype=np.float32), []]

    audio_16k = resample(audio)

    # Append the new microphone chunk to the pending audio.
    state[0] = np.concatenate([state[0], audio_16k])

    # Run the model on each full buffer, then slide forward by one chunk,
    # keeping `overhead_len` samples of overlapping context.
    while len(state[0]) > total_buffer:
        buffer = state[0][:total_buffer]
        state[0] = state[0][total_buffer - overhead_len:]
        logits = model(buffer)
        state[1].append(logits)

    if len(state[1]) == 0:
        text = ""
    else:
        text = decode_predictions(state[1])
    return text, state
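
# Note: decode_predictions() re-decodes the full logits history on every call,
# so the textbox always shows the complete transcript so far.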


# Live interface: the microphone streams short recordings to `transcribe`, and
# the session state carries the pending audio and logits between calls.
# (`source=` is the Gradio 3.x argument; Gradio 4 renamed it to `sources`.)
gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(source="microphone", type="filepath", streaming=True),
        gr.State(None)
    ],
    outputs=[
        "textbox",
        "state"
    ],
    live=True,
).launch()
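
# To try this locally (assuming the file is saved as app.py, with Gradio 3.x
# and nemo_toolkit[asr] installed):
#   python app.py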