import os
import sys
import time
import argparse

import torch
import torchvision
import numpy as np
import torch.nn as nn
from PIL import Image
from skimage.io import imread
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from dataset import HWRecogIAMDataset
from model_main import CRNN, STN_CRNN
from utils import ctc_decode, compute_wer_and_cer_for_sample


class DatasetFinalEval(HWRecogIAMDataset):
    """
    Dataset class for final evaluation - inherits main dataset class

    Only images are served (no labels). The parent constructor is not
    called here, exactly as in the original implementation; this class only
    sets the attributes its own methods read.
    """

    def __init__(self, dir_images, image_height=32, image_width=768):
        """
        ---------
        Arguments
        ---------
        dir_images : str
            full path to directory containing images
        image_height : int
            image height (default: 32)
        image_width : int
            image width (default: 768)
        """
        self.dir_images = dir_images
        # os.listdir returns entries in arbitrary order; sorting makes the
        # evaluation order reproducible across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(self.dir_images) if f.endswith(".png")
        )
        self.image_width = image_width
        self.image_height = image_height
        # NOTE(review): a PIL resampling constant is passed as the
        # interpolation argument; newer torchvision releases prefer
        # transforms.InterpolationMode - confirm against the pinned version.
        self.transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(
                    (self.image_height, self.image_width), Image.BILINEAR
                ),
                transforms.ToTensor(),
                # per-channel normalisation constants (presumably the
                # ImageNet statistics - verify against the training code)
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

    def __len__(self):
        """Return the number of .png images found in the directory."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        ---------
        Arguments
        ---------
        idx : int
            index into the list of image files

        -------
        Returns
        -------
        image_3_channel : object
            transformed image produced by self.transform
        """
        image_file_name = self.image_files[idx]
        image_gray = imread(os.path.join(self.dir_images, image_file_name))
        # the image is replicated along a new trailing axis to obtain 3
        # channels (assumes imread returned a 2-D single channel array -
        # TODO confirm for the data in use)
        image_3_channel = np.repeat(np.expand_dims(image_gray, -1), 3, -1)
        image_3_channel = self.transform(image_3_channel)
        return image_3_channel


def get_dataloader_for_evaluation(
    dir_images, image_height=32, image_width=768, batch_size=1
):
    """
    ---------
    Arguments
    ---------
    dir_images : str
        full path to directory containing images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size to use for final evaluation (default: 1)

    -------
    Returns
    -------
    test_loader : object
        dataset loader object for final evaluation
    """
    test_dataset = DatasetFinalEval(
        dir_images=dir_images, image_height=image_height, image_width=image_width
    )
    # shuffle stays off so that batches follow the order of
    # test_dataset.image_files
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    return test_loader


def final_eval(
    hw_model, device, test_loader, dir_images, dir_results, save_predictions=True
):
    """
    Run the model over the test loader, print every prediction and
    optionally write each one to "<dir_results>/<image file name>.txt".

    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    device : str
        device to be used for running the evaluation
    test_loader : object
        dataset loader object
    dir_images : str
        full path to directory containing test images
    dir_results : str
        relative path to directory to save the predictions as txt files
    save_predictions : bool
        write the predictions to txt files when True (default: True)
    """
    hw_model.eval()
    num_test_samples = len(test_loader.dataset)

    # Take the file names from the dataset that actually produced the
    # images. Listing the directory a second time (as was done before)
    # also returns non-png entries which the dataset filters out, so
    # predictions ended up stored under the wrong file name.
    list_test_files = getattr(test_loader.dataset, "image_files", None)
    if list_test_files is None:
        # unknown dataset type: fall back to the plain directory listing
        list_test_files = os.listdir(dir_images)

    if save_predictions and not os.path.isdir(dir_results):
        print(f"creating directory: {dir_results}")
        os.makedirs(dir_results)

    # NOTE: only the first prediction of every batch is decoded and one
    # file name is consumed per batch, i.e. a batch size of 1 is assumed.
    with torch.no_grad():
        for count, image_test in enumerate(test_loader, start=1):
            file_test = list_test_files[count - 1]

            image_test = image_test.to(device, dtype=torch.float)
            log_probs = hw_model(image_test)
            pred_labels = ctc_decode(log_probs)
            str_pred = "".join(
                DatasetFinalEval.LABEL_2_CHAR[i] for i in pred_labels[0]
            )

            if save_predictions:
                with open(
                    os.path.join(dir_results, file_test + ".txt"),
                    "w",
                    encoding="utf-8",
                    newline="\n",
                ) as fh_pred:
                    fh_pred.write(str_pred)

            print(f"progress: {count}/{num_test_samples}, test file: {file_test}")
            print(f"{str_pred}\n")

    if save_predictions:
        print(f"predictions saved in directory: ./{dir_results}\n")
    return


def test_hw_recognizer(FLAGS):
    """
    Build the requested model, load its weights and run the final evaluation.

    ---------
    Arguments
    ---------
    FLAGS : object
        parsed command line arguments; the attributes which_hw_model,
        image_height, image_width, file_model and dir_images are read,
        save_predictions is read when present (default: save)
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    # one extra class on top of the characters in LABEL_2_CHAR
    # (presumably the CTC blank - verify against the training code)
    num_classes = len(DatasetFinalEval.LABEL_2_CHAR) + 1

    print("task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}")
    print(f"image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")

    # load the right model
    if FLAGS.which_hw_model == "crnn":
        hw_model = CRNN(num_classes, FLAGS.image_height)
    elif FLAGS.which_hw_model == "stn_crnn":
        hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
    else:
        print(f"unidentified option : {FLAGS.which_hw_model}")
        # an unknown model is an error, so exit with a non-zero status
        sys.exit(1)

    dir_results = f"results_{FLAGS.which_hw_model}"

    # choose a device for evaluation
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    hw_model.to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine as well
    hw_model.load_state_dict(torch.load(FLAGS.file_model, map_location=device))

    # get test set dataloader
    test_loader = get_dataloader_for_evaluation(
        dir_images=FLAGS.dir_images,
        image_height=FLAGS.image_height,
        image_width=FLAGS.image_width,
    )

    # start the evaluation on the final test set
    print(
        f"final evaluation of handwriting recognition model {FLAGS.which_hw_model} started\n"
    )
    final_eval(
        hw_model,
        device,
        test_loader,
        FLAGS.dir_images,
        dir_results,
        # honour --save_predictions; FLAGS objects without it keep saving
        save_predictions=bool(getattr(FLAGS, "save_predictions", 1)),
    )
    print("final evaluation of handwriting recognition model completed!!!!")
    return


def main():
    """Parse the command line arguments and start the final evaluation."""
    image_height = 32
    image_width = 768
    which_hw_model = "crnn"
    dir_images = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/img/"
    file_model = "model_crnn/crnn_H_32_W_768_E_177.pth"
    save_predictions = 1

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--image_height",
        default=image_height,
        type=int,
        help="image height to be used to predict with the model",
    )
    parser.add_argument(
        "--image_width",
        default=image_width,
        type=int,
        help="image width to be used to predict with the model",
    )
    parser.add_argument(
        "--dir_images",
        default=dir_images,
        type=str,
        help="full directory path to directory containing images",
    )
    parser.add_argument(
        "--which_hw_model",
        default=which_hw_model,
        type=str,
        choices=["crnn", "stn_crnn"],
        help="which model to be used for prediction",
    )
    parser.add_argument(
        "--file_model",
        default=file_model,
        type=str,
        help="full path to trained model file (.pth)",
    )
    parser.add_argument(
        "--save_predictions",
        default=save_predictions,
        type=int,
        choices=[0, 1],
        help="save or do not save the predictions (1 - save, 0 - do not save)",
    )

    # unknown arguments are tolerated and ignored
    FLAGS, _ = parser.parse_known_args()
    test_hw_recognizer(FLAGS)
    return


if __name__ == "__main__":
    main()