|
import os
|
|
import torch
|
|
import torch.nn
|
|
import numpy as np
|
|
from PIL import Image
|
|
from skimage.io import imread
|
|
import torchvision.transforms as transforms
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
def read_IAM_label_txt_file(file_txt_labels):
    """
    Read the IAM label text file and return image file names and labels.

    The file repeats groups of three lines: an image file name, its label,
    and a separator line that is skipped.

    ---------
    Arguments
    ---------
    file_txt_labels : str
        full path to the text file containing labels

    -------
    Returns
    -------
    a tuple of
    all_image_files : list
        a list of all image file names
    all_labels : list
        a list of all labels
    """
    # use a context manager so the file handle is always closed
    # (the previous version leaked the open handle)
    with open(file_txt_labels, mode="r") as label_file_handler:
        all_lines = label_file_handler.readlines()

    all_image_files = []
    all_labels = []

    for cur_line_num, cur_line in enumerate(all_lines):
        if cur_line_num % 3 == 0:
            # first line of each group: image file name
            all_image_files.append(cur_line.strip())
        elif cur_line_num % 3 == 1:
            # second line of each group: label text
            all_labels.append(cur_line.strip())
        # third line of each group is a separator and is skipped

    return all_image_files, all_labels
|
|
|
|
|
|
class HWRecogIAMDataset(Dataset):
    """
    Main dataset class to be used only for training, validation and internal testing
    """

    # recognizable character vocabulary; label ids start at 1, leaving id 0
    # free (presumably the CTC blank -- confirm against the model's loss)
    CHAR_SET = " !\"#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    CHAR_2_LABEL = {char: i + 1 for i, char in enumerate(CHAR_SET)}
    LABEL_2_CHAR = {label: char for char, label in CHAR_2_LABEL.items()}

    def __init__(
        self,
        list_image_files,
        list_labels,
        dir_images,
        image_height=32,
        image_width=768,
        which_set="train",
    ):
        """
        ---------
        Arguments
        ---------
        list_image_files : list
            list of image files
        list_labels : list
            list of labels
        dir_images : str
            full path to directory containing images
        image_height : int
            image height (default: 32)
        image_width : int
            image width (default: 768)
        which_set : str
            a string indicating which set is being used (default: train)
        """
        self.list_labels = list_labels
        self.dir_images = dir_images
        self.list_image_files = list_image_files
        self.image_width = image_width
        self.image_height = image_height
        self.which_set = which_set

        # Shared pipeline for all sets; the training set additionally gets a
        # random-affine augmentation between resize and tensor conversion.
        # InterpolationMode.BILINEAR is used instead of the deprecated PIL
        # Image.BILINEAR constant, for consistency with RandomAffine below.
        pipeline = [
            transforms.ToPILImage(),
            transforms.Resize(
                (self.image_height, self.image_width),
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
        ]
        if self.which_set == "train":
            pipeline.append(
                transforms.RandomAffine(
                    degrees=[-0.75, 0.75],
                    translate=[0, 0.05],
                    scale=[0.75, 1],
                    shear=[-35, 35],
                    interpolation=transforms.InterpolationMode.BILINEAR,
                    # fill exposed regions with white (value 255),
                    # presumably matching the page background -- confirm
                    fill=255,
                )
            )
        pipeline += [
            transforms.ToTensor(),
            # ImageNet channel statistics
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        # number of samples in this split
        return len(self.list_image_files)

    def __getitem__(self, idx):
        """
        Return one sample as (image tensor, encoded label, label length).
        """
        image_file_name = self.list_image_files[idx]
        # assumes images on disk are single-channel grayscale -- TODO confirm
        image_gray = imread(os.path.join(self.dir_images, image_file_name))
        # replicate the gray channel 3x so the 3-channel ImageNet
        # normalization above applies
        image_3_channel = np.repeat(np.expand_dims(image_gray, -1), 3, -1)
        image_3_channel = self.transform(image_3_channel)

        label_string = self.list_labels[idx]
        # map each character to its integer id (1-based; see CHAR_2_LABEL)
        label_encoded = [self.CHAR_2_LABEL[c] for c in label_string]
        label_length = [len(label_encoded)]

        label_encoded = torch.LongTensor(label_encoded)
        label_length = torch.LongTensor(label_length)

        return image_3_channel, label_encoded, label_length
|
|
|
|
|
|
def IAM_collate_fn(batch):
    """
    collate function

    ---------
    Arguments
    ---------
    batch : tuple
        a batch of input data as a tuple

    -------
    Returns
    -------
    a collated tuple of
    images : tensor
        tensor of batch images
    labels : tensor
        tensor of batch labels
    label_lengths : tensor
        tensor of batch label lengths
    """
    # transpose the batch: list of samples -> tuples of per-field values
    image_list, label_list, length_list = zip(*batch)

    # images share a shape, so they stack along a new batch dimension;
    # labels are variable-length, so they are flattened end-to-end
    batched_images = torch.stack(image_list, dim=0)
    flat_labels = torch.cat(label_list, dim=0)
    flat_lengths = torch.cat(length_list, dim=0)

    return batched_images, flat_labels, flat_lengths
|
|
|
|
|
|
def split_dataset(file_txt_labels, for_train=True):
    """
    Split the labeled data into train / validation / internal-test sets.

    ---------
    Arguments
    ---------
    file_txt_labels : str
        full path to the text file containing labels
    for_train : bool
        indicating whether split is for training or internal testing

    -------
    Returns
    -------
    a tuple of files depending for train or internal testing
    """
    image_files, labels = read_IAM_label_txt_file(file_txt_labels)

    # hold out 10% as the internal test set (fixed seed for reproducibility)
    remaining_files, test_files, remaining_labels, test_labels = train_test_split(
        image_files, labels, test_size=0.1, random_state=4
    )
    if not for_train:
        return test_files, test_labels

    # hold out 10% of the remainder as the validation set (same fixed seed)
    train_files, valid_files, train_labels, valid_labels = train_test_split(
        remaining_files, remaining_labels, test_size=0.1, random_state=4
    )
    return train_files, valid_files, train_labels, valid_labels
|
|
|
|
|
|
def get_dataloaders_for_training(
    train_x,
    train_y,
    valid_x,
    valid_y,
    dir_images,
    image_height=32,
    image_width=768,
    batch_size=8,
):
    """
    ---------
    Arguments
    ---------
    train_x : list
        list of train file names
    train_y : list
        list of train labels
    valid_x : list
        list of validation file names
    valid_y : list
        list of validation labels
    dir_images : str
        full directory path containing the images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size (default: 8)

    -------
    Returns
    -------
    a tuple of dataloaders objects
    train_loader : object
        object of train set dataloader
    valid_loader : object
        object of validation set dataloader
    """
    # build both loaders from one spec; only the file lists, set name,
    # and shuffling differ between train and validation
    loaders = []
    for files, labels, set_name, do_shuffle in (
        (train_x, train_y, "train", True),
        (valid_x, valid_y, "valid", False),
    ):
        dataset = HWRecogIAMDataset(
            files,
            labels,
            dir_images,
            image_height=image_height,
            image_width=image_width,
            which_set=set_name,
        )
        loaders.append(
            DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=do_shuffle,
                num_workers=4,
                collate_fn=IAM_collate_fn,
            )
        )

    train_loader, valid_loader = loaders
    return train_loader, valid_loader
|
|
|
|
|
|
def get_dataloader_for_testing(
    test_x, test_y, dir_images, image_height=32, image_width=768, batch_size=1
):
    """
    ---------
    Arguments
    ---------
    test_x : list
        list of test file names
    test_y : list
        list of test labels
    dir_images : str
        full directory path containing the images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size (default: 1)

    -------
    Returns
    -------
    test_loader : object
        object of test set dataloader
    """
    # test set: no augmentation (which_set="test") and no shuffling
    dataset = HWRecogIAMDataset(
        test_x,
        test_y,
        dir_images=dir_images,
        image_height=image_height,
        image_width=image_width,
        which_set="test",
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=IAM_collate_fn,
    )
|
|
|