|
import gradio as gr
import numpy as np
import librosa
import torch

from math import ceil
import nemo.collections.asr as nemo_asr
|
# Load the Ukrainian FastConformer checkpoint on CPU.
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "theodotus/stt_ua_fastconformer_hybrid_large_pc", map_location="cpu"
)

# Disable dithering and padding so feature extraction is deterministic
# and chunk boundaries line up exactly between buffers.
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0

# Inference only: disable dropout and freeze all weights.
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()
|
# Buffered streaming: each 3.2 s window advances by one 0.8 s chunk,
# so consecutive windows share 2.4 s of context.
buffer_len = 3.2  # seconds of audio per inference window
chunk_len = 0.8   # seconds of new audio consumed per step
total_buffer = round(buffer_len * asr_model.cfg.sample_rate)
overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)
model_stride = 4  # encoder subsampling factor assumed by this app

# Duration covered by one encoder output frame, the number of frames one
# chunk contributes, and the frame offset of the chunk's centred region
# inside the window (used below to splice per-buffer logits together).
model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)
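# Worked numbers (a sketch, assuming the checkpoint's config reports a
# 16 kHz sample rate and a 0.01 s window stride; both values come from the
# model, so these literals are illustrative):
#   total_buffer         = round(3.2 * 16000)        -> 51200 samples
#   overhead_len         = round(2.4 * 16000)        -> 38400 samples
#   model_stride_in_secs = 0.01 * 4                  -> 0.04 s per frame
#   tokens_per_chunk     = ceil(0.8 / 0.04)          -> 20 frames
#   mid_delay            = ceil((0.8 + 1.2) / 0.04)  -> 50 frames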
|
def resample(audio):
    """Load an audio file as mono float32 at the model's sample rate."""
    audio_16k, _ = librosa.load(audio, sr=asr_model.cfg.sample_rate,
                                mono=True, res_type="soxr_hq")
    return audio_16k
|
def model(audio_16k):
    """Run one audio buffer through the acoustic model and return logits."""
    with torch.no_grad():
        logits, logits_len, greedy_predictions = asr_model.forward(
            input_signal=torch.tensor([audio_16k]),
            input_signal_length=torch.tensor([len(audio_16k)]),
        )
    return logits
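# With the illustrative numbers above, a full 3.2 s buffer yields roughly
# buffer_len / model_stride_in_secs = 80 encoder frames, i.e. logits of
# shape [1, ~80, vocab_size].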
|
def decode_predictions(logits_list):
    """Splice the buffered logits into one sequence and greedy-decode it."""
    logits_len = logits_list[0].shape[1]

    # Keep only the centred chunk of each buffer: the first buffer also
    # keeps its left context, the last keeps everything up to the final frame.
    cut_logits = []
    for idx in range(len(logits_list)):
        start_cut = 0 if (idx == 0) else logits_len - 1 - mid_delay
        end_cut = -1 if (idx == len(logits_list) - 1) else logits_len - 1 - mid_delay + tokens_per_chunk
        cut_logits.append(logits_list[idx][:, start_cut:end_cut])

    logits = torch.cat(cut_logits, dim=1)
    logits_len = torch.tensor([logits.shape[1]])
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )

    return current_hypotheses[0]
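# Splice arithmetic with the illustrative numbers: for a middle buffer with
# logits_len = 80, start_cut = 80 - 1 - 50 = 29 and end_cut = 29 + 20 = 49,
# so each middle buffer contributes exactly tokens_per_chunk = 20 frames.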
|
|
def transcribe(audio, state, reset_state):
    # state holds [pending raw samples, logits collected so far].
    if reset_state or state is None:
        state = [np.array([], dtype=np.float32), []]

    audio_16k = resample(audio)
    state[0] = np.concatenate([state[0], audio_16k])

    # Run the model each time a full buffer has accumulated, keeping the
    # trailing overlap as left context for the next window.
    while len(state[0]) > total_buffer:
        buffer = state[0][:total_buffer]
        state[0] = state[0][total_buffer - overhead_len:]
        logits = model(buffer)
        state[1].append(logits)

    if len(state[1]) == 0:
        text = ""
    else:
        text = decode_predictions(state[1])
    return text, state
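# Worked trace (illustrative): a single 4.0 s upload at 16 kHz contributes
# 64000 samples. The loop fires once: buffer = state[0][:51200], and
# state[0] keeps samples[12800:], i.e. 51200 samples (the 2.4 s overlap
# plus the unread tail). Because the test is strictly >, that remainder
# waits for the next upload.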
|
gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(source="upload", type="filepath"),
        gr.State(None),
        # A Button cannot feed a value into an Interface input; a Checkbox
        # supplies the boolean reset_state flag the function expects.
        gr.Checkbox(label="Reset State"),
    ],
    outputs=[
        "textbox",
        "state",
    ],
).launch()
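# Offline usage sketch (the file path "sample.wav" is hypothetical): the
# streaming state can be exercised without the UI, e.g.
#
#   state = None
#   text, state = transcribe("sample.wav", state, reset_state=False)
#   print(text)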
|
|