# aspram/aspram/fine_tune.py
# (Hugging Face blame-view page chrome preserved below as comments:
#  lilitket's picture | "Move to package" | commit cab7f7b |
#  raw history blame | No virus | 11.7 kB)
from typing import Any, Dict, List, Optional, Union
import os
import json
import time
import numpy as np
from transformers import Trainer
from transformers import Wav2Vec2ForCTC
from transformers import TrainingArguments
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from datasets import load_dataset, load_metric, Audio, concatenate_datasets, load_from_disk
from aim import Run
from aim.hugging_face import AimCallback
import fire
from aspram.collator import DataCollatorCTCWithPadding
from aspram.utils import clean_characters, extract_all_chars, prepare_dataset
def load_data(dataset_name: str, *, split: str):
    """Load a (possibly composite) 16 kHz speech dataset.

    ``dataset_name`` supports a tiny expression language:

    * ``a+b``   — concatenation of the datasets named ``a`` and ``b``;
    * ``3*a`` / ``a*3`` — dataset ``a`` repeated three times (oversampling);
    * a name containing ``teacher`` — loaded from disk and filtered to
      audio arrays shorter than 250k samples;
    * a name containing ``common_voice`` — loaded via ``load_dataset`` with
      the ``hy-AM`` config (``train+validation+other`` when ``split`` is
      ``train``);
    * anything else — loaded with the ``hy_am`` config, its ``transcription``
      column mirrored into ``sentence`` and filtered by ``num_samples``.

    Only the ``audio``/``path``/``sentence``/``client_id`` columns are kept,
    and ``audio`` is cast to 16 kHz.
    """
    dataset_name = dataset_name.replace(' ', '')
    if '+' in dataset_name:
        return concatenate_datasets([
            load_data(name, split=split)
            for name in dataset_name.split('+')
        ])
    if '*' in dataset_name:
        a, _, b = dataset_name.partition('*')
        if a.isnumeric():
            num_repeats = int(a)
            dataset_name = b
        else:
            num_repeats = int(b)
            dataset_name = a
        dataset = load_data(dataset_name, split=split)
        return concatenate_datasets([dataset] * num_repeats)
    if 'teacher' in dataset_name:
        dataset = load_from_disk(
            dataset_name,
        ).filter(
            lambda sample: len(sample['audio']['array']) < 250_000
        )
    elif 'common_voice' in dataset_name:
        dataset = load_dataset(
            dataset_name,
            "hy-AM",
            split="train+validation+other" if split == 'train' else split,
            use_auth_token=True,
        )
    else:
        dataset = load_dataset(
            dataset_name,
            'hy_am',
            split='train',
        ).map(
            lambda sample: dict(sentence=sample['transcription'])
        ).filter(
            lambda sample: sample['num_samples'] < 250_000
        )
    # BUGFIX: `remove_columns` expects a column name or list of names, not a
    # set; the original passed an (unordered) set through an identity `.map`.
    # Sorting keeps column-drop order deterministic, and calling
    # `remove_columns` directly avoids a full map pass over the data.
    non_wanted_column_names = sorted(
        set(dataset.column_names) - {'audio', 'path', 'sentence', 'client_id'}
    )
    dataset = dataset.remove_columns(non_wanted_column_names)
    return dataset.cast_column("audio", Audio(sampling_rate=16_000))
def exec(
    *,
    batch_size: int,
    lr: float,
    warmup_steps: int = 2000,
    grad_acc: int = 1,
    group_by_length: bool = True,
    fp16: bool = True,
    bf16: bool = False,
    pretrained_model: str = "facebook/wav2vec2-xls-r-2b",
    dataset: str = "mozilla-foundation/common_voice_8_0",
    num_train_epochs: int = 1200,
    blacklist_enabled: bool = True,
    seed: int = 42,
    # random augment — waveform-level augmentation probabilities, applied
    # inside DataCollatorCTCWithPadding at batch time
    apply_gaussian_noise_with_p: float = 0,
    apply_gain_with_p: float = 0,
    apply_pitch_shift_with_p: float = 0,
    apply_time_stretch_with_p: float = 0,
    # spec augment — forwarded to the Wav2Vec2 config in model_init()
    mask_time_prob: float = 0.05, # value that is used in the previous models
    mask_time_length: int = 10,
    mask_time_min_masks: int = 2,
    mask_feature_prob: float = 0,
    mask_feature_length: int = 10,
    mask_feature_min_masks: int = 0,
    layerdrop: float = 0,
    activation_dropout: float = 0.1,
    lower: bool = False,
    only_mesropatar: bool = False,
    gradient_checkpointing: bool = False,
    resume_from_hash: Optional[str] = None,
):
    """Fine-tune a Wav2Vec2 model for Armenian CTC speech recognition.

    End-to-end training entry point (invoked via ``fire.Fire`` below):
    loads and filters data, builds/loads a character tokenizer, prepares
    features, and runs a HuggingFace ``Trainer`` with Aim experiment
    tracking, pushing checkpoints to the Hub.

    When ``resume_from_hash`` is given, the Aim run and the output
    directory of the original run are reused and training resumes from
    the last checkpoint.
    """
    # bf16 and fp16 are mutually exclusive in TrainingArguments; bf16 wins.
    if bf16:
        fp16 = False
    # Snapshot every CLI argument for logging. NOTE: relies on locals()
    # containing only the parameters at this point — keep this line before
    # any new local variables are introduced.
    fire_args = locals()
    run = Run(resume_from_hash, log_system_params=(not resume_from_hash))
    if not resume_from_hash:
        # Fresh run: derive a new output dir from a timestamp and log all
        # hyperparameters to Aim (duplicated under 'hparams' and 'fire').
        timestr = time.strftime("%Y%m%d-%H%M%S")
        repo_name = os.path.join('models', timestr)
        for key, value in fire_args.items():
            run['hparams', key] = value
            run['fire', key] = value
    else:
        # Resumed run: reuse the output dir recorded by the original run.
        repo_name = run['hparams', 'output_dir']
    run_hash = run.hash
    # Release our handle; AimCallback below re-attaches via run_hash so the
    # Trainer owns the run for the rest of training.
    run = None
    train_dataset = load_data(dataset, split="train")
    blacklist_client_ids = set()
    blacklist_sentences = set()
    if blacklist_enabled:
        # Hard-coded Common Voice speaker hashes excluded from training —
        # presumably speakers known to overlap with evaluation data;
        # TODO(review): confirm provenance of these ids.
        blacklist_client_ids = {
            "93fa435db2b9e077af647c9f846d8b6031bcb1f6cd731e894a835e70a0ab4aec1faffce01c882bdcdcb854b98b601c83a1c412bae8e5ee411556f0e2f88c1c5c",
            "f0aba38a8ab8705a40d05d96829ded5738a7eec7a9a182394c2ed288fc1c64553abcb1e0c4c966ffab9e8b76c27616b9f0503f92c42fe11249af36c50d3de5ef",
            "a528aa436a34dce3b4ddc198c105ebb904967acdd04157bd1b0e0b2ffadd99b36a6cc5fe76f23c3dd2263d1507bec6038c41cb521ac8ee34126133e559df9e75",
            "b83375c41b8ef9ab1b64491b624302b1541b0ba8496ed4e5cb4a751766d7a2cf7430e49e7118eaac98f5ae478d8cdd2b59d18526632297185bbc2e10e2126b18",
            "330411ed21c5d9cda96180ac633b4dd10f5b6e50968e83a64f0016c9e15f22445fa8f396ef92b70ff03fc78e36b35b1693af60431b61b50b706aa58a00f80641",
        }
        # Evaluation is pinned to Common Voice 9.0 test regardless of the
        # training dataset expression.
        # valid_dataset = load_data(dataset, split="test")
        valid_dataset = load_data("yerevann/common_voice_9_0", split="test")
        # Any speaker or sentence that appears in the validation set is also
        # blacklisted from training, to avoid train/test leakage.
        # train_client_ids = set(train_dataset['client_id']) - { None }
        valid_client_ids = set(valid_dataset['client_id']) - { None }
        blacklist_sentences = set(valid_dataset['sentence'])
        blacklist_client_ids |= valid_client_ids
    # NOTE(review): with blacklist_enabled=False, valid_dataset is never
    # assigned and the .map() below would raise NameError — the flag appears
    # to have only ever been used with its default (True); confirm intent.
    train_dataset = train_dataset.filter(
        lambda sample: (
            sample.get("client_id") not in blacklist_client_ids
            and
            sample.get("sentence") not in blacklist_sentences
        )
    )
    # print('\n' * 10 + '================================' + '\n' * 10)
    # print(train_client_ids & valid_client_ids)
    # print('\n' * 10 + '================================' + '\n' * 10)
    # Column pruning is now handled inside load_data(); the explicit
    # remove_columns calls below are kept for reference.
    # train_dataset = train_dataset.remove_columns(
    #     [
    #         "accent",
    #         "age",
    #         "client_id",
    #         "down_votes",
    #         "gender",
    #         "locale",
    #         "segment",
    #         "up_votes",
    #     ]
    # )
    # valid_dataset = valid_dataset.remove_columns(
    #     [
    #         "accent",
    #         "age",
    #         "client_id",
    #         "down_votes",
    #         "gender",
    #         "locale",
    #         "segment",
    #         "up_votes",
    #     ]
    # )
    # Normalize transcripts (character cleanup; optional lowercasing and
    # restriction to the Mesropian Armenian alphabet).
    train_dataset = train_dataset.map(clean_characters, fn_kwargs=dict(lower=lower, only_mesropatar=only_mesropatar))
    valid_dataset = valid_dataset.map(clean_characters, fn_kwargs=dict(lower=lower, only_mesropatar=only_mesropatar))
    if 'models/' in pretrained_model:
        # Continuing from a local checkpoint dir: reuse its tokenizer.
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model)
    elif not resume_from_hash:
        # Fresh run: build a character-level vocabulary from the union of
        # all characters seen in the train and validation transcripts.
        vocab_train = train_dataset.map(
            extract_all_chars,
            batched=True,
            batch_size=-1,
            keep_in_memory=True,
            remove_columns=train_dataset.column_names,
        )
        vocab_valid = valid_dataset.map(
            extract_all_chars,
            batched=True,
            batch_size=-1,
            keep_in_memory=True,
            remove_columns=valid_dataset.column_names,
        )
        vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_valid["vocab"][0]))
        vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
        # CTC convention: "|" replaces the space character as word delimiter.
        vocab_dict["|"] = vocab_dict[" "]
        del vocab_dict[" "]
        vocab_dict["[UNK]"] = len(vocab_dict)
        vocab_dict["[PAD]"] = len(vocab_dict)
        # Written to the CWD so from_pretrained("./") can pick it up below.
        with open("vocab.json", "w") as vocab_file:
            json.dump(vocab_dict, vocab_file)
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
            "./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
        )
        tokenizer.push_to_hub(repo_name) # smth is wrong here
    else:
        # Resumed run: load the tokenizer saved in the run's output dir.
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(repo_name)
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=True,
    )
    processor = Wav2Vec2Processor(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
    )
    # Redundant with the cast in load_data(), but harmless (idempotent).
    train_dataset = train_dataset.cast_column(
        "audio", Audio(sampling_rate=16_000)
    )
    valid_dataset = valid_dataset.cast_column(
        "audio", Audio(sampling_rate=16_000)
    )
    # Turn raw audio + sentence into model inputs/labels; original columns
    # are dropped so only the features remain.
    train_dataset = train_dataset.map(
        prepare_dataset, remove_columns=train_dataset.column_names,
        fn_kwargs=dict(processor=processor)
    )
    valid_dataset = valid_dataset.map(
        prepare_dataset, remove_columns=valid_dataset.column_names,
        fn_kwargs=dict(processor=processor)
    )
    # Batch-time padding plus optional waveform augmentation.
    data_collator = DataCollatorCTCWithPadding(
        processor=processor,
        padding=True,
        sample_rate=16_000,
        apply_gaussian_noise_with_p=apply_gaussian_noise_with_p,
        apply_gain_with_p=apply_gain_with_p,
        apply_pitch_shift_with_p=apply_pitch_shift_with_p,
        apply_time_stretch_with_p=apply_time_stretch_with_p,
    )
    def compute_metrics(pred):
        """Greedy-decode logits and report WER/CER against the references.

        Closes over `processor` and the `wer_metric`/`cer_metric` defined
        just below (bound before the Trainer ever calls this).
        """
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)
        # -100 is the loss-masking sentinel; map it back to PAD for decoding.
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        cer = cer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer, "cer": cer}
    wer_metric = load_metric("wer")
    cer_metric = load_metric("cer")
    def model_init():
        """Build the CTC model fresh; passed to Trainer for reproducible init."""
        from transformers import Wav2Vec2Config  # NOTE(review): unused import
        model = Wav2Vec2ForCTC.from_pretrained(
            pretrained_model,
            attention_dropout=0.0,
            hidden_dropout=0.0,
            feat_proj_dropout=0.0,
            mask_time_prob=mask_time_prob,
            mask_time_length=mask_time_length,
            mask_time_min_masks=mask_time_min_masks,
            mask_feature_prob=mask_feature_prob,
            mask_feature_length=mask_feature_length,
            mask_feature_min_masks=mask_feature_min_masks,
            layerdrop=layerdrop,
            activation_dropout=activation_dropout,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
            vocab_size=len(processor.tokenizer),
        )
        # Keep the pretrained convolutional feature encoder frozen.
        model.freeze_feature_extractor()
        return model
    training_args = TrainingArguments(
        output_dir=repo_name,
        group_by_length=group_by_length,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_acc,
        evaluation_strategy="steps",
        num_train_epochs=num_train_epochs,
        # Checkpointing is forced on when resuming, regardless of the flag.
        gradient_checkpointing=gradient_checkpointing if resume_from_hash is None else True,
        fp16=fp16,
        bf16=bf16,
        save_steps=4000,
        eval_steps=200,
        logging_steps=200,
        learning_rate=lr, # TODO
        warmup_steps=warmup_steps,
        save_total_limit=1,
        push_to_hub=True,
        metric_for_best_model="eval_wer",
        greater_is_better=False,
        seed=seed,
    )
    aim_callback = AimCallback()
    # NOTE(review): pokes a private attribute so the callback attaches to the
    # run created/resumed above instead of starting a new one.
    aim_callback._run_hash = run_hash
    print(train_dataset)
    # run = aim_callback.experiment
    trainer = Trainer(
        model_init=model_init,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        tokenizer=processor.feature_extractor,
        callbacks=[aim_callback],
    )
    # resume_from_checkpoint=True lets the Trainer find the latest checkpoint
    # in output_dir; False starts from scratch.
    trainer.train(resume_from_checkpoint=bool(resume_from_hash))
    trainer.push_to_hub()
if __name__ == "__main__":
    # Expose exec's keyword-only parameters as CLI flags via python-fire.
    fire.Fire(exec)