theodotus committed
Commit adffa4f
1 Parent(s): 2a5f9c9

Added decoding at the end

Files changed (1): app.py (+21 -20)
app.py CHANGED
@@ -29,24 +29,30 @@ def resample(sr, audio_data):
     return audio_16k
 
 
-def model(audio_16k, is_start):
+def model(audio_16k):
     logits, logits_len, greedy_predictions = asr_model.forward(
         input_signal=torch.tensor([audio_16k]),
         input_signal_length=torch.tensor([len(audio_16k)])
     )
-
-    # cut overhead
-    buffer_len = len(audio_16k)
-    logits_overhead = (logits.shape[1] - 1) * overhead_len // buffer_len
-    logits_overhead //= 2
-    delay = (logits.shape[1] - 1) - (2 * logits_overhead)
-    start_cut = 0 if is_start else logits_overhead
-    delay += 0 if not is_start else logits_overhead
-    logits = logits[:, start_cut:start_cut+delay]
     return logits
 
 
-def decode_predictions(logits):
+def decode_predictions(logits_list):
+    # calc overhead
+    logits_overhead = logits_list[0].shape[1] * overhead_len // total_buffer
+    logits_overhead //= 2
+    #delay = (logits.shape[1] - 1) - (2 * logits_overhead)
+
+    # cut overhead
+    cutted_logits = []
+    for idx in range(len(logits_list)):
+        start_cut = 0 if (idx==0) else logits_overhead
+        end_cut = 1 if (idx==len(logits_list)-1) else logits_overhead
+        logits = logits_list[idx][:, start_cut:-end_cut]
+        cutted_logits.append(logits)
+
+    # join
+    logits = torch.cat(cutted_logits, axis=1)
     logits_len = torch.tensor([logits.shape[1]])
     current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
         logits, decoder_lengths=logits_len, return_hypotheses=False,
@@ -57,8 +63,7 @@ def decode_predictions(logits):
 
 def transcribe(audio, state):
     if state is None:
-        state = [np.array([], dtype=np.float32), None]
-    is_start = state[1] is None
+        state = [np.array([], dtype=np.float32), []]
 
     sr, audio_data = audio
     audio_16k = resample(sr, audio_data)
@@ -70,15 +75,11 @@ def transcribe(audio, state):
     buffer = state[0][:total_buffer]
     state[0] = state[0][total_buffer - overhead_len:]
     # run model
-    is_start = state[1] is None
-    logits = model(buffer, is_start)
+    logits = model(buffer)
     # add logits
-    if is_start:
-        state[1] = logits
-    else:
-        state[1] = torch.cat([state[1],logits], axis=1)
+    state[1].append(logits)
 
-    if is_start:
+    if len(state[1]) == 0:
         text = ""
     else:
         text = decode_predictions(state[1])
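
What changed, in short: model() no longer trims the overlap between consecutive audio buffers on every call. Instead, raw per-chunk logits accumulate in state[1], and decode_predictions() trims and joins them all at once before a single CTC decode over the whole utterance. A self-contained sketch of that trim-and-join step, with stand-in buffer sizes (the real total_buffer and overhead_len are globals defined elsewhere in app.py):

import torch

# Stand-in values; in app.py these come from the buffering setup.
total_buffer = 5 * 16000   # samples fed to the model per call (assumption)
overhead_len = 1 * 16000   # samples of overlap between buffers (assumption)

def trim_and_join(logits_list):
    # Number of logit frames covered by the audio overlap,
    # halved so the cut is shared between neighbouring chunks.
    logits_overhead = logits_list[0].shape[1] * overhead_len // total_buffer
    logits_overhead //= 2

    cutted = []
    for idx, chunk in enumerate(logits_list):
        # Keep the left edge of the first chunk and the right edge of the last.
        start_cut = 0 if idx == 0 else logits_overhead
        end_cut = 1 if idx == len(logits_list) - 1 else logits_overhead
        cutted.append(chunk[:, start_cut:-end_cut])
    return torch.cat(cutted, axis=1)

# Two fake chunks of CTC logits: batch 1, 100 frames, 30 classes.
chunks = [torch.randn(1, 100, 30), torch.randn(1, 100, 30)]
print(trim_and_join(chunks).shape)  # torch.Size([1, 179, 30])

Note that end_cut is 1 rather than 0 for the last chunk: the slice chunk[:, start_cut:-0] would be empty, so the code drops a single trailing frame instead.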
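The hunk above cuts off inside the ctc_decoder_predictions_tensor() call, so the tail of decode_predictions() is not visible in this diff. A plausible ending, assuming the transcript of the single batch element is returned (hypothetical reconstruction, not taken from the commit):

    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )
    # With return_hypotheses=False, NeMo returns plain strings;
    # batch size is 1 here, so the first entry is the transcript.
    return current_hypotheses[0]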
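Since this is a Space's app.py, transcribe(audio, state) follows Gradio's streaming-microphone pattern: audio arrives as (sample_rate, np.ndarray) chunks, and state round-trips between calls, carrying the leftover samples and the accumulated logits. A minimal sketch of the wiring under Gradio 3.x conventions (the component setup is an assumption, not part of this diff):

import gradio as gr

demo = gr.Interface(
    fn=transcribe,                                      # assumed to return (text, state)
    inputs=[
        gr.Audio(source="microphone", streaming=True),  # emits (sr, audio_data) chunks
        "state",                                        # feeds `state` back in on every call
    ],
    outputs=["text", "state"],
    live=True,
)
demo.launch()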