abhishekrs4's picture
code formatting
44066b7
import os
import sys
import time
import torch
import argparse
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from model_main import CRNN, STN_CRNN
from logger_utils import CSVWriter, write_json_file
from utils import compute_wer_and_cer_for_sample, ctc_decode
from dataset import HWRecogIAMDataset, split_dataset, get_dataloaders_for_training
def train(hw_model, optimizer, criterion, train_loader, device, max_grad_norm=5.0):
    """
    Run one training epoch over ``train_loader`` and return the mean batch loss.

    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object; its output is passed to
        ``criterion`` as ``log_probs`` and its first dimension is used as the
        prediction length of every sample in the batch
    optimizer : object
        optimizer object to be used for optimization
    criterion : object
        criterion or loss object called as
        ``criterion(log_probs, labels, lengths_preds, lengths_labels)``
        (the training script uses ``nn.CTCLoss``)
    train_loader : object
        iterable yielding ``(images, labels, lengths_labels)`` batches
    device : str or torch.device
        device to be used for running the training step
    max_grad_norm : float, optional
        maximum total gradient norm used for gradient clipping
        (default 5.0, the value that was previously hard-coded)

    -------
    Returns
    -------
    train_loss : float
        mean training loss over the batches of the epoch

    ------
    Raises
    ------
    ValueError
        if ``train_loader`` yields no batches
    """
    hw_model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels, lengths_labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.long)
        # flatten so that e.g. a (batch, 1) column of lengths becomes (batch,)
        lengths_labels = torch.flatten(lengths_labels.to(device, dtype=torch.long))
        batch_size = images.size(0)

        optimizer.zero_grad()
        log_probs = hw_model(images)
        # every sample is assigned the full output length log_probs.size(0);
        # the tensor stays on the CPU, as in the original implementation
        lengths_preds = torch.full(
            (batch_size,), log_probs.size(0), dtype=torch.long
        )
        loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
        running_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(hw_model.parameters(), max_grad_norm)
        optimizer.step()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return running_loss / num_batches
def validate(hw_model, criterion, valid_loader, device):
    """
    Evaluate the model on ``valid_loader`` and return loss, CER and WER.

    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    criterion : object
        criterion or loss object called as
        ``criterion(log_probs, labels, lengths_preds, lengths_labels)``
    valid_loader : object
        iterable yielding ``(images, labels, lengths_labels)`` batches
    device : str or torch.device
        device to be used for running the evaluation

    -------
    Returns
    -------
    a 3 tuple of
    valid_loss : float
        mean validation loss over the batches
    valid_cer : float
        mean character error rate (CER) over the samples actually evaluated
    valid_wer : float
        mean word error rate (WER) over the samples actually evaluated

    ------
    Raises
    ------
    ValueError
        if ``valid_loader`` yields no samples to evaluate

    -----
    Notes
    -----
    A sample whose decoded prediction is empty is assigned CER = WER = 100.
    NOTE(review): confirm that 100 matches the scale returned by
    ``compute_wer_and_cer_for_sample`` (percent vs. fraction).
    """
    hw_model.eval()
    running_loss = 0.0
    running_cer = 0.0
    running_wer = 0.0
    num_batches = 0
    # count evaluated samples instead of using len(dataset), so the means stay
    # correct when the loader drops the last batch or samples a subset
    num_samples = 0
    label_2_char = HWRecogIAMDataset.LABEL_2_CHAR

    with torch.no_grad():
        for images, labels, lengths_labels in valid_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.long)
            # flatten exactly like train() so both see (batch,) target lengths
            lengths_labels = torch.flatten(
                lengths_labels.to(device, dtype=torch.long)
            )
            batch_size = images.size(0)

            log_probs = hw_model(images)
            lengths_preds = torch.full(
                (batch_size,), log_probs.size(0), dtype=torch.long
            )
            loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
            running_loss += loss.item()
            num_batches += 1

            pred_labels = ctc_decode(log_probs)
            # labels appear to hold the whole batch's targets concatenated into
            # one sequence; split them back per sample using lengths_labels
            # (TODO confirm against the dataset's collate function)
            flat_labels = labels.cpu().tolist()
            target_lengths = lengths_labels.cpu().tolist()

            offset = 0
            for pred_label, target_length in zip(pred_labels, target_lengths):
                true_label = flat_labels[offset : offset + target_length]
                offset += target_length
                if len(pred_label) != 0:
                    str_label = "".join(label_2_char[c] for c in true_label)
                    str_pred = "".join(label_2_char[c] for c in pred_label)
                    # NOTE(review): the helper's name suggests (wer, cer) order,
                    # but it is unpacked as (cer, wer) here; confirm in utils
                    cer_sample, wer_sample = compute_wer_and_cer_for_sample(
                        str_pred, str_label
                    )
                else:
                    cer_sample, wer_sample = 100, 100
                running_cer += cer_sample
                running_wer += wer_sample
                num_samples += 1

    if num_samples == 0:
        raise ValueError("valid_loader yielded no samples to evaluate")
    valid_loss = running_loss / num_batches
    valid_cer = running_cer / num_samples
    valid_wer = running_wer / num_samples
    return valid_loss, valid_cer, valid_wer
def train_hw_recognizer(FLAGS):
    """
    Train a handwriting recognition model using the options in ``FLAGS``.

    Splits the dataset, builds the dataloaders, model, optimizer and CTC loss,
    then for every epoch trains, validates, appends the metrics to
    ``model_<which_hw_model>/train_metrics.csv`` and saves a checkpoint in the
    same directory. The options are also written to ``params.json`` there.

    ---------
    Arguments
    ---------
    FLAGS : argparse.Namespace
        parsed command-line options (built by ``main``)

    The process exits with a non-zero status if no CUDA device is available
    or if the requested model or optimizer is not supported.
    """
    file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(FLAGS.dir_dataset, "img")
    # NOTE(review): this overrides any CUDA_VISIBLE_DEVICES set by the caller;
    # presumably only effective because CUDA has not been initialised yet here
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    # train only on a CUDA device (GPU); a missing GPU is an error, not success
    if not torch.cuda.is_available():
        print("CUDA device not found, so exiting....")
        sys.exit(1)
    device = torch.device("cuda")

    # split dataset into train and validation sets
    train_x, valid_x, train_y, valid_y = split_dataset(file_txt_labels, for_train=True)
    num_train_samples = len(train_x)
    num_valid_samples = len(valid_x)

    # get dataloaders for train and validation sets
    train_loader, valid_loader = get_dataloaders_for_training(
        train_x,
        train_y,
        valid_x,
        valid_y,
        dir_images=dir_images,
        image_height=FLAGS.image_height,
        image_width=FLAGS.image_width,
        batch_size=FLAGS.batch_size,
    )

    # create a directory for saving the model
    dir_model = f"model_{FLAGS.which_hw_model}"
    if not os.path.isdir(dir_model):
        print(f"creating directory: {dir_model}")
        os.makedirs(dir_model)

    # save train and validation metrics in a csv file
    file_logger_train = os.path.join(dir_model, "train_metrics.csv")
    csv_writer = CSVWriter(
        file_name=file_logger_train,
        column_names=["epoch", "loss_train", "loss_valid", "cer_valid", "wer_valid"],
    )
    # everything after opening the CSV runs under try/finally so the file is
    # closed even when training raises or the process exits early
    try:
        file_params = os.path.join(dir_model, "params.json")
        write_json_file(file_params, vars(FLAGS))

        # +1 output class on top of the character vocabulary
        num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1

        print("task - handwriting recognition")
        print(f"model: {FLAGS.which_hw_model}")
        print(
            f"optimizer: {FLAGS.which_optimizer}, learning rate: {FLAGS.learning_rate:.6f}, weight decay: {FLAGS.weight_decay:.8f}"
        )
        print(
            f"batch size: {FLAGS.batch_size}, image height: {FLAGS.image_height}, image width: {FLAGS.image_width}"
        )
        print(
            f"num train samples: {num_train_samples}, num validation samples: {num_valid_samples}\n"
        )

        # load the right model
        if FLAGS.which_hw_model == "crnn":
            hw_model = CRNN(num_classes, FLAGS.image_height)
        elif FLAGS.which_hw_model == "stn_crnn":
            hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
        else:
            print(f"unidentified option: {FLAGS.which_hw_model}")
            sys.exit(1)
        hw_model.to(device)

        # load the right optimizer based on user option
        if FLAGS.which_optimizer == "adam":
            optimizer = torch.optim.Adam(
                hw_model.parameters(),
                lr=FLAGS.learning_rate,
                weight_decay=FLAGS.weight_decay,
            )
        elif FLAGS.which_optimizer == "adadelta":
            optimizer = torch.optim.Adadelta(
                hw_model.parameters(),
                lr=FLAGS.learning_rate,
                rho=0.95,
                eps=1e-8,
                weight_decay=FLAGS.weight_decay,
            )
        else:
            print(f"unidentified option: {FLAGS.which_optimizer}")
            sys.exit(1)

        # use the CTC loss as the objective function for training
        criterion = nn.CTCLoss(reduction="mean", zero_infinity=True)

        # start training the model
        print(f"training of handwriting recognition model {FLAGS.which_hw_model} started\n")
        for epoch in range(1, FLAGS.num_epochs + 1):
            time_start = time.time()
            train_loss = train(hw_model, optimizer, criterion, train_loader, device)
            valid_loss, valid_cer, valid_wer = validate(
                hw_model, criterion, valid_loader, device
            )
            time_end = time.time()
            print(
                f"epoch: {epoch}/{FLAGS.num_epochs}, time: {time_end-time_start:.3f} sec."
            )
            print(
                f"train loss: {train_loss:.6f}, validation loss: {valid_loss:.6f}, validation cer: {valid_cer:.4f}, validation wer: {valid_wer:.4f}\n"
            )
            csv_writer.write_row(
                [
                    epoch,
                    round(train_loss, 6),
                    round(valid_loss, 6),
                    round(valid_cer, 4),
                    round(valid_wer, 4),
                ]
            )
            # one checkpoint per epoch, named after model, image size and epoch
            torch.save(
                hw_model.state_dict(),
                os.path.join(
                    dir_model,
                    f"{FLAGS.which_hw_model}_H_{FLAGS.image_height}_W_{FLAGS.image_width}_E_{epoch}.pth",
                ),
            )
        print(
            f"Training of handwriting recognition model {FLAGS.which_hw_model} complete!!!!"
        )
    finally:
        # close the csv file
        csv_writer.close()
    return
def main():
    """
    Parse the command-line options and launch training.

    The defaults below train a CRNN with Adadelta. Suggested settings from the
    original experiments:
      - learning rate: 3e-4 for Adam, 1 for Adadelta
      - weight decay: 3e-5 with Adam, 0 with Adadelta (both CRNN and STN-CRNN)
    """
    learning_rate = 1
    weight_decay = 0
    batch_size = 64
    num_epochs = 100
    image_height = 32
    image_width = 768
    which_hw_model = "crnn"
    which_optimizer = "adadelta"
    dir_dataset = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/"

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--learning_rate",
        default=learning_rate,
        type=float,
        help="learning rate to use for training",
    )
    parser.add_argument(
        "--weight_decay",
        default=weight_decay,
        type=float,
        help="weight decay to use for training",
    )
    parser.add_argument(
        "--batch_size",
        default=batch_size,
        type=int,
        help="batch size to use for training",
    )
    parser.add_argument(
        "--num_epochs",
        default=num_epochs,
        type=int,
        help="num epochs to train the model",
    )
    parser.add_argument(
        "--image_height",
        default=image_height,
        type=int,
        help="image height to be used to train the model",
    )
    parser.add_argument(
        "--image_width",
        default=image_width,
        type=int,
        help="image width to be used to train the model",
    )
    parser.add_argument(
        "--dir_dataset",
        default=dir_dataset,
        type=str,
        help="full directory path to the dataset",
    )
    parser.add_argument(
        "--which_optimizer",
        default=which_optimizer,
        type=str,
        choices=["adadelta", "adam"],
        help="which optimizer to use to train",
    )
    # only the models train_hw_recognizer can actually build are offered;
    # "stn_pp_crnn" was listed before but had no implementation
    parser.add_argument(
        "--which_hw_model",
        default=which_hw_model,
        type=str,
        choices=["crnn", "stn_crnn"],
        help="which model to train",
    )
    FLAGS, unparsed = parser.parse_known_args()
    # parse_known_args silently drops unknown options (e.g. a typo such as
    # --learning-rate), so report them instead of ignoring them quietly
    if unparsed:
        print(f"ignoring unrecognized arguments: {unparsed}")
    train_hw_recognizer(FLAGS)
    return
if __name__ == "__main__":
    # script entry point: run training only when executed directly, not on
    # import; main() returns None, so the process exits with status 0
    sys.exit(main())