sanchit-gandhi HF staff committed on
Commit
569a4c9
1 Parent(s): 2c11eb6

Add scripts and weights

Browse files
Files changed (1) hide show
  1. run_speech_recognition_whisper.py +2 -4
run_speech_recognition_whisper.py CHANGED
@@ -23,6 +23,7 @@ import os
23
  import whisper
24
  import sys
25
  from dataclasses import dataclass, field
 
26
 
27
  from typing import Optional, Dict, Union, List
28
 
@@ -275,7 +276,6 @@ class WhisperDataCollatorWithPadding:
275
  """
276
 
277
  eos_token_id: int
278
- time_stamp_token_id: int
279
 
280
  def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
281
  """
@@ -626,9 +626,7 @@ def main():
626
  torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
627
 
628
  # Define data collator
629
- eos = tokenizer.eos_token_id
630
- t_stamp = tokenizer("<|notimestamps|>").input_ids[0]
631
- whisper_data_collator = WhisperDataCollatorWithPadding(eos_token_id=eos, time_stamp_token_id=t_stamp)
632
 
633
  # make sure model uses 50257 as BOS
634
  bos = tokenizer("<|startoftranscript|>").input_ids[0]
 
23
  import whisper
24
  import sys
25
  from dataclasses import dataclass, field
26
+ import tempfile
27
 
28
  from typing import Optional, Dict, Union, List
29
 
 
276
  """
277
 
278
  eos_token_id: int
 
279
 
280
  def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
281
  """
 
626
  torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
627
 
628
  # Define data collator
629
+ whisper_data_collator = WhisperDataCollatorWithPadding(eos_token_id=tokenizer.eos_token_id)
 
 
630
 
631
  # make sure model uses 50257 as BOS
632
  bos = tokenizer("<|startoftranscript|>").input_ids[0]