rosyvs committed on
Commit
7c2d6fa
1 Parent(s): 3c0cc82

main.py created - contains code for transcription

Files changed (1)
  1. main.py +270 -0
main.py ADDED
@@ -0,0 +1,270 @@
+ import os
+ from peft import PeftModel, PeftConfig
+ import torch
+ from torch.cuda.amp import autocast
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ import transformers
+ from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig
+ from transformers import pipeline, AutomaticSpeechRecognitionPipeline
+ import argparse
+ import time
+ from pathlib import Path
+ import json
+ import pandas as pd
+ import csv
+
+ def prepare_pipeline(model_type='large-v2',
+     model_dir="../models/whisat-1.2/",
+     use_stock_model=False,
+     generate_opts={'max_new_tokens': 112,
+                    'num_beams': 1,
+                    'repetition_penalty': 1,
+                    'do_sample': False}
+     ):
+     """Build a Transformers ASR pipeline from a stock Whisper model or a fine-tuned (optionally PEFT) checkpoint."""
+     #%% options (TODO make these CLI options)
+     lang = 'english'
+     USE_INT8 = False
+
+     import warnings
+     warnings.filterwarnings("ignore")
+     transformers.utils.logging.set_verbosity_error()
+
+     init_from_hub_path = f"openai/whisper-{model_type}"  # TODO infer automatically from PEFT checkpoint
+     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+     print(device)
+     feature_extractor = WhisperFeatureExtractor.from_pretrained(init_from_hub_path)
+     # TODO: no need to specify lang/task?
+     tokenizer = WhisperTokenizer.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
+     processor = WhisperProcessor.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
+
+     if use_stock_model:
+         model = WhisperForConditionalGeneration.from_pretrained(init_from_hub_path)
+     else:
+         checkpoint_dir = os.path.expanduser(model_dir)
+         # check if PEFT
+         if os.path.isdir(os.path.join(checkpoint_dir, "adapter_model")):
+             print('...it looks like this model was tuned using PEFT, because adapter_model/ is present in ckpt dir')
+
+             # checkpoint dir needs an adapter_model/ subdir with adapter_model.bin and adapter_config.json
+             peft_config = PeftConfig.from_pretrained(os.path.join(checkpoint_dir, "adapter_model"))
+             # except ValueError as e:  # if final checkpoint these are in the parent checkpoint directory
+             #     peft_config = PeftConfig.from_pretrained(checkpoint_dir, subfolder=None)
+             model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path,
+                                                                     load_in_8bit=USE_INT8,
+                                                                     device_map='auto',
+                                                                     use_cache=False,
+                                                                     )
+             model = PeftModel.from_pretrained(model, os.path.join(checkpoint_dir, "adapter_model"))
+         else:
+             model = WhisperForConditionalGeneration.from_pretrained(checkpoint_dir,
+                                                                     load_in_8bit=USE_INT8,
+                                                                     device_map='auto',
+                                                                     use_cache=False,
+                                                                     )
+     model.eval()  # needed?
+
+     pipe = AutomaticSpeechRecognitionPipeline(
+         # task="automatic-speech-recognition",
+         model=model,
+         tokenizer=tokenizer,
+         feature_extractor=feature_extractor,
+         chunk_length_s=30,
+         device=device,
+         return_timestamps=False,
+         generate_kwargs=generate_opts,
+     )
+
+     return pipe
+
+ def load_model(model_type='large-v2',
+     model_dir="../models/whisat-1.2/"):
+     """Load a PEFT-adapted Whisper model plus its tokenizer and processor, without wrapping them in a pipeline."""
+     lang = 'english'
+     USE_INT8 = False
+
+     import warnings
+     warnings.filterwarnings("ignore")
+     transformers.utils.logging.set_verbosity_error()
+
+     init_from_hub_path = f"openai/whisper-{model_type}"  # TODO infer automatically from PEFT checkpoint
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     print(device)
+     feature_extractor = WhisperFeatureExtractor.from_pretrained(init_from_hub_path)
+     # TODO: no need to specify lang/task?
+     tokenizer = WhisperTokenizer.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
+     processor = WhisperProcessor.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
+
+     checkpoint_dir = os.path.expanduser(model_dir)
+     # checkpoint dir needs an adapter_model/ subdir with adapter_model.bin and adapter_config.json
+     peft_config = PeftConfig.from_pretrained(os.path.join(checkpoint_dir, "adapter_model"))
+     # except ValueError as e:  # if final checkpoint these are in the parent checkpoint directory
+     #     peft_config = PeftConfig.from_pretrained(checkpoint_dir, subfolder=None)
+     model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path,
+                                                             load_in_8bit=USE_INT8,  # TODO: seemed slightly better without?
+                                                             device_map='auto',
+                                                             use_cache=False,
+                                                             )
+     model = PeftModel.from_pretrained(model, os.path.join(checkpoint_dir, "adapter_model"))
+     model.eval()  # needed?
+     return model, tokenizer, processor
+
+ def ASRdirWhisat(
+     audio_dir,
+     files_to_include=None,
+     out_dir='../whisat_results/',
+     model_type='large-v2',
+     model_name='whisat-1.2',
+     model_dir="../models/whisat-1.2",
+     use_stock_model=False,
+     max_new_tokens=112,
+     num_beams=1,
+     do_sample=False,
+     repetition_penalty=1,
+ ):
+     ## ASR using fine-tuned Transformers Whisper
+     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+     # Simply transcribe each file in the specified folder separately.
+     # Whisper takes 30-second input: anything shorter is zero-padded; longer audio is
+     # chunked and the chunk transcripts concatenated.
+     # Save output in the same directory structure as the input, under the specified top-level folder.
+     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+     # TODO optional arg listing files to transcribe in a list or a text file
+
+     asr_model = prepare_pipeline(
+         model_type=model_type,
+         model_dir=model_dir,
+         use_stock_model=use_stock_model,
+         generate_opts={'max_new_tokens': max_new_tokens,
+                        'num_beams': num_beams,
+                        'repetition_penalty': repetition_penalty,
+                        'do_sample': do_sample
+                        }
+     )
+
+     if use_stock_model:  # set some alternative defaults if using stock model
+         model_name = 'whisper_' + model_type + '_stock'
+
+     if files_to_include:
+         assert isinstance(files_to_include, list), 'files_to_include should be a list of paths relative to audio_dir to transcribe'
+         audio_files = files_to_include
+         # audio_files = []
+         # for f in [str(f) for f in Path(audio_dir).rglob("*") if (str(f).rsplit('.', maxsplit=1)[-1] in ['MOV', 'mov', 'WAV', 'wav', 'mp4', 'mp3', 'm4a', 'aac', 'flac', 'alac', 'ogg'] and f.is_file())]:
+         #     print(f)
+         #     if os.path.join(audio_dir, f) in files_to_include:
+         #         audio_files.append(f)
+         # print(f'Including {len(audio_files)} hypotheses matching files_to_include...')
+     else:
+         audio_files = [str(f) for f in Path(audio_dir).rglob("*") if (str(f).rsplit('.', maxsplit=1)[-1] in ['MOV', 'mov', 'WAV', 'wav', 'mp4', 'mp3', 'm4a', 'aac', 'flac', 'alac', 'ogg'] and f.is_file())]
+
+     # audio_identifier = os.path.basename(audio_dir)
+     asrDir = os.path.join(out_dir, f'ASR_{model_name}')  # dir where full-session ASR results will be stored
+     jsonDir = os.path.join(out_dir, f'JSON_{model_name}')
+     os.makedirs(asrDir, exist_ok=True)
+     os.makedirs(jsonDir, exist_ok=True)
+
+     message = "This may take a while on CPU. Go make a cuppa" if asr_model.device.type == "cpu" else "Running on GPU"
+     print(f'Running ASR for {len(audio_files)} files. {message} ...')
+     compute_time = 0
+     total_audio_dur = 0
+     # get the start time
+     st = time.time()
+
+     for audiofile in tqdm(audio_files):
+         sessname = Path(audiofile).stem
+         sesspath = os.path.relpath(os.path.dirname(Path(audiofile).resolve()), Path(audio_dir).resolve())
+         asrFullFile = os.path.join(asrDir, sesspath, f"{sessname}.asr.txt")  # full-session ASR results file
+         jsonFile = os.path.join(jsonDir, sesspath, f"{sessname}.json")
+         os.makedirs(os.path.join(asrDir, sesspath), exist_ok=True)
+         os.makedirs(os.path.join(jsonDir, sesspath), exist_ok=True)
+
+         with torch.no_grad():
+             with autocast():
+                 try:
+                     result = asr_model(audiofile)
+                 except ValueError as e:
+                     print(f'{e}: {audiofile}')
+                     continue
+
+         # save full result JSON
+         with open(jsonFile, "w") as jf:
+             json.dump(result, jf, indent=4)
+         # save full result transcript
+         # if asr_model.return_timestamps:
+         #     asrtext = '\n'.join([r['text'].strip() for r in result['chunks']])
+         # else:
+         asrtext = result['text']
+
+         with open(asrFullFile, 'w') as outfile:
+             outfile.write(asrtext)
+         # print(asrtext)
+     et = time.time()
+     compute_time = (et - st)
+     print(f'...transcription complete in {compute_time:.1f} sec')
+
+
+ def ASRmanifestWhisat(
+     manifest_csv,
+     out_csv,
+     corpora_root,
+     model_type='large-v2',
+     model_dir="../models/whisat-1.2",
+     use_stock_model=False,
+     max_new_tokens=112,
+     num_beams=1,
+     do_sample=False,
+     repetition_penalty=1,
+ ):
+     ## ASR using fine-tuned Transformers Whisper
+     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+     # Transcribe each file listed in the manifest CSV separately.
+     # Whisper takes 30-second input: anything shorter is zero-padded; longer audio is
+     # chunked and the chunk transcripts concatenated.
+     # Write one row per input row to out_csv, with the hypothesis in an added 'asr' column.
+     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+     df = pd.read_csv(manifest_csv, keep_default_na=False)
+     fieldnames = list(df.columns) + ['asr']
+
+     asr_model = prepare_pipeline(
+         model_type=model_type,
+         model_dir=model_dir,
+         use_stock_model=use_stock_model,
+         generate_opts={'max_new_tokens': max_new_tokens,
+                        'num_beams': num_beams,
+                        'repetition_penalty': repetition_penalty,
+                        'do_sample': do_sample
+                        }
+     )
+
+     message = "This may take a while on CPU. Go make a cuppa" if asr_model.device.type == "cpu" else "Running on GPU"
+     print(f'Running ASR for {len(df)} files. {message} ...')
+     compute_time = 0
+     total_audio_dur = 0
+     # get the start time
+     st = time.time()
+
+     with open(out_csv, 'w', newline='') as csvfile:
+         writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=',')
+         writer.writeheader()
+
+         for i, row in tqdm(df.iterrows(), total=df.shape[0]):
+             audiofile = row['wav'].replace('$DATAROOT', corpora_root)
+             with torch.no_grad():
+                 with autocast():
+                     try:
+                         result = asr_model(audiofile)
+                         asrtext = result['text']
+                     except ValueError as e:
+                         print(f'{e}: {audiofile}')
+                         asrtext = ''
+
+             row['asr'] = asrtext
+             writer.writerow(row.to_dict())
+
+     et = time.time()
+     compute_time = (et - st)
+     print(f'...transcription complete in {compute_time:.1f} sec')
+
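
Usage note: ASRmanifestWhisat expects the manifest CSV to contain a 'wav' column whose paths may use the literal placeholder $DATAROOT, which is substituted with corpora_root. The file imports argparse but this commit defines no command-line entry point; the sketch below shows how the two functions might be wired to a CLI. It is not part of the commit, and the subcommand names and defaults are assumptions.

# Hypothetical CLI wrapper (not in the committed main.py); argument names and defaults are assumptions.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Transcribe audio with a (fine-tuned) Whisper model')
    sub = parser.add_subparsers(dest='mode', required=True)

    p_dir = sub.add_parser('dir', help='transcribe every audio file found under a directory')
    p_dir.add_argument('audio_dir')
    p_dir.add_argument('--out_dir', default='../whisat_results/')
    p_dir.add_argument('--model_dir', default='../models/whisat-1.2')
    p_dir.add_argument('--use_stock_model', action='store_true')

    p_man = sub.add_parser('manifest', help='transcribe the files listed in a manifest CSV')
    p_man.add_argument('manifest_csv')
    p_man.add_argument('out_csv')
    p_man.add_argument('corpora_root')
    p_man.add_argument('--model_dir', default='../models/whisat-1.2')
    p_man.add_argument('--use_stock_model', action='store_true')

    args = parser.parse_args()
    if args.mode == 'dir':
        ASRdirWhisat(args.audio_dir, out_dir=args.out_dir,
                     model_dir=args.model_dir, use_stock_model=args.use_stock_model)
    else:
        ASRmanifestWhisat(args.manifest_csv, args.out_csv, args.corpora_root,
                          model_dir=args.model_dir, use_stock_model=args.use_stock_model)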