# NOTE: removed stray scrape artifacts that preceded the docstring
# ("xuehongyang", "ser", "83d8d3c" — a committer name and commit hash);
# they were not valid Python and broke the module at import time.
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
"""
import json
import os.path
import pickle
import random
import numpy as np
import torch
import util.util as util
from data.base_dataset import apply_img_affine
from data.base_dataset import apply_lm_affine
from data.base_dataset import BaseDataset
from data.base_dataset import get_affine_mat
from data.base_dataset import get_transform
from data.image_folder import make_dataset
from PIL import Image
from scipy.io import loadmat
from scipy.io import savemat
from util.load_mats import load_lm3d
from util.preprocess import align_img
from util.preprocess import estimate_norm
def default_flist_reader(flist):
    """Read an image-path list file, one path per line.

    flist format: impath label\nimpath label\n ...(same to caffe's filelist)

    Parameters:
        flist (str) -- path to the list file

    Returns:
        list of str -- each line stripped of surrounding whitespace; blank
        lines yield empty strings (matching the original behavior)
    """
    # Iterate the file object directly instead of readlines(): avoids
    # materializing the entire file in memory, and the comprehension
    # replaces the manual loop-and-append.
    with open(flist, "r") as rf:
        return [line.strip() for line in rf]
def jason_flist_reader(flist):
    """Load a file-list description stored as JSON.

    Parameters:
        flist (str) -- path to a JSON file

    Returns:
        the parsed JSON object (whatever structure the file contains)
    """
    with open(flist, "r") as handle:
        return json.load(handle)
def parse_label(label):
    """Convert an array-like label to a float32 torch tensor (fresh copy)."""
    # Coerce to float32 once via NumPy, then let torch.tensor make its
    # own copy — identical result to the original chained expression.
    as_float32 = np.asarray(label, dtype=np.float32)
    return torch.tensor(as_float32)
class FlistDataset(BaseDataset):
    """Face-reconstruction dataset driven by a flat file list of mask paths.

    It requires one directories to host training images '/path/to/data/train'
    You can train the model with the dataset flag '--dataroot /path/to/data'.

    Each list entry is a mask path; the matching image and landmark paths are
    derived from it by string substitution (see __getitem__).
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        # Standard 3D landmark template used later by align_img for face alignment.
        self.lm3d_std = load_lm3d(opt.bfm_folder)

        # opt.flist holds relative mask paths, one per line; prefix each with
        # the data root to get absolute paths.
        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        # Dataset name, e.g. "train" or "val", optionally suffixed with the
        # first underscore-separated token of the flist filename
        # (e.g. flist "ffhq_list.txt" -> "train_ffhq").
        self.name = "train" if opt.isTrain else "val"
        if "_" in opt.flist:
            self.name += "_" + opt.flist.split(os.sep)[-1].split("_")[0]

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            img (tensor) -- an image in the input domain
            msk (tensor) -- its corresponding attention mask
            lm (tensor) -- its corresponding 3d landmarks
            im_paths (str) -- image paths
            aug_flag (bool) -- a flag used to tell whether its raw or augmented
        """
        msk_path = self.msk_paths[index % self.size]  # make sure index is within then range
        # Derive the image path by dropping the "mask/" directory component.
        # NOTE(review): str.replace substitutes EVERY occurrence of "mask/",
        # so this assumes "mask/" appears only once in the path — confirm.
        img_path = msk_path.replace("mask/", "")
        # Landmark file: swap "mask" -> "landmarks" in the path and replace
        # the file extension with ".txt".
        lm_path = ".".join(msk_path.replace("mask", "landmarks").split(".")[:-1]) + ".txt"

        raw_img = Image.open(img_path).convert("RGB")
        raw_msk = Image.open(msk_path).convert("RGB")
        # Landmarks are whitespace-separated numbers; presumably N x 2 image
        # coordinates — TODO confirm against util.preprocess.align_img.
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        # Align/crop image, landmarks and mask against the standard 3D
        # landmark template (first return value is discarded here).
        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        # Geometric augmentation only during training and only when enabled.
        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        # PIL Image.size is (width, height), so H is the image height.
        _, H = img.size
        M = estimate_norm(lm, H)
        transform = get_transform()
        img_tensor = transform(img)
        # Mask is RGB but channels are redundant; keep a single channel.
        msk_tensor = transform(msk)[:1, ...]
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)

        return {
            "imgs": img_tensor,
            "lms": lm_tensor,
            "msks": msk_tensor,
            "M": M_tensor,
            "im_paths": img_path,
            "aug_flag": aug_flag,
            "dataset": self.name,
        }

    def _augmentation(self, img, lm, opt, msk=None):
        """Apply one random affine transform (and optional flip) to image,
        landmarks and, when given, the mask. Returns (img, lm, msk)."""
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        # Images are warped with the inverse matrix while landmarks use the
        # forward one — this pairing comes from data.base_dataset's helpers.
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk

    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.size