|
|
|
|
|
""" |
|
imagefolder loader |
|
inspired by https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
|
@author: Tu Bui @surrey.ac.uk |
|
""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
import os |
|
import sys |
|
import io |
|
import time |
|
import pandas as pd |
|
import numpy as np |
|
import random |
|
from PIL import Image |
|
from typing import Any, Callable, List, Optional, Tuple |
|
import torch |
|
from .base_lmdb import PILlmdb, ArrayDatabase |
|
from torchvision import transforms |
|
|
|
|
|
|
|
def worker_init_fn(worker_id):
    """Seed Python's and NumPy's RNGs for one DataLoader worker process.

    Derives distinct, reproducible per-worker seeds from torch's initial
    seed so workers do not produce identical random streams.

    Args:
        worker_id (int): index of the DataLoader worker.
    """
    base_seed = torch.initial_seed()
    random.seed(base_seed + worker_id)
    # numpy only accepts 32-bit seeds; fold large torch seeds down first.
    if base_seed >= 2**30:
        base_seed = base_seed % 2**30
    np.random.seed(base_seed + worker_id)
|
|
|
|
|
def pil_loader(path: str) -> Image.Image:
    """Open the image file at *path* and return it as an RGB PIL image.

    The file is opened explicitly so the handle is closed by the context
    manager; convert() forces the decode while the file is still open.
    """
    with open(path, 'rb') as file_handle:
        return Image.open(file_handle).convert('RGB')
|
|
|
|
|
class ImageDataset(torch.utils.data.Dataset):
    r"""Customised Image Folder class for pytorch.

    Accepts an lmdb (or array) database directory plus a csv list as input.

    Usage:
        dataset = ImageDataset(img_dir, img_list)
        dataset.set_transform(some_pytorch_transforms)
        loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True,
            num_workers=4, worker_init_fn=worker_init_fn)
        for batch in loader:
            # batch is a dict with keys 'image' and 'secret'.
    """
    # indent width used when pretty-printing the repr
    _repr_indent = 4

    def __init__(self, data_dir, data_list, secret_len=100, transform=None, target_transform=None, **kwargs):
        """Build the dataset.

        Args:
            data_dir: directory holding the lmdb/array database.
            data_list: csv list describing the samples.
            secret_len: length of the random 0/1 secret vector per sample.
            transform: optional callable applied to each image.
            target_transform: optional callable applied to each label.
            **kwargs: forwarded to build_data (e.g. dtype='array').
        """
        super().__init__()
        self.set_transform(transform, target_transform)
        self.build_data(data_dir, data_list, **kwargs)
        self.secret_len = secret_len
        self.kwargs = kwargs

    def set_transform(self, transform, target_transform=None):
        """Install (or replace) the input and target transforms."""
        self.transform = transform
        self.target_transform = target_transform

    def build_data(self, data_dir, data_list, **kwargs):
        """
        Args:
            data_list (text file) must have at least 3 fields: id, path and label

        This method must create an attribute self.samples containing ID, input
        and target samples; and another attribute N storing the dataset size.

        Optional attributes: classes (list of unique classes), group (useful
        for metric learning).
        """
        self.data_dir = data_dir
        self.list = data_list
        # dtype='array' selects the raw-array backend; anything else -> PIL lmdb
        wants_array = ('dtype' in kwargs) and (kwargs['dtype'].lower() == 'array')
        if wants_array:
            data = ArrayDatabase(data_dir, data_list)
        else:
            data = PILlmdb(data_dir, data_list, **kwargs)
        self.N = len(data)
        self.classes = np.unique(data.labels)
        self.samples = {'x': data, 'y': data.labels}

    def __getitem__(self, index: int) -> Any:
        """
        Args:
            index (int): Index
        Returns:
            dict with 'image' (float32 array scaled to [-1, 1]) and 'secret'
            (random 0/1 float tensor of length secret_len).
        """
        x = self.samples['x'][index]
        y = self.samples['y'][index]
        if self.transform is not None:
            x = self.transform(x)
        # label transform is applied for its effects; the label itself is
        # not part of the returned dict
        if self.target_transform is not None:
            y = self.target_transform(y)
        # map pixel range [0, 255] -> [-1, 1]
        x = np.array(x, dtype=np.float32) / 127.5 - 1.
        secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
        return {'image': x, 'secret': secret}

    def __len__(self) -> int:
        """Dataset size, as recorded by build_data."""
        return self.N

    def __repr__(self) -> str:
        """Multi-line summary: class name, size, data location, kwargs, transform."""
        head = "\nDataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(len(self))]
        if getattr(self, 'data_dir', None) is not None:
            body.append("data_dir location: {}".format(self.data_dir))
        if hasattr(self, 'kwargs'):
            body.append(f'kwargs: {self.kwargs}')
        body.extend(self.extra_repr().splitlines())
        if getattr(self, 'transform', None) is not None:
            body.append(repr(self.transform))
        indented = [" " * self._repr_indent + line for line in body]
        return '\n'.join([head] + indented)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Render a transform repr with *head* on the first line and hanging indent."""
        first, *rest = transform.__repr__().splitlines()
        pad = " " * len(head)
        return ["{}{}".format(head, first)] + ["{}{}".format(pad, line) for line in rest]

    def extra_repr(self) -> str:
        """Hook for subclasses to append extra lines to __repr__."""
        return ""
|
|
|
class ImageFolder(torch.utils.data.Dataset):
    """Folder-of-images dataset yielding {'image', 'secret'} dicts.

    Images are loaded from data_dir joined with each relative path in
    data_list, resized (unless a custom transform is given), scaled to
    [-1, 1], and paired with a random 0/1 secret vector.
    """
    # indent width kept consistent with ImageDataset
    _repr_indent = 4

    def __init__(self, data_dir, data_list, secret_len=100, resize=256, transform=None, **kwargs):
        """
        Args:
            data_dir: root directory prepended to every listed path.
            data_list: list of paths, csv file path, or DataFrame with a 'path' column.
            secret_len: length of the random 0/1 secret vector per sample.
            resize: target side length when no transform is supplied.
            transform: optional callable replacing the default Resize.
        """
        super().__init__()
        if transform is None:
            self.transform = transforms.Resize((resize, resize))
        else:
            self.transform = transform
        self.build_data(data_dir, data_list, **kwargs)
        self.kwargs = kwargs
        self.secret_len = secret_len

    def build_data(self, data_dir, data_list, **kwargs):
        """Resolve data_list into self.data_list and record the dataset size.

        Raises:
            ValueError: if data_list is not a list, str or pd.DataFrame.
        """
        self.data_dir = data_dir
        if isinstance(data_list, list):
            paths = data_list
        elif isinstance(data_list, str):
            # treat a string as a csv file with a 'path' column
            paths = pd.read_csv(data_list)['path'].tolist()
        elif isinstance(data_list, pd.DataFrame):
            paths = data_list['path'].tolist()
        else:
            raise ValueError('data_list must be a list, str or pd.DataFrame')
        self.data_list = paths
        self.N = len(paths)

    def __getitem__(self, index):
        """Load, transform and scale the image at *index*; attach a random secret."""
        rel_path = self.data_list[index]
        img = pil_loader(os.path.join(self.data_dir, rel_path))
        img = self.transform(img)
        # map pixel range [0, 255] -> [-1, 1]
        img = np.array(img, dtype=np.float32) / 127.5 - 1.
        secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
        return {'image': img, 'secret': secret}

    def __len__(self) -> int:
        """Number of listed image paths."""
        return self.N