Spaces:
Paused
Paused
import torch | |
import os | |
from PIL import Image | |
import random | |
import numpy as np | |
import pickle | |
import torchvision.transforms as transforms | |
class BaseDataset(torch.utils.data.Dataset):
    """Base dataset pairing images with an AU dictionary loaded from a pickle.

    Construction is two-phase: ``__init__`` takes no arguments and
    ``initialize(opt)`` performs the real setup from an options object.
    ``make_dataset`` and ``get_aus_by_path`` return ``None`` here and are
    presumably overridden by subclasses; ``__len__`` only works once
    ``imgs_path`` is a sized collection.

    Options read from ``opt``: ``data_root``, ``imgs_dir``, ``mode``,
    ``train_csv``, ``test_csv``, ``aus_pkl``, ``img_nc``, ``resize_or_crop``,
    ``load_size``, ``final_size`` and ``no_flip``.
    """

    def __init__(self):
        super().__init__()

    def name(self):
        """Return the last path component of ``opt.data_root``.

        Leading/trailing '/' characters are stripped first, so
        ``"/data/celeba/"`` yields ``"celeba"``.
        """
        return os.path.basename(self.opt.data_root.strip('/'))

    def initialize(self, opt):
        """Configure the dataset from the options object *opt*.

        Sets ``imgs_dir``, ``is_train`` and ``imgs_name_file``. Then it builds
        ``imgs_path`` via ``make_dataset()``, ``aus_dict`` via
        ``load_dict()`` and ``img2tensor`` via ``img_transformer()``.
        """
        self.opt = opt
        self.imgs_dir = os.path.join(self.opt.data_root, self.opt.imgs_dir)
        self.is_train = self.opt.mode == "train"
        # Load image paths: the list file depends on train vs. test mode.
        # (Presumably make_dataset() reads imgs_name_file — verify in subclasses.)
        filename = self.opt.train_csv if self.is_train else self.opt.test_csv
        self.imgs_name_file = os.path.join(self.opt.data_root, filename)
        self.imgs_path = self.make_dataset()
        # Load the AUs dictionary.
        aus_pkl = os.path.join(self.opt.data_root, self.opt.aus_pkl)
        self.aus_dict = self.load_dict(aus_pkl)
        # Build the PIL-image -> tensor transform.
        self.img2tensor = self.img_transformer()

    def make_dataset(self):
        """Return the collection of image paths; ``None`` in the base class."""
        return None

    def load_dict(self, pkl_path):
        """Unpickle and return the object stored at *pkl_path*.

        NOTE(security): ``pickle.load`` can execute arbitrary code, so only
        load trusted files here.
        """
        with open(pkl_path, 'rb') as f:
            # encoding='latin1' is the documented setting for reading Python 2
            # pickles that contain NumPy arrays; presumably the AU pickle was
            # produced that way.
            return pickle.load(f, encoding='latin1')

    def get_img_by_path(self, img_path):
        """Load the image at *img_path* as a PIL image.

        The image is converted to mode ``'L'`` when ``opt.img_nc == 1``,
        otherwise to ``'RGB'``.

        Raises:
            FileNotFoundError: if *img_path* is not an existing regular file.
        """
        # An explicit raise (rather than assert) keeps this check under `python -O`.
        if not os.path.isfile(img_path):
            raise FileNotFoundError("Cannot find image file: %s" % img_path)
        img_type = 'L' if self.opt.img_nc == 1 else 'RGB'
        # convert() returns a new, fully loaded image, so the file can be closed.
        with Image.open(img_path) as img:
            return img.convert(img_type)

    def get_aus_by_path(self, img_path):
        """Return the AU labels for *img_path*; ``None`` in the base class."""
        return None

    def img_transformer(self):
        """Build the transform that turns a PIL image into a normalized tensor.

        The pipeline is:

        1. An optional resize and/or random crop, selected by ``opt.resize_or_crop``.
        2. A random horizontal flip, only in train mode and unless ``opt.no_flip`` is set.
        3. ``ToTensor``.
        4. ``Normalize`` with mean 0.5 and std 0.5 for every channel.

        Raises:
            ValueError: if ``opt.resize_or_crop`` is not ``'resize_and_crop'``,
                ``'crop'`` or ``'none'``.
        """
        transform_list = []
        if self.opt.resize_or_crop == 'resize_and_crop':
            transform_list.append(transforms.Resize([self.opt.load_size, self.opt.load_size], Image.BICUBIC))
            transform_list.append(transforms.RandomCrop(self.opt.final_size))
        elif self.opt.resize_or_crop == 'crop':
            transform_list.append(transforms.RandomCrop(self.opt.final_size))
        elif self.opt.resize_or_crop == 'none':
            # No geometric transform. Omitting it (instead of an identity lambda)
            # keeps the pipeline picklable for multi-process DataLoader workers.
            pass
        else:
            raise ValueError("--resize_or_crop %s is not a valid option." % self.opt.resize_or_crop)
        if self.is_train and not self.opt.no_flip:
            transform_list.append(transforms.RandomHorizontalFlip())
        transform_list.append(transforms.ToTensor())
        # Normalize needs one mean/std entry per tensor channel. get_img_by_path
        # yields a 1-channel 'L' image when img_nc == 1 and 'RGB' otherwise, so
        # hard-coded 3-channel stats would fail on grayscale input.
        num_channels = 1 if self.opt.img_nc == 1 else 3
        half = (0.5,) * num_channels
        transform_list.append(transforms.Normalize(half, half))
        return transforms.Compose(transform_list)

    def __len__(self):
        return len(self.imgs_path)