"""Dataset utilities for IAM handwriting recognition.

Provides label-file parsing, a torch ``Dataset`` with train-time
augmentation, a CTC-friendly collate function, and train/valid/test
split plus dataloader builders.
"""
import os

import numpy as np
import torch
import torch.nn
import torchvision.transforms as transforms
from PIL import Image
from skimage.io import imread
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
def read_IAM_label_txt_file(file_txt_labels):
    """
    Read an IAM label file laid out in repeating 3-line groups.

    Within each group, line 0 is an image file name, line 1 is its
    label string, and line 2 is a separator that is skipped.

    ---------
    Arguments
    ---------
    file_txt_labels : str
        full path to the text file containing labels

    -------
    Returns
    -------
    a tuple of
    all_image_files : list
        a list of all image file names
    all_labels : list
        a list of all labels
    """
    all_image_files = []
    all_labels = []
    # context manager guarantees the handle is closed (the original
    # opened the file and never closed it); iterating the handle also
    # avoids loading the whole file into memory at once
    with open(file_txt_labels, mode="r") as label_file_handler:
        for cur_line_num, cur_line in enumerate(label_file_handler):
            if cur_line_num % 3 == 0:
                all_image_files.append(cur_line.strip())
            elif cur_line_num % 3 == 1:
                all_labels.append(cur_line.strip())
            # every third line is a separator and is intentionally skipped
    return all_image_files, all_labels
class HWRecogIAMDataset(Dataset):
    """
    Main dataset class to be used only for training, validation and internal testing
    """
    # character vocabulary; labels start at 1 (0 is left unused,
    # presumably reserved for the CTC blank -- TODO confirm against the model)
    CHAR_SET = ' !"#&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
    CHAR_2_LABEL = {char: i + 1 for i, char in enumerate(CHAR_SET)}
    LABEL_2_CHAR = {label: char for char, label in CHAR_2_LABEL.items()}

    def __init__(self, list_image_files, list_labels, dir_images, image_height=32, image_width=768, which_set="train"):
        """
        ---------
        Arguments
        ---------
        list_image_files : list
            list of image files
        list_labels : list
            list of labels
        dir_images : str
            full path to directory containing images
        image_height : int
            image height (default: 32)
        image_width : int
            image width (default: 768)
        which_set : str
            a string indicating which set is being used (default: train)
        """
        self.list_labels = list_labels
        self.dir_images = dir_images
        self.list_image_files = list_image_files
        self.image_width = image_width
        self.image_height = image_height
        self.which_set = which_set
        # deterministic preprocessing shared by every split; use the
        # torchvision InterpolationMode enum (passing the PIL constant
        # positionally is deprecated and was inconsistent with the
        # RandomAffine call below)
        pipeline = [
            transforms.ToPILImage(),
            transforms.Resize(
                (self.image_height, self.image_width),
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
        ]
        if self.which_set == "train":
            # apply data augmentation only for train set
            pipeline.append(
                transforms.RandomAffine(
                    degrees=[-0.75, 0.75], translate=[0, 0.05], scale=[0.75, 1],
                    shear=[-35, 35], interpolation=transforms.InterpolationMode.BILINEAR, fill=255,
                )
            )
        pipeline += [
            transforms.ToTensor(),
            # ImageNet channel statistics (standard for pretrained backbones)
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.list_image_files)

    def __getitem__(self, idx):
        """
        Load one sample: read the grayscale image, replicate it to 3
        channels, apply the split-specific transform, and encode the
        label string to integer class ids.

        Returns a tuple (image_tensor, label_encoded, label_length).
        """
        image_file_name = self.list_image_files[idx]
        image_gray = imread(os.path.join(self.dir_images, image_file_name))
        # (H, W) grayscale -> (H, W, 3) so ImageNet-style 3-channel
        # transforms and models can be reused
        image_3_channel = np.repeat(np.expand_dims(image_gray, -1), 3, -1)
        image_3_channel = self.transform(image_3_channel)
        label_string = self.list_labels[idx]
        label_encoded = [self.CHAR_2_LABEL[c] for c in label_string]
        label_length = [len(label_encoded)]
        label_encoded = torch.LongTensor(label_encoded)
        label_length = torch.LongTensor(label_length)
        return image_3_channel, label_encoded, label_length
def IAM_collate_fn(batch):
    """
    collate function
    ---------
    Arguments
    ---------
    batch : tuple
        a batch of input data as a tuple
    -------
    Returns
    -------
    a collated tuple of
    images : tensor
        tensor of batch images
    labels : tensor
        tensor of batch labels
    label_lengths : tensor
        tensor of batch label lengths
    """
    # transpose the batch from per-sample tuples to per-field sequences
    image_seq, label_seq, length_seq = zip(*batch)
    # images share a shape, so they stack into one batch tensor;
    # labels are variable-length and are concatenated flat (CTC-style)
    # alongside their per-sample lengths
    return (
        torch.stack(image_seq, 0),
        torch.cat(label_seq, 0),
        torch.cat(length_seq, 0),
    )
def split_dataset(file_txt_labels, for_train=True):
    """
    ---------
    Arguments
    ---------
    file_txt_labels : str
        full path to the text file containing labels
    for_train : bool
        indicating whether split is for training or internal testing
    -------
    Returns
    -------
    a tuple of files depending for train or internal testing
    """
    all_image_files, all_labels = read_IAM_label_txt_file(file_txt_labels)
    # first carve out 10% as the held-out internal test split;
    # the fixed random_state keeps the split reproducible across calls
    remaining_images, test_image_files, remaining_labels, test_labels = train_test_split(
        all_image_files, all_labels, test_size=0.1, random_state=4,
    )
    if not for_train:
        return test_image_files, test_labels
    # then carve 10% of the remainder off as the validation split
    train_image_files, valid_image_files, train_labels, valid_labels = train_test_split(
        remaining_images, remaining_labels, test_size=0.1, random_state=4,
    )
    return train_image_files, valid_image_files, train_labels, valid_labels
def get_dataloaders_for_training(train_x, train_y, valid_x, valid_y, dir_images, image_height=32, image_width=768, batch_size=8):
    """
    ---------
    Arguments
    ---------
    train_x : list
        list of train file names
    train_y : list
        list of train labels
    valid_x : list
        list of validation file names
    valid_y : list
        list of validation labels
    dir_images : str
        full directory path containing the images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size (default: 8)
    -------
    Returns
    -------
    a tuple of dataloaders objects
    train_loader : object
        object of train set dataloader
    valid_loader : object
        object of validation set dataloader
    """
    # build both loaders from one spec: (files, labels, set name, shuffle?);
    # only the train loader shuffles, and "train" enables augmentation
    # inside HWRecogIAMDataset
    loaders = []
    for files, labels, set_name, do_shuffle in (
        (train_x, train_y, "train", True),
        (valid_x, valid_y, "valid", False),
    ):
        dataset = HWRecogIAMDataset(
            files, labels, dir_images,
            image_height=image_height, image_width=image_width, which_set=set_name,
        )
        loaders.append(
            DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=do_shuffle,
                num_workers=4,
                collate_fn=IAM_collate_fn,
            )
        )
    train_loader, valid_loader = loaders
    return train_loader, valid_loader
def get_dataloader_for_testing(test_x, test_y, dir_images, image_height=32, image_width=768, batch_size=1):
    """
    ---------
    Arguments
    ---------
    test_x : list
        list of test file names
    test_y : list
        list of test labels
    dir_images : str
        full directory path containing the images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size (default: 1)
    -------
    Returns
    -------
    test_loader : object
        object of test set dataloader
    """
    # "test" selects the deterministic (non-augmented) transform pipeline;
    # no shuffling so evaluation order is stable
    test_dataset = HWRecogIAMDataset(
        test_x, test_y, dir_images=dir_images,
        image_height=image_height, image_width=image_width, which_set="test",
    )
    return DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=IAM_collate_fn,
    )