theodotus commited on
Commit
c2163fe
1 Parent(s): f5f9215

Used Nemo streaming logic

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import resampy
4
  import torch
5
 
6
- from math import floor,ceil
7
  import nemo.collections.asr as nemo_asr
8
 
9
 
@@ -17,9 +17,17 @@ asr_model.encoder.freeze()
17
  asr_model.decoder.freeze()
18
 
19
 
20
- total_buffer = asr_model.cfg["sample_rate"] * 19 // 10
21
- overhead_len = total_buffer * 5 // 8
22
- model_stride = 4
 
 
 
 
 
 
 
 
23
 
24
 
25
 
@@ -39,19 +47,13 @@ def model(audio_16k):
39
 
40
 
41
  def decode_predictions(logits_list):
42
- # calc overhead
43
- logits_overhead = logits_list[0].shape[1] * overhead_len / total_buffer / 2
44
- if (logits_overhead * 2 != int(logits_overhead * 2)):
45
- raise ValueError("Wrong total_buffer")
46
-
47
  # cut overhead
48
  cutted_logits = []
49
  for idx in range(len(logits_list)):
50
- start_cut = 0 if (idx==0) else floor(logits_overhead)
51
- end_cut = 1 if (idx==len(logits_list)-1) else ceil(logits_overhead)
52
- if (logits_overhead == int(logits_overhead)) and (end_cut != 1):
53
- end_cut +=1
54
- logits = logits_list[idx][:, start_cut:-end_cut]
55
  cutted_logits.append(logits)
56
 
57
  # join
 
3
  import resampy
4
  import torch
5
 
6
+ from math import ceil
7
  import nemo.collections.asr as nemo_asr
8
 
9
 
 
17
  asr_model.decoder.freeze()
18
 
19
 
20
+ buffer_len = 1.6
21
+ chunk_len = 0.8
22
+ total_buffer = round(buffer_len * asr_model.cfg.sample_rate)
23
+ overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)
24
+ model_stride = 8
25
+
26
+
27
+
28
+ model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
29
+ tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
30
+ mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)
31
 
32
 
33
 
 
47
 
48
 
49
  def decode_predictions(logits_list):
50
+ logits_len = logits_list[0].shape[1]
 
 
 
 
51
  # cut overhead
52
  cutted_logits = []
53
  for idx in range(len(logits_list)):
54
+ start_cut = 0 if (idx==0) else logits_len - 1 - mid_delay
55
+ end_cut = -1 if (idx==len(logits_list)-1) else logits_len - 1 - mid_delay + tokens_per_chunk
56
+ logits = logits_list[idx][:, start_cut:end_cut]
 
 
57
  cutted_logits.append(logits)
58
 
59
  # join