"""Fine-tune a Wav2Vec2 CTC model on Armenian speech data, tracked with Aim."""
from typing import Optional
import os
import json
import time

import numpy as np
import fire
from transformers import (
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
from datasets import Audio, concatenate_datasets, load_dataset, load_from_disk, load_metric
from aim import Run
from aim.hugging_face import AimCallback

from aspram.collator import DataCollatorCTCWithPadding
from aspram.utils import clean_characters, extract_all_chars, prepare_dataset


def load_data(dataset_name: str, *, split: str):
    """Load a dataset; supports 'a + b' concatenation and 'n * name' repetition."""
    dataset_name = dataset_name.replace(' ', '')

    if '+' in dataset_name:
        return concatenate_datasets([
            load_data(name, split=split)
            for name in dataset_name.split('+')
        ])

    if '*' in dataset_name:
        a, _, b = dataset_name.partition('*')
        if a.isnumeric():
            num_repeats = int(a)
            dataset_name = b
        else:
            num_repeats = int(b)
            dataset_name = a
        dataset = load_data(dataset_name, split=split)
        return concatenate_datasets([dataset for _ in range(num_repeats)])

    if 'teacher' in dataset_name:
        # pseudo-labeled dataset stored on disk; drop clips of 250k+ samples (~15.6 s at 16 kHz)
        dataset = load_from_disk(
            dataset_name,
        ).filter(
            lambda sample: len(sample['audio']['array']) < 250_000
        )
    elif 'common_voice' in dataset_name:
        dataset = load_dataset(
            dataset_name,
            "hy-AM",
            split="train+validation+other" if split == 'train' else split,
            use_auth_token=True,
        )
    else:
        dataset = load_dataset(
            dataset_name,
            'hy_am',
            split='train',
        ).map(
            lambda sample: dict(sentence=sample['transcription'])
        ).filter(
            lambda sample: sample['num_samples'] < 250_000
        )

    # keep only the columns the training pipeline needs
    non_wanted_column_names = set(dataset.column_names) - {'audio', 'path', 'sentence', 'client_id'}
    dataset = dataset.remove_columns(list(non_wanted_column_names))
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
    return dataset


def exec(
    *,
    batch_size: int,
    lr: float,
    warmup_steps: int = 2000,
    grad_acc: int = 1,
    group_by_length: bool = True,
    fp16: bool = True,
    bf16: bool = False,
    pretrained_model: str = "facebook/wav2vec2-xls-r-2b",
    dataset: str = "mozilla-foundation/common_voice_8_0",
    num_train_epochs: int = 1200,
    blacklist_enabled: bool = True,
    seed: int = 42,
    # waveform augmentations
    apply_gaussian_noise_with_p: float = 0,
    apply_gain_with_p: float = 0,
    apply_pitch_shift_with_p: float = 0,
    apply_time_stretch_with_p: float = 0,
    # SpecAugment
    mask_time_prob: float = 0.05,  # value used in the previous models
    mask_time_length: int = 10,
    mask_time_min_masks: int = 2,
    mask_feature_prob: float = 0,
    mask_feature_length: int = 10,
    mask_feature_min_masks: int = 0,
    layerdrop: float = 0,
    activation_dropout: float = 0.1,
    lower: bool = False,
    only_mesropatar: bool = False,
    gradient_checkpointing: bool = False,
    resume_from_hash: Optional[str] = None,
):
    # fp16 and bf16 are mutually exclusive; prefer bf16 when both are set
    if bf16:
        fp16 = False
    fire_args = locals()

    run = Run(resume_from_hash, log_system_params=(not resume_from_hash))
    if not resume_from_hash:
        timestr = time.strftime("%Y%m%d-%H%M%S")
        repo_name = os.path.join('models', timestr)
        for key, value in fire_args.items():
            run['hparams', key] = value
            run['fire', key] = value
    else:
        repo_name = run['hparams', 'output_dir']
    run_hash = run.hash
    run = None

    train_dataset = load_data(dataset, split="train")
    # valid_dataset = load_data(dataset, split="test")
    valid_dataset = load_data("yerevann/common_voice_9_0", split="test")

    blacklist_client_ids = set()
    blacklist_sentences = set()
    if blacklist_enabled:
        # hand-picked speakers to exclude from training
        blacklist_client_ids = {
            "93fa435db2b9e077af647c9f846d8b6031bcb1f6cd731e894a835e70a0ab4aec1faffce01c882bdcdcb854b98b601c83a1c412bae8e5ee411556f0e2f88c1c5c",
"f0aba38a8ab8705a40d05d96829ded5738a7eec7a9a182394c2ed288fc1c64553abcb1e0c4c966ffab9e8b76c27616b9f0503f92c42fe11249af36c50d3de5ef", "a528aa436a34dce3b4ddc198c105ebb904967acdd04157bd1b0e0b2ffadd99b36a6cc5fe76f23c3dd2263d1507bec6038c41cb521ac8ee34126133e559df9e75", "b83375c41b8ef9ab1b64491b624302b1541b0ba8496ed4e5cb4a751766d7a2cf7430e49e7118eaac98f5ae478d8cdd2b59d18526632297185bbc2e10e2126b18", "330411ed21c5d9cda96180ac633b4dd10f5b6e50968e83a64f0016c9e15f22445fa8f396ef92b70ff03fc78e36b35b1693af60431b61b50b706aa58a00f80641", } # valid_dataset = load_data(dataset, split="test") valid_dataset = load_data("yerevann/common_voice_9_0", split="test") # train_client_ids = set(train_dataset['client_id']) - { None } valid_client_ids = set(valid_dataset['client_id']) - { None } blacklist_sentences = set(valid_dataset['sentence']) blacklist_client_ids |= valid_client_ids train_dataset = train_dataset.filter( lambda sample: ( sample.get("client_id") not in blacklist_client_ids and sample.get("sentence") not in blacklist_sentences ) ) # print('\n' * 10 + '================================' + '\n' * 10) # print(train_client_ids & valid_client_ids) # print('\n' * 10 + '================================' + '\n' * 10) # train_dataset = train_dataset.remove_columns( # [ # "accent", # "age", # "client_id", # "down_votes", # "gender", # "locale", # "segment", # "up_votes", # ] # ) # valid_dataset = valid_dataset.remove_columns( # [ # "accent", # "age", # "client_id", # "down_votes", # "gender", # "locale", # "segment", # "up_votes", # ] # ) train_dataset = train_dataset.map(clean_characters, fn_kwargs=dict(lower=lower, only_mesropatar=only_mesropatar)) valid_dataset = valid_dataset.map(clean_characters, fn_kwargs=dict(lower=lower, only_mesropatar=only_mesropatar)) if 'models/' in pretrained_model: tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model) elif not resume_from_hash: vocab_train = train_dataset.map( extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=train_dataset.column_names, ) vocab_valid = valid_dataset.map( extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=valid_dataset.column_names, ) vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_valid["vocab"][0])) vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))} vocab_dict["|"] = vocab_dict[" "] del vocab_dict[" "] vocab_dict["[UNK]"] = len(vocab_dict) vocab_dict["[PAD]"] = len(vocab_dict) with open("vocab.json", "w") as vocab_file: json.dump(vocab_dict, vocab_file) tokenizer = Wav2Vec2CTCTokenizer.from_pretrained( "./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|" ) tokenizer.push_to_hub(repo_name) # smth is wrong here else: tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(repo_name) feature_extractor = Wav2Vec2FeatureExtractor( feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True, ) processor = Wav2Vec2Processor( feature_extractor=feature_extractor, tokenizer=tokenizer, ) train_dataset = train_dataset.cast_column( "audio", Audio(sampling_rate=16_000) ) valid_dataset = valid_dataset.cast_column( "audio", Audio(sampling_rate=16_000) ) train_dataset = train_dataset.map( prepare_dataset, remove_columns=train_dataset.column_names, fn_kwargs=dict(processor=processor) ) valid_dataset = valid_dataset.map( prepare_dataset, remove_columns=valid_dataset.column_names, fn_kwargs=dict(processor=processor) ) data_collator = DataCollatorCTCWithPadding( processor=processor, padding=True, sample_rate=16_000, 
        apply_gaussian_noise_with_p=apply_gaussian_noise_with_p,
        apply_gain_with_p=apply_gain_with_p,
        apply_pitch_shift_with_p=apply_pitch_shift_with_p,
        apply_time_stretch_with_p=apply_time_stretch_with_p,
    )

    wer_metric = load_metric("wer")
    cer_metric = load_metric("cer")

    def compute_metrics(pred):
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        # -100 marks ignored label positions; map them back to the pad token
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        cer = cer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer, "cer": cer}

    def model_init():
        model = Wav2Vec2ForCTC.from_pretrained(
            pretrained_model,
            attention_dropout=0.0,
            hidden_dropout=0.0,
            feat_proj_dropout=0.0,
            mask_time_prob=mask_time_prob,
            mask_time_length=mask_time_length,
            mask_time_min_masks=mask_time_min_masks,
            mask_feature_prob=mask_feature_prob,
            mask_feature_length=mask_feature_length,
            mask_feature_min_masks=mask_feature_min_masks,
            layerdrop=layerdrop,
            activation_dropout=activation_dropout,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
            vocab_size=len(processor.tokenizer),
        )
        # keep the convolutional feature encoder frozen during fine-tuning
        model.freeze_feature_extractor()
        return model

    training_args = TrainingArguments(
        output_dir=repo_name,
        group_by_length=group_by_length,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_acc,
        evaluation_strategy="steps",
        num_train_epochs=num_train_epochs,
        # always checkpoint activations when resuming a run
        gradient_checkpointing=gradient_checkpointing if resume_from_hash is None else True,
        fp16=fp16,
        bf16=bf16,
        save_steps=4000,
        eval_steps=200,
        logging_steps=200,
        learning_rate=lr,  # TODO
        warmup_steps=warmup_steps,
        save_total_limit=1,
        push_to_hub=True,
        metric_for_best_model="eval_wer",
        greater_is_better=False,
        seed=seed,
    )

    # attach the Trainer to the existing Aim run instead of creating a new one
    aim_callback = AimCallback()
    aim_callback._run_hash = run_hash

    print(train_dataset)

    trainer = Trainer(
        model_init=model_init,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        tokenizer=processor.feature_extractor,
        callbacks=[aim_callback],
    )

    trainer.train(resume_from_checkpoint=bool(resume_from_hash))
    trainer.push_to_hub()


if __name__ == "__main__":
    fire.Fire(exec)
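
# Example usage (a sketch; the file name `train.py` and all flag values below
# are illustrative assumptions, not taken from the repo). `fire.Fire(exec)`
# exposes every keyword argument of `exec` as a command-line flag, with
# `--batch_size` and `--lr` required:
#
#   python train.py \
#       --batch_size=8 \
#       --lr=5e-5 \
#       --grad_acc=4 \
#       --pretrained_model='facebook/wav2vec2-xls-r-300m' \
#       --apply_gaussian_noise_with_p=0.5
#
# The `dataset` flag accepts the small composition syntax parsed by
# `load_data`: '+' concatenates datasets and 'n*name' repeats one, e.g.
#
#   --dataset='2*mozilla-foundation/common_voice_8_0 + yerevann/common_voice_9_0'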