import torch
import torchvision

from utils.dataset import folders
from utils.dataset.process import ToTensor, Normalize, RandHorizontalFlip

# Maps a ``config.dataset`` key to the name of its dataset class in
# ``utils.dataset.folders``. Names are resolved with ``getattr`` only for the
# selected key, so a missing class elsewhere in ``folders`` does not break
# loading of the others.
_DATASET_CLASSES = {
    'livec': 'LIVEC',
    'koniq10k': 'Koniq10k',
    'bid': 'BID',
    'spaq': 'SPAQ',
}


class Data_Loader(object):
    """Dataset class for IQA databases.

    Builds the dataset selected by ``config.dataset`` over the images given by
    ``img_indx`` and wraps it in a ``torch.utils.data.DataLoader`` on request.

    Args:
        config: Object exposing ``batch_size`` and ``dataset`` attributes.
        path: Passed to the dataset class as ``root``.
        img_indx: Passed to the dataset class as ``index``.
        istrain: When True, random horizontal flipping is added to the
            transform pipeline and ``get_data`` shuffles the samples.

    Raises:
        ValueError: If ``config.dataset`` is not one of
            ``livec``, ``koniq10k``, ``bid``, ``spaq``.
    """

    def __init__(self, config, path, img_indx, istrain=True):
        self.batch_size = config.batch_size
        self.istrain = istrain
        dataset = config.dataset

        # Train and test share the same pipeline; the only difference is the
        # random flip augmentation, inserted between Normalize and ToTensor
        # during training.
        steps = [Normalize(0.5, 0.5)]
        if istrain:
            steps.append(RandHorizontalFlip(prob_aug=0.5))
        steps.append(ToTensor())
        transforms = torchvision.transforms.Compose(steps)

        try:
            class_name = _DATASET_CLASSES[dataset]
        except (KeyError, TypeError):
            # ValueError subclasses Exception, so callers catching the
            # previously raised bare Exception keep working.
            raise ValueError(
                f"Unsupported dataset {dataset!r}. "
                "Only support livec, koniq10k, bid, spaq."
            ) from None
        dataset_cls = getattr(folders, class_name)
        self.data = dataset_cls(root=path, index=img_indx, transform=transforms)

    def get_data(self, num_workers=8):
        """Return a DataLoader over the dataset built in ``__init__``.

        Args:
            num_workers: Number of worker processes for the DataLoader
                (defaults to 8, the previously hard-coded value).

        Returns:
            A ``torch.utils.data.DataLoader`` using ``self.batch_size``; it
            shuffles only when the loader was constructed with ``istrain``.
        """
        return torch.utils.data.DataLoader(
            self.data,
            batch_size=self.batch_size,
            shuffle=self.istrain,
            num_workers=num_workers,
        )