|
import os |
|
import sys |
|
import time |
|
import torch |
|
import argparse |
|
import torchvision |
|
import numpy as np |
|
import torch.nn as nn |
|
from PIL import Image |
|
from skimage.io import imread |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
import torchvision.transforms as transforms |
|
|
|
from dataset import HWRecogIAMDataset |
|
from model_main import CRNN, STN_CRNN |
|
from utils import ctc_decode, compute_wer_and_cer_for_sample |
|
|
|
|
|
class DatasetFinalEval(HWRecogIAMDataset):
    """
    Image-only dataset used for the final evaluation run.

    Subclasses the main dataset class but serves just the transformed image
    tensors of the ".png" files found in a directory (no labels are returned).
    The parent constructor is not invoked here.
    """

    def __init__(self, dir_images, image_height=32, image_width=768):
        """
        ---------
        Arguments
        ---------
        dir_images : str
            full path to directory containing images
        image_height : int
            height every image is resized to (default: 32)
        image_width : int
            width every image is resized to (default: 768)
        """
        self.dir_images = dir_images
        # keep the directory listing order as-is; only ".png" entries are used
        self.image_files = list(
            filter(lambda name: name.endswith(".png"), os.listdir(self.dir_images))
        )
        self.image_width = image_width
        self.image_height = image_height

        target_size = (self.image_height, self.image_width)
        pipeline = [
            transforms.ToPILImage(),
            transforms.Resize(target_size, Image.BILINEAR),
            transforms.ToTensor(),
            # fixed per-channel mean / std used to normalise the tensor
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        """Number of ".png" files picked up from the image directory."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Read the idx-th image, replicate it along a new trailing axis three
        times and return the transformed result.
        """
        path_image = os.path.join(self.dir_images, self.image_files[idx])
        pixels = imread(path_image)
        # add a trailing channel axis and copy the data into 3 channels
        stacked = np.repeat(pixels[..., np.newaxis], 3, axis=-1)
        return self.transform(stacked)
|
|
|
|
|
def get_dataloader_for_evaluation(
    dir_images, image_height=32, image_width=768, batch_size=1, num_workers=4
):
    """
    Build the data loader that feeds images to the final evaluation.

    ---------
    Arguments
    ---------
    dir_images : str
        full path to directory containing images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size to use for final evaluation (default: 1)
    num_workers : int
        number of worker processes used by the loader (default: 4);
        pass 0 to load the data in the main process

    -------
    Returns
    -------
    test_loader : object
        dataset loader object for final evaluation
    """
    test_dataset = DatasetFinalEval(
        dir_images=dir_images, image_height=image_height, image_width=image_width
    )
    # shuffle stays off: the evaluation loop matches predictions to file names
    # by position, so batches must arrive in dataset order
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return test_loader
|
|
|
|
|
def final_eval(hw_model, device, test_loader, dir_images, dir_results):
    """
    Run the model over the whole loader and write one prediction text file
    per input image ("<image file name>.txt") into dir_results.

    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    device : torch.device or str
        device to be used for running the evaluation
    test_loader : object
        dataset loader object (must iterate in dataset order, i.e. unshuffled)
    dir_images : str
        full path to directory containing test images; only consulted for
        file names when the loader's dataset has no `image_files` attribute
    dir_results : str
        relative path to directory to save the predictions as txt files
    """
    hw_model.eval()
    num_test_samples = len(test_loader.dataset)

    # Use the very list the dataset indexes into. Re-listing the directory
    # would also return non-".png" entries the dataset skips, which shifts
    # every following prediction onto the wrong file name.
    list_test_files = getattr(test_loader.dataset, "image_files", None)
    if list_test_files is None:
        list_test_files = os.listdir(dir_images)

    if not os.path.isdir(dir_results):
        print(f"creating directory: {dir_results}")
        os.makedirs(dir_results, exist_ok=True)

    count = 0
    with torch.no_grad():
        for image_test in test_loader:
            image_test = image_test.to(device, dtype=torch.float)
            batch_len = image_test.size(0)

            log_probs = hw_model(image_test)
            # NOTE(review): assumes ctc_decode returns one label sequence per
            # sample of the batch (the previous code read index 0 for a batch
            # of one) - confirm against utils.ctc_decode
            pred_labels = ctc_decode(log_probs)

            # handle every sample of the batch, not just the first one
            batch_files = list_test_files[count : count + batch_len]
            for file_test, labels in zip(batch_files, pred_labels):
                count += 1
                str_pred = "".join(DatasetFinalEval.LABEL_2_CHAR[i] for i in labels)

                with open(
                    os.path.join(dir_results, file_test + ".txt"),
                    "w",
                    encoding="utf-8",
                    newline="\n",
                ) as fh_pred:
                    fh_pred.write(str_pred)

                print(f"progress: {count}/{num_test_samples}, test file: {file_test}")
                print(f"{str_pred}\n")
    print(f"predictions saved in directory: ./{dir_results}\n")
|
|
|
|
|
def test_hw_recognizer(FLAGS):
    """
    Build the selected model, load its trained weights and run the final
    evaluation over the image directory.

    ---------
    Arguments
    ---------
    FLAGS : object
        parsed command line options; the fields read here are
        which_hw_model, image_height, image_width, file_model and dir_images
    """
    # overwrites any value already present in the environment
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    # one extra class on top of the character set
    num_classes = len(DatasetFinalEval.LABEL_2_CHAR) + 1
    print("task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}")
    print(f"image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")

    if FLAGS.which_hw_model == "crnn":
        hw_model = CRNN(num_classes, FLAGS.image_height)
    elif FLAGS.which_hw_model == "stn_crnn":
        hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
    else:
        print(f"unidentified option : {FLAGS.which_hw_model}")
        # an unknown model is an error, so report a failing exit status
        sys.exit(1)
    dir_results = f"results_{FLAGS.which_hw_model}"

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    hw_model.to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host
    hw_model.load_state_dict(torch.load(FLAGS.file_model, map_location=device))

    test_loader = get_dataloader_for_evaluation(
        dir_images=FLAGS.dir_images,
        image_height=FLAGS.image_height,
        image_width=FLAGS.image_width,
    )

    print(
        f"final evaluation of handwriting recognition model {FLAGS.which_hw_model} started\n"
    )
    final_eval(hw_model, device, test_loader, FLAGS.dir_images, dir_results)
    print("final evaluation of handwriting recognition model completed!!!!")
|
|
|
|
|
def main():
    """
    Command line entry point: declare the options, parse them and start the
    evaluation. Arguments the parser does not know are ignored.
    """
    # (flag, add_argument keyword settings) in the order they are registered
    option_specs = (
        (
            "--image_height",
            dict(
                default=32,
                type=int,
                help="image height to be used to predict with the model",
            ),
        ),
        (
            "--image_width",
            dict(
                default=768,
                type=int,
                help="image width to be used to predict with the model",
            ),
        ),
        (
            "--dir_images",
            dict(
                default="/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/img/",
                type=str,
                help="full directory path to directory containing images",
            ),
        ),
        (
            "--which_hw_model",
            dict(
                default="crnn",
                type=str,
                choices=["crnn", "stn_crnn"],
                help="which model to be used for prediction",
            ),
        ),
        (
            "--file_model",
            dict(
                default="model_crnn/crnn_H_32_W_768_E_177.pth",
                type=str,
                help="full path to trained model file (.pth)",
            ),
        ),
        (
            "--save_predictions",
            dict(
                default=1,
                type=int,
                choices=[0, 1],
                help="save or do not save the predictions (1 - save, 0 - do not save)",
            ),
        ),
    )

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)

    # unrecognised arguments are collected separately and not used
    FLAGS, _ = parser.parse_known_args()
    test_hw_recognizer(FLAGS)
|
|
|
|
|
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|