Upload main.py
Browse files
main.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import argparse
|
6 |
+
from functools import partial
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import Any, Dict, List, Optional, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
from datasets import set_caching_enabled
|
14 |
+
# Switch off the `datasets` library's caching for this process; this runs at
# import time, before main() loads or maps any dataset.
# NOTE(review): main() still passes cache_file_name / load_from_cache_file=True
# to Dataset.map — confirm how those interact with caching being disabled.
set_caching_enabled(False)
|
15 |
+
|
16 |
+
from datasets import (
|
17 |
+
load_dataset,
|
18 |
+
load_from_disk,
|
19 |
+
load_metric,)
|
20 |
+
|
21 |
+
from transformers import (
|
22 |
+
Wav2Vec2CTCTokenizer,
|
23 |
+
Wav2Vec2FeatureExtractor,
|
24 |
+
Wav2Vec2Processor,
|
25 |
+
Wav2Vec2ForCTC,
|
26 |
+
TrainingArguments,
|
27 |
+
Trainer,
|
28 |
+
)
|
29 |
+
|
30 |
+
import torchaudio
|
31 |
+
|
32 |
+
|
33 |
+
def preprocess_data(example, tok_func=None):
    """Re-tokenize an example's transcript and normalize its spacing.

    The ``'sentence'`` entry is split into tokens with ``tok_func`` and the
    tokens are joined back together with single spaces.

    Args:
        example: Mapping with a ``'sentence'`` string entry. Updated in place.
        tok_func: Callable that takes a string and returns an iterable of
            string tokens. When omitted, the module-level ``word_tokenize``
            is used.

    Returns:
        The same ``example`` object, with ``'sentence'`` rewritten.

    Raises:
        NameError: If ``tok_func`` is omitted and no ``word_tokenize`` is
            defined in this module.
    """
    if tok_func is None:
        # Resolve the default lazily. ``word_tokenize`` is not imported
        # anywhere in this file, so binding it in the signature raised
        # NameError as soon as the module was loaded.
        # NOTE(review): import the intended tokenizer at the top of the file
        # (presumably nltk.tokenize.word_tokenize — confirm).
        tok_func = word_tokenize  # noqa: F821
    example['sentence'] = ' '.join(tok_func(example['sentence']))
    return example
|
36 |
+
|
37 |
+
|
38 |
+
def speech_file_to_array_fn(batch,
                            text_col="sentence",
                            fname_col="path",
                            resampling_to=16000):
    """Load one example's audio file, resample it and attach its transcript.

    Args:
        batch: Mapping for a single example. It must hold the audio file path
            under ``fname_col`` and the transcript under ``text_col``.
            Updated in place.
        text_col: Key of the transcript in ``batch``.
        fname_col: Key of the audio file path in ``batch``.
        resampling_to: Target sampling rate passed to the resampler and
            recorded in the output.

    Returns:
        ``batch`` with three added entries: ``"speech"`` (NumPy array of the
        resampled audio), ``"sampling_rate"`` (``resampling_to``) and
        ``"target_text"`` (copy of ``batch[text_col]``).
    """
    speech_array, sampling_rate = torchaudio.load(batch[fname_col])
    resampler = torchaudio.transforms.Resample(sampling_rate, resampling_to)
    # Index 0 keeps only the first row of the resampled tensor
    # (presumably the first audio channel — confirm for multi-channel files).
    batch["speech"] = resampler(speech_array)[0].numpy()
    batch["sampling_rate"] = resampling_to
    batch["target_text"] = batch[text_col]
    # The example must be returned: a bare ``return`` hands None back to
    # Dataset.map, which then has nothing to write and the new columns are lost.
    return batch
|
48 |
+
|
49 |
+
@dataclass
class DataCollatorCTCWithPadding:
    """
    Collate CTC examples into one batch, padding inputs and labels on the fly.

    Audio inputs and label ids have different lengths, so they are padded in
    two separate calls to the processor; padded label positions are then set
    to -100.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            Processor whose ``pad`` method performs the padding.
        padding (:obj:`bool` or :obj:`str`, `optional`, defaults to :obj:`True`):
            Padding strategy forwarded unchanged to ``processor.pad`` for both
            the inputs and the labels (for example :obj:`True` / :obj:`'longest'`,
            :obj:`'max_length'`, or :obj:`False` / :obj:`'do_not_pad'`).
        max_length (:obj:`int`, `optional`):
            ``max_length`` forwarded when padding ``input_values``.
        max_length_labels (:obj:`int`, `optional`):
            ``max_length`` forwarded when padding the labels.
        pad_to_multiple_of (:obj:`int`, `optional`):
            ``pad_to_multiple_of`` forwarded when padding ``input_values``.
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            ``pad_to_multiple_of`` forwarded when padding the labels.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` and return the batch with a ``"labels"`` entry added."""
        # Inputs and labels are collected separately because each side needs
        # its own padding call.
        audio_inputs = []
        for item in features:
            audio_inputs.append({"input_values": item["input_values"]})
        label_inputs = []
        for item in features:
            label_inputs.append({"input_ids": item["labels"]})

        padded = self.processor.pad(
            audio_inputs,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        with self.processor.as_target_processor():
            padded_labels = self.processor.pad(
                label_inputs,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # Positions where the attention mask is not 1 are padding; they are
        # overwritten with -100 so the loss ignores them.
        is_padding = padded_labels.attention_mask.ne(1)
        padded["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)

        return padded
|
110 |
+
|
111 |
+
|
112 |
+
def main():
    """Fine-tune a pre-trained wav2vec2 CTC model and evaluate it with CER.

    Command-line arguments:
        --pre_trained_model: Local path to the pre-trained wav2vec2 model; the
            tokenizer, feature extractor and model weights are loaded from it.
        --train_file_path / --valid_file_path: CSV files for the two splits.
        --warmup_steps, --learning_rate: Forwarded to ``TrainingArguments``.

    Steps: load the CSV splits, re-tokenize the transcripts, load and resample
    the audio, turn audio and text into model inputs and label ids, then train
    with ``Trainer``, reporting the character error rate at each evaluation.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--pre_trained_model", default='', type=str, help='Local path to pre-trained wav2vec2 model')
    parser.add_argument("--train_file_path", default='', type=str, help='Local path to train file')
    parser.add_argument("--valid_file_path", default='', type=str, help='Local path to valid file')

    parser.add_argument("--warmup_steps", default=20000, type=int, help='')
    parser.add_argument("--learning_rate", default=3e-5, type=float, help='')
    args = parser.parse_args()

    def prepare_dataset(batch):
        """Convert a batch of raw speech/text columns into model inputs and labels."""
        # Only the first row's sampling rate is passed on. speech_file_to_array_fn
        # writes the same resampling target into every row, so they should agree.
        batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values

        with processor.as_target_processor():
            batch["labels"] = processor(batch["target_text"]).input_ids
        # The batch must be returned, otherwise Dataset.map receives None and
        # the computed columns are never written.
        return batch

    def compute_metrics(pred, processor, metric):
        """Decode predictions and references and return ``{"cer": <score>}``."""
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        # -100 marks padded label positions (set by the data collator); map
        # them back to the pad token so they can be decoded.
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        # Use the metric object that was passed in, and return the value that
        # was actually computed.
        cer = metric.compute(predictions=pred_str, references=label_str)

        return {"cer": cer}

    # load dataset
    print('Loading dataset....')
    datasets = load_dataset('csv', name='cn', data_files={'train': args.train_file_path, 'valid': args.valid_file_path},
                            cache_dir='/path/to/csv')
    datasets = datasets.map(preprocess_data)

    dataset_train = datasets['train']
    dataset_valid = datasets['valid']

    dataset_train = dataset_train.map(speech_file_to_array_fn,
                                      remove_columns=dataset_train.column_names,
                                      cache_file_name='/path/to/cache/of/train/speech/file')

    dataset_valid = dataset_valid.map(speech_file_to_array_fn,
                                      remove_columns=dataset_valid.column_names,
                                      cache_file_name='/path/to/cache/of/valid/speech/file')

    print('Tokenization')
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(args.pre_trained_model)

    print('Feature extracting....')
    # Load the saved feature-extractor configuration. Calling the constructor
    # with the path would pass it as the first positional argument instead.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.pre_trained_model)
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    dataset_train = dataset_train.map(prepare_dataset,
                                      remove_columns=dataset_train.column_names,
                                      batched=True,
                                      load_from_cache_file=True,
                                      cache_file_name='/path/to/train')

    dataset_valid = dataset_valid.map(prepare_dataset,
                                      remove_columns=dataset_valid.column_names,
                                      batched=True,
                                      load_from_cache_file=True,
                                      cache_file_name='/path/to/valid')

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    # Named after the metric it holds, and the same name is given to Trainer below.
    cer_metric = load_metric("cer")

    # create model
    model = Wav2Vec2ForCTC.from_pretrained(
        args.pre_trained_model,
        vocab_size=len(processor.tokenizer)
    )
    model.freeze_feature_extractor()

    training_args = TrainingArguments(
        output_dir="/path/to/output",
        group_by_length=True,
        per_device_train_batch_size=3,
        gradient_accumulation_steps=1,
        per_device_eval_batch_size=1,
        metric_for_best_model='cer',
        # CER is an error rate: lower is better.
        greater_is_better=False,
        evaluation_strategy="steps",
        eval_steps=15000,
        logging_strategy="steps",
        logging_steps=15000,
        save_strategy="steps",
        save_steps=15000,
        num_train_epochs=100,
        fp16=True,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        save_total_limit=3,
        report_to="tensorboard"
    )

    print('Training model....')
    # Train
    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=partial(compute_metrics, metric=cer_metric, processor=processor),
        train_dataset=dataset_train,
        eval_dataset=dataset_valid,
        tokenizer=processor.feature_extractor,
    )

    trainer.train()


# Run training only when executed as a script; importing the module stays
# free of side effects.
if __name__ == "__main__":
    main()
|