marinone94 committed
Commit bb38ae2
1 Parent(s): 0713f7f

reset to set prefix tokens

run_speech_recognition_seq2seq_streaming.py CHANGED
```diff
@@ -330,13 +330,13 @@ class DataCollatorSpeechSeq2SeqWithPadding:
         # labels[:, 1] = torch.full_like(labels[:, 1], -100)
 
         # remove start of sentence token from labels
-        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
-            labels = labels[:, 1:]
+        # if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
+        #     labels = labels[:, 1:]
 
-        # add start of sentence token to labels + language + task
-        labels = torch.cat((torch.full_like(labels[:, 0], self.task_id).unsqueeze(0).T, labels), dim=-1)
-        labels = torch.cat((torch.full_like(labels[:, 0], self.language_id).unsqueeze(0).T, labels), dim=-1)
-        labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id).unsqueeze(0).T, labels), dim=-1)
+        # # add start of sentence token to labels + language + task
+        # labels = torch.cat((torch.full_like(labels[:, 0], self.task_id).unsqueeze(0).T, labels), dim=-1)
+        # labels = torch.cat((torch.full_like(labels[:, 0], self.language_id).unsqueeze(0).T, labels), dim=-1)
+        # labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id).unsqueeze(0).T, labels), dim=-1)
 
         batch["labels"] = labels
 
@@ -640,14 +640,16 @@ def main():
 
     if model_args.freeze_encoder:
         model.freeze_encoder()
-
-    if data_args.language_train is not None and len(data_args.language_train.split(",")) == 1:
-        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
-        # If more than a language is specified, it will be specified in the data collator
-        tokenizer.set_prefix_tokens(language=data_args.language_train, task=data_args.task)
-    elif data_args.language_train is not None and len(data_args.language_train.split(",")) > 1:
-        # make sure language and task are not stored in the model config
-        model.config.forced_decoder_ids = None
+
+    tokenizer.set_prefix_tokens(language=data_args.language_train, task=data_args.task)
+
+    # if data_args.language_train is not None and len(data_args.language_train.split(",")) == 1:
+    #     # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
+    #     # If more than a language is specified, it will be specified in the data collator
+    #     tokenizer.set_prefix_tokens(language=data_args.language_train, task=data_args.task)
+    # elif data_args.language_train is not None and len(data_args.language_train.split(",")) > 1:
+    #     # make sure language and task are not stored in the model config
+    #     model.config.forced_decoder_ids = None
 
     # 6. Resample speech dataset if necessary
     # logger.info("*** Resample dataset ***")
```
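For context, the collator lines commented out in the first hunk prepended the decoder-start, language and task ids to every label row by hand. Below is a minimal, self-contained sketch of that pattern; the token ids are the usual multilingual Whisper values (<|startoftranscript|> = 50258, <|sv|> = 50273, <|transcribe|> = 50359) but serve only as illustrative placeholders, not values taken from this commit.

```python
import torch

# Illustrative placeholder ids (multilingual Whisper vocabulary).
decoder_start_token_id = 50258  # <|startoftranscript|>
language_id = 50273             # <|sv|>
task_id = 50359                 # <|transcribe|>

# Two fake label rows, already padded to the same length.
labels = torch.tensor([[101, 102, 103],
                       [104, 105, 106]])

# Prepend task, then language, then decoder-start, one column at a time,
# mirroring the torch.full_like / torch.cat pattern removed by this commit.
for token_id in (task_id, language_id, decoder_start_token_id):
    prefix_column = torch.full_like(labels[:, 0], token_id).unsqueeze(0).T  # shape (batch, 1)
    labels = torch.cat((prefix_column, labels), dim=-1)

print(labels[0].tolist())  # [50258, 50273, 50359, 101, 102, 103]
```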
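The unconditional call added in the second hunk moves that responsibility to the tokenizer: once the prefix tokens are set, WhisperTokenizer emits them itself while encoding the transcripts. A minimal sketch, assuming transformers' WhisperTokenizer, the openai/whisper-small checkpoint and "swedish" as a stand-in for data_args.language_train (neither the checkpoint nor the language value comes from this commit):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")

# After this call every encoded transcript starts with
# <|startoftranscript|><|sv|><|transcribe|><|notimestamps|>,
# so the data collator no longer has to strip and re-add those ids.
tokenizer.set_prefix_tokens(language="swedish", task="transcribe")

label_ids = tokenizer("ett litet exempel").input_ids
print(tokenizer.convert_ids_to_tokens(label_ids)[:4])
# ['<|startoftranscript|>', '<|sv|>', '<|transcribe|>', '<|notimestamps|>']
```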