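# Scraped page residue (kept for provenance): "reysarms's picture", "updated", "c08ab4e"
# (presumably the uploader, the commit message, and the commit hash).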
import torch
from torch.utils.data import Dataset
class HumanActionDataset(Dataset):
    """PyTorch ``Dataset`` adapter over a Hugging Face dataset split.

    Each record of the wrapped split is read as a mapping with an ``"image"``
    entry and a ``"labels"`` entry; ``__getitem__`` returns the pair
    ``(image, label)``, with ``transform`` applied to the image when given.
    """

    def __init__(self, hf_dataset_split, transform=None):
        """
        Args:
            hf_dataset_split: Hugging Face dataset split, e.g. ``ds["train"]``.
                Any object supporting ``len()`` and indexing that yields
                mappings with ``"image"`` and ``"labels"`` keys works.
            transform: Optional callable (e.g. torchvision transforms) applied
                to each image. ``None`` means images are returned unchanged.
        """
        self.dataset = hf_dataset_split
        self.transform = transform

    def __len__(self):
        """Return the number of records in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for record ``idx``.

        ``idx`` is normally a Python int; if a caller or custom sampler passes
        a tensor instead, it is converted to plain Python values first.
        """
        if torch.is_tensor(idx):
            # Index the underlying split with ordinary Python values rather
            # than tensor objects.
            idx = idx.tolist()
        item = self.dataset[idx]
        image = item["image"]  # the original author notes this is a PIL.Image.Image
        label = item["labels"]
        # Compare against None explicitly: a callable transform can still be
        # falsy (e.g. one defining __len__ that returns 0) and must still run.
        if self.transform is not None:
            image = self.transform(image)
        return image, label