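"""Streaming speech-to-text demo built with Gradio and NVIDIA NeMo.

Microphone audio is accumulated into overlapping buffers, each buffer is run
through a Ukrainian Squeezeformer CTC model from the Hugging Face Hub, and only
the frames from the centre of every buffer are kept before the merged logits
are decoded into text.
"""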
import gradio as gr
import numpy as np
import librosa
import torch

from math import ceil
import nemo.collections.asr as nemo_asr


asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "theodotus/stt_uk_squeezeformer_ctc_sm", map_location="cpu"
)

# Make feature extraction deterministic and put the model into inference mode.
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()


# Streaming parameters: every model call sees `buffer_len` seconds of audio,
# of which only `chunk_len` seconds are new; the rest is overlapping context
# carried over from the previous step.
buffer_len = 3.2
chunk_len = 0.8
total_buffer = round(buffer_len * asr_model.cfg.sample_rate)
overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)
model_stride = 4  # encoder subsampling factor (audio frames per output step)



# Convert the chunk geometry from seconds to encoder output frames.
model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)



def resample(audio):
    # Load the recorded file as mono audio at the model's sample rate.
    audio_16k, _ = librosa.load(audio, sr=asr_model.cfg["sample_rate"],
                                mono=True, res_type="soxr_hq")
    return audio_16k


def model(audio_16k):
    # Run one buffer through the acoustic model and return its per-frame CTC output.
    with torch.no_grad():
        logits, logits_len, greedy_predictions = asr_model.forward(
            input_signal=torch.tensor([audio_16k]),
            input_signal_length=torch.tensor([len(audio_16k)])
        )
    return logits


def decode_predictions(logits_list):
    buffer_frames = logits_list[0].shape[1]

    # Drop the overlapping context: from every buffer keep only one chunk's worth
    # of frames taken from its centre, where the model had context on both sides.
    # The first and last buffers additionally keep their leading/trailing frames.
    trimmed_logits = []
    for idx in range(len(logits_list)):
        start_cut = 0 if (idx == 0) else buffer_frames - 1 - mid_delay
        end_cut = -1 if (idx == len(logits_list) - 1) else buffer_frames - 1 - mid_delay + tokens_per_chunk
        trimmed_logits.append(logits_list[idx][:, start_cut:end_cut])

    # Join the trimmed pieces and run CTC decoding on the full sequence.
    logits = torch.cat(trimmed_logits, dim=1)
    logits_len = torch.tensor([logits.shape[1]])
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )

    return current_hypotheses[0]


def transcribe(audio, state):
    # state = [accumulated audio samples, logits collected so far]
    if state is None:
        state = [np.array([], dtype=np.float32), []]

    audio_16k = resample(audio)

    # Append the new microphone chunk to the audio accumulated so far.
    state[0] = np.concatenate([state[0], audio_16k])

    # Whenever a full buffer is available, run the model on it and slide the
    # window forward by one chunk, keeping the overlap as context.
    while len(state[0]) > total_buffer:
        buffer = state[0][:total_buffer]
        state[0] = state[0][total_buffer - overhead_len:]
        # run the model on the buffer
        logits = model(buffer)
        # collect its logits for decoding
        state[1].append(logits)

    # Re-decode the whole utterance from all logits collected so far.
    if len(state[1]) == 0:
        text = ""
    else:
        text = decode_predictions(state[1])
    return text, state


# Live interface: the microphone streams audio into `transcribe`, which returns
# the running transcript and the updated streaming state on every call.
gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(source="microphone", type="filepath", streaming=True),
        gr.State(None)
    ],
    outputs=[
        "textbox",
        "state"
    ],
    live=True,
).launch()