import gradio as gr
import numpy as np
import librosa
import torch

from math import ceil
import nemo.collections.asr as nemo_asr


asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "theodotus/stt_ua_fastconformer_hybrid_large_pc", map_location="cpu"
)

# Disable dithering and padding so feature extraction is deterministic
# across repeated buffers, then put the model into inference mode.
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()


# Streaming buffer geometry: every model call sees `buffer_len` seconds of
# audio, of which only the last `chunk_len` seconds are new; the other
# `buffer_len - chunk_len` seconds are overlap kept as acoustic context.
buffer_len = 3.2
chunk_len = 0.8
total_buffer = round(buffer_len * asr_model.cfg.sample_rate)
overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)
model_stride = 4  # encoder subsampling: one logit frame per `model_stride` feature frames
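# Assuming asr_model.cfg.sample_rate == 16000 (typical for NeMo ASR models),
# these work out to total_buffer = 51200 samples and overhead_len = 38400
# samples, so each loop iteration consumes 12800 samples (0.8 s) of new audio.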



model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)
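# Assuming the common NeMo window_stride of 0.01 s (check
# asr_model.cfg.preprocessor.window_stride), model_stride_in_secs = 0.04 s, so
# tokens_per_chunk = ceil(0.8 / 0.04) = 20 logit frames per chunk and
# mid_delay = ceil((0.8 + 2.4 / 2) / 0.04) = 50.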



def resample(audio):
    # Load the file as mono and resample to the model's sample rate.
    audio_16k, _ = librosa.load(audio, sr=asr_model.cfg["sample_rate"],
                                mono=True, res_type="soxr_hq")
    return audio_16k


def model(audio_16k):
    # Run a single buffer through the acoustic model; no gradients are needed.
    with torch.no_grad():
        logits, logits_len, greedy_predictions = asr_model.forward(
            input_signal=torch.tensor([audio_16k]),
            input_signal_length=torch.tensor([len(audio_16k)])
        )
    return logits
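# `logits` has shape [1, T, vocab_size + 1] (CTC blank included); with the
# assumed 0.04 s stride, T is roughly buffer_len / 0.04 = 80 frames per buffer.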


def decode_predictions(logits_list):
    logits_len = logits_list[0].shape[1]
    # Trim the overlap from each buffer's logits so that, once concatenated,
    # consecutive buffers tile the audio without duplicated frames.
    cutted_logits = []
    for idx in range(len(logits_list)):
        start_cut = 0 if (idx == 0) else logits_len - 1 - mid_delay
        end_cut = -1 if (idx == len(logits_list) - 1) else logits_len - 1 - mid_delay + tokens_per_chunk
        logits = logits_list[idx][:, start_cut:end_cut]
        cutted_logits.append(logits)

    # Join the trimmed pieces and greedy-decode the full sequence with CTC.
    logits = torch.cat(cutted_logits, axis=1)
    logits_len = torch.tensor([logits.shape[1]])
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )

    return current_hypotheses[0]
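# Worked example, assuming 80 logit frames per 3.2 s buffer: middle buffers
# keep frames [29:49] (exactly tokens_per_chunk = 20), the first keeps [0:49],
# and the last keeps [29:-1], so the decoded stream covers each chunk once.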


def transcribe(audio, state, reset_state):
    # `state` holds [queued audio samples, logits collected so far].
    if reset_state or state is None:
        state = [np.array([], dtype=np.float32), []]

    audio_16k = resample(audio)

    # Append the new audio to the queued sample sequence.
    state[0] = np.concatenate([state[0], audio_16k])

    # Consume full buffers, keeping `overhead_len` samples of overlap
    # as context for the next buffer.
    while len(state[0]) > total_buffer:
        buffer = state[0][:total_buffer]
        state[0] = state[0][total_buffer - overhead_len:]
        # run model
        logits = model(buffer)
        # accumulate logits
        state[1].append(logits)

    if len(state[1]) == 0:
        text = ""
    else:
        text = decode_predictions(state[1])
    return text, state
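# Offline sanity check (hypothetical usage; assumes a local "sample.wav"):
#   text, state = transcribe("sample.wav", None, False)
#   print(text)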

gr.Interface(
    fn=transcribe,
    inputs=[
        # gr.Audio(source="upload", type="filepath", streaming=True),
        gr.Audio(source="upload", type="filepath"),
        # "state"
        gr.State(None),
        # a Checkbox supplies the boolean `reset_state` flag the function
        # expects; a Button has no value to pass as an Interface input.
        gr.Checkbox(label="Reset State"),
    ],
    outputs=[
        "textbox",
        "state",
    ],
    # live=True
).launch()