|
import os |
|
import sys |
|
import time |
|
import torch |
|
import argparse |
|
import torchvision |
|
import numpy as np |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from logger_utils import CSVWriter |
|
from model_main import CRNN, STN_CRNN |
|
from utils import ctc_decode, compute_wer_and_cer_for_sample |
|
from dataset import HWRecogIAMDataset, split_dataset, get_dataloader_for_testing |
|
|
|
|
|
def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam_search", save_prediction_stats=False):
    """
    Evaluate a handwriting recognition model on the test set, printing the
    per-sample and mean character / word error rates.

    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    test_loader : object
        dataset loader object
    device : str
        device to be used for running the evaluation
    list_test_files : list
        list of all the test files, in the same order the loader yields them
    which_ctc_decoder : str
        string indicating which ctc decoder to use
    save_prediction_stats : bool
        whether to save per-sample prediction stats to "pred_stats.csv"

    -------
    Returns
    -------
    (mean_test_cer, mean_test_wer) : tuple of float
        mean CER and WER over all evaluated samples (nan if the loader is empty)

    NOTE(review): the per-sample logic (progress counter vs. num_test_samples,
    pred_labels[0], list_test_files[count - 1]) assumes the loader yields one
    sample per batch (batch_size=1) in list_test_files order, without
    shuffling -- TODO confirm against get_dataloader_for_testing.
    """
    hw_model.eval()

    num_test_samples = len(test_loader.dataset)
    list_test_cers, list_test_wers = [], []

    csv_writer = None
    if save_prediction_stats:
        csv_writer = CSVWriter(
            file_name="pred_stats.csv",
            column_names=["file_name", "num_chars", "num_words", "cer", "wer"]
        )

    # try/finally guarantees the stats file is closed even if evaluation fails
    try:
        with torch.no_grad():
            for count, (images, labels, _) in enumerate(test_loader, start=1):
                images = images.to(device, dtype=torch.float)
                log_probs = hw_model(images)
                pred_labels = ctc_decode(log_probs, which_ctc_decoder=which_ctc_decoder)

                labels = labels.cpu().numpy().tolist()
                str_label = "".join(HWRecogIAMDataset.LABEL_2_CHAR[i] for i in labels)
                # only the first decoded sequence is used (see batch_size note above)
                str_pred = "".join(HWRecogIAMDataset.LABEL_2_CHAR[i] for i in pred_labels[0])

                cer_sample, wer_sample = compute_wer_and_cer_for_sample(str_pred, str_label)
                list_test_cers.append(cer_sample)
                list_test_wers.append(wer_sample)

                test_file = list_test_files[count - 1]
                print(f"progress: {count}/{num_test_samples}, test file: {test_file}")
                print(f"{str_label} - label")
                print(f"{str_pred} - prediction")
                print(f"cer: {cer_sample:.3f}, wer: {wer_sample:.3f}\n")

                if csv_writer is not None:
                    csv_writer.write_row([
                        test_file,
                        len(str_label),
                        len(str_label.split(" ")),
                        cer_sample,
                        wer_sample,
                    ])

        # avoid numpy's "Mean of empty slice" warning when nothing was evaluated
        if list_test_cers:
            mean_test_cer = float(np.mean(list_test_cers))
            mean_test_wer = float(np.mean(list_test_wers))
        else:
            mean_test_cer = mean_test_wer = float("nan")

        print(f"test set - mean cer: {mean_test_cer:.3f}, mean wer: {mean_test_wer:.3f}\n")
    finally:
        if csv_writer is not None:
            csv_writer.close()

    return mean_test_cer, mean_test_wer
|
|
|
def test_hw_recognizer(FLAGS):
    """
    Build the test dataloader and the requested model, load the trained
    weights and run the evaluation.

    ---------
    Arguments
    ---------
    FLAGS : argparse.Namespace
        parsed command line arguments; reads dir_dataset, image_height,
        image_width, which_hw_model, which_ctc_decoder, file_model and
        save_prediction_stats
    """
    file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(FLAGS.dir_dataset, "img")

    # Respect a device selection the user already made in the environment;
    # only fall back to GPUs 0 and 1 when nothing was specified.
    # NOTE(review): presumably only effective if CUDA has not been initialised yet.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    test_x, test_y = split_dataset(file_txt_labels, for_train=False)
    num_test_samples = len(test_x)

    test_loader = get_dataloader_for_testing(
        test_x, test_y,
        dir_images=dir_images, image_height=FLAGS.image_height, image_width=FLAGS.image_width,
    )

    # one extra class on top of the character set (presumably the CTC blank)
    num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1

    print("task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}, ctc decoder: {FLAGS.which_ctc_decoder}")
    print(f"image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")
    print(f"num test samples: {num_test_samples}")

    if FLAGS.which_hw_model == "crnn":
        hw_model = CRNN(num_classes, FLAGS.image_height)
    elif FLAGS.which_hw_model == "stn_crnn":
        hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
    else:
        print(f"unidentified option : {FLAGS.which_hw_model}")
        # non-zero status so shells / CI can detect the failure
        sys.exit(1)

    hw_model.to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host
    hw_model.load_state_dict(torch.load(FLAGS.file_model, map_location=device))

    print(f"testing of handwriting recognition model {FLAGS.which_hw_model} started\n")
    test(hw_model, test_loader, device, test_x, FLAGS.which_ctc_decoder, bool(FLAGS.save_prediction_stats))
    print("testing handwriting recognition model completed!!!!")
    return
|
|
|
def main(argv=None):
    """
    Parse command line arguments and run the handwriting recognition test.

    ---------
    Arguments
    ---------
    argv : list of str, optional
        arguments to parse; when None, argparse reads sys.argv[1:]
    """
    image_height = 32
    image_width = 768
    which_hw_model = "crnn"
    dir_dataset = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/"
    file_model = "model_crnn/crnn_H_32_W_768_E_177.pth"
    which_ctc_decoder = "beam_search"
    save_prediction_stats = 0

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("--image_height", default=image_height,
        type=int, help="image height to be used to predict with the model")
    parser.add_argument("--image_width", default=image_width,
        type=int, help="image width to be used to predict with the model")
    parser.add_argument("--dir_dataset", default=dir_dataset,
        type=str, help="full directory path to the dataset")
    parser.add_argument("--which_hw_model", default=which_hw_model,
        type=str, choices=["crnn", "stn_crnn"], help="which model to be used for prediction")
    parser.add_argument("--which_ctc_decoder", default=which_ctc_decoder,
        type=str, choices=["beam_search", "greedy"], help="which ctc decoder to use")
    parser.add_argument("--file_model", default=file_model,
        type=str, help="full path to trained model file (.pth)")
    parser.add_argument("--save_prediction_stats", default=save_prediction_stats,
        type=int, choices=[0, 1], help="save prediction stats (1 - yes, 0 - no)")

    FLAGS, unparsed = parser.parse_known_args(argv)
    if unparsed:
        # parse_known_args silently drops unknown options, so a mistyped flag
        # would otherwise fall back to its default without any notice
        print(f"warning: ignoring unrecognized arguments: {' '.join(unparsed)}", file=sys.stderr)

    test_hw_recognizer(FLAGS)
    return
|
|
|
# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|