Automatic Speech Recognition
Transformers
4 languages
whisper
whisper-event
Generated from Trainer
Inference Endpoints
marinone94 committed on
Commit aac3fbb
1 Parent(s): 9e467b3

use decode to inspect ds

run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -294,6 +294,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:
     processor: Any
     decoder_start_token_id: int
     task_id: int
+    # TODO: remove - infer language from dataset
+    language_id: int = -100
 
     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
@@ -312,8 +314,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
 
         # if bos token is appended in previous tokenization step,
         # cut bos token here as it's append later anyways
-        # if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
-        #     labels = labels[:, 1:]
+
         # lang_token_ids = self.processor.tokenizer(lang_features).input_ids
         # # Replace language and task if they are in the beginning, otherwise add them
         # if (labels[:, 1] == self.task_id).all().cpu().item():
@@ -328,6 +329,15 @@ class DataCollatorSpeechSeq2SeqWithPadding:
         # labels[:, 0] = torch.full_like(labels[:, 0], -100)
         # labels[:, 1] = torch.full_like(labels[:, 1], -100)
 
+        # remove start of sentence token from labels
+        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
+            labels = labels[:, 1:]
+
+        # add start of sentence token to labels + language + task
+        labels = torch.cat((torch.full_like(labels[:, 0], self.task_id), labels), dim=1)
+        labels = torch.cat((torch.full_like(labels[:, 0], self.language_id), labels), dim=1)
+        labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id), labels), dim=1)
+
         batch["labels"] = labels
 
         return batch
@@ -461,7 +471,7 @@ def load_maybe_streaming_dataset(
 def print_data_samples(dataset, tokenizer, max_samples=5):
     shown_samples = 0
     for batch in dataset:
-        print("Target: ", tokenizer.batch_decode(batch["labels"]))
+        print("Target: ", tokenizer.decode(batch["labels"]))
        shown_samples += len(batch)
         if shown_samples >= max_samples:
             break
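
For readers following the collator change: the three torch.cat calls aim to rebuild Whisper's decoder prompt, i.e. prepend <|startoftranscript|>, a language token and a task token to every label sequence after any existing leading start token has been stripped. Below is a minimal, standalone sketch of that prefix logic, not the repository's code; the token ids and the helper name are made up for illustration, and it slices with labels[:, :1] so every piece stays two-dimensional before concatenating along dim=1.

# Minimal sketch (not the repository's code): prepend Whisper-style prefix tokens
# <|startoftranscript|><|language|><|task|> to a batch of label ids.
# The token ids below are placeholders; real ids come from the Whisper tokenizer.
import torch

DECODER_START_ID = 50258   # hypothetical <|startoftranscript|> id
LANGUAGE_ID = 50283        # hypothetical language token id
TASK_ID = 50359            # hypothetical <|transcribe|> id


def prepend_prefix_tokens(labels: torch.Tensor) -> torch.Tensor:
    """labels: (batch, seq_len) tensor of label token ids."""
    # Drop an existing leading start-of-transcript token, if every row has one.
    if (labels[:, 0] == DECODER_START_ID).all().item():
        labels = labels[:, 1:]
    # Build one (batch, 1) column per prefix token; slicing with [:, :1] keeps
    # each tensor 2-D so torch.cat along dim=1 sees matching ranks.
    prefix = [
        torch.full_like(labels[:, :1], DECODER_START_ID),
        torch.full_like(labels[:, :1], LANGUAGE_ID),
        torch.full_like(labels[:, :1], TASK_ID),
    ]
    return torch.cat(prefix + [labels], dim=1)


if __name__ == "__main__":
    dummy = torch.tensor([[50258, 11, 22, 33], [50258, 44, 55, 66]])
    print(prepend_prefix_tokens(dummy))
    # tensor([[50258, 50283, 50359,    11,    22,    33],
    #         [50258, 50283, 50359,    44,    55,    66]])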
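
On the inspection helper: the commit switches tokenizer.batch_decode to tokenizer.decode when printing targets. With a Hugging Face tokenizer, decode turns one sequence of ids into a single string, while batch_decode maps over a list of sequences; if the streaming dataset is iterated example by example, batch["labels"] is presumably a single sequence, so decode gives one readable string. A rough illustration, assuming the openai/whisper-small checkpoint purely for demonstration:

# Rough illustration (not from the repository): decode vs batch_decode.
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")

ids = tokenizer("a small test").input_ids   # one sequence of token ids
print(tokenizer.decode(ids))                # one string for the whole sequence
print(tokenizer.batch_decode([ids]))        # a list containing that one string
print(tokenizer.batch_decode(ids))          # one string per token id - hard to read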