# File 4: utils/data_loader.py
"""Data-loading helpers: an image-folder DataLoader and a user-feedback dataset."""
import os

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms


def get_dataloader(batch_size=64, dataset_path='data/train'):
    """Build a shuffling DataLoader over an image folder.

    Every image goes through the same pipeline: ``Resize(128)``,
    ``CenterCrop(112)``, ``ToTensor()`` and a per-channel
    ``Normalize`` with mean 0.5 and std 0.5 on three channels.

    Args:
        batch_size: Number of samples per batch.
        dataset_path: Root directory handed to ``datasets.ImageFolder``.

    Returns:
        A ``DataLoader`` over the folder with ``shuffle=True``.
    """
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(112),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = datasets.ImageFolder(dataset_path, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


class UserFeedbackDataset(Dataset):
    """Dataset of user-feedback records, one serialized file per sample.

    Each file is read with ``torch.load`` and must yield a mapping with the
    keys ``'latent'``, ``'text'`` and ``'similarity'``.
    """

    def __init__(self, storage_path='storage/user_data'):
        """Index the sample files found directly inside *storage_path*.

        Args:
            storage_path: Directory holding one serialized record per file.

        Raises:
            OSError: If *storage_path* cannot be listed (e.g. it is missing).
        """
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        entries = sorted(os.listdir(storage_path))
        candidates = (os.path.join(storage_path, name) for name in entries)
        # Skip sub-directories: torch.load cannot read them.
        self.data_files = [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the number of indexed sample files."""
        return len(self.data_files)

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(latent, text, similarity)``.

        Raises:
            IndexError: If *idx* is outside the indexed files.
            KeyError: If the loaded record lacks one of the expected keys.
        """
        # NOTE(review): torch.load unpickles the file. These files originate
        # from user data, so confirm they are trusted, or load them with
        # weights_only=True on torch versions that support it.
        data = torch.load(self.data_files[idx])
        return data['latent'], data['text'], data['similarity']