"""This script defines the custom dataset for Deep3DFaceRecon_pytorch """
import json
import os.path
import pickle
import random

import numpy as np
import torch
import util.util as util
from data.base_dataset import apply_img_affine
from data.base_dataset import apply_lm_affine
from data.base_dataset import BaseDataset
from data.base_dataset import get_affine_mat
from data.base_dataset import get_transform
from data.image_folder import make_dataset
from PIL import Image
from scipy.io import loadmat
from scipy.io import savemat
from util.load_mats import load_lm3d
from util.preprocess import align_img
from util.preprocess import estimate_norm


def default_flist_reader(flist):
    """Read a plain-text file list: one image path per line.

    Each line is stripped of surrounding whitespace and taken verbatim as a
    (relative) image path.  Blank-line filtering is NOT performed, matching
    the original behavior.

    Parameters:
        flist (str) -- path to the list file

    Returns:
        list[str] -- the image paths, in file order
    """
    with open(flist, "r") as rf:
        return [line.strip() for line in rf]


def jason_flist_reader(flist):
    """Read a JSON-encoded file list.

    NOTE: the misspelled name ("jason") is kept for backward compatibility
    with existing callers.

    Parameters:
        flist (str) -- path to a JSON file

    Returns:
        the decoded JSON object
    """
    with open(flist, "r") as fp:
        return json.load(fp)


def parse_label(label):
    """Convert a label (list / nested list / ndarray) to a float32 tensor."""
    return torch.tensor(np.array(label).astype(np.float32))


class FlistDataset(BaseDataset):
    """
    It requires one directories to host training images '/path/to/data/train'
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a
                subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        # Standard 3D landmark template used to align every face image.
        self.lm3d_std = load_lm3d(opt.bfm_folder)

        # The flist enumerates MASK paths; image and landmark paths are
        # derived from them in __getitem__.
        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        self.name = "train" if opt.isTrain else "val"
        # Prefix the split name with the flist's leading tag, e.g.
        # "foo_list.txt" -> "train_foo".
        if "_" in opt.flist:
            self.name += "_" + opt.flist.split(os.sep)[-1].split("_")[0]

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            img (tensor)       -- an image in the input domain
            msk (tensor)       -- its corresponding attention mask
            lm (tensor)        -- its corresponding 3d landmarks
            im_paths (str)     -- image paths
            aug_flag (bool)    -- a flag used to tell whether its raw or augmented
        """
        # Wrap the index so any integer is a valid request.
        msk_path = self.msk_paths[index % self.size]
        # The image lives next to the mask, minus the "mask/" directory part.
        img_path = msk_path.replace("mask/", "")
        # Landmarks mirror the mask path under "landmarks/" with a .txt
        # extension.  os.path.splitext only strips the final extension, so
        # dots in directory names are handled correctly (the previous
        # split('.')-based code mangled such paths).
        lm_path = os.path.splitext(msk_path.replace("mask", "landmarks"))[0] + ".txt"

        raw_img = Image.open(img_path).convert("RGB")
        raw_msk = Image.open(msk_path).convert("RGB")
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        # Align image, landmarks and mask to the canonical face template.
        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        # PIL's Image.size is (width, height); only the height is needed to
        # estimate the normalization matrix.
        _, H = img.size
        M = estimate_norm(lm, H)

        transform = get_transform()
        img_tensor = transform(img)
        # The mask is grayscale content in an RGB image; keep one channel.
        msk_tensor = transform(msk)[:1, ...]
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)

        return {
            "imgs": img_tensor,
            "lms": lm_tensor,
            "msks": msk_tensor,
            "M": M_tensor,
            "im_paths": img_path,
            "aug_flag": aug_flag,
            "dataset": self.name,
        }

    def _augmentation(self, img, lm, opt, msk=None):
        """Apply a random affine transform (and optional flip) to the sample.

        The image (and mask, if given) are warped with the inverse affine
        while the landmarks are moved with the forward affine so they stay
        attached to the same facial features.
        """
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk

    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.size