countgd / datasets /dataset.py
nikigoli's picture
Upload folder using huggingface_hub
a277bb8 verified
from __future__ import print_function
import torch
import torchvision.datasets as datasets
from torch.utils.data import Dataset
from PIL import Image
from .tsv_io import TSVFile
import numpy as np
import base64
import io
class TSVDataset(Dataset):
    """TSV dataset for ImageNet 1K training.

    Each row returned by ``TSVFile.seek`` is read as follows: column 1 is
    parsed as the integer class index, and the last column is decoded as
    base64 image bytes. (Presumably the row is a sequence of strings --
    confirm against ``TSVFile``.)

    Args:
        tsv_file: Path handed to ``TSVFile``.
        transform: Optional callable applied to the decoded RGB PIL image.
        target_transform: Optional callable applied to the integer target.
    """

    def __init__(self, tsv_file, transform=None, target_transform=None):
        self.tsv = TSVFile(tsv_file)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        row = self.tsv.seek(index)
        image_data = base64.b64decode(row[-1])
        # Image.open is lazy and keeps a reference to its source stream;
        # convert() produces a new, fully loaded image, so the source image
        # can be closed deterministically instead of waiting for GC.
        with Image.open(io.BytesIO(image_data)) as raw:
            image = raw.convert('RGB')
        target = int(row[1])
        img = self.transform(image) if self.transform is not None else image
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of rows reported by the underlying TSVFile."""
        return self.tsv.num_rows()