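"""Streaming speech-to-text demo: a live Gradio microphone app running buffered
inference with a NeMo CTC model (theodotus/stt_uk_squeezeformer_ctc_sm, Ukrainian)."""
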
import gradio as gr
import numpy as np
import librosa
import torch
from math import ceil
import nemo.collections.asr as nemo_asr
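
# Load the pretrained CTC model on CPU and put it into inference-only mode.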
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "theodotus/stt_uk_squeezeformer_ctc_sm", map_location="cpu"
)
asr_model.preprocessor.featurizer.dither = 0.0  # disable dithering for deterministic features
asr_model.preprocessor.featurizer.pad_to = 0
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()
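
# Buffered streaming parameters: each inference call sees a 3.2 s buffer in
# which only the newest 0.8 s is new audio; the rest is overlapping context.
# The model subsamples features by `model_stride` (the code assumes 4x), so
# tokens_per_chunk and mid_delay convert chunk durations into output-frame
# counts, used later to keep only the middle of each buffer when merging.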
buffer_len = 3.2
chunk_len = 0.8
total_buffer = round(buffer_len * asr_model.cfg.sample_rate)
overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)
model_stride = 4
model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)


def resample(audio):
    # Load the recorded file and resample it to the model's sample rate (mono).
    audio_16k, _sr = librosa.load(
        audio, sr=asr_model.cfg["sample_rate"], mono=True, res_type="soxr_hq"
    )
    return audio_16k


def model(audio_16k):
    # Run the acoustic model on one buffer; no gradients are needed at inference.
    with torch.no_grad():
        logits, logits_len, greedy_predictions = asr_model.forward(
            input_signal=torch.tensor([audio_16k]),
            input_signal_length=torch.tensor([len(audio_16k)]),
        )
    return logits
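

# Each buffer overlaps its neighbours, so the per-buffer logits are trimmed to
# the frames centred on the new chunk (via mid_delay and tokens_per_chunk)
# before being concatenated and decoded as one utterance.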
def decode_predictions(logits_list):
    logits_len = logits_list[0].shape[1]
    # Trim the overlap: keep the first buffer's start, the last buffer's end,
    # and only the middle frames of everything in between.
    trimmed_logits = []
    for idx in range(len(logits_list)):
        start_cut = 0 if (idx == 0) else logits_len - 1 - mid_delay
        end_cut = -1 if (idx == len(logits_list) - 1) else logits_len - 1 - mid_delay + tokens_per_chunk
        logits = logits_list[idx][:, start_cut:end_cut]
        trimmed_logits.append(logits)
    # Join the trimmed pieces and greedily CTC-decode the whole sequence.
    logits = torch.cat(trimmed_logits, dim=1)
    logits_len = torch.tensor([logits.shape[1]])
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )
    return current_hypotheses[0]
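

# `state` carries two things across streaming calls: [0] the unprocessed audio
# backlog and [1] the list of per-buffer logits produced so far.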
def transcribe(audio, state):
    if state is None:
        state = [np.array([], dtype=np.float32), []]
    audio_16k = resample(audio)
    # Append the new audio to the backlog.
    state[0] = np.concatenate([state[0], audio_16k])
    # Whenever a full buffer is available, run the model and slide the window
    # forward by one chunk, keeping the overlap as context.
    while len(state[0]) > total_buffer:
        buffer = state[0][:total_buffer]
        state[0] = state[0][total_buffer - overhead_len:]
        logits = model(buffer)
        state[1].append(logits)
    if len(state[1]) == 0:
        text = ""
    else:
        text = decode_predictions(state[1])
    return text, state
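

# Live interface: Gradio streams microphone audio to `transcribe`, threading
# the session state through `gr.State` (Gradio 3.x `source=`/"state" API).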
gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(source="microphone", type="filepath", streaming=True),
        gr.State(None),
    ],
    outputs=[
        "textbox",
        "state",
    ],
    live=True,
).launch()