## 1. Setting Up Environment Variables & Devices
import os
import torch
abs_path = os.path.abspath('.')
# base_dir = os.path.dirname(os.path.dirname(abs_path))
base_dir = os.path.dirname(abs_path)
os.environ['TRANSFORMERS_CACHE'] = os.path.join(base_dir, 'models_cache')
os.environ['TRANSFORMERS_OFFLINE'] = '0'
os.environ['HF_DATASETS_CACHE'] = os.path.join(base_dir, 'datasets_cache')
os.environ['HF_DATASETS_OFFLINE'] = '0'
# device = "GPU" if torch.cuda.is_available() else "CPU"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\n\n Device to be used: {device} \n\n")
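## Optional sanity check (a small sketch, not part of the original flow): uncomment to also print which GPU was selected.
# if device == 'cuda':
#     print(torch.cuda.get_device_name(0))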
## 2. Setting Up Variables
model_name = "openai/whisper-tiny"
# model_name = "openai/whisper-small"
# model_name = "openai/whisper-large-v2"
language = "Odia"
task = "transcribe" # transcribe or translate
print(f"\n\n Loading {model_name} for {language} ({task})... this might take a while... \n\n")
## 3. Setting Up Training Args
output_dir = "./"
overwrite_output_dir = True
max_steps = 16000
# max_steps = 5
per_device_train_batch_size = 8
# per_device_train_batch_size = 1
per_device_eval_batch_size = 2
# per_device_eval_batch_size = 1
gradient_accumulation_steps = 1
# gradient_accumulation_steps = 1
dataloader_num_workers = 0 # default is 0; keep 0 on Windows
gradient_checkpointing = False
evaluation_strategy = "steps"
# eval_steps = 5
eval_steps = 1000
save_strategy = "steps"
save_steps = 1000
# save_steps = 5
save_total_limit = 5
learning_rate = 1e-5
lr_scheduler_type = "cosine" # "constant", "constant_with_warmup", "cosine", "cosine_with_restarts", "linear"(default), "polynomial", "inverse_sqrt"
warmup_steps = 8000 # (1 epoch)
# warmup_steps = 1
logging_steps = 25
# logging_steps = 1
# weight_decay = 0.01
weight_decay = 0
dropout = 0.1 # any value > 0.1 hurts performance. So, use values between 0.0 and 0.1
load_best_model_at_end = True
metric_for_best_model = "wer"
greater_is_better = False
bf16 = True
# bf16 = False
tf32 = True
# tf32 = False
generation_max_length = 448 # ensure that the generation_max_length is equal to model max_length. model max_length = 448 for whisper-small (see config.json).
report_to = ["tensorboard"]
predict_with_generate = True
push_to_hub = True
# push_to_hub = False
freeze_feature_encoder = False
early_stopping_patience = 10
apply_spec_augment = True
torch_compile = False # torch.compile is not yet supported on Windows
optim="adamw_hf" # adamw_hf (default), adamw_torch, adamw_torch_fused (improved), adamw_apex_fused, adamw_anyprecision or adafactor
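## Note: the effective batch size per optimizer step is per_device_train_batch_size * gradient_accumulation_steps * n_GPUs,
## i.e. 8 * 1 * 1 = 8 on a single GPU with the values above, so max_steps = 16000 covers roughly 128k training samples.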
## 4. Load Datasets
print("\n\n Loading Datasets...this might take a while..\n\n")
from datasets import load_dataset, DatasetDict, Features, Value, Audio
# common_voice = DatasetDict()
# google_fleurs = DatasetDict()
openslr = DatasetDict()
## common_voice_13.0 + google_fleurs + openslr53 (only openslr is actually used below)
my_dataset = DatasetDict()
# common_voice["train"] = load_dataset("mozilla-foundation/common_voice_13_0", "or", split="train+validation+other", cache_dir=os.path.join(base_dir, 'datasets_cache'), trust_remote_code=True)
#####################
# google_fleurs["train"] = load_dataset("google/fleurs", "or_in", split="train+validation", cache_dir=os.path.join(base_dir, 'datasets_cache'), trust_remote_code=True)
openslr["train"] = load_dataset("Ranjit/or_in_dataset", split="train+validation", cache_dir=os.path.join(base_dir, 'datasets_cache'), trust_remote_code=True)
# common_voice["test"] = load_dataset("mozilla-foundation/common_voice_13_0", "or", split="test", cache_dir=os.path.join(base_dir, 'datasets_cache'))
#####################
# google_fleurs["test"] = load_dataset("google/fleurs", "or_in", split="test", cache_dir=os.path.join(base_dir, 'datasets_cache'))
openslr["test"] = load_dataset("Ranjit/or_in_dataset", split="test", cache_dir=os.path.join(base_dir, 'datasets_cache'), trust_remote_code=True)
# see count of samples in each dataset
print("\n\n Datasets Loaded \n\n")
# print(common_voice)
#####################
# print(google_fleurs)
print(openslr)
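## Optional peek at one raw example (uncomment to run; assumes the column names used below, "audio" and "transcription"):
# print(openslr["train"][0]["audio"]["sampling_rate"])
# print(openslr["train"][0]["transcription"])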
## Removing bad samples from common_voice based on upvotes and downvotes
# print("\n BEFORE Filtering by Upvotes (Common Voice): \n")
# print(common_voice["train"])
# # FILTERING!!! Will get 37k Data if >0 and will get 201k Data if >=0 out of 207k
# common_voice["train"] = common_voice["train"].filter(lambda x: (x["up_votes"] - x["down_votes"]) >= 0, num_proc=None)
# print("\n AFTER Filtering by Upvotes (Common Voice): \n")
# print(common_voice["train"])
# print("\n\n So, the datasets to be trained are: \n\n")
# print("\n Common Voice 13.0 - Odia\n")
# print(common_voice)
#####################
# print("\n Google Fleurs - Odia \n")
# print(google_fleurs)
print("\n OpenSLR-53 - Odia \n")
print(openslr)
# print("\n")
## 6. Merge Datasets
from datasets import concatenate_datasets, Audio
sampling_rate = 16000
## resample to specified sampling rate
# common_voice = common_voice.cast_column("audio", Audio(sampling_rate))
#####################
# google_fleurs = google_fleurs.cast_column("audio", Audio(sampling_rate))
openslr = openslr.cast_column("audio", Audio(sampling_rate))
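## Casting the "audio" column makes datasets decode and resample each file to 16 kHz lazily,
## the first time a sample is accessed (e.g. during the .map() preprocessing further below).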
## normalise columns to ["audio", "sentence"]
# common_voice = common_voice.remove_columns(
# set(common_voice['test'].features.keys()) - {"audio", "sentence"}
# )
#####################
openslr = openslr.rename_column("transcription", "sentence")
# google_fleurs = google_fleurs.remove_columns(
# set(google_fleurs['test'].features.keys()) - {"audio", "sentence"}
# )
openslr = openslr.remove_columns(
    set(openslr['train'].features.keys()) - {"audio", "sentence"}
)
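## Optional check (uncomment) that only the two expected columns remain after the rename/removal above:
# assert set(openslr["train"].features.keys()) == {"audio", "sentence"}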
## check if all audio are in float32 dtype or not.
## a fix is: https://github.com/huggingface/datasets/issues/5345
# print("\n Checking all audio dtype is float32 or not... \n")
# print(f'Common Voice Train: {common_voice["train"][0]["audio"]["array"].dtype}')
# print(f'Common Voice Test: {common_voice["test"][0]["audio"]["array"].dtype}')
#####################
# print(f'Google Fleurs Train: {google_fleurs["train"][0]["audio"]["array"].dtype}')
# print(f'Google Fleurs Test: {google_fleurs["test"][0]["audio"]["array"].dtype}')
print(f'OpenSlR: {openslr["train"][0]["audio"]["array"].dtype}')
print("\n")
## merge the three datasets
# my_dataset['train'] = concatenate_datasets([common_voice['train'], google_fleurs['train'], openslr['train']]) #for linux
#####################
my_dataset['train'] = concatenate_datasets([openslr['train']])
my_dataset['test'] = concatenate_datasets([openslr['test']])
# my_dataset['test'] = concatenate_datasets([common_voice['test'], google_fleurs['test'], openslr['test']]) #for linux
# my_dataset['train'] = concatenate_datasets([common_voice['train'], openslr['train']])
# my_dataset['train'] = concatenate_datasets([google_fleurs['train'], openslr['train']]) #for windows no commonvoice as it requires ffmpeg-4
# my_dataset['train'] = google_fleurs['train']
# my_dataset['test'] = common_voice['test']
# my_dataset['test'] = concatenate_datasets([google_fleurs['test']]) #for windows no commonvoice as it requires ffmpeg-4
# shuffle the train set (seed=10 for reproducibility)
my_dataset['train'] = my_dataset['train'].shuffle(seed=10)
print("\n\n AFTER MERGING, train and validation sets are: ")
print(my_dataset)
print("\n")
# print("\n\n AFTER AUGMENTATION, FINAL train and validation sets are: ")
print("\n FINAL DATASET: \n")
print(my_dataset)
## 8. Preprocessing Data
print("\n\n Preprocessing Datasets...this might take a while..\n\n")
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
tokenizer = WhisperTokenizer.from_pretrained(model_name, task=task)
processor = WhisperProcessor.from_pretrained(model_name, task=task)
def prepare_dataset(batch):
    # load and (possibly) resample audio data to 16kHz
    audio = batch["audio"]
    # compute log-Mel input features from input audio array
    inputs = processor.feature_extractor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
    )
    batch["input_features"] = inputs.input_features[0]
    # optional pre-processing steps could be applied here
    transcription = batch["sentence"]
    # encode target text to label ids
    batch["labels"] = tokenizer(transcription).input_ids
    return batch
## Map the preprocessing over the full dataset (results are written to the cache files below, so re-runs are fast)
my_dataset = my_dataset.map(
    prepare_dataset,
    num_proc=1,  # if num_proc>1, mapping might get stuck; use num_proc=1 in that case
    load_from_cache_file=True,
    cache_file_names={
        "train": os.path.join(base_dir, 'datasets_cache', 'preprocessed_train_cache_8.arrow'),
        "test": os.path.join(base_dir, 'datasets_cache', 'preprocessed_test_cache_8.arrow'),
    },
)
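## Optional sanity check of one preprocessed sample (uncomment to run; shapes assume the default Whisper
## feature extractor: 80 mel bins x 3000 frames for the padded 30 s window):
# sample = my_dataset["train"][0]
# print(len(sample["input_features"]), len(sample["input_features"][0]))  # expect 80, 3000
# print(tokenizer.decode(sample["labels"], skip_special_tokens=False))    # label ids decoded back to text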
print("\n\n AFTER PREPROCESSING, final train and validation sets are: ")
print(my_dataset)
print("\n")
## 9. Filter too Short or too Long Audio Files
# MAX_DURATION_IN_SECONDS = 30.0
# max_input_length = MAX_DURATION_IN_SECONDS * 16000
# def filter_inputs(input_length):
# """Filter inputs with zero input length or longer than 30s"""
# return 0 < input_length < max_input_length
# my_dataset = my_dataset.filter(filter_inputs, input_columns=["input_length"])
# print("\n\n AFTER FILTERING INPUTS, final train and validation sets are: ")
# print(my_dataset)
# print("\n")
# max_label_length = generation_max_length #(max_label_length should be equal to max_length of model which is equal to generation_max_length)
# def filter_labels(labels_length):
# """Filter label sequences longer than max length (448)"""
# return labels_length < max_label_length
# my_dataset = my_dataset.filter(filter_labels, input_columns=["labels_length"])
# print("\n\n AFTER FILTERING LABELS, final train and validation sets are: ")
# print(my_dataset)
# print("\n")
# import re
# def filter_transcripts(transcript):
# """Filter transcripts with empty strings and samples containing English characters & numbers"""
# pattern = r'^.*[a-zA-Z0-9]+.*$'
# match = re.match(pattern, transcript)
# return len(transcript.split(" ")) > 1 and not bool(match)
# my_dataset = my_dataset.filter(filter_transcripts, input_columns=["sentence"])
# print("\n\n AFTER FILTERING TRANSCRIPTS, final train and validation sets are: ")
# print("\n My FINAL DATASET \n")
# print(my_dataset)
# print("\n")
## Remove unused cache files (cleanup_cache_files() returns the number of files removed)
print("\n Removing UNUSED Cache Files: \n")
try:
    # print(f"{common_voice.cleanup_cache_files()} for common_voice")
    # print(f"{google_fleurs.cleanup_cache_files()} for google_fleurs")
    print(f"{openslr.cleanup_cache_files()} for openslr")
    # print(f"{crblp.cleanup_cache_files()} for crblp")
    print(f"{my_dataset.cleanup_cache_files()} for my_dataset")
except Exception as e:
    print(f"\n\n UNABLE to REMOVE some Cache files. \n Error: {e} \n\n")
## Load the pre-trained checkpoint here, before the data collator, which needs model.config.decoder_start_token_id
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)
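## Optional: uncomment to report the model size
# print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")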
## 12. Define Data Collator
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # replace padding with -100 to ignore it correctly in the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # if a bos token was appended in the previous tokenization step,
        # cut it here as it is appended again later anyway
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, decoder_start_token_id=model.config.decoder_start_token_id)
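## Optional quick test of the collator (a small sketch, uncomment to run; the two indices are arbitrary):
# _batch = data_collator([my_dataset["train"][i] for i in range(2)])
# print(_batch["input_features"].shape)  # e.g. torch.Size([2, 80, 3000])
# print(_batch["labels"].shape)          # [2, longest label in batch]; padded positions are -100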
## 13. Define Evaluation Metrics
import evaluate
wer_metric = evaluate.load("wer", cache_dir=os.path.join(base_dir, "metrics_cache"))
cer_metric = evaluate.load("cer", cache_dir=os.path.join(base_dir, "metrics_cache"))
do_normalize_eval = False
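## If do_normalize_eval is enabled, a text normaliser has to be defined first. A minimal sketch, assuming the
## basic (language-agnostic) normaliser shipped with transformers is acceptable for Odia:
# from transformers.models.whisper.english_normalizer import BasicTextNormalizer
# normalizer = BasicTextNormalizer()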
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    if do_normalize_eval:
        pred_str = [normalizer(pred) for pred in pred_str]
        label_str = [normalizer(label) for label in label_str]
    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer, "wer": wer}
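## Toy illustration of the WER metric (not part of training): one deleted word against a 3-word reference -> 1/3.
# wer_metric.compute(predictions=["hello world"], references=["hello there world"])  # ~0.3333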
## 14. Load a Pre-Trained Checkpoint
## (already done above, before defining the data collator, which needs model.config.decoder_start_token_id)
## 15. Override generation arguments
model.config.apply_spec_augment = apply_spec_augment
model.config.max_length = generation_max_length
model.config.dropout = dropout
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
if gradient_checkpointing:
    model.config.use_cache = False
if freeze_feature_encoder:
    model.freeze_feature_encoder()
model.generation_config.max_length = generation_max_length
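## No tokens are forced as decoder outputs (forced_decoder_ids=None) and none are suppressed
## (suppress_tokens=[]) during generation, which is the usual setting for Whisper fine-tuning.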
## 16. Define the Training Configuration
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    max_steps=max_steps,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=gradient_checkpointing,
    dataloader_num_workers=dataloader_num_workers,
    evaluation_strategy=evaluation_strategy,
    eval_steps=eval_steps,
    save_strategy=save_strategy,
    save_steps=save_steps,
    save_total_limit=save_total_limit,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_steps=warmup_steps,
    logging_steps=logging_steps,
    weight_decay=weight_decay,
    load_best_model_at_end=load_best_model_at_end,
    metric_for_best_model=metric_for_best_model,
    greater_is_better=greater_is_better,
    bf16=bf16,
    tf32=tf32,
    torch_compile=torch_compile,
    optim=optim,
    generation_max_length=generation_max_length,
    report_to=report_to,
    predict_with_generate=predict_with_generate,
    push_to_hub=push_to_hub,
)
from transformers import Seq2SeqTrainer, EarlyStoppingCallback
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=my_dataset["train"],
    eval_dataset=my_dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)],
)
processor.save_pretrained("best_model")
## 17. Training
print("\n\n Training STARTED..\n\n")
train_result = trainer.train()
print("\n\n Training COMPLETED...\n\n")
## 18. Evaluating & Saving Metrics & Model
print("\n\n Evaluating Model & Saving Metrics...\n\n")
processor.save_pretrained(save_directory=output_dir)
# trainer.save_model()
metrics = train_result.metrics
trainer.save_metrics("train", metrics)
trainer.save_state()
metrics = trainer.evaluate(
    metric_key_prefix="eval",
    max_length=training_args.generation_max_length,
    num_beams=training_args.generation_num_beams,
)
trainer.save_metrics("eval", metrics)
## 19. Push to Hub
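## Hypothetical example: Trainer.push_to_hub forwards extra kwargs to create_model_card, so model-card
## metadata could be supplied like this (values are illustrative only):
# trainer.push_to_hub(language="or", tasks="automatic-speech-recognition", dataset_tags="Ranjit/or_in_dataset")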
if push_to_hub:
    print("\n\n Pushing to Hub...\n\n")
    trainer.create_model_card()
    # trainer.push_to_hub(**kwargs)
    trainer.push_to_hub()
print("\n\n DONEEEEEE \n\n")