zzuczy committed on
Commit
12f5621
1 Parent(s): a020911

Upload main.py

Files changed (1)
main.py +229 -0
main.py ADDED
@@ -0,0 +1,229 @@
+ import os
+ import re
+ import json
+ import torch
+ import argparse
+ from functools import partial
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Union
+
+ import numpy as np
+ import pandas as pd
+
+ # NOTE: word_tokenize is used below but was not imported in the original;
+ # NLTK's tokenizer is assumed here.
+ from nltk.tokenize import word_tokenize
+
+ # Disable the global datasets cache; explicit cache_file_name paths are
+ # passed to each .map() call instead.
+ from datasets import set_caching_enabled
+ set_caching_enabled(False)
+
+ from datasets import (
+     load_dataset,
+     load_from_disk,
+     load_metric,
+ )
+
+ from transformers import (
+     Wav2Vec2CTCTokenizer,
+     Wav2Vec2FeatureExtractor,
+     Wav2Vec2Processor,
+     Wav2Vec2ForCTC,
+     TrainingArguments,
+     Trainer,
+ )
+
+ import torchaudio
+
+
+ def preprocess_data(example, tok_func=word_tokenize):
+     # Re-join the tokenized sentence with spaces so the CTC tokenizer
+     # sees word-segmented text.
+     example['sentence'] = ' '.join(tok_func(example['sentence']))
+     return example
+
+
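+ # Each row's audio file is loaded and resampled to the 16 kHz rate that
+ # wav2vec2 models expect; the transcript is carried along as target_text.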
+ def speech_file_to_array_fn(batch,
+                             text_col="sentence",
+                             fname_col="path",
+                             resampling_to=16000):
+     speech_array, sampling_rate = torchaudio.load(batch[fname_col])
+     resampler = torchaudio.transforms.Resample(sampling_rate, resampling_to)
+     batch["speech"] = resampler(speech_array)[0].numpy()
+     batch["sampling_rate"] = resampling_to
+     batch["target_text"] = batch[text_col]
+     return batch
+
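+ # Audio frames and label token ids vary in length from example to example,
+ # so padding is deferred to collation time rather than done globally.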
+ @dataclass
+ class DataCollatorCTCWithPadding:
+     """
+     Data collator that will dynamically pad the inputs received.
+     Args:
+         processor (:class:`~transformers.Wav2Vec2Processor`)
+             The processor used for processing the data.
+         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+             among:
+             * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+               sequence is provided).
+             * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+               maximum acceptable input length for the model if that argument is not provided.
+             * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
+               different lengths).
+         max_length (:obj:`int`, `optional`):
+             Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
+         max_length_labels (:obj:`int`, `optional`):
+             Maximum length of the ``labels`` returned list and optionally padding length (see above).
+         pad_to_multiple_of (:obj:`int`, `optional`):
+             If set will pad the sequence to a multiple of the provided value.
+             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+             7.5 (Volta).
+     """
+
+     processor: Wav2Vec2Processor
+     padding: Union[bool, str] = True
+     max_length: Optional[int] = None
+     max_length_labels: Optional[int] = None
+     pad_to_multiple_of: Optional[int] = None
+     pad_to_multiple_of_labels: Optional[int] = None
+
+     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+         # split inputs and labels since they have to be of different lengths and need
+         # different padding methods
+         input_features = [{"input_values": feature["input_values"]} for feature in features]
+         label_features = [{"input_ids": feature["labels"]} for feature in features]
+
+         batch = self.processor.pad(
+             input_features,
+             padding=self.padding,
+             max_length=self.max_length,
+             pad_to_multiple_of=self.pad_to_multiple_of,
+             return_tensors="pt",
+         )
+         with self.processor.as_target_processor():
+             labels_batch = self.processor.pad(
+                 label_features,
+                 padding=self.padding,
+                 max_length=self.max_length_labels,
+                 pad_to_multiple_of=self.pad_to_multiple_of_labels,
+                 return_tensors="pt",
+             )
+
+         # replace label padding with -100 so it is ignored by the CTC loss
+         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
+
+         batch["labels"] = labels
+
+         return batch
+
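+ # Minimal usage sketch (illustrative only; `processor` is a loaded
+ # Wav2Vec2Processor and the feature dicts come from prepare_dataset below):
+ #   collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
+ #   batch = collator([{"input_values": [...], "labels": [...]}, ...])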
+
+ def main():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--pre_trained_model", default='', type=str, help='Local path to pre-trained wav2vec2 model')
+     parser.add_argument("--train_file_path", default='', type=str, help='Local path to train file')
+     parser.add_argument("--valid_file_path", default='', type=str, help='Local path to valid file')
+
+     parser.add_argument("--warmup_steps", default=20000, type=int, help='Number of learning-rate warmup steps')
+     parser.add_argument("--learning_rate", default=3e-5, type=float, help='Peak learning rate')
+     args = parser.parse_args()
+
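+     # prepare_dataset closes over `processor`, which is created further down
+     # in main(); the name is only resolved when the function is called, by
+     # which point the processor exists.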
+     def prepare_dataset(batch):
+         # check that all files have the correct sampling rate
+         # assert (
+         #     len(set(batch["sampling_rate"])) == 1
+         # ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
+
+         batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
+
+         with processor.as_target_processor():
+             batch["labels"] = processor(batch["target_text"]).input_ids
+         return batch
+
+     def compute_metrics(pred, processor, metric):
+         pred_logits = pred.predictions
+         pred_ids = np.argmax(pred_logits, axis=-1)
+
+         pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
+
+         pred_str = processor.batch_decode(pred_ids)
+         # we do not want to group tokens when computing the metrics
+         label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
+
+         cer = metric.compute(predictions=pred_str, references=label_str)
+
+         return {"cer": cer}
+
+     # load dataset
+     print('Loading dataset....')
+     datasets = load_dataset('csv', name='cn', data_files={'train': args.train_file_path, 'valid': args.valid_file_path},
+                             cache_dir='/path/to/csv')
+     datasets = datasets.map(preprocess_data)
+
+     dataset_train = datasets['train']
+     dataset_valid = datasets['valid']
+
+     dataset_train = dataset_train.map(speech_file_to_array_fn,
+                                       remove_columns=dataset_train.column_names,
+                                       cache_file_name='/path/to/cache/of/train/speech/file')
+
+     dataset_valid = dataset_valid.map(speech_file_to_array_fn,
+                                       remove_columns=dataset_valid.column_names,
+                                       cache_file_name='/path/to/cache/of/valid/speech/file')
+
+     print('Tokenization')
+     tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(args.pre_trained_model)
+
+     print('Feature extracting....')
+     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.pre_trained_model)
+     processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+     dataset_train = dataset_train.map(prepare_dataset,
+                                       remove_columns=dataset_train.column_names,
+                                       batched=True,
+                                       load_from_cache_file=True,
+                                       cache_file_name='/path/to/train')
+
+     dataset_valid = dataset_valid.map(prepare_dataset,
+                                       remove_columns=dataset_valid.column_names,
+                                       batched=True,
+                                       load_from_cache_file=True,
+                                       cache_file_name='/path/to/valid')
+
+
+     data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
+     cer_metric = load_metric("cer")
+
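+     # Character error rate (CER) is used rather than word error rate, which
+     # suits character-level transcripts such as the Chinese ('cn') data here.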
+     # create model
+     model = Wav2Vec2ForCTC.from_pretrained(
+         args.pre_trained_model,
+         vocab_size=len(processor.tokenizer)
+     )
+     model.freeze_feature_extractor()
+
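+     # The convolutional feature encoder is kept frozen during fine-tuning;
+     # it is already well trained from pre-training, and freezing it
+     # stabilizes CTC fine-tuning on smaller datasets.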
+     training_args = TrainingArguments(
+         output_dir="/path/to/output",
+         group_by_length=True,
+         per_device_train_batch_size=3,
+         gradient_accumulation_steps=1,
+         per_device_eval_batch_size=1,
+         metric_for_best_model='cer',
+         greater_is_better=False,  # lower CER is better (added; not in the original)
+         evaluation_strategy="steps",
+         eval_steps=15000,
+         logging_strategy="steps",
+         logging_steps=15000,
+         save_strategy="steps",
+         save_steps=15000,
+         num_train_epochs=100,
+         fp16=True,
+         learning_rate=args.learning_rate,
+         warmup_steps=args.warmup_steps,
+         save_total_limit=3,
+         report_to="tensorboard",
+     )
+
+     print('Training model....')
+     # Train
+     trainer = Trainer(
+         model=model,
+         data_collator=data_collator,
+         args=training_args,
+         compute_metrics=partial(compute_metrics, metric=cer_metric, processor=processor),
+         train_dataset=dataset_train,
+         eval_dataset=dataset_valid,
+         tokenizer=processor.feature_extractor,
+     )
+
+     trainer.train()
+
+
+ if __name__ == "__main__":
+     main()
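+
+ # Example invocation (all paths are placeholders):
+ #   python main.py --pre_trained_model /path/to/wav2vec2 \
+ #       --train_file_path /path/to/train.csv --valid_file_path /path/to/valid.csv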