rosyvs committed on
Commit
e404b97
1 Parent(s): f60eaa1

new README, tidy up main and add hparams

Files changed (3)
  1. README.md +35 -7
  2. hparams.yaml +50 -0
  3. main.py +25 -194
README.md CHANGED
@@ -8,12 +8,13 @@ Model trained in int8 with LoRA
 
 Usage:
 
-prepare pipeline, setting to default generate_opts will give you (deterministic) greedy decoding with up to 112 tokens generated, no repetition penalty:
+prepare pipeline, providing any custom generate_kwargs supported by https://huggingface.co/docs/transformers/v4.40.0/en/main_classes/text_generation#transformers.GenerationConfig
 
 ```
 asr_model=prepare_pipeline(
     model_dir='.', # wherever you save the model
-    generate_opts={'max_new_tokens':112,
+    generate_kwargs={
+        'max_new_tokens':112,
         'num_beams':1,
         'repetition_penalty':1,
         'do_sample':False
@@ -25,8 +26,35 @@ run ASR:
 asr_model(audio_path)
 ```
 
-See also:
-https://github.com/rosyvs/isatasr
-Model is on Github at https://github.com/rosyvs/isatasr/tree/main/models/whisat-1.2
-Training script: https://github.com/rosyvs/isatasr/blob/main/train/whisat/tune_hf_whisper.py
-Training hyperparameters: https://github.com/rosyvs/isatasr/blob/main/train/whisat/hparams/redo_for_ICASSP/publicKS_ig_hf_LoRA_int8_largev2.yaml
+run ASR on a full directory `audio_dir`:
+If generate_kwargs is not specified, this gives (deterministic) greedy decoding with up to 112 tokens generated and no repetition penalty.
+
+```
+ASRdirWhisat(
+    audio_dir,
+    out_dir = '../whisat_results/',
+    model_dir=".",
+)
+```
+
+
+Training information:
+Training script: tune_hf_whisper.py
+Training hyperparameters: hparams.yaml
+Training data manifest: PUBLIC_KIDS_TRAIN_v4_deduped.csv
+
+Note: to recreate this training you will need to acquire the following public datasets:
+MyST (myst-v0.4.2)
+CuKids
+CSLU
+
+and ensure they are stored at paths consistent with those in the data manifest above.
+
+Reference:
+@inproceedings{southwell2024,
+  title={Automatic speech recognition tuned for child speech in the classroom},
+  author={Southwell, Rosy and Ward, Wayne and Trinh, Viet Anh and Clevenger, Charis and Clevenger, Clay and Watts, Emily and Reitman, Jason and D’Mello, Sidney and Whitehill, Jacob},
+  booktitle={{IEEE} International Conference on Acoustics, Speech and Signal Processing
+             {ICASSP} 2024, Seoul, South Korea, April 14-19, 2024},
+  year={2024},
+}
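Any other option accepted by the GenerationConfig class linked in the README above can be passed through generate_kwargs in the same way. A minimal illustrative sketch (not part of the repository README; the beam-search values here are hypothetical), mirroring the call shown above:

```
asr_model=prepare_pipeline(
    model_dir='.', # wherever you save the model
    generate_kwargs={
        'max_new_tokens':112,
        'num_beams':5,             # beam search instead of greedy decoding (illustrative value)
        'repetition_penalty':1.3,  # discourage repeated phrases (illustrative value)
        'do_sample':False
        }
    )
asr_model(audio_path)
```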
hparams.yaml ADDED
@@ -0,0 +1,50 @@
+# parameters to set
+
+model_cfg:
+  init_from_hub_path: openai/whisper-large-v2
+  # lang: None
+  # apply_spec_augment: True
+  # mask_time_prob: 0.05
+  # mask_feature_prob: 0.05
+  # mask_time_length: 40
+  # mask_feature_length: 30
+  # mask_time_min_masks: 2
+  # mask_feature_min_masks: 2
+
+data_cfg:
+  data_root: ~/corpora/
+  train_manif: ~/corpora/data_manifests/ASR/PUBLIC_KIDS_TRAIN_v4_deduped.csv
+  val_manif: # small private dataset of classroom speech, only affects training if load_best_model_at_end: True
+  test_manif: # small private dataset of classroom speech, doesn't affect training
+
+experiment_cfg:
+  OUT_DIR: train/whisat/save/publicKS_LoRA_int8
+  use_lora: True
+  use_int8: True
+
+train_cfg:
+  training_args:
+    output_dir: !ref <experiment_cfg[OUT_DIR]>
+    per_device_train_batch_size: 32 # 64
+    learning_rate: 0.0001 # 1e-5 orig, 1e-3 lora
+    warmup_steps: 50 # 500 orig 50 lora
+    num_train_epochs: 1
+    fp16: True # True
+    evaluation_strategy: steps # or epochs
+    per_device_eval_batch_size: 4
+    predict_with_generate: True
+    generation_max_length: 112
+    save_steps: 500
+    eval_steps: 500
+    eval_accumulation_steps: 2
+    logging_steps: 25
+    report_to:
+    - tensorboard
+    load_best_model_at_end: False
+    metric_for_best_model: wer
+    greater_is_better: False
+    push_to_hub: False
+    remove_unused_columns: False # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
+    label_names:
+    - labels
+
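The `!ref <experiment_cfg[OUT_DIR]>` tag above is not plain YAML; it matches the reference syntax of HyperPyYAML (as used by SpeechBrain). Assuming that loader is what the training script uses (an assumption; it may parse the file differently), a minimal sketch of reading the config:

```
# Sketch only: assumes `pip install hyperpyyaml`; the !ref tag then resolves
# against other keys in the same file when the YAML is loaded.
from hyperpyyaml import load_hyperpyyaml

with open("hparams.yaml") as f:
    hparams = load_hyperpyyaml(f)

# output_dir is expected to resolve to experiment_cfg's OUT_DIR:
print(hparams["train_cfg"]["training_args"]["output_dir"])
# train/whisat/save/publicKS_LoRA_int8
```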
main.py CHANGED
@@ -14,110 +14,30 @@ import json
 import pandas as pd
 import csv
 
-def prepare_pipeline(model_type='large-v2',
-                     model_dir="../models/whisat-1.2/",
-                     use_stock_model=False,
-                     generate_opts={'max_new_tokens':112,
-                                    'num_beams':1,
-                                    'repetition_penalty':1,
-                                    'do_sample':False}
-                     ):
-    #%% options (TODO make these CLI options)
-    lang='english'
-    USE_INT8 = False
-
-
-    import warnings
-    warnings.filterwarnings("ignore")
-    transformers.utils.logging.set_verbosity_error()
-
-    init_from_hub_path = f"openai/whisper-{model_type}" # TODO infer automatically from PEFT checkpoint
-    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-    print(device)
-    feature_extractor = WhisperFeatureExtractor.from_pretrained(init_from_hub_path)
-    # TODO: no need to specify lanf/task?
-    tokenizer = WhisperTokenizer.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
-    processor = WhisperProcessor.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
-
-    if use_stock_model:
-        model =WhisperForConditionalGeneration.from_pretrained(init_from_hub_path)
-    else:
-        checkpoint_dir = os.path.expanduser(model_dir)
-        # check if PEFT
-        if os.path.isdir(os.path.join(checkpoint_dir , "adapter_model")):
-            print('...it looks like this model was tuned using PEFT, because adapter_model/ is present in ckpt dir')
-
-            # checkpoint dir needs adapter model subdir with adapter_model.bin and adapter_confg.json
-            peft_config = PeftConfig.from_pretrained(os.path.join(checkpoint_dir , "adapter_model"))
-            # except ValueError as e: # if final checkpoint these are in the parent checkpoint direcory
-            #     peft_config = PeftConfig.from_pretrained(os.path.join(checkpoint_dir ), subfolder=None)
-            model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path,
-                                                                    load_in_8bit=USE_INT8,
-                                                                    device_map='auto',
-                                                                    use_cache=False,
-                                                                    )
-            model = PeftModel.from_pretrained(model, os.path.join(checkpoint_dir,"adapter_model"))
-        else:
-            model = WhisperForConditionalGeneration.from_pretrained(checkpoint_dir,
-                                                                    load_in_8bit=USE_INT8,
-                                                                    device_map='auto',
-                                                                    use_cache=False,
-                                                                    )
-    model.eval() # needed?
-
-    pipe = AutomaticSpeechRecognitionPipeline(
-        # task="automatic-speech-recognition",
-        model=model,
-        tokenizer=tokenizer,
-        feature_extractor=feature_extractor,
-        chunk_length_s=30,
-        device=device,
-        return_timestamps=False,
-        generate_kwargs=generate_opts,
-        )
-
-    return(pipe)
-
-def load_model(model_type='large-v2',
-               model_dir="../models/whisat-1.2/"):
-
-    lang='english'
-    USE_INT8 = False
-
-    import warnings
-    warnings.filterwarnings("ignore")
-    transformers.utils.logging.set_verbosity_error()
-
-    init_from_hub_path = f"openai/whisper-{model_type}" # TODO infer automatically from PEFT checkpoint
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    print(device)
-    feature_extractor = WhisperFeatureExtractor.from_pretrained(init_from_hub_path)
-    # TODO: no need to specify lanf/task?
-    tokenizer = WhisperTokenizer.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
-    processor = WhisperProcessor.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
-
-    checkpoint_dir = os.path.expanduser(model_dir)
-    # checkpoint dir needs adapter model subdir with adapter_model.bin and adapter_confg.json
-    peft_config = PeftConfig.from_pretrained(os.path.join(checkpoint_dir , "adapter_model"))
-    # except ValueError as e: # if final checkpoint these are in the parent checkpoint direcory
-    #     peft_config = PeftConfig.from_pretrained(os.path.join(checkpoint_dir ), subfolder=None)
-    model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path,
-                                                            load_in_8bit=USE_INT8, # TODO: seemed slightly better without?
-                                                            device_map='auto',
-                                                            use_cache=False,
-                                                            )
-    model = PeftModel.from_pretrained(model, os.path.join(checkpoint_dir,"adapter_model"))
-    model.eval() # needed?
-    return(model, tokenizer, processor)
+def prepare_pipeline(model_path, generate_kwargs):
+    """Prepare a pipeline for ASR inference
+    Args:
+        model_path (str): path to model directory / huggingface model name
+        generate_kwargs (dict): options to pass to pipeline
+    Returns:
+        pipeline: ASR pipeline
+    """
+    processor = WhisperProcessor.from_pretrained(model_path)
+
+    asr_pipeline = pipeline(
+        "automatic-speech-recognition",
+        model=model_path,
+        tokenizer=processor.tokenizer,
+        feature_extractor=processor.feature_extractor,
+        generate_kwargs=generate_kwargs,
+        model_kwargs={"load_in_8bit": False},
+        device_map='auto')
+    return asr_pipeline
 
 def ASRdirWhisat(
     audio_dir,
-    files_to_include=None,
     out_dir = '../whisat_results/',
-    model_type='large-v2',
-    model_name='whisat-1.2',
-    model_dir="../models/whisat-1.2",
-    use_stock_model=False,
+    model_dir=".",
     max_new_tokens=112,
     num_beams=1,
     do_sample=False,
@@ -131,54 +51,36 @@ def ASRdirWhisat(
     # Save output in same directory structure as input in specified top-level folder
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-    #TODO optional arg listing files to transcribe in a list or a text file
 
     asr_model=prepare_pipeline(
         model_type=model_type,
         model_dir=model_dir,
         use_stock_model=use_stock_model,
-        generate_opts={'max_new_tokens':max_new_tokens,
+        generate_kwargs={'max_new_tokens':max_new_tokens,
            'num_beams':num_beams,
            'repetition_penalty':repetition_penalty,
            'do_sample':do_sample
            }
        )
 
-    if use_stock_model: # set some alternative defaults if using stock model
-        model_name='whisper_' + model_type + '_stock'
 
-    if files_to_include:
-        assert isinstance(files_to_include,list) ,'files_to_include should be a list of paths relative to audio_dir to transcribe'
-        audio_files=files_to_include
-        # audio_files=[]
-        # for f in [str(f) for f in Path(audio_dir).rglob("*") if (str(f).rsplit('.',maxsplit=1)[-1] in ['MOV', 'mov', 'WAV', 'wav', 'mp4', 'mp3', 'm4a', 'aac', 'flac', 'alac', 'ogg'] and f.is_file() )]:
-        #     print(f)
-        #     if os.path.join(audio_dir,f) in files_to_include:
-        #         audio_files.append(f)
-        # print(f'Including {len(audio_files)} hypotheses matching files_to_include...')
-    else:
-        audio_files = [str(f) for f in Path(audio_dir).rglob("*") if (str(f).rsplit('.',maxsplit=1)[-1] in ['MOV', 'mov', 'WAV', 'wav', 'mp4', 'mp3', 'm4a', 'aac', 'flac', 'alac', 'ogg'] and f.is_file() )]
+    audio_files = [str(f) for f in Path(audio_dir).rglob("*") if (str(f).rsplit('.',maxsplit=1)[-1] in ['MOV', 'mov', 'WAV', 'wav', 'mp4', 'mp3', 'm4a', 'aac', 'flac', 'alac', 'ogg'] and f.is_file() )]
 
     # audio_identifier = os.path.basename(audio_dir)
-    asrDir = os.path.join(out_dir,f'ASR_{model_name}') # Dir where full session asr result will be stored
-    jsonDir = os.path.join(out_dir,f'JSON_{model_name}')
-    os.makedirs(asrDir, exist_ok=True)
-    os.makedirs(jsonDir, exist_ok=True)
+    os.makedirs(out_dir, exist_ok=True)
 
-    message = "This may take a while on CPU. Go make a cuppa" if asr_model.device.type=="cpu" else "Running on GPU"
+    message = "This may take a while on CPU." if asr_model.device.type=="cpu" else "Running on GPU"
     print(f'Running ASR for {len(audio_files)} files. {message} ...')
     compute_time=0
     total_audio_dur=0
     # get the start time
     st = time.time()
-
+    asrDir = out_dir
     for audiofile in tqdm(audio_files):
        sessname=Path(audiofile).stem
        sesspath=os.path.relpath(os.path.dirname(Path(audiofile).resolve()),Path(audio_dir).resolve())
        asrFullFile = os.path.join(asrDir,sesspath,f"{sessname}.asr.txt") # full session ASR results file
-       jsonFile = os.path.join(jsonDir,sesspath, f"{sessname}.json")
        os.makedirs(os.path.join(asrDir,sesspath),exist_ok=True)
-       os.makedirs(os.path.join(jsonDir,sesspath),exist_ok=True)
 
        with torch.no_grad():
            with autocast():
@@ -188,13 +90,6 @@ def ASRdirWhisat(
                print(f'{e}: {audiofile}')
                continue
 
-        # save full result JSON
-        with open(jsonFile, "w") as jf:
-            json.dump(result, jf, indent=4)
-        # save full result transcript
-        # if asr_model.return_timestamps:
-        #     asrtext = '\n'.join([r['text'].strip() for r in result['chunks']])
-        # else:
        asrtext = result['text']
 
        with open(asrFullFile,'w') as outfile:
@@ -204,67 +99,3 @@ def ASRdirWhisat(
     compute_time = (et-st)
     print(f'...transcription complete in {compute_time:.1f} sec')
 
-
-def ASRmanifestWhisat(
-        manifest_csv,
-        out_csv,
-        corpora_root,
-        model_type='large-v2',
-        model_dir="../models/whisat-1.2",
-        use_stock_model=False,
-        max_new_tokens=112,
-        num_beams=1,
-        do_sample=False,
-        repetition_penalty=1,
-        ):
-
-    ## ASR using fine-tuned Transformers Whisper
-    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-    # Simply trancsribe each file in the specified folder separately
-    # Whisper takes 30-second input. Anything shorter than this will be 0 padded. Longer will be concatenated.
-    # Save output in same directory structure as input in specified top-level folder
-    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-    df = pd.read_csv(manifest_csv,keep_default_na=False)
-    fieldnames = list(df.columns) + ['asr']
-
-    asr_model=prepare_pipeline(
-        model_type=model_type,
-        model_dir=model_dir,
-        use_stock_model=use_stock_model,
-        generate_opts={'max_new_tokens':max_new_tokens,
-            'num_beams':num_beams,
-            'repetition_penalty':repetition_penalty,
-            'do_sample':do_sample
-            }
-        )
-
-    message = "This may take a while on CPU. Go make a cuppa " if asr_model.device.type=="cpu" else "Running on GPU"
-    print(f'Running ASR for {len(df)} files. {message} ...')
-    compute_time=0
-    total_audio_dur=0
-    # get the start time
-    st = time.time()
-
-    with open(out_csv, 'w', newline='') as csvfile:
-        writer = csv.DictWriter(csvfile, fieldnames=fieldnames,delimiter=',')
-        writer.writeheader()
-
-        for i,row in tqdm(df.iterrows(), total=df.shape[0]):
-
-            audiofile=row['wav'].replace('$DATAROOT',corpora_root)
-            with torch.no_grad():
-                with autocast():
-                    try:
-                        result = asr_model(audiofile)
-                        asrtext = result['text']
-                    except ValueError as e:
-                        print(f'{e}: {audiofile}')
-                        asrtext=''
-
-            row['asr']=asrtext
-            writer.writerow( row.to_dict())
-
-    et = time.time()
-    compute_time = (et-st)
-    print(f'...transcription complete in {compute_time:.1f} sec')
-
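For reference, a short usage sketch of the simplified prepare_pipeline() added in this commit (not part of the commit itself; the import from `main` and the audio filename are assumptions for illustration):

```
# Illustrative only: call the new prepare_pipeline() with its two arguments.
from main import prepare_pipeline

asr = prepare_pipeline(
    model_path=".",  # local checkpoint directory or a Hugging Face model name
    generate_kwargs={
        'max_new_tokens': 112,
        'num_beams': 1,
        'repetition_penalty': 1,
        'do_sample': False,
    },
)
result = asr("example.wav")  # hypothetical audio file
print(result['text'])
```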