sanchit-gandhi
commited on
Commit
·
569a4c9
1
Parent(s):
2c11eb6
Add scripts and weights
Browse files
run_speech_recognition_whisper.py
CHANGED
@@ -23,6 +23,7 @@ import os
|
|
23 |
import whisper
|
24 |
import sys
|
25 |
from dataclasses import dataclass, field
|
|
|
26 |
|
27 |
from typing import Optional, Dict, Union, List
|
28 |
|
@@ -275,7 +276,6 @@ class WhisperDataCollatorWithPadding:
|
|
275 |
"""
|
276 |
|
277 |
eos_token_id: int
|
278 |
-
time_stamp_token_id: int
|
279 |
|
280 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
281 |
"""
|
@@ -626,9 +626,7 @@ def main():
|
|
626 |
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
627 |
|
628 |
# Define data collator
|
629 |
-
|
630 |
-
t_stamp = tokenizer("<|notimestamps|>").input_ids[0]
|
631 |
-
whisper_data_collator = WhisperDataCollatorWithPadding(eos_token_id=eos, time_stamp_token_id=t_stamp)
|
632 |
|
633 |
# make sure model uses 50257 as BOS
|
634 |
bos = tokenizer("<|startoftranscript|>").input_ids[0]
|
|
|
23 |
import whisper
|
24 |
import sys
|
25 |
from dataclasses import dataclass, field
|
26 |
+
import tempfile
|
27 |
|
28 |
from typing import Optional, Dict, Union, List
|
29 |
|
|
|
276 |
"""
|
277 |
|
278 |
eos_token_id: int
|
|
|
279 |
|
280 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
281 |
"""
|
|
|
626 |
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
627 |
|
628 |
# Define data collator
|
629 |
+
whisper_data_collator = WhisperDataCollatorWithPadding(eos_token_id=tokenizer.eos_token_id)
|
|
|
|
|
630 |
|
631 |
# make sure model uses 50257 as BOS
|
632 |
bos = tokenizer("<|startoftranscript|>").input_ids[0]
|