|
import os
|
|
import re
|
|
import json
|
|
import torch
|
|
import argparse
|
|
from functools import partial
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from datasets import set_caching_enabled
|
|
# Turn off the `datasets` library's caching for the whole process, at import time.
# NOTE(review): main() still passes explicit `cache_file_name=` / `load_from_cache_file=`
# to `Dataset.map`; confirm how those interact with caching being disabled.
# NOTE(review): newer `datasets` releases deprecate `set_caching_enabled` in favour of
# `datasets.disable_caching()` — confirm against the installed version.
set_caching_enabled(False)
|
|
|
|
from datasets import (
|
|
load_dataset,
|
|
load_from_disk,
|
|
load_metric,)
|
|
|
|
from transformers import (
|
|
Wav2Vec2CTCTokenizer,
|
|
Wav2Vec2FeatureExtractor,
|
|
Wav2Vec2Processor,
|
|
Wav2Vec2ForCTC,
|
|
TrainingArguments,
|
|
Trainer,
|
|
)
|
|
|
|
import torchaudio
|
|
|
|
|
|
def preprocess_data(example, tok_func=None):
    """Normalize one example's transcript by re-joining its tokens with single spaces.

    Args:
        example: A mapping holding a string under ``'sentence'``; it is updated in place.
        tok_func: Callable that splits a string into a list of tokens. When ``None``
            (the default), whitespace tokenization via :meth:`str.split` is used.
            ``word_tokenize`` is not imported in this file, so it cannot serve as the
            default; pass e.g. ``nltk.word_tokenize`` explicitly to tokenize with NLTK.

    Returns:
        The same ``example`` mapping, with ``'sentence'`` rewritten.
    """
    if tok_func is None:
        tok_func = str.split
    example['sentence'] = ' '.join(tok_func(example['sentence']))
    return example
|
|
|
|
|
|
def speech_file_to_array_fn(batch,
                            text_col="sentence",
                            fname_col="path",
                            resampling_to=16000):
    """Load one example's audio file, resample it and attach waveform and transcript.

    Intended for ``Dataset.map`` (one example at a time).

    Args:
        batch: A mapping with an audio file path under ``fname_col`` and a transcript
            under ``text_col``; it is updated in place.
        text_col: Key of the transcript to copy into ``'target_text'``.
        fname_col: Key of the audio file path passed to ``torchaudio.load``.
        resampling_to: Target sampling rate in Hz.

    Returns:
        The same mapping with three added keys:
        ``'speech'`` (NumPy array of the resampled first channel),
        ``'sampling_rate'`` (``resampling_to``) and ``'target_text'``.
    """
    speech_array, sampling_rate = torchaudio.load(batch[fname_col])
    resampler = torchaudio.transforms.Resample(sampling_rate, resampling_to)
    # ``[0]`` keeps only the first channel (torchaudio.load is channels-first by default).
    batch["speech"] = resampler(speech_array)[0].numpy()
    batch["sampling_rate"] = resampling_to
    batch["target_text"] = batch[text_col]
    # Dataset.map only stores the new columns when the function returns the mapping;
    # returning None would silently drop them.
    return batch
|
|
|
|
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that dynamically pads the audio inputs and the CTC labels it receives.

    ``input_values`` are padded through the processor's default path and ``labels`` are
    padded inside ``processor.as_target_processor()``. Padded label positions are then
    replaced by ``-100`` so they can be told apart from real token ids.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`):
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
              only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, pad the ``input_values`` to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            If set, pad the ``labels`` to a multiple of the provided value.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into one batch with ``input_values`` and ``labels`` tensors."""
        # Inputs and labels have different lengths and are padded through different
        # processor paths, so separate them first.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Inside this context the processor pads with its tokenizer (label ids).
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # Wherever the label attention mask is not 1 (i.e. padding), write -100.
        labels = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)
        batch["labels"] = labels
        return batch
|
|
|
|
|
|
def main():
    """Parse CLI arguments, build the train/valid datasets and fine-tune a Wav2Vec2 CTC model.

    Both CSV files are mapped through ``preprocess_data`` and ``speech_file_to_array_fn``
    with their default column names, so they presumably need ``sentence`` and ``path``
    columns — TODO confirm against the data. Output, cache and dataset paths are the
    hard-coded placeholders below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pre_trained_model", default='', type=str, help='Local path to pre-trained wav2vec2 model')
    parser.add_argument("--train_file_path", default='', type=str, help='Local path to train file')
    parser.add_argument("--valid_file_path", default='', type=str, help='Local path to valid file')
    parser.add_argument("--warmup_steps", default=20000, type=int, help='')
    parser.add_argument("--learning_rate", default=3e-5, type=float, help='')
    args = parser.parse_args()

    def prepare_dataset(batch):
        """Convert a batch of waveforms + transcripts into ``input_values`` and ``labels``.

        Used with ``batched=True``, so every value is a list. Reads ``processor`` from
        the enclosing scope; it is assigned before the ``map`` calls below run.
        """
        # All rows share the rate set by speech_file_to_array_fn, so the first suffices.
        batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
        with processor.as_target_processor():
            batch["labels"] = processor(batch["target_text"]).input_ids
        # Dataset.map only keeps the new columns when the batch is returned.
        return batch

    def compute_metrics(pred, processor, metric):
        """Greedy-decode predictions and score them against the references with ``metric``.

        Returns ``{"cer": <score>}``.
        """
        pred_ids = np.argmax(pred.predictions, axis=-1)
        # Map the -100 label padding back to the pad token so labels can be decoded.
        # np.where builds a new array instead of mutating the Trainer's label_ids.
        label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)
        pred_str = processor.batch_decode(pred_ids)
        # Labels are not CTC output: repeated characters must not be merged.
        label_str = processor.batch_decode(label_ids, group_tokens=False)
        cer = metric.compute(predictions=pred_str, references=label_str)
        return {"cer": cer}

    print('Loading dataset....')
    raw_datasets = load_dataset('csv', name='cn',
                                data_files={'train': args.train_file_path, 'valid': args.valid_file_path},
                                cache_dir='/path/to/csv')
    raw_datasets = raw_datasets.map(preprocess_data)

    dataset_train = raw_datasets['train']
    dataset_valid = raw_datasets['valid']

    dataset_train = dataset_train.map(speech_file_to_array_fn,
                                      remove_columns=dataset_train.column_names,
                                      cache_file_name='/path/to/cache/of/train/speech/file')
    dataset_valid = dataset_valid.map(speech_file_to_array_fn,
                                      remove_columns=dataset_valid.column_names,
                                      cache_file_name='/path/to/cache/of/valid/speech/file')

    print('Tokenization')
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(args.pre_trained_model)

    print('Feature extracting....')
    # from_pretrained: the constructor's first positional parameter is `feature_size`,
    # not a path, so the previous direct call built a misconfigured extractor.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.pre_trained_model)
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    dataset_train = dataset_train.map(prepare_dataset,
                                      remove_columns=dataset_train.column_names,
                                      batched=True,
                                      load_from_cache_file=True,
                                      cache_file_name='/path/to/train')
    dataset_valid = dataset_valid.map(prepare_dataset,
                                      remove_columns=dataset_valid.column_names,
                                      batched=True,
                                      load_from_cache_file=True,
                                      cache_file_name='/path/to/valid')

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    cer_metric = load_metric("cer")

    model = Wav2Vec2ForCTC.from_pretrained(
        args.pre_trained_model,
        # Keep the model's pad id (the CTC blank) aligned with the tokenizer's pad
        # token, which batch_decode strips during evaluation.
        pad_token_id=processor.tokenizer.pad_token_id,
        vocab_size=len(processor.tokenizer),
    )
    model.freeze_feature_extractor()

    training_args = TrainingArguments(
        output_dir="/path/to/output",
        group_by_length=True,
        per_device_train_batch_size=3,
        gradient_accumulation_steps=1,
        per_device_eval_batch_size=1,
        metric_for_best_model='cer',
        # Lower CER is better; HF defaults to "greater is better" for any
        # metric other than loss, which would track the worst checkpoint.
        greater_is_better=False,
        evaluation_strategy="steps",
        eval_steps=15000,
        logging_strategy="steps",
        logging_steps=15000,
        save_strategy="steps",
        save_steps=15000,
        num_train_epochs=100,
        fp16=True,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        save_total_limit=3,
        report_to="tensorboard",
    )

    print('Training model....')
    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=partial(compute_metrics, metric=cer_metric, processor=processor),
        train_dataset=dataset_train,
        eval_dataset=dataset_valid,
        # The feature extractor stands in for the tokenizer so group_by_length can
        # read lengths from its input name and it is saved with checkpoints.
        tokenizer=processor.feature_extractor,
    )

    trainer.train()


if __name__ == "__main__":
    main()
|
|
|