import gradio as gr
import numpy as np
import librosa
import torch
from math import ceil
import nemo.collections.asr as nemo_asr
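
# Load the pretrained Ukrainian FastConformer checkpoint and keep inference on CPU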
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "theodotus/stt_ua_fastconformer_hybrid_large_pc", map_location="cpu"
)
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()
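
# Buffered ("chunked") inference parameters: every forward pass sees buffer_len
# seconds of audio, of which only chunk_len seconds are new; the remaining
# buffer_len - chunk_len seconds (overhead_len samples) are context carried over
# between consecutive buffers.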
buffer_len = 3.2
chunk_len = 0.8
total_buffer = round(buffer_len * asr_model.cfg.sample_rate)
overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)
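# model_stride is the assumed ratio between the preprocessor frame rate and the
# encoder output frame rate, so model_stride_in_secs is the duration covered by
# one output frame; tokens_per_chunk is how many output frames a chunk of new
# audio produces, and mid_delay is where that slice starts, counted back from
# the end of a buffer's logits (see decode_predictions).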
model_stride = 4
model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)
def resample(audio):
    audio_16k, sr = librosa.load(audio, sr=asr_model.cfg["sample_rate"],
                                 mono=True, res_type='soxr_hq')
    return audio_16k
def model(audio_16k):
    logits, logits_len, greedy_predictions = asr_model.forward(
        input_signal=torch.tensor([audio_16k]),
        input_signal_length=torch.tensor([len(audio_16k)])
    )
    return logits
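
# Consecutive buffers overlap, so each buffer's logits are trimmed to the slice
# that corresponds to its new chunk of audio before the slices are concatenated
# and decoded greedily with the model's CTC decoding.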
def decode_predictions(logits_list):
    logits_len = logits_list[0].shape[1]
    # cut overhead
    cutted_logits = []
    for idx in range(len(logits_list)):
        start_cut = 0 if (idx == 0) else logits_len - 1 - mid_delay
        end_cut = -1 if (idx == len(logits_list) - 1) else logits_len - 1 - mid_delay + tokens_per_chunk
        logits = logits_list[idx][:, start_cut:end_cut]
        cutted_logits.append(logits)
    # join
    logits = torch.cat(cutted_logits, axis=1)
    logits_len = torch.tensor([logits.shape[1]])
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )
    return current_hypotheses[0]
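
# state[0] accumulates 16 kHz samples that have not been fed to the model yet;
# state[1] collects the logits returned for each processed buffer.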
def transcribe_(audio, state):
    if state is None:
        state = [np.array([], dtype=np.float32), []]
    audio_16k = resample(audio)
    # join to audio sequence
    state[0] = np.concatenate([state[0], audio_16k])
    while len(state[0]) > total_buffer:
        buffer = state[0][:total_buffer]
        state[0] = state[0][total_buffer - overhead_len:]
        # run model
        logits = model(buffer)
        # add logits
        state[1].append(logits)
    if len(state[1]) == 0:
        text = ""
    else:
        text = decode_predictions(state[1])
    return text, state
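
# Same as transcribe_, but with an extra flag that lets the UI clear the
# accumulated audio and logits before processing the new input.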
def transcribe(audio, state, reset_state):
    if reset_state:
        state = [np.array([], dtype=np.float32), []]
    if state is None:
        state = [np.array([], dtype=np.float32), []]
    audio_16k = resample(audio)
    # join to audio sequence
    state[0] = np.concatenate([state[0], audio_16k])
    while len(state[0]) > total_buffer:
        buffer = state[0][:total_buffer]
        state[0] = state[0][total_buffer - overhead_len:]
        # run model
        logits = model(buffer)
        # add logits
        state[1].append(logits)
    if len(state[1]) == 0:
        text = ""
    else:
        text = decode_predictions(state[1])
    return text, state
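
# Upload-based demo UI; the commented-out lines sketch the live streaming
# variant (a streaming Audio input together with live=True).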
gr.Interface(
    fn=transcribe,
    inputs=[
        # gr.Audio(source="upload", type="filepath", streaming=True),
        gr.Audio(source="upload", type="filepath"),
        # "state"
        gr.State(None),
        # reset_state is consumed as a boolean, so a checkbox stands in for the
        # original reset button (gr.Button takes no text=/label= pair and does
        # not supply a value as an Interface input)
        gr.Checkbox(label="Reset State"),
    ],
    outputs=[
        "textbox",
        "state"
    ],
    # live=True
).launch()