import os import torch from torchvision.transforms import transforms from PIL import Image class HandGestureDataset(torch.utils.data.Dataset): def __init__(self, data_dir, transform=None): self.data_dir = data_dir self.transform = transform self.image_files = [os.path.join(self.data_dir, f) for f in os.listdir(self.data_dir) if f.endswith('.jpg')] def __len__(self): return len(self.image_files) def __getitem__(self, idx): image_path = self.image_files[idx] image = Image.open(image_path) if self.transform: image = self.transform(image) label = self.get_label(image_path) return image, label def get_label(self, image_path): label = os.path.basename(os.path.dirname(image_path)) return label