# "Spaces: Sleeping" — HuggingFace Spaces page status header (scrape artifact, not code)
import os
import torch
import torch.nn
import numpy as np
from PIL import Image
from skimage.io import imread
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
def read_IAM_label_txt_file(file_txt_labels):
    """
    Read an IAM label text file laid out in repeating 3-line groups:
    line 0 of each group is an image file name, line 1 is its label,
    line 2 is a separator that is ignored.

    ---------
    Arguments
    ---------
    file_txt_labels : str
        full path to the text file containing labels

    -------
    Returns
    -------
    a tuple of
    all_image_files : list
        a list of all image file names
    all_labels : list
        a list of all labels
    """
    all_image_files = []
    all_labels = []
    # context manager guarantees the handle is closed even on error
    # (the original opened the file and never closed it)
    with open(file_txt_labels, mode="r") as label_file_handler:
        for cur_line_num, cur_line in enumerate(label_file_handler):
            if cur_line_num % 3 == 0:
                all_image_files.append(cur_line.strip())
            elif cur_line_num % 3 == 1:
                all_labels.append(cur_line.strip())
            # every third line is a separator and is skipped
    return all_image_files, all_labels
class HWRecogIAMDataset(Dataset):
    """
    Main dataset class to be used only for training, validation and internal testing.

    Loads grayscale word images from disk, replicates them to 3 channels,
    resizes and ImageNet-normalizes them (with affine augmentation for the
    train set only), and encodes text labels as integer sequences using
    CHAR_2_LABEL.
    """

    # Character inventory; characters map to labels 1..len(CHAR_SET),
    # leaving label 0 free (presumably the CTC blank — confirm against the model).
    CHAR_SET = " !\"#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    CHAR_2_LABEL = {char: i + 1 for i, char in enumerate(CHAR_SET)}
    LABEL_2_CHAR = {label: char for char, label in CHAR_2_LABEL.items()}

    def __init__(
        self,
        list_image_files,
        list_labels,
        dir_images,
        image_height=32,
        image_width=768,
        which_set="train",
    ):
        """
        ---------
        Arguments
        ---------
        list_image_files : list
            list of image files
        list_labels : list
            list of labels
        dir_images : str
            full path to directory containing images
        image_height : int
            image height (default: 32)
        image_width : int
            image width (default: 768)
        which_set : str
            a string indicating which set is being used (default: train)
        """
        self.list_labels = list_labels
        self.dir_images = dir_images
        self.list_image_files = list_image_files
        self.image_width = image_width
        self.image_height = image_height
        self.which_set = which_set

        # Shared steps: fixed-size resize + ImageNet mean/std normalization.
        # Interpolation is given as InterpolationMode.BILINEAR (keyword) for
        # consistency with the RandomAffine call below — the original passed
        # the deprecated PIL Image.BILINEAR constant positionally.
        resize = transforms.Resize(
            (self.image_height, self.image_width),
            interpolation=transforms.InterpolationMode.BILINEAR,
        )
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        if self.which_set == "train":
            # apply data augmentation only for train set
            augment = transforms.RandomAffine(
                degrees=[-0.75, 0.75],
                translate=[0, 0.05],
                scale=[0.75, 1],
                shear=[-35, 35],
                interpolation=transforms.InterpolationMode.BILINEAR,
                fill=255,  # fill exposed regions with white (page background)
            )
            steps = [
                transforms.ToPILImage(),
                resize,
                augment,
                transforms.ToTensor(),
                normalize,
            ]
        else:
            steps = [
                transforms.ToPILImage(),
                resize,
                transforms.ToTensor(),
                normalize,
            ]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        # number of samples == number of image files
        return len(self.list_image_files)

    def __getitem__(self, idx):
        """
        Return (image_tensor, encoded_label, label_length) for sample `idx`.

        The grayscale image is replicated along a new last axis to form a
        3-channel image before the transform pipeline is applied.
        """
        image_file_name = self.list_image_files[idx]
        # assumes images on disk are single-channel grayscale — TODO confirm
        image_gray = imread(os.path.join(self.dir_images, image_file_name))
        image_3_channel = np.repeat(np.expand_dims(image_gray, -1), 3, -1)
        image_3_channel = self.transform(image_3_channel)

        label_string = self.list_labels[idx]
        label_encoded = [self.CHAR_2_LABEL[c] for c in label_string]
        label_length = [len(label_encoded)]
        label_encoded = torch.LongTensor(label_encoded)
        label_length = torch.LongTensor(label_length)
        return image_3_channel, label_encoded, label_length
def IAM_collate_fn(batch):
    """
    Collate a list of dataset samples into batched tensors.

    ---------
    Arguments
    ---------
    batch : tuple
        a batch of input data as a tuple

    -------
    Returns
    -------
    a collated tuple of
    images : tensor
        tensor of batch images
    labels : tensor
        tensor of batch labels
    label_lengths : tensor
        tensor of batch label lengths
    """
    image_list = []
    label_list = []
    length_list = []
    for sample_image, sample_label, sample_length in batch:
        image_list.append(sample_image)
        label_list.append(sample_label)
        length_list.append(sample_length)
    # images share one shape, so they stack along a new batch dimension;
    # labels are variable-length, so they are concatenated end to end and
    # the per-sample lengths tensor records where each one stops
    images = torch.stack(image_list, 0)
    labels = torch.cat(label_list, 0)
    label_lengths = torch.cat(length_list, 0)
    return images, labels, label_lengths
def split_dataset(file_txt_labels, for_train=True):
    """
    Split the IAM samples into train / validation / internal-test subsets.

    ---------
    Arguments
    ---------
    file_txt_labels : str
        full path to the text file containing labels
    for_train : bool
        indicating whether split is for training or internal testing

    -------
    Returns
    -------
    a tuple of files depending for train or internal testing
    """
    image_files, labels = read_IAM_label_txt_file(file_txt_labels)
    # first carve off 10% as the held-out internal test set; the fixed
    # random_state keeps both splits reproducible across runs
    files_trainval, files_test, labels_trainval, labels_test = train_test_split(
        image_files, labels, test_size=0.1, random_state=4
    )
    # then take 10% of the remainder as the validation set
    files_train, files_valid, labels_train, labels_valid = train_test_split(
        files_trainval, labels_trainval, test_size=0.1, random_state=4
    )
    if for_train:
        return files_train, files_valid, labels_train, labels_valid
    return files_test, labels_test
def get_dataloaders_for_training(
    train_x,
    train_y,
    valid_x,
    valid_y,
    dir_images,
    image_height=32,
    image_width=768,
    batch_size=8,
    num_workers=4,
):
    """
    Build the train and validation dataloaders over HWRecogIAMDataset.

    ---------
    Arguments
    ---------
    train_x : list
        list of train file names
    train_y : list
        list of train labels
    valid_x : list
        list of validation file names
    valid_y : list
        list of validation labels
    dir_images : str
        full directory path containing the images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size (default: 8)
    num_workers : int
        number of dataloader worker processes (default: 4)

    -------
    Returns
    -------
    a tuple of dataloaders objects
    train_loader : object
        object of train set dataloader
    valid_loader : object
        object of validation set dataloader
    """
    train_dataset = HWRecogIAMDataset(
        train_x,
        train_y,
        dir_images=dir_images,
        image_height=image_height,
        image_width=image_width,
        which_set="train",
    )
    valid_dataset = HWRecogIAMDataset(
        valid_x,
        valid_y,
        dir_images=dir_images,
        image_height=image_height,
        image_width=image_width,
        which_set="valid",
    )
    # only the train loader is shuffled; both need the custom collate_fn
    # because labels are variable-length
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=IAM_collate_fn,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=IAM_collate_fn,
    )
    return train_loader, valid_loader
def get_dataloader_for_testing(
    test_x,
    test_y,
    dir_images,
    image_height=32,
    image_width=768,
    batch_size=1,
    num_workers=4,
):
    """
    Build the internal-test dataloader over HWRecogIAMDataset.

    ---------
    Arguments
    ---------
    test_x : list
        list of test file names
    test_y : list
        list of test labels
    dir_images : str
        full directory path containing the images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size (default: 1)
    num_workers : int
        number of dataloader worker processes (default: 4)

    -------
    Returns
    -------
    test_loader : object
        object of test set dataloader
    """
    test_dataset = HWRecogIAMDataset(
        test_x,
        test_y,
        dir_images=dir_images,
        image_height=image_height,
        image_width=image_width,
        which_set="test",
    )
    # no shuffling for evaluation; custom collate_fn handles the
    # variable-length labels
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=IAM_collate_fn,
    )
    return test_loader