Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import time | |
import torch | |
import argparse | |
import torchvision | |
import numpy as np | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from logger_utils import CSVWriter | |
from model_main import CRNN, STN_CRNN | |
from utils import ctc_decode, compute_wer_and_cer_for_sample | |
from dataset import HWRecogIAMDataset, split_dataset, get_dataloader_for_testing | |
def test(
    hw_model,
    test_loader,
    device,
    list_test_files,
    which_ctc_decoder="beam_search",
    save_prediction_stats=False,
):
    """
    Evaluate a handwriting recognition model sample by sample and print CER/WER.

    For every batch yielded by ``test_loader`` the model output is decoded
    with ``ctc_decode``; the ground-truth labels and the first decoded
    sequence are mapped to strings through ``HWRecogIAMDataset.LABEL_2_CHAR``
    and scored with ``compute_wer_and_cer_for_sample``. Per-sample results are
    printed (and optionally written to ``pred_stats.csv``); the mean CER/WER
    over all samples is printed at the end.

    NOTE(review): the loop treats each batch as a single sample -- only
    ``pred_labels[0]`` is used and ``list_test_files[count - 1]`` is taken as
    the file name of batch ``count``. This presumes the test loader uses batch
    size 1 and no shuffling; TODO confirm in ``get_dataloader_for_testing``.

    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object (put into eval mode here)
    test_loader : object
        dataset loader object yielding (images, labels, length_labels)
    device : str
        device to be used for running the evaluation
    list_test_files : list
        list of all the test files, in the same order as the loader
    which_ctc_decoder : str
        string indicating which ctc decoder to use
    save_prediction_stats : bool
        whether to save per-sample prediction stats to "pred_stats.csv"

    -------
    Returns
    -------
    None
    """
    hw_model.eval()
    num_test_samples = len(test_loader.dataset)
    count = 0
    list_test_cers, list_test_wers = [], []

    csv_writer = None
    if save_prediction_stats:
        csv_writer = CSVWriter(
            file_name="pred_stats.csv",
            column_names=["file_name", "num_chars", "num_words", "cer", "wer"],
        )

    try:
        with torch.no_grad():
            # the third item (label lengths) is not needed for evaluation
            for images, labels, _ in test_loader:
                count += 1
                images = images.to(device, dtype=torch.float)
                log_probs = hw_model(images)
                pred_labels = ctc_decode(log_probs, which_ctc_decoder=which_ctc_decoder)

                labels = labels.cpu().numpy().tolist()
                str_label = "".join(HWRecogIAMDataset.LABEL_2_CHAR[i] for i in labels)
                # only the first decoded sequence is used (single-sample batches assumed)
                str_pred = "".join(
                    HWRecogIAMDataset.LABEL_2_CHAR[i] for i in pred_labels[0]
                )

                cer_sample, wer_sample = compute_wer_and_cer_for_sample(
                    str_pred, str_label
                )
                list_test_cers.append(cer_sample)
                list_test_wers.append(wer_sample)

                file_name = list_test_files[count - 1]
                print(f"progress: {count}/{num_test_samples}, test file: {file_name}")
                print(f"{str_label} - label")
                print(f"{str_pred} - prediction")
                print(f"cer: {cer_sample:.3f}, wer: {wer_sample:.3f}\n")

                if csv_writer is not None:
                    csv_writer.write_row(
                        [
                            file_name,
                            len(str_label),
                            len(str_label.split(" ")),
                            cer_sample,
                            wer_sample,
                        ]
                    )
    finally:
        # close the stats file even if evaluation fails part-way through
        if csv_writer is not None:
            csv_writer.close()

    if not list_test_cers:
        # np.mean of an empty list would print "nan" with a RuntimeWarning
        print("test set - no samples evaluated\n")
        return

    mean_test_cer = np.mean(list_test_cers)
    mean_test_wer = np.mean(list_test_wers)
    print(f"test set - mean cer: {mean_test_cer:.3f}, mean wer: {mean_test_wer:.3f}\n")
    return
def test_hw_recognizer(FLAGS):
    """
    Build the internal test set and the requested model from FLAGS, then run `test`.

    ---------
    Arguments
    ---------
    FLAGS : argparse.Namespace
        parsed command-line options; reads dir_dataset, image_height,
        image_width, which_hw_model, which_ctc_decoder, file_model and
        save_prediction_stats

    -------
    Returns
    -------
    None (exits the process with status 1 for an unknown model name)
    """
    file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(FLAGS.dir_dataset, "img")

    # Keep a CUDA_VISIBLE_DEVICES value the user already exported instead of
    # overwriting it; fall back to "0,1". CUDA reads this variable when it is
    # initialised, so it must be set before the first CUDA call below.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

    # choose a device for testing
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # get the internal test set files
    test_x, test_y = split_dataset(file_txt_labels, for_train=False)
    num_test_samples = len(test_x)

    # get the internal test set dataloader
    test_loader = get_dataloader_for_testing(
        test_x,
        test_y,
        dir_images=dir_images,
        image_height=FLAGS.image_height,
        image_width=FLAGS.image_width,
    )

    # one extra class on top of the character set -- presumably the CTC blank; TODO confirm
    num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1

    print("task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}, ctc decoder: {FLAGS.which_ctc_decoder}")
    print(f"image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")
    print(f"num test samples: {num_test_samples}")

    # load the right model
    if FLAGS.which_hw_model == "crnn":
        hw_model = CRNN(num_classes, FLAGS.image_height)
    elif FLAGS.which_hw_model == "stn_crnn":
        hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
    else:
        print(f"unidentified option : {FLAGS.which_hw_model}")
        # non-zero status so callers/scripts can detect the failure
        sys.exit(1)

    hw_model.to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host
    hw_model.load_state_dict(torch.load(FLAGS.file_model, map_location=device))

    # start testing of the model on the internal set
    print(f"testing of handwriting recognition model {FLAGS.which_hw_model} started\n")
    test(
        hw_model,
        test_loader,
        device,
        test_x,
        FLAGS.which_ctc_decoder,
        bool(FLAGS.save_prediction_stats),
    )
    print("testing handwriting recognition model completed!!!!")
    return
def main():
    """
    Parse command-line options and run the handwriting recognition evaluation.

    Unknown command-line arguments are not fatal (``parse_known_args`` is
    used), but they are reported on stderr so a mistyped flag does not
    silently fall back to its default value.
    """
    image_height = 32
    image_width = 768
    which_hw_model = "crnn"
    dir_dataset = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/"
    file_model = "model_crnn/crnn_H_32_W_768_E_177.pth"
    which_ctc_decoder = "beam_search"
    save_prediction_stats = 0

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--image_height",
        default=image_height,
        type=int,
        help="image height to be used to predict with the model",
    )
    parser.add_argument(
        "--image_width",
        default=image_width,
        type=int,
        help="image width to be used to predict with the model",
    )
    parser.add_argument(
        "--dir_dataset",
        default=dir_dataset,
        type=str,
        help="full directory path to the dataset",
    )
    parser.add_argument(
        "--which_hw_model",
        default=which_hw_model,
        type=str,
        choices=["crnn", "stn_crnn"],
        help="which model to be used for prediction",
    )
    parser.add_argument(
        "--which_ctc_decoder",
        default=which_ctc_decoder,
        type=str,
        choices=["beam_search", "greedy"],
        help="which ctc decoder to use",
    )
    parser.add_argument(
        "--file_model",
        default=file_model,
        type=str,
        help="full path to trained model file (.pth)",
    )
    parser.add_argument(
        "--save_prediction_stats",
        default=save_prediction_stats,
        type=int,
        choices=[0, 1],
        help="save prediction stats (1 - yes, 0 - no)",
    )

    FLAGS, unparsed = parser.parse_known_args()
    if unparsed:
        # parse_known_args drops unknown flags (e.g. "--image-height" instead of
        # "--image_height"); surface them rather than silently using defaults
        print(f"warning: ignoring unrecognized arguments: {unparsed}", file=sys.stderr)

    test_hw_recognizer(FLAGS)
    return
# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()