Spaces:
Running
Running
import torch, os | |
from PIL import Image | |
import numpy as np | |
import torchvision.transforms as transforms | |
import torch.utils.data as data | |
from einops import rearrange | |
class ImageLoader:
    """Callable that loads image files, addressed relative to a root directory."""

    def __init__(self, root: str):
        """Store ``root`` as the directory that image paths are resolved against.

        Args:
            root: Directory prefix prepended to every path passed to ``__call__``.
        """
        self.img_dir = root

    def __call__(self, img: str) -> "Image.Image":
        """Open ``<root>/<img>`` and return its contents converted to RGB.

        Args:
            img: Path of the image file, relative to the root directory.

        Returns:
            The image converted to ``'RGB'`` mode.
        """
        file = f'{self.img_dir}/{img}'
        # Image.open() is lazy and keeps the file open; convert() returns a
        # converted copy, so the source can be closed deterministically here
        # instead of lingering until garbage collection.
        with Image.open(file) as source:
            return source.convert('RGB')
def imagenet_transform(phase):
    """Build the torchvision preprocessing pipeline for ``phase``.

    Args:
        phase: ``'train'`` selects a random resized crop to 224 followed by a
            random horizontal flip; ``'test'`` selects a fixed resize to
            224x224. Both pipelines end with ``ToTensor``.

    Returns:
        The ``transforms.Compose`` pipeline for the requested phase.

    Raises:
        ValueError: If ``phase`` is neither ``'train'`` nor ``'test'``.
    """
    if phase == 'train':
        return transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    if phase == 'test':
        return transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
        ])
    # Fail with a clear message instead of an UnboundLocalError on `transform`.
    raise ValueError(f"phase must be 'train' or 'test', got {phase!r}")
class Dataset_embedding(data.Dataset):
    """Dataset yielding ``(type index, transformed image)`` pairs.

    Images are read from ``<split dir>/<type name>/<file name>``. File names
    are listed from the first type's sub-directory only and reused for every
    selected type, so each type directory must contain the same file names.
    """

    def __init__(self, cfg_data, phase='train'):
        """Index the image files of one split.

        Args:
            cfg_data: Config object providing ``type_name`` (sequence of type
                names, also used as sub-directory names), ``train_dir`` and
                ``test_dir``.
            phase: ``'train'`` or ``'test'``.

        Raises:
            ValueError: If ``phase`` is neither ``'train'`` nor ``'test'``.
        """
        # Validate first so an unknown phase fails clearly rather than leaving
        # the instance without `loader` / `data`.
        if phase not in ('train', 'test'):
            raise ValueError(f"phase must be 'train' or 'test', got {phase!r}")

        self.transform = imagenet_transform(phase)
        self.type_name = cfg_data.type_name
        self.type2idx = {name: idx for idx, name in enumerate(self.type_name)}

        if phase == 'train':
            root = cfg_data.train_dir
            first_type = 0
        else:
            root = cfg_data.test_dir
            # The test split starts at the second type, skipping the first.
            # NOTE(review): presumably intentional (first type not evaluated);
            # confirm with the training/evaluation code.
            first_type = 1

        self.loader = ImageLoader(root)
        # os.listdir() returns entries in arbitrary order; sort so the sample
        # order is deterministic across runs and machines.
        file_names = sorted(os.listdir(f'{root}/{self.type_name[0]}'))
        selected_types = list(self.type_name)[first_type:]
        self.data = [
            [type_name, file_name]
            for type_name in selected_types
            for file_name in file_names
        ]
        print(f'The amount of {phase} data is {len(self.data)}')

    def __getitem__(self, index):
        """Return ``(type index, transformed image)`` for sample ``index``."""
        type_name, image_name = self.data[index]
        scene = self.type2idx[type_name]
        image = self.transform(self.loader(f'{type_name}/{image_name}'))
        return (scene, image)

    def __len__(self):
        """Return the number of ``(type, file name)`` samples."""
        return len(self.data)
def init_embedding_data(cfg_em, phase):
    """Create the data loaders for the requested phase.

    Args:
        cfg_em: Config object providing ``batch`` and ``num_workers``, plus
            the attributes ``Dataset_embedding`` reads.
        phase: ``'train'`` builds a shuffled train loader and an unshuffled
            test loader, both with batch size ``cfg_em.batch``;
            ``'inference'`` builds only an unshuffled test loader with batch
            size 1.

    Returns:
        A ``(train_loader, test_loader)`` tuple. ``train_loader`` is ``None``
        in the ``'inference'`` phase.

    Raises:
        ValueError: If ``phase`` is neither ``'train'`` nor ``'inference'``.
    """
    if phase not in ('train', 'inference'):
        raise ValueError(
            f"phase must be 'train' or 'inference', got {phase!r}"
        )

    # Inference has no train loader; define it up front so the shared return
    # below no longer raises UnboundLocalError in that phase.
    train_loader = None
    if phase == 'train':
        train_dataset = Dataset_embedding(cfg_em, 'train')
        test_dataset = Dataset_embedding(cfg_em, 'test')
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=cfg_em.batch,
                                       shuffle=True,
                                       num_workers=cfg_em.num_workers,
                                       pin_memory=True)
        test_loader = data.DataLoader(test_dataset,
                                      batch_size=cfg_em.batch,
                                      shuffle=False,
                                      num_workers=cfg_em.num_workers,
                                      pin_memory=True)
        print(len(train_dataset), len(test_dataset))
    else:
        test_dataset = Dataset_embedding(cfg_em, 'test')
        test_loader = data.DataLoader(test_dataset,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=cfg_em.num_workers,
                                      pin_memory=True)
    return train_loader, test_loader