import os

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image


class ImageLoader:
    """Callable that loads an image, given relative to a root directory, as RGB."""

    def __init__(self, root):
        """Remember ``root``; every path passed to the loader is resolved under it."""
        self.img_dir = root

    def __call__(self, img):
        """Open ``<root>/<img>`` and return it as an RGB ``PIL.Image``."""
        path = f'{self.img_dir}/{img}'
        # ``Image.open`` is lazy and keeps the file open; ``convert`` forces
        # the pixel data to load and returns an independent image, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(path) as handle:
            return handle.convert('RGB')


def imagenet_transform(phase):
    """Build the torchvision preprocessing pipeline for ``phase``.

    Args:
        phase: ``'train'`` (random resized crop to 224 plus random horizontal
            flip) or ``'test'`` (plain resize to 224x224). Both pipelines end
            with ``ToTensor``.

    Returns:
        A ``transforms.Compose`` instance.

    Raises:
        ValueError: if ``phase`` is neither ``'train'`` nor ``'test'``.
    """
    if phase == 'train':
        return transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    if phase == 'test':
        return transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
        ])
    raise ValueError(f"phase must be 'train' or 'test', got {phase!r}")


class Dataset_embedding(data.Dataset):
    """Dataset yielding ``(type index, image tensor)`` pairs.

    Images are read from ``<split dir>/<type name>/<file name>``. The file
    names are listed once, from the directory of the first type, and reused
    for every type.
    """

    def __init__(self, cfg_data, phase='train'):
        """Index the image files of one split.

        Args:
            cfg_data: configuration object; ``type_name``, ``train_dir`` and
                ``test_dir`` are read from it.
            phase: ``'train'`` or ``'test'``.

        Raises:
            ValueError: if ``phase`` is neither ``'train'`` nor ``'test'``.
        """
        # imagenet_transform rejects unknown phases, so below this line
        # ``phase`` is guaranteed to be either 'train' or 'test'.
        self.transform = imagenet_transform(phase)
        self.type_name = cfg_data.type_name
        self.type2idx = {name: idx for idx, name in enumerate(self.type_name)}

        is_train = phase == 'train'
        root = cfg_data.train_dir if is_train else cfg_data.test_dir
        # NOTE(review): the test split skips the first type while the train
        # split uses all of them -- kept as-is, confirm this is intentional.
        first_type = 0 if is_train else 1

        self.loader = ImageLoader(root)
        # Presumably every type directory holds identically named files, so
        # only the first type's directory is listed -- TODO confirm.
        file_names = os.listdir(f'{root}/{self.type_name[0]}')
        self.data = [
            [type_name, file_name]
            for type_name in list(self.type_name)[first_type:]
            for file_name in file_names
        ]
        print(f'The amount of {phase} data is {len(self.data)}')

    def __getitem__(self, index):
        """Return ``(scene, image)``: the type's index and the transformed image."""
        type_name, image_name = self.data[index]
        scene = self.type2idx[type_name]
        image = self.transform(self.loader(f'{type_name}/{image_name}'))
        return (scene, image)

    def __len__(self):
        """Return the number of indexed ``(type, file)`` entries."""
        return len(self.data)


def init_embedding_data(cfg_em, phase):
    """Create the data loaders for the embedding stage.

    Args:
        cfg_em: configuration object; ``batch`` and ``num_workers`` are read
            here, and the object is forwarded to ``Dataset_embedding``.
        phase: ``'train'`` builds a shuffled train loader and an unshuffled
            test loader, both with batch size ``cfg_em.batch``;
            ``'inference'`` builds only an unshuffled test loader with batch
            size 1.

    Returns:
        ``(train_loader, test_loader)``. ``train_loader`` is ``None`` when
        ``phase == 'inference'``.

    Raises:
        ValueError: if ``phase`` is neither ``'train'`` nor ``'inference'``.
    """
    if phase == 'train':
        train_dataset = Dataset_embedding(cfg_em, 'train')
        test_dataset = Dataset_embedding(cfg_em, 'test')
        train_loader = data.DataLoader(
            train_dataset,
            batch_size=cfg_em.batch,
            shuffle=True,
            num_workers=cfg_em.num_workers,
            pin_memory=True,
        )
        test_loader = data.DataLoader(
            test_dataset,
            batch_size=cfg_em.batch,
            shuffle=False,
            num_workers=cfg_em.num_workers,
            pin_memory=True,
        )
        print(len(train_dataset), len(test_dataset))
    elif phase == 'inference':
        # No training data is needed for inference; the original left
        # ``train_loader`` unassigned here, which made the return statement
        # raise UnboundLocalError.
        train_loader = None
        test_dataset = Dataset_embedding(cfg_em, 'test')
        test_loader = data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=cfg_em.num_workers,
            pin_memory=True,
        )
    else:
        raise ValueError(f"phase must be 'train' or 'inference', got {phase!r}")
    return train_loader, test_loader