| from __future__ import print_function | |
| import torch | |
| import torchvision.datasets as datasets | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| from .tsv_io import TSVFile | |
| import numpy as np | |
| import base64 | |
| import io | |
class TSVDataset(Dataset):
    """TSV dataset for ImageNet 1K training.

    Rows are fetched with ``TSVFile.seek``. The integer class label is read
    from column 1 and the base64-encoded image bytes from the last column.
    """

    def __init__(self, tsv_file, transform=None, target_transform=None):
        """
        Args:
            tsv_file: TSV file handed to ``TSVFile`` (presumably a path --
                confirm against ``tsv_io.TSVFile``).
            transform (callable, optional): Applied to the decoded RGB PIL
                image before it is returned.
            target_transform (callable, optional): Applied to the integer
                class label before it is returned.
        """
        self.tsv = TSVFile(tsv_file)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index of the row to load.

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        row = self.tsv.seek(index)
        image_bytes = base64.b64decode(row[-1])
        # convert() loads the pixel data and returns an independent image, so
        # the source image can be closed as soon as the RGB copy exists.
        with Image.open(io.BytesIO(image_bytes)) as raw_image:
            image = raw_image.convert('RGB')
        target = int(row[1])
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        """Return the number of rows reported by the underlying TSV file."""
        return self.tsv.num_rows()