Lakoc commited on
Commit
20efd0f
1 Parent(s): 4020d91

Update ctc_scorer.py

Browse files
Files changed (1) hide show
  1. ctc_scorer.py +66 -12
ctc_scorer.py CHANGED
@@ -1,14 +1,7 @@
1
  # pylint: skip-file
2
  # Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
3
  import torch
4
- from transformers import GenerationConfig, LogitsProcessor
5
-
6
-
7
- class GenerationConfigWithCTC(GenerationConfig):
8
- def __init__(self, ctc_weight=0.0, ctc_margin=0, **kwargs):
9
- super().__init__(**kwargs)
10
- self.ctc_weight = ctc_weight
11
- self.ctc_margin = ctc_margin
12
 
13
 
14
  class CTCPrefixScoreTH(object):
@@ -93,7 +86,7 @@ class CTCPrefixScoreTH(object):
93
  else:
94
  r_prev, s_prev, f_min_prev, f_max_prev = state
95
 
96
- # select input dimensions for scoring
97
  if self.scoring_num > 0:
98
  scoring_idmap = torch.full((n_bh, self.odim), -1, dtype=torch.long, device=self.device)
99
  snum = self.scoring_num
@@ -173,8 +166,8 @@ class CTCPrefixScoreTH(object):
173
  dim=0,
174
  )
175
 
176
- for si in range(n_bh):
177
- log_psi[si, self.eos] = max(log_psi[si, self.eos], r_sum[self.end_frames[si // n_hyps], si])
178
 
179
  # exclude blank probs
180
  log_psi[:, self.blank] = self.logzero
@@ -273,8 +266,14 @@ class CTCRescorerLogitsProcessor(LogitsProcessor):
273
  ctc_margin: int,
274
  ctc_weight: float,
275
  num_beams: int,
 
 
 
 
276
  ):
277
  super().__init__()
 
 
278
  self.pad_token_id = pad_token_id
279
  self.ctc_prefix_scorer = CTCPrefixScoreTH(
280
  torch.nn.functional.log_softmax(encoder_logits, dim=-1),
@@ -286,6 +285,41 @@ class CTCRescorerLogitsProcessor(LogitsProcessor):
286
  self.ctc_weight = ctc_weight
287
  self.ctc_states = None
288
  self.num_beams = num_beams
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
291
  scores[:, self.pad_token_id] = self.ctc_prefix_scorer.logzero
@@ -296,7 +330,27 @@ class CTCRescorerLogitsProcessor(LogitsProcessor):
296
  ctc_scores, ctc_states = self.ctc_prefix_scorer(input_ids, self.ctc_states)
297
  self.ctc_states = ctc_states
298
  next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores
299
- # return scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  return next_token_scores
301
 
302
 
 
1
  # pylint: skip-file
2
  # Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
3
  import torch
4
+ from transformers import LogitsProcessor
 
 
 
 
 
 
 
5
 
6
 
7
  class CTCPrefixScoreTH(object):
 
86
  else:
87
  r_prev, s_prev, f_min_prev, f_max_prev = state
88
 
89
+ # select input dimensions for decred_scoring
90
  if self.scoring_num > 0:
91
  scoring_idmap = torch.full((n_bh, self.odim), -1, dtype=torch.long, device=self.device)
92
  snum = self.scoring_num
 
166
  dim=0,
167
  )
168
 
169
+ # for si in range(n_bh):
170
+ # log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
171
 
172
  # exclude blank probs
173
  log_psi[:, self.blank] = self.logzero
 
266
  ctc_margin: int,
267
  ctc_weight: float,
268
  num_beams: int,
269
+ space_token_id: int,
270
+ apply_eos_space_trick: bool,
271
+ eos_space_trick_weight: float,
272
+ debug: bool = False,
273
  ):
274
  super().__init__()
275
+ # reduce_lens_by = (encoder_logits.argmax(dim=-1) == eos_token_id).sum(dim=-1)
276
+ # encoder_output_lens = encoder_output_lens - reduce_lens_by
277
  self.pad_token_id = pad_token_id
278
  self.ctc_prefix_scorer = CTCPrefixScoreTH(
279
  torch.nn.functional.log_softmax(encoder_logits, dim=-1),
 
285
  self.ctc_weight = ctc_weight
286
  self.ctc_states = None
287
  self.num_beams = num_beams
288
+ self.eos_token_id = eos_token_id
289
+ self.apply_eos_space_trick = apply_eos_space_trick
290
+ self.space_token_id = space_token_id
291
+ self.eos_space_trick_weight = eos_space_trick_weight
292
+ self.debug = debug
293
+
294
+ @staticmethod
295
+ def analyze_predictions(
296
+ scores, ctc_scores, next_token_scores, input_ids, k=10, tokenizer="Lakoc/english_corpus_uni5000_normalized"
297
+ ):
298
+ from transformers import AutoTokenizer
299
+
300
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
301
+ best_att_ids = scores.topk(k=k, dim=1)
302
+ best_ctc_ids = ctc_scores.topk(k=k, dim=1)
303
+ best_ids = next_token_scores.topk(k=k, dim=1)
304
+
305
+ def print_prediction(best_ids, name):
306
+ new_tensor = torch.zeros((best_ids.indices.shape[0], best_ids.indices.shape[1] * 2), dtype=torch.long)
307
+ new_tensor[:, 0::2] = best_ids.indices
308
+ new_tensor[:, 1::2] = 4976
309
+ print(f"{name}:")
310
+ for index, (next_ids, scores) in enumerate(zip(tokenizer.batch_decode(new_tensor), best_ids.values)):
311
+ print(f"HYP {index}:\n{next_ids} {scores}")
312
+
313
+ print(f"PREFIX:")
314
+ for index, prefix in enumerate(tokenizer.batch_decode(input_ids)):
315
+ print(f"HYP {index}:\n{prefix}")
316
+ print_prediction(best_att_ids, "ATT_SCORES")
317
+ print()
318
+ print_prediction(best_ctc_ids, "CTC_SCORES")
319
+ print()
320
+ print(f"CTC_EOS: {ctc_scores[:, 1]}")
321
+ print_prediction(best_ids, "NEXT_TOKEN_SCORES")
322
+ print()
323
 
324
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
325
  scores[:, self.pad_token_id] = self.ctc_prefix_scorer.logzero
 
330
  ctc_scores, ctc_states = self.ctc_prefix_scorer(input_ids, self.ctc_states)
331
  self.ctc_states = ctc_states
332
  next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores
333
+ if self.apply_eos_space_trick:
334
+ space_eos_conflict = torch.logical_and(
335
+ scores.argmax(dim=1) == self.eos_token_id, ctc_scores.argmax(dim=1) == self.space_token_id
336
+ )
337
+ if space_eos_conflict.any():
338
+ apply_trick_on = torch.logical_and(
339
+ torch.logical_and(
340
+ space_eos_conflict,
341
+ next_token_scores[:, self.eos_token_id] < next_token_scores[:, self.space_token_id],
342
+ ),
343
+ self.eos_space_trick_weight * next_token_scores[:, self.eos_token_id]
344
+ > next_token_scores[:, self.space_token_id],
345
+ )
346
+ if apply_trick_on.any():
347
+ next_token_scores[apply_trick_on, self.eos_token_id] = (
348
+ next_token_scores[apply_trick_on, self.eos_token_id] * self.eos_space_trick_weight
349
+ )
350
+
351
+ if self.debug:
352
+ self.analyze_predictions(scores, ctc_scores, next_token_scores, input_ids)
353
+
354
  return next_token_scores
355
 
356