import os
import sys
import time
import torch
import argparse
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from logger_utils import CSVWriter
from model_main import CRNN, STN_CRNN
from utils import ctc_decode, compute_wer_and_cer_for_sample
from dataset import HWRecogIAMDataset, split_dataset, get_dataloader_for_testing


def _labels_to_string(label_ids):
    """Join the characters that HWRecogIAMDataset.LABEL_2_CHAR maps each label id to."""
    return "".join(HWRecogIAMDataset.LABEL_2_CHAR[i] for i in label_ids)


def test(
    hw_model,
    test_loader,
    device,
    list_test_files,
    which_ctc_decoder="beam_search",
    save_prediction_stats=False,
):
    """
    Evaluate a handwriting recognition model sample by sample and report CER/WER.

    Only the first decoded prediction of each batch (``pred_labels[0]``) is
    used, and the batch's labels are mapped to one string, so the loop
    effectively assumes one sample per batch.
    NOTE(review): confirm get_dataloader_for_testing uses batch_size=1.

    --------- Arguments ---------
    hw_model : object
        handwriting recognition model object
    test_loader : object
        dataset loader object
    device : str
        device to be used for running the evaluation
    list_test_files : list
        list of all the test files, in the same order as test_loader yields them
    which_ctc_decoder : str
        string indicating which ctc decoder to use
    save_prediction_stats : bool
        whether to save per-sample prediction stats to pred_stats.csv

    --------- Returns ---------
    tuple of float
        (mean_test_cer, mean_test_wer) over the test set; both are nan when
        the loader yields no samples
    """
    hw_model.eval()
    num_test_samples = len(test_loader.dataset)

    list_test_cers, list_test_wers = [], []

    csv_writer = None
    if save_prediction_stats:
        csv_writer = CSVWriter(
            file_name="pred_stats.csv",
            column_names=["file_name", "num_chars", "num_words", "cer", "wer"],
        )

    # try/finally so the stats file is closed even if evaluation fails midway
    try:
        with torch.no_grad():
            for count, (images, labels, _length_labels) in enumerate(
                test_loader, start=1
            ):
                images = images.to(device, dtype=torch.float)
                log_probs = hw_model(images)
                pred_labels = ctc_decode(
                    log_probs, which_ctc_decoder=which_ctc_decoder
                )

                # labels is iterated element-wise as class ids, so it is
                # presumably a flat tensor for the single sample -- TODO confirm
                str_label = _labels_to_string(labels.cpu().numpy().tolist())
                str_pred = _labels_to_string(pred_labels[0])

                cer_sample, wer_sample = compute_wer_and_cer_for_sample(
                    str_pred, str_label
                )
                list_test_cers.append(cer_sample)
                list_test_wers.append(wer_sample)

                test_file = list_test_files[count - 1]
                print(f"progress: {count}/{num_test_samples}, test file: {test_file}")
                print(f"{str_label} - label")
                print(f"{str_pred} - prediction")
                print(f"cer: {cer_sample:.3f}, wer: {wer_sample:.3f}\n")

                if csv_writer is not None:
                    csv_writer.write_row(
                        [
                            test_file,
                            len(str_label),
                            len(str_label.split(" ")),
                            cer_sample,
                            wer_sample,
                        ]
                    )
    finally:
        if csv_writer is not None:
            csv_writer.close()

    # Avoid np.mean on an empty list (RuntimeWarning); report nan explicitly.
    if list_test_cers:
        mean_test_cer = float(np.mean(np.array(list_test_cers)))
        mean_test_wer = float(np.mean(np.array(list_test_wers)))
    else:
        mean_test_cer = mean_test_wer = float("nan")
    print(f"test set - mean cer: {mean_test_cer:.3f}, mean wer: {mean_test_wer:.3f}\n")
    return mean_test_cer, mean_test_wer


def test_hw_recognizer(FLAGS):
    """
    Build the test dataloader and the model from FLAGS, then run `test`.

    --------- Arguments ---------
    FLAGS : argparse.Namespace
        must provide dir_dataset, image_height, image_width, which_hw_model,
        which_ctc_decoder, file_model and save_prediction_stats

    --------- Returns ---------
    tuple of float
        (mean_test_cer, mean_test_wer) as returned by `test`
    """
    file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(FLAGS.dir_dataset, "img")

    # Default to GPUs 0 and 1, but do not overwrite a CUDA_VISIBLE_DEVICES the
    # user already set in the environment.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

    # choose a device for testing
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # get the internal test set files
    test_x, test_y = split_dataset(file_txt_labels, for_train=False)
    num_test_samples = len(test_x)

    # get the internal test set dataloader
    test_loader = get_dataloader_for_testing(
        test_x,
        test_y,
        dir_images=dir_images,
        image_height=FLAGS.image_height,
        image_width=FLAGS.image_width,
    )

    # +1 presumably reserves an extra class for the CTC blank -- TODO confirm
    num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1

    print("task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}, ctc decoder: {FLAGS.which_ctc_decoder}")
    print(f"image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")
    print(f"num test samples: {num_test_samples}")

    # load the right model
    if FLAGS.which_hw_model == "crnn":
        hw_model = CRNN(num_classes, FLAGS.image_height)
    elif FLAGS.which_hw_model == "stn_crnn":
        hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
    else:
        print(f"unidentified option : {FLAGS.which_hw_model}")
        # non-zero exit status: this is an error, not a successful run
        sys.exit(1)

    hw_model.to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    hw_model.load_state_dict(torch.load(FLAGS.file_model, map_location=device))

    # start testing of the model on the internal set
    print(f"testing of handwriting recognition model {FLAGS.which_hw_model} started\n")
    results = test(
        hw_model,
        test_loader,
        device,
        test_x,
        FLAGS.which_ctc_decoder,
        bool(FLAGS.save_prediction_stats),
    )
    print("testing handwriting recognition model completed!!!!")
    return results


def main():
    """Parse command-line arguments and evaluate the selected model."""
    image_height = 32
    image_width = 768
    which_hw_model = "crnn"
    dir_dataset = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/"
    file_model = "model_crnn/crnn_H_32_W_768_E_177.pth"
    which_ctc_decoder = "beam_search"
    save_prediction_stats = 0

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--image_height",
        default=image_height,
        type=int,
        help="image height to be used to predict with the model",
    )
    parser.add_argument(
        "--image_width",
        default=image_width,
        type=int,
        help="image width to be used to predict with the model",
    )
    parser.add_argument(
        "--dir_dataset",
        default=dir_dataset,
        type=str,
        help="full directory path to the dataset",
    )
    parser.add_argument(
        "--which_hw_model",
        default=which_hw_model,
        type=str,
        choices=["crnn", "stn_crnn"],
        help="which model to be used for prediction",
    )
    parser.add_argument(
        "--which_ctc_decoder",
        default=which_ctc_decoder,
        type=str,
        choices=["beam_search", "greedy"],
        help="which ctc decoder to use",
    )
    parser.add_argument(
        "--file_model",
        default=file_model,
        type=str,
        help="full path to trained model file (.pth)",
    )
    parser.add_argument(
        "--save_prediction_stats",
        default=save_prediction_stats,
        type=int,
        choices=[0, 1],
        help="save prediction stats (1 - yes, 0 - no)",
    )

    FLAGS, unparsed = parser.parse_known_args()
    # parse_known_args tolerates unknown flags; surface them so typos are not
    # silently ignored.
    if unparsed:
        print(f"warning: ignoring unrecognized arguments: {unparsed}", file=sys.stderr)
    test_hw_recognizer(FLAGS)


if __name__ == "__main__":
    main()