import gradio as gr
import numpy as np
import librosa
import torch
from math import ceil
import nemo.collections.asr as nemo_asr
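# Load the pretrained Ukrainian Squeezeformer-CTC model for CPU inference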
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "theodotus/stt_uk_squeezeformer_ctc_ml", map_location="cpu"
)
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
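# Inference only: switch to eval mode and freeze encoder/decoder weights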
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()
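# Buffered streaming parameters: audio is processed in 8 s buffers that
# advance by one 4.8 s chunk at a time, so consecutive buffers overlap by
# buffer_len - chunk_len = 3.2 s of shared context around each chunk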
buffer_len = 8.0
chunk_len = 4.8
total_buffer = round(buffer_len * asr_model.cfg.sample_rate)
overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)
model_stride = 4
model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)
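# Worked example, assuming the model's default 0.01 s preprocessor window
# stride at 16 kHz (not pinned here): one output frame per 0.04 s, so
# tokens_per_chunk = ceil(4.8 / 0.04) = 120 frames are kept per buffer and
# mid_delay = ceil((4.8 + 1.6) / 0.04) = 160 frames locate that kept window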
def resample(audio):
    # Load the uploaded file as mono audio at the model's sample rate (16 kHz)
    audio_16k, _ = librosa.load(audio, sr=asr_model.cfg["sample_rate"],
                                mono=True, res_type="soxr_hq")
    return audio_16k
def model(audio_16k):
    # Run the acoustic model on one buffer and keep only the frame logits
    logits, logits_len, greedy_predictions = asr_model.forward(
        input_signal=torch.tensor([audio_16k]),
        input_signal_length=torch.tensor([len(audio_16k)]),
    )
    return logits
def decode_predictions(logits_list):
    logits_len = logits_list[0].shape[1]
    # Cut the overlap out of every buffer: keep the tokens_per_chunk frames
    # that start mid_delay frames before the buffer's end, except that the
    # first buffer keeps its beginning and the last keeps its tail
    cut_logits = []
    for idx in range(len(logits_list)):
        start_cut = 0 if (idx == 0) else logits_len - 1 - mid_delay
        end_cut = -1 if (idx == len(logits_list) - 1) else logits_len - 1 - mid_delay + tokens_per_chunk
        cut_logits.append(logits_list[idx][:, start_cut:end_cut])
    # Join the trimmed chunks into one logit sequence and greedy CTC-decode it
    logits = torch.cat(cut_logits, dim=1)
    logits_len = torch.tensor([logits.shape[1]])
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )
    return current_hypotheses[0]
def transcribe(audio):
    # state[0]: pending raw samples; state[1]: logits collected per buffer
    state = [np.array([], dtype=np.float32), []]
    audio_16k = resample(audio)
    # Append the new audio to the pending sample queue
    state[0] = np.concatenate([state[0], audio_16k])
    # Slide a total_buffer-sized window over the audio, advancing by one
    # chunk (total_buffer - overhead_len samples) per step; run at least once
    while (len(state[0]) > overhead_len) or (len(state[1]) == 0):
        buffer = state[0][:total_buffer]
        state[0] = state[0][total_buffer - overhead_len:]
        # Run the model on this buffer and collect its logits
        logits = model(buffer)
        state[1].append(logits)
    if len(state[1]) == 0:
        text = ""
    else:
        text = decode_predictions(state[1])
    return text
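# Quick local sanity check (hypothetical file name):
#   print(transcribe("sample_uk.wav"))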
gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(source="upload", type="filepath"),
    ],
    outputs=[
        "textbox",
    ],
).launch()