|
import os |
|
import sys |
|
import time |
|
import torch |
|
import argparse |
|
import torchvision |
|
import numpy as np |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from logger_utils import CSVWriter |
|
from model_main import CRNN, STN_CRNN |
|
from utils import ctc_decode, compute_wer_and_cer_for_sample |
|
from dataset import HWRecogIAMDataset, split_dataset, get_dataloader_for_testing |
|
|
|
|
|
def test(
    hw_model,
    test_loader,
    device,
    list_test_files,
    which_ctc_decoder="beam_search",
    save_prediction_stats=False,
):
    """
    Evaluate a handwriting recognition model and print CER / WER.

    For every batch yielded by the loader the model output is decoded with
    the selected CTC decoder, converted to text together with the ground
    truth label, and the character / word error rates of that sample are
    printed. The mean CER and WER over all processed samples are printed at
    the end. Nothing is returned.

    NOTE(review): each batch is handled as one sample (only
    ``pred_labels[0]`` is used and one entry of ``list_test_files`` is
    consumed per batch), so the loader is presumably built with batch size
    1 - confirm against ``get_dataloader_for_testing``.

    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    test_loader : object
        dataset loader object yielding (images, labels, length_labels)
    device : str or torch.device
        device to be used for running the evaluation
    list_test_files : list
        list of all the test files, in the same order as the loader
    which_ctc_decoder : str
        string indicating which ctc decoder to use
    save_prediction_stats : bool
        whether to save per-sample prediction stats to "pred_stats.csv"
    """
    hw_model.eval()
    num_test_samples = len(test_loader.dataset)
    label_2_char = HWRecogIAMDataset.LABEL_2_CHAR

    list_test_cers, list_test_wers = [], []

    csv_writer = None
    if save_prediction_stats:
        csv_writer = CSVWriter(
            file_name="pred_stats.csv",
            column_names=["file_name", "num_chars", "num_words", "cer", "wer"],
        )

    # try/finally guarantees the stats file is closed even when inference,
    # decoding or metric computation raises part-way through the test set.
    try:
        with torch.no_grad():
            for count, (images, labels, _length_labels) in enumerate(
                test_loader, start=1
            ):
                test_file = list_test_files[count - 1]

                images = images.to(device, dtype=torch.float)
                log_probs = hw_model(images)
                pred_labels = ctc_decode(
                    log_probs, which_ctc_decoder=which_ctc_decoder
                )
                labels = labels.cpu().numpy().tolist()

                str_label = "".join(label_2_char[i] for i in labels)
                str_pred = "".join(label_2_char[i] for i in pred_labels[0])

                cer_sample, wer_sample = compute_wer_and_cer_for_sample(
                    str_pred, str_label
                )
                list_test_cers.append(cer_sample)
                list_test_wers.append(wer_sample)

                print(f"progress: {count}/{num_test_samples}, test file: {test_file}")
                print(f"{str_label} - label")
                print(f"{str_pred} - prediction")
                print(f"cer: {cer_sample:.3f}, wer: {wer_sample:.3f}\n")

                if csv_writer is not None:
                    csv_writer.write_row(
                        [
                            test_file,
                            len(str_label),
                            len(str_label.split(" ")),
                            cer_sample,
                            wer_sample,
                        ]
                    )

        if list_test_cers:
            mean_test_cer = np.mean(list_test_cers)
            mean_test_wer = np.mean(list_test_wers)
        else:
            # An empty loader has no defined mean; report nan explicitly
            # instead of letting np.mean warn about an empty slice.
            mean_test_cer = mean_test_wer = float("nan")
        print(
            f"test set - mean cer: {mean_test_cer:.3f}, mean wer: {mean_test_wer:.3f}\n"
        )
    finally:
        if csv_writer is not None:
            csv_writer.close()
|
|
|
|
|
def test_hw_recognizer(FLAGS):
    """
    Set up the test data and the model, then run the evaluation.

    Builds the test split from "<dir_dataset>/iam_lines_gt.txt" with images
    under "<dir_dataset>/img", instantiates the requested model, loads the
    trained weights and calls ``test`` on the resulting loader.

    ---------
    Arguments
    ---------
    FLAGS : object
        parsed options; the attributes read here are dir_dataset,
        image_height, image_width, which_hw_model, which_ctc_decoder,
        file_model and save_prediction_stats

    Exits the process with a non-zero status when ``which_hw_model`` is not
    one of "crnn" / "stn_crnn".
    """
    file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(FLAGS.dir_dataset, "img")
    # NOTE(review): this unconditionally overrides any CUDA_VISIBLE_DEVICES
    # value coming from the caller's environment.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    test_x, test_y = split_dataset(file_txt_labels, for_train=False)
    num_test_samples = len(test_x)

    test_loader = get_dataloader_for_testing(
        test_x,
        test_y,
        dir_images=dir_images,
        image_height=FLAGS.image_height,
        image_width=FLAGS.image_width,
    )

    # One class more than the character table; presumably the extra class is
    # the CTC blank - confirm against the model / training code.
    num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
    print("task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}, ctc decoder: {FLAGS.which_ctc_decoder}")
    print(f"image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")
    print(f"num test samples: {num_test_samples}")

    if FLAGS.which_hw_model == "crnn":
        hw_model = CRNN(num_classes, FLAGS.image_height)
    elif FLAGS.which_hw_model == "stn_crnn":
        hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
    else:
        print(f"unidentified option : {FLAGS.which_hw_model}")
        # This is an error path, so report failure to the calling shell.
        sys.exit(1)
    hw_model.to(device)
    # map_location lets a checkpoint saved on a GPU be loaded when the
    # evaluation falls back to the CPU (or runs on a different GPU).
    hw_model.load_state_dict(torch.load(FLAGS.file_model, map_location=device))

    print(f"testing of handwriting recognition model {FLAGS.which_hw_model} started\n")
    test(
        hw_model,
        test_loader,
        device,
        test_x,
        FLAGS.which_ctc_decoder,
        bool(FLAGS.save_prediction_stats),
    )
    print("testing handwriting recognition model completed!!!!")
|
|
|
|
|
def main():
    """
    Parse the command-line options and start the evaluation.

    All options have defaults, so the script can be run without arguments.
    Unrecognised command-line arguments are ignored (``parse_known_args``).
    The parsed options are handed to ``test_hw_recognizer``.
    """
    # One row per option: (flag, default, type, allowed values, help text).
    # A value of None for the allowed values means "no restriction".
    option_table = (
        (
            "--image_height",
            32,
            int,
            None,
            "image height to be used to predict with the model",
        ),
        (
            "--image_width",
            768,
            int,
            None,
            "image width to be used to predict with the model",
        ),
        (
            "--dir_dataset",
            "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/",
            str,
            None,
            "full directory path to the dataset",
        ),
        (
            "--which_hw_model",
            "crnn",
            str,
            ["crnn", "stn_crnn"],
            "which model to be used for prediction",
        ),
        (
            "--which_ctc_decoder",
            "beam_search",
            str,
            ["beam_search", "greedy"],
            "which ctc decoder to use",
        ),
        (
            "--file_model",
            "model_crnn/crnn_H_32_W_768_E_177.pth",
            str,
            None,
            "full path to trained model file (.pth)",
        ),
        (
            "--save_prediction_stats",
            0,
            int,
            [0, 1],
            "save prediction stats (1 - yes, 0 - no)",
        ),
    )

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, default_value, value_type, allowed, description in option_table:
        parser.add_argument(
            flag,
            default=default_value,
            type=value_type,
            choices=allowed,
            help=description,
        )

    FLAGS, _ignored_args = parser.parse_known_args()
    test_hw_recognizer(FLAGS)
|
|
|
|
|
# Script entry point: run the evaluation CLI only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|