import gradio as gr
import numpy as np
import librosa
import torch
from math import ceil

import nemo.collections.asr as nemo_asr

# Load the Ukrainian FastConformer CTC model on CPU and prepare it for inference
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "theodotus/stt_ua_fastconformer_hybrid_large_pc", map_location="cpu"
)
asr_model.preprocessor.featurizer.dither = 0.0  # disable dithering for deterministic features
asr_model.preprocessor.featurizer.pad_to = 0    # no padding: buffers already have a fixed length
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()

# Buffered streaming parameters (in seconds): each 3.2 s buffer advances by a
# 0.8 s chunk, so consecutive buffers overlap by 2.4 s of context.
buffer_len = 3.2
chunk_len = 0.8
total_buffer = round(buffer_len * asr_model.cfg.sample_rate)                # buffer size in samples
overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)  # overlap size in samples

# The encoder subsamples the feature sequence by a factor of 4, so one output
# token covers window_stride * 4 seconds of audio.
model_stride = 4
model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)


def resample(audio):
    """Load an audio file as mono float32 at the model's sample rate."""
    audio_16k, _ = librosa.load(
        audio, sr=asr_model.cfg["sample_rate"], mono=True, res_type="soxr_hq"
    )
    return audio_16k


def model(audio_16k):
    """Run one audio buffer through the acoustic model and return its logits."""
    with torch.no_grad():
        logits, logits_len, greedy_predictions = asr_model.forward(
            input_signal=torch.from_numpy(audio_16k).unsqueeze(0),
            input_signal_length=torch.tensor([len(audio_16k)]),
        )
    return logits


def decode_predictions(logits_list):
    """Trim the overlapping regions of per-buffer logits, concatenate them,
    and greedy-decode the result into text."""
    logits_len = logits_list[0].shape[1]
    # Keep only tokens_per_chunk tokens from the middle of each buffer; the
    # first buffer keeps its left context and the last keeps its right context.
    cut_logits = []
    for idx in range(len(logits_list)):
        start_cut = 0 if (idx == 0) else logits_len - 1 - mid_delay
        end_cut = -1 if (idx == len(logits_list) - 1) else logits_len - 1 - mid_delay + tokens_per_chunk
        logits = logits_list[idx][:, start_cut:end_cut]
        cut_logits.append(logits)
    # Join the trimmed buffers back into one logit sequence and decode it
    logits = torch.cat(cut_logits, dim=1)
    logits_len = torch.tensor([logits.shape[1]])
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )
    return current_hypotheses[0]


def transcribe(audio, state, reset_state):
    """Append new audio to the rolling buffer, run the model on every full
    buffer, and decode the accumulated logits into a transcript."""
    if reset_state or state is None:
        state = [np.array([], dtype=np.float32), []]  # [pending samples, per-buffer logits]

    audio_16k = resample(audio)

    # Append the new samples to the pending audio
    state[0] = np.concatenate([state[0], audio_16k])

    # Consume full buffers, sliding forward by one chunk (total_buffer - overhead_len samples)
    while len(state[0]) > total_buffer:
        buffer = state[0][:total_buffer]
        state[0] = state[0][total_buffer - overhead_len:]
        # Run the model and collect its logits
        logits = model(buffer)
        state[1].append(logits)

    if len(state[1]) == 0:
        text = ""
    else:
        text = decode_predictions(state[1])
    return text, state


gr.Interface(
    fn=transcribe,
    inputs=[
        # gr.Audio(source="upload", type="filepath", streaming=True),
        gr.Audio(source="upload", type="filepath"),
        gr.State(None),
        # A Button is not a valid Interface input; a Checkbox supplies the boolean reset flag
        gr.Checkbox(label="Reset State"),
    ],
    outputs=[
        "textbox",
        "state",
    ],
    # live=True
).launch()
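
# ---------------------------------------------------------------------------
# Sketch: how the trimming constants line up, plus an offline smoke test.
# Assuming the typical NeMo FastConformer window_stride of 0.01 s (the real
# value comes from asr_model.cfg.preprocessor.window_stride):
#   model_stride_in_secs = 0.01 * 4 = 0.04 s per encoder token
#   tokens_per_chunk     = ceil(0.8 / 0.04) = 20 fresh tokens per buffer
#   mid_delay            = ceil((0.8 + 2.4 / 2) / 0.04) = 50 tokens
# so every buffer after the first contributes a 20-token window taken from the
# middle of the buffer, which is the slice decode_predictions() cuts out.
#
# Minimal offline check of the same pipeline ("sample.wav" is a hypothetical
# file path; run this instead of launch() above, since launch() blocks):
#
# if __name__ == "__main__":
#     text, state = transcribe("sample.wav", None, False)
#     print(text)
# ---------------------------------------------------------------------------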