import json
import os
import time
from typing import Optional

import fire
import numpy as np
from aim import Run
from aim.hugging_face import AimCallback
from datasets import (
    Audio,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
    load_metric,
)
from transformers import (
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

from aspram.collator import DataCollatorCTCWithPadding
from aspram.utils import clean_characters, extract_all_chars, prepare_dataset


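# `dataset_name` is a tiny expression language:
#   'a+b'          -> concatenation of datasets a and b
#   '3*a' or 'a*3' -> dataset a repeated three times (oversampling)
# Whitespace is ignored, so 'a + 2*b' works as well.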
def load_data(dataset_name: str, *, split: str):
    dataset_name = dataset_name.replace(' ', '')

    if '+' in dataset_name:
        return concatenate_datasets([
            load_data(name, split=split)
            for name in dataset_name.split('+')
        ])

    if '*' in dataset_name:
        a, _, b = dataset_name.partition('*')
        if a.isnumeric():
            num_repeats = int(a)
            dataset_name = b
        else:
            num_repeats = int(b)
            dataset_name = a

        dataset = load_data(dataset_name, split=split)

        return concatenate_datasets([
            dataset
            for _ in range(num_repeats)
        ])

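    # 250_000 samples at 16 kHz is ~15.6 s; longer clips are filtered out.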
    if 'teacher' in dataset_name:
        dataset = load_from_disk(
            dataset_name,
        ).filter(
            lambda sample: len(sample['audio']['array']) < 250_000
        )
    elif 'common_voice' in dataset_name:
        dataset = load_dataset(
            dataset_name,
            "hy-AM",
            split="train+validation+other" if split == 'train' else split,
            use_auth_token=True,
        )
    else:
        dataset = load_dataset(
            dataset_name,
            'hy_am',
            split='train',
        ).map(
            lambda sample: dict(sentence=sample['transcription'])
        ).filter(
            lambda sample: sample['num_samples'] < 250_000
        )

    unwanted_column_names = set(dataset.column_names) - {'audio', 'path', 'sentence', 'client_id'}
    dataset = dataset.remove_columns(list(unwanted_column_names))
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

    return dataset


def train(
    *,
    batch_size: int,
    lr: float,
    warmup_steps: int = 2000,
    grad_acc: int = 1,
    group_by_length: bool = True,
    fp16: bool = True,
    bf16: bool = False,
    pretrained_model: str = "facebook/wav2vec2-xls-r-2b",
    dataset: str = "mozilla-foundation/common_voice_8_0",
    num_train_epochs: int = 1200,
    blacklist_enabled: bool = True,
    seed: int = 42,
    # waveform augmentations, applied in the data collator
    apply_gaussian_noise_with_p: float = 0,
    apply_gain_with_p: float = 0,
    apply_pitch_shift_with_p: float = 0,
    apply_time_stretch_with_p: float = 0,
    # SpecAugment-style masking
    mask_time_prob: float = 0.05,
    mask_time_length: int = 10,
    mask_time_min_masks: int = 2,
    mask_feature_prob: float = 0,
    mask_feature_length: int = 10,
    mask_feature_min_masks: int = 0,
    # regularization
    layerdrop: float = 0,
    activation_dropout: float = 0.1,
    # text normalization
    lower: bool = False,
    only_mesropatar: bool = False,
    gradient_checkpointing: bool = False,
    resume_from_hash: Optional[str] = None,
):
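    # fp16 defaults to True, so it must be switched off explicitly when bf16
    # is requested: TrainingArguments rejects enabling both at once.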
    if bf16:
        fp16 = False
    # Snapshot the CLI arguments before any other locals are defined.
    fire_args = locals()

    run = Run(resume_from_hash, log_system_params=(not resume_from_hash))
    if not resume_from_hash:
        timestr = time.strftime("%Y%m%d-%H%M%S")
        repo_name = os.path.join('models', timestr)
        for key, value in fire_args.items():
            run['hparams', key] = value
            run['fire', key] = value
    else:
        repo_name = run['hparams', 'output_dir']
    run_hash = run.hash
    # Drop our reference to the run; AimCallback re-attaches to it by hash below.
    run = None

    train_dataset = load_data(dataset, split="train")

    blacklist_client_ids = set()
    blacklist_sentences = set()

    if blacklist_enabled:
        # client_ids that are excluded from training outright.
        blacklist_client_ids = {
            "93fa435db2b9e077af647c9f846d8b6031bcb1f6cd731e894a835e70a0ab4aec1faffce01c882bdcdcb854b98b601c83a1c412bae8e5ee411556f0e2f88c1c5c",
            "f0aba38a8ab8705a40d05d96829ded5738a7eec7a9a182394c2ed288fc1c64553abcb1e0c4c966ffab9e8b76c27616b9f0503f92c42fe11249af36c50d3de5ef",
            "a528aa436a34dce3b4ddc198c105ebb904967acdd04157bd1b0e0b2ffadd99b36a6cc5fe76f23c3dd2263d1507bec6038c41cb521ac8ee34126133e559df9e75",
            "b83375c41b8ef9ab1b64491b624302b1541b0ba8496ed4e5cb4a751766d7a2cf7430e49e7118eaac98f5ae478d8cdd2b59d18526632297185bbc2e10e2126b18",
            "330411ed21c5d9cda96180ac633b4dd10f5b6e50968e83a64f0016c9e15f22445fa8f396ef92b70ff03fc78e36b35b1693af60431b61b50b706aa58a00f80641",
        }

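    # Guard against train/validation leakage: any training sample that shares
    # a speaker (client_id) or a sentence with the validation set is dropped.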
    valid_dataset = load_data("yerevann/common_voice_9_0", split="test")

    valid_client_ids = set(valid_dataset['client_id']) - {None}
    blacklist_sentences = set(valid_dataset['sentence'])
    blacklist_client_ids |= valid_client_ids

    train_dataset = train_dataset.filter(
        lambda sample: (
            sample.get("client_id") not in blacklist_client_ids
            and sample.get("sentence") not in blacklist_sentences
        )
    )

    train_dataset = train_dataset.map(clean_characters, fn_kwargs=dict(lower=lower, only_mesropatar=only_mesropatar))
    valid_dataset = valid_dataset.map(clean_characters, fn_kwargs=dict(lower=lower, only_mesropatar=only_mesropatar))

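    # Three ways to obtain a tokenizer: reuse the one shipped with a local
    # checkpoint, build a fresh character-level vocab from the data, or, when
    # resuming, load the one already pushed to the hub repo.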
    if 'models/' in pretrained_model:
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model)
    elif not resume_from_hash:
        vocab_train = train_dataset.map(
            extract_all_chars,
            batched=True,
            batch_size=-1,
            keep_in_memory=True,
            remove_columns=train_dataset.column_names,
        )
        vocab_valid = valid_dataset.map(
            extract_all_chars,
            batched=True,
            batch_size=-1,
            keep_in_memory=True,
            remove_columns=valid_dataset.column_names,
        )
        vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_valid["vocab"][0]))
        vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
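        # CTC convention: "|" replaces " " as the word delimiter.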
vocab_dict["|"] = vocab_dict[" "] |
|
del vocab_dict[" "] |
|
|
|
vocab_dict["[UNK]"] = len(vocab_dict) |
|
vocab_dict["[PAD]"] = len(vocab_dict) |
|
|
|
with open("vocab.json", "w") as vocab_file: |
|
json.dump(vocab_dict, vocab_file) |
|
|
|
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained( |
|
"./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|" |
|
) |
|
tokenizer.push_to_hub(repo_name) |
|
else: |
|
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(repo_name) |
|
|

    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=True,
    )
    processor = Wav2Vec2Processor(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
    )

    # Both datasets were already cast to 16 kHz in load_data; the repeated
    # cast is a harmless safeguard.
    train_dataset = train_dataset.cast_column(
        "audio", Audio(sampling_rate=16_000)
    )
    valid_dataset = valid_dataset.cast_column(
        "audio", Audio(sampling_rate=16_000)
    )

    train_dataset = train_dataset.map(
        prepare_dataset, remove_columns=train_dataset.column_names,
        fn_kwargs=dict(processor=processor)
    )
    valid_dataset = valid_dataset.map(
        prepare_dataset, remove_columns=valid_dataset.column_names,
        fn_kwargs=dict(processor=processor)
    )

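    # The collator pads each batch dynamically and applies the waveform
    # augmentations configured above on the fly.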
    data_collator = DataCollatorCTCWithPadding(
        processor=processor,
        padding=True,
        sample_rate=16_000,
        apply_gaussian_noise_with_p=apply_gaussian_noise_with_p,
        apply_gain_with_p=apply_gain_with_p,
        apply_pitch_shift_with_p=apply_pitch_shift_with_p,
        apply_time_stretch_with_p=apply_time_stretch_with_p,
    )

    def compute_metrics(pred):
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

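        # -100 marks label positions ignored by the loss; map them back to the
        # pad token so the tokenizer can decode the references.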
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

        pred_str = processor.batch_decode(pred_ids)
        # group_tokens=False: the references are plain text, not CTC output,
        # so repeated characters must not be merged.
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        cer = cer_metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer, "cer": cer}

    wer_metric = load_metric("wer")
    cer_metric = load_metric("cer")

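    # Trainer calls model_init to build the model, which keeps weight
    # initialization under the control of its seed handling.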
    def model_init():
        model = Wav2Vec2ForCTC.from_pretrained(
            pretrained_model,
            attention_dropout=0.0,
            hidden_dropout=0.0,
            feat_proj_dropout=0.0,
            mask_time_prob=mask_time_prob,
            mask_time_length=mask_time_length,
            mask_time_min_masks=mask_time_min_masks,
            mask_feature_prob=mask_feature_prob,
            mask_feature_length=mask_feature_length,
            mask_feature_min_masks=mask_feature_min_masks,
            layerdrop=layerdrop,
            activation_dropout=activation_dropout,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
            vocab_size=len(processor.tokenizer),
        )
        # Keep the convolutional feature encoder frozen during fine-tuning.
        model.freeze_feature_extractor()
        return model

    training_args = TrainingArguments(
        output_dir=repo_name,
        group_by_length=group_by_length,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_acc,
        evaluation_strategy="steps",
        num_train_epochs=num_train_epochs,
        # Resumed runs always use gradient checkpointing.
        gradient_checkpointing=gradient_checkpointing if resume_from_hash is None else True,
        fp16=fp16,
        bf16=bf16,
        save_steps=4000,
        eval_steps=200,
        logging_steps=200,
        learning_rate=lr,
        warmup_steps=warmup_steps,
        save_total_limit=1,
        push_to_hub=True,
        metric_for_best_model="eval_wer",
        greater_is_better=False,
        seed=seed,
    )

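    # Re-attach Aim logging to the run opened earlier, by hash; AimCallback
    # stores the hash in a private attribute.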
    aim_callback = AimCallback()
    aim_callback._run_hash = run_hash

    print(train_dataset)

    trainer = Trainer(
        model_init=model_init,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        # Passing the feature extractor as `tokenizer` makes Trainer save it
        # alongside every checkpoint.
        tokenizer=processor.feature_extractor,
        callbacks=[aim_callback],
    )

    trainer.train(resume_from_checkpoint=bool(resume_from_hash))

    trainer.push_to_hub()


if __name__ == "__main__":
    fire.Fire(train)