# Author: goldpulpy
# Commit message: "Upload model and app"
# Commit: f9e4a6c
import torch
import os
from PIL import Image
import random
import numpy as np
import pickle
import torchvision.transforms as transforms
class BaseDataset(torch.utils.data.Dataset):
    """Base class for image datasets paired with a pickled AUs dictionary.

    Construction does no I/O; call ``initialize(opt)`` to resolve paths, load
    the image list and AUs dictionary, and build the image-to-tensor pipeline.
    ``make_dataset`` and ``get_aus_by_path`` are hooks that return ``None``
    here and are meant to be overridden by subclasses (this base class also
    defines no ``__getitem__``).
    """

    def __init__(self):
        super().__init__()

    def name(self):
        """Return the last path component of ``opt.data_root``.

        Leading/trailing '/' are stripped first so a trailing slash does not
        produce an empty name.
        """
        return os.path.basename(self.opt.data_root.strip('/'))

    def initialize(self, opt):
        """Bind ``opt`` and load everything the dataset needs.

        Reads ``opt.data_root``, ``opt.imgs_dir``, ``opt.mode``,
        ``opt.train_csv`` / ``opt.test_csv`` and ``opt.aus_pkl`` here, plus the
        options consumed by ``img_transformer``.
        """
        self.opt = opt
        self.imgs_dir = os.path.join(self.opt.data_root, self.opt.imgs_dir)
        self.is_train = self.opt.mode == "train"
        # Load image paths: the list file depends on the train/test split.
        filename = self.opt.train_csv if self.is_train else self.opt.test_csv
        self.imgs_name_file = os.path.join(self.opt.data_root, filename)
        self.imgs_path = self.make_dataset()
        # Load the AUs dictionary.
        aus_pkl = os.path.join(self.opt.data_root, self.opt.aus_pkl)
        self.aus_dict = self.load_dict(aus_pkl)
        # Build the image-to-tensor transformer.
        self.img2tensor = self.img_transformer()

    def make_dataset(self):
        """Hook for subclasses: return the list of image paths (base: None)."""
        return None

    def load_dict(self, pkl_path):
        """Unpickle and return the object stored at ``pkl_path``.

        ``encoding='latin1'`` is presumably there so pickles written by
        Python 2 (e.g. containing NumPy arrays) load under Python 3.
        """
        # NOTE: pickle can execute arbitrary code while loading; only point
        # opt.aus_pkl at trusted files.
        with open(pkl_path, 'rb') as f:
            return pickle.load(f, encoding='latin1')

    def _num_channels(self):
        """Channel count images are loaded with: 1 ('L') if img_nc == 1, else 3 ('RGB')."""
        return 1 if self.opt.img_nc == 1 else 3

    def get_img_by_path(self, img_path):
        """Load ``img_path`` as a PIL image in 'L' or 'RGB' mode (see ``_num_channels``).

        Raises:
            AssertionError: if ``img_path`` is not an existing file.
        """
        assert os.path.isfile(img_path), "Cannot find image file: %s" % img_path
        img_type = 'L' if self._num_channels() == 1 else 'RGB'
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as img:
            return img.convert(img_type)

    def get_aus_by_path(self, img_path):
        """Hook for subclasses: return the AUs for ``img_path`` (base: None)."""
        return None

    def img_transformer(self):
        """Build the PIL image -> normalized tensor pipeline from ``opt``.

        ``opt.resize_or_crop`` selects the geometric step:
        'resize_and_crop' (bicubic resize to ``load_size`` then random crop to
        ``final_size``), 'crop' (random crop to ``final_size``) or 'none'.
        Random horizontal flips are added in train mode unless ``opt.no_flip``.
        The result is scaled to tensors and normalized with mean=std=0.5 per
        channel.

        Raises:
            ValueError: if ``opt.resize_or_crop`` is not one of the above.
        """
        transform_list = []
        if self.opt.resize_or_crop == 'resize_and_crop':
            transform_list.append(
                transforms.Resize([self.opt.load_size, self.opt.load_size], Image.BICUBIC))
            transform_list.append(transforms.RandomCrop(self.opt.final_size))
        elif self.opt.resize_or_crop == 'crop':
            transform_list.append(transforms.RandomCrop(self.opt.final_size))
        elif self.opt.resize_or_crop == 'none':
            # Intentionally no geometric step. Compose needs no identity
            # placeholder, and a lambda here would make the pipeline
            # unpicklable (breaking spawn-based DataLoader workers).
            pass
        else:
            raise ValueError("--resize_or_crop %s is not a valid option." % self.opt.resize_or_crop)
        # NOTE(review): RandomCrop above is also used when not training, so
        # evaluation crops are random; consider CenterCrop for the test split.
        if self.is_train and not self.opt.no_flip:
            transform_list.append(transforms.RandomHorizontalFlip())
        transform_list.append(transforms.ToTensor())
        # One mean/std entry per channel: a 3-entry Normalize applied to a
        # 1-channel ('L') tensor fails with a broadcast-shape RuntimeError.
        half = (0.5,) * self._num_channels()
        transform_list.append(transforms.Normalize(half, half))
        return transforms.Compose(transform_list)

    def __len__(self):
        return len(self.imgs_path)