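# Gradio demo: Ukrainian speech-to-text with the NeMo Squeezeformer-CTC model
# "theodotus/stt_uk_squeezeformer_ctc_ml". Uploaded audio is resampled to the
# model's rate, split into overlapping buffers, run through the acoustic model
# buffer by buffer, and the per-buffer logits are trimmed and joined before
# greedy CTC decoding.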
import gradio as gr
import numpy as np
import librosa
import torch

from math import ceil
import nemo.collections.asr as nemo_asr


# Load the pretrained model on CPU.
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "theodotus/stt_uk_squeezeformer_ctc_ml", map_location="cpu"
)

asr_model.preprocessor.featurizer.dither = 0.0  # deterministic features: no dither noise
asr_model.preprocessor.featurizer.pad_to = 0    # no padding of the feature sequence
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()


# Buffered inference: each buffer holds 8 s of audio, of which 4.8 s is new,
# so consecutive buffers overlap by 3.2 s.
buffer_len = 8.0   # seconds of audio per buffer
chunk_len = 4.8    # seconds of new audio per buffer
total_buffer = round(buffer_len * asr_model.cfg.sample_rate)                # buffer size in samples
overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)  # overlap in samples
model_stride = 4   # encoder downsampling factor (feature frames per output step)


model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)
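# model_stride_in_secs: audio seconds covered by one encoder output frame.
# decode_predictions() keeps tokens_per_chunk frames from each buffer, starting
# mid_delay frames before the buffer's end, so overlapping audio is decoded only once.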


def resample(audio):
    # Load the uploaded file as mono at the model's sample rate (16 kHz).
    audio_16k, _ = librosa.load(audio, sr=asr_model.cfg["sample_rate"],
                                mono=True, res_type="soxr_hq")
    return audio_16k


def model(audio_16k):
    # Run the acoustic model on one buffer and return its CTC logits.
    logits, logits_len, greedy_predictions = asr_model.forward(
        input_signal=torch.tensor(audio_16k).unsqueeze(0),
        input_signal_length=torch.tensor([len(audio_16k)])
    )
    return logits


def decode_predictions(logits_list):
    logits_len = logits_list[0].shape[1]
    # Trim the overlap: the first buffer keeps its start, the last keeps its end,
    # and every other buffer contributes only tokens_per_chunk frames from its middle.
    cut_logits = []
    for idx in range(len(logits_list)):
        start_cut = 0 if (idx == 0) else logits_len - 1 - mid_delay
        end_cut = -1 if (idx == len(logits_list) - 1) else logits_len - 1 - mid_delay + tokens_per_chunk
        logits = logits_list[idx][:, start_cut:end_cut]
        cut_logits.append(logits)

    # Join the trimmed pieces and run greedy CTC decoding over the whole utterance.
    logits = torch.cat(cut_logits, dim=1)
    logits_len = torch.tensor([logits.shape[1]])
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )

    return current_hypotheses[0]


def transcribe(audio):
    # state[0]: audio samples not yet consumed; state[1]: logits collected per buffer.
    state = [np.array([], dtype=np.float32), []]

    audio_16k = resample(audio)

    # append the new audio to the pending sample queue
    state[0] = np.concatenate([state[0], audio_16k])

    # Slide a buffer_len window over the audio, advancing by chunk_len seconds
    # each step; run at least once so short clips are still transcribed.
    while (len(state[0]) > overhead_len) or (len(state[1]) == 0):
        buffer = state[0][:total_buffer]
        state[0] = state[0][total_buffer - overhead_len:]
        # run the model on this buffer and keep its logits
        logits = model(buffer)
        state[1].append(logits)

    if len(state[1]) == 0:
        text = ""
    else:
        text = decode_predictions(state[1])
    return text


gr.Interface(
    fn=transcribe, 
    inputs=[
        gr.Audio(source="upload", type="filepath"), 
    ],
    outputs=[
        "textbox",
    ],
).launch()
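# launch() blocks and serves the app on a local URL; passing share=True to
# launch() would also create a temporary public link.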