# Training script for handwriting recognition (CRNN / STN-CRNN) on the IAM
# line-level dataset, optimized with the CTC loss.
import os | |
import sys | |
import time | |
import torch | |
import argparse | |
import torchvision | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from model_main import CRNN, STN_CRNN | |
from logger_utils import CSVWriter, write_json_file | |
from utils import compute_wer_and_cer_for_sample, ctc_decode | |
from dataset import HWRecogIAMDataset, split_dataset, get_dataloaders_for_training | |
def train(hw_model, optimizer, criterion, train_loader, device):
    """
    Runs one training epoch and returns the mean batch loss.

    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    optimizer : object
        optimizer object to be used for optimization
    criterion : object
        criterion or loss object (CTC loss) to be used as the objective function for optimization
    train_loader : object
        train set dataloader object yielding (images, labels, label lengths)
    device : str
        device to be used for running the training

    -------
    Returns
    -------
    train_loss : float
        mean training loss for an epoch
    """
    hw_model.train()
    train_running_loss = 0.0
    num_train_batches = len(train_loader)
    for images, labels, lengths_labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.long)
        lengths_labels = lengths_labels.to(device, dtype=torch.long)
        batch_size = images.size(0)
        optimizer.zero_grad()
        # log_probs is assumed to be (seq_len, batch, num_classes),
        # the layout expected by nn.CTCLoss -- TODO confirm against model_main
        log_probs = hw_model(images)
        # every sample in the batch uses the full output sequence length
        lengths_preds = torch.LongTensor([log_probs.size(0)] * batch_size)
        # CTC loss expects target lengths with shape (batch,)
        lengths_labels = torch.flatten(lengths_labels)
        loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
        train_running_loss += loss.item()
        loss.backward()
        # gradient clipping with max norm 5 to stabilize CTC training
        torch.nn.utils.clip_grad_norm_(hw_model.parameters(), 5)
        optimizer.step()
    train_loss = train_running_loss / num_train_batches
    return train_loss
def validate(hw_model, criterion, valid_loader, device):
    """
    Evaluates the model on the validation set: loss, CER and WER.

    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    criterion : object
        criterion or loss object (CTC loss) to be used as the objective function for optimization
    valid_loader : object
        validation set dataloader object yielding (images, labels, label lengths)
    device : str
        device to be used for running the evaluation

    -------
    Returns
    -------
    a 3 tuple of
    valid_loss : float
        mean validation loss for an epoch
    valid_cer : float
        mean character error rate (CER) for validation set
    valid_wer : float
        mean word error rate (WER) for validation set
    """
    hw_model.eval()
    valid_running_loss = 0.0
    valid_running_cer = 0.0
    valid_running_wer = 0.0
    num_valid_samples = len(valid_loader.dataset)
    num_valid_batches = len(valid_loader)
    with torch.no_grad():
        for images, labels, lengths_labels in valid_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.long)
            lengths_labels = lengths_labels.to(device, dtype=torch.long)
            batch_size = images.size(0)
            log_probs = hw_model(images)
            lengths_preds = torch.LongTensor([log_probs.size(0)] * batch_size)
            # flatten to shape (batch,) as CTC loss expects, consistent with train()
            lengths_labels = torch.flatten(lengths_labels)
            loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
            valid_running_loss += loss.item()
            # greedy CTC decoding of the batch predictions
            pred_labels = ctc_decode(log_probs)
            labels_for_eval = labels.cpu().numpy().tolist()
            lengths_labels_for_eval = lengths_labels.cpu().numpy().tolist()
            # split the flat label sequence back into per-sample label lists
            # using the stored per-sample lengths
            final_labels_for_eval = []
            length_label_counter = 0
            for _, length_label in zip(pred_labels, lengths_labels_for_eval):
                label = labels_for_eval[
                    length_label_counter : length_label_counter + length_label
                ]
                length_label_counter += length_label
                final_labels_for_eval.append(label)
            for i in range(len(final_labels_for_eval)):
                if len(pred_labels[i]) != 0:
                    # map numeric labels back to characters before scoring
                    str_label = "".join(
                        HWRecogIAMDataset.LABEL_2_CHAR[c]
                        for c in final_labels_for_eval[i]
                    )
                    str_pred = "".join(
                        HWRecogIAMDataset.LABEL_2_CHAR[c] for c in pred_labels[i]
                    )
                    cer_sample, wer_sample = compute_wer_and_cer_for_sample(
                        str_pred, str_label
                    )
                else:
                    # empty prediction: count the sample as fully wrong
                    cer_sample, wer_sample = 100, 100
                valid_running_cer += cer_sample
                valid_running_wer += wer_sample
    valid_loss = valid_running_loss / num_valid_batches
    valid_cer = valid_running_cer / num_valid_samples
    valid_wer = valid_running_wer / num_valid_samples
    return valid_loss, valid_cer, valid_wer
def train_hw_recognizer(FLAGS):
    """
    Trains a handwriting recognition model on the IAM dataset.

    Builds dataloaders, the model, the optimizer and the CTC criterion from
    the parsed command line flags, then trains for FLAGS.num_epochs epochs,
    logging metrics to CSV and saving a checkpoint after every epoch.

    ---------
    Arguments
    ---------
    FLAGS : argparse.Namespace
        parsed command line arguments (see main() for the full list)
    """
    file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(FLAGS.dir_dataset, "img")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    # train only on a CUDA device (GPU)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        print("CUDA device not found, so exiting....")
        # nonzero exit status: missing GPU is a failure, not a success
        sys.exit(1)
    # split dataset into train and validation sets
    train_x, valid_x, train_y, valid_y = split_dataset(file_txt_labels, for_train=True)
    num_train_samples = len(train_x)
    num_valid_samples = len(valid_x)
    # get dataloaders for train and validation sets
    train_loader, valid_loader = get_dataloaders_for_training(
        train_x,
        train_y,
        valid_x,
        valid_y,
        dir_images=dir_images,
        image_height=FLAGS.image_height,
        image_width=FLAGS.image_width,
        batch_size=FLAGS.batch_size,
    )
    # create a directory for saving the model
    dir_model = f"model_{FLAGS.which_hw_model}"
    if not os.path.isdir(dir_model):
        print(f"creating directory: {dir_model}")
        os.makedirs(dir_model)
    # save train and validation metrics in a csv file
    file_logger_train = os.path.join(dir_model, "train_metrics.csv")
    csv_writer = CSVWriter(
        file_name=file_logger_train,
        column_names=["epoch", "loss_train", "loss_valid", "cer_valid", "wer_valid"],
    )
    # persist the run configuration alongside the checkpoints
    file_params = os.path.join(dir_model, "params.json")
    write_json_file(file_params, vars(FLAGS))
    # +1 for the CTC blank class
    num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
    print("task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}")
    print(
        f"optimizer: {FLAGS.which_optimizer}, learning rate: {FLAGS.learning_rate:.6f}, weight decay: {FLAGS.weight_decay:.8f}"
    )
    print(
        f"batch size: {FLAGS.batch_size}, image height: {FLAGS.image_height}, image width: {FLAGS.image_width}"
    )
    print(
        f"num train samples: {num_train_samples}, num validation samples: {num_valid_samples}\n"
    )
    # load the right model
    if FLAGS.which_hw_model == "crnn":
        hw_model = CRNN(num_classes, FLAGS.image_height)
    elif FLAGS.which_hw_model == "stn_crnn":
        hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
    else:
        print(f"unidentified option: {FLAGS.which_hw_model}")
        sys.exit(1)
    hw_model.to(device)
    # load the right optimizer based on user option
    if FLAGS.which_optimizer == "adam":
        optimizer = torch.optim.Adam(
            hw_model.parameters(),
            lr=FLAGS.learning_rate,
            weight_decay=FLAGS.weight_decay,
        )
    elif FLAGS.which_optimizer == "adadelta":
        optimizer = torch.optim.Adadelta(
            hw_model.parameters(),
            lr=FLAGS.learning_rate,
            rho=0.95,
            eps=1e-8,
            weight_decay=FLAGS.weight_decay,
        )
    else:
        print(f"unidentified option: {FLAGS.which_optimizer}")
        sys.exit(1)
    # use the CTC loss as the objective function for training
    criterion = nn.CTCLoss(reduction="mean", zero_infinity=True)
    # start training the model
    print(f"training of handwriting recognition model {FLAGS.which_hw_model} started\n")
    for epoch in range(1, FLAGS.num_epochs + 1):
        time_start = time.time()
        train_loss = train(hw_model, optimizer, criterion, train_loader, device)
        valid_loss, valid_cer, valid_wer = validate(
            hw_model, criterion, valid_loader, device
        )
        time_end = time.time()
        print(
            f"epoch: {epoch}/{FLAGS.num_epochs}, time: {time_end-time_start:.3f} sec."
        )
        print(
            f"train loss: {train_loss:.6f}, validation loss: {valid_loss:.6f}, validation cer: {valid_cer:.4f}, validation wer: {valid_wer:.4f}\n"
        )
        csv_writer.write_row(
            [
                epoch,
                round(train_loss, 6),
                round(valid_loss, 6),
                round(valid_cer, 4),
                round(valid_wer, 4),
            ]
        )
        # checkpoint after every epoch, encoding model/shape/epoch in the name
        torch.save(
            hw_model.state_dict(),
            os.path.join(
                dir_model,
                f"{FLAGS.which_hw_model}_H_{FLAGS.image_height}_W_{FLAGS.image_width}_E_{epoch}.pth",
            ),
        )
    print(
        f"Training of handwriting recognition model {FLAGS.which_hw_model} complete!!!!"
    )
    # close the csv file
    csv_writer.close()
def main():
    """Parses the command line arguments and starts a training run."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # learning rate: 3e-4 for Adam, 1 for Adadelta
    parser.add_argument(
        "--learning_rate",
        default=1,
        type=float,
        help="learning rate to use for training",
    )
    # weight decay: 3e-5 with Adam for both CRNN and STN-CRNN,
    # 0 with Adadelta for CRNN and STN-CRNN
    parser.add_argument(
        "--weight_decay",
        default=0,
        type=float,
        help="weight decay to use for training",
    )
    parser.add_argument(
        "--batch_size",
        default=64,
        type=int,
        help="batch size to use for training",
    )
    parser.add_argument(
        "--num_epochs",
        default=100,
        type=int,
        help="num epochs to train the model",
    )
    parser.add_argument(
        "--image_height",
        default=32,
        type=int,
        help="image height to be used to train the model",
    )
    parser.add_argument(
        "--image_width",
        default=768,
        type=int,
        help="image width to be used to train the model",
    )
    parser.add_argument(
        "--dir_dataset",
        default="/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/",
        type=str,
        help="full directory path to the dataset",
    )
    parser.add_argument(
        "--which_optimizer",
        default="adadelta",
        type=str,
        choices=["adadelta", "adam"],
        help="which optimizer to use to train",
    )
    parser.add_argument(
        "--which_hw_model",
        default="crnn",
        type=str,
        choices=["crnn", "stn_crnn", "stn_pp_crnn"],
        help="which model to train",
    )
    FLAGS, unparsed = parser.parse_known_args()
    train_hw_recognizer(FLAGS)
# script entry point
if __name__ == "__main__":
    main()