Spaces:
Runtime error
Runtime error
File size: 4,186 Bytes
83d8d3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
"""
import json
import os.path
import pickle
import random
import numpy as np
import torch
import util.util as util
from data.base_dataset import apply_img_affine
from data.base_dataset import apply_lm_affine
from data.base_dataset import BaseDataset
from data.base_dataset import get_affine_mat
from data.base_dataset import get_transform
from data.image_folder import make_dataset
from PIL import Image
from scipy.io import loadmat
from scipy.io import savemat
from util.load_mats import load_lm3d
from util.preprocess import align_img
from util.preprocess import estimate_norm
def default_flist_reader(flist):
    """Read a plain-text file list with one entry per line.

    Every line is stripped of surrounding whitespace and returned unchanged;
    no label column is parsed. Blank (or whitespace-only) lines are skipped so
    that a stray empty line in the list does not turn into an empty path.

    Parameters:
        flist (str) -- path to the text file to read

    Returns:
        imlist (list of str) -- the non-empty, stripped lines in file order
    """
    with open(flist, "r") as rf:
        # Iterate the file object directly instead of materialising every
        # line with readlines() first.
        stripped = (line.strip() for line in rf)
        return [entry for entry in stripped if entry]
def jason_flist_reader(flist):
    """Load a JSON file list and return the decoded content.

    Parameters:
        flist (str) -- path to the JSON file to read

    Returns:
        the Python object produced by json.load for that file
    """
    with open(flist, "r") as handle:
        return json.load(handle)
def parse_label(label):
    """Convert an array-like label into a float32 torch tensor.

    Parameters:
        label -- any value accepted by np.array

    Returns:
        a torch tensor built from the float32 copy of the input
    """
    as_array = np.array(label)
    as_float32 = as_array.astype(np.float32)
    return torch.tensor(as_float32)
class FlistDataset(BaseDataset):
    """Dataset driven by a text file that lists mask image paths.

    The list file is given by opt.flist; every entry is joined onto
    opt.data_root to form a mask path. The matching image and landmark paths
    are derived from each mask path by string replacement (see __getitem__).
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                                  This class reads opt.bfm_folder, opt.flist, opt.data_root,
                                  opt.isTrain and (in __getitem__) opt.use_aug.
        """
        BaseDataset.__init__(self, opt)
        self.lm3d_std = load_lm3d(opt.bfm_folder)
        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
        self.size = len(self.msk_paths)
        self.opt = opt

        # Dataset tag: "train"/"val", optionally suffixed with the part of the
        # list file's *basename* that precedes its first underscore. Only the
        # basename is inspected so that an underscore in a parent directory
        # name does not leak the whole file name into the tag.
        self.name = "train" if opt.isTrain else "val"
        flist_basename = os.path.basename(opt.flist)
        if "_" in flist_basename:
            self.name += "_" + flist_basename.split("_")[0]

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- an integer for data indexing; wrapped modulo the dataset size

        Returns a dictionary with the following keys:
            imgs (tensor)     -- the aligned (and optionally augmented) image after get_transform()
            lms (tensor)      -- float32 tensor of the landmarks returned by align_img
                                 (after optional augmentation)
            msks (tensor)     -- first channel of the transformed mask image
            M (tensor)        -- float32 tensor of estimate_norm(lm, H), H being the image height
            im_paths (str)    -- path of the image that was loaded
            aug_flag          -- value of `opt.use_aug and opt.isTrain`; truthy when augmentation ran
            dataset (str)     -- the dataset tag built in __init__
        """
        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range

        # Sibling files are located purely by string substitution on the mask
        # path. NOTE: str.replace substitutes every occurrence, so "mask"
        # appearing elsewhere in the path is rewritten as well.
        img_path = msk_path.replace("mask/", "")
        lm_path = ".".join(msk_path.replace("mask", "landmarks").split(".")[:-1]) + ".txt"

        raw_img = Image.open(img_path).convert("RGB")
        raw_msk = Image.open(msk_path).convert("RGB")
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        # align_img's first return value is not needed here.
        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        # PIL's size is (width, height); only the height is passed on.
        _, H = img.size
        M = estimate_norm(lm, H)

        transform = get_transform()
        img_tensor = transform(img)
        # The mask was loaded as RGB; keep a single channel.
        msk_tensor = transform(msk)[:1, ...]
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)

        return {
            "imgs": img_tensor,
            "lms": lm_tensor,
            "msks": msk_tensor,
            "M": M_tensor,
            "im_paths": img_path,
            "aug_flag": aug_flag,
            "dataset": self.name,
        }

    def _augmentation(self, img, lm, opt, msk=None):
        """Apply one affine augmentation to the image, landmarks and mask.

        A single set of matrices from get_affine_mat is used for all three so
        that they stay consistent with each other.

        Parameters:
            img (PIL image) -- image to augment
            lm              -- landmarks matching `img`
            opt             -- options object forwarded to get_affine_mat
            msk (PIL image) -- optional mask; left as None when not given

        Returns:
            (img, lm, msk) after augmentation
        """
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk

    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.size
|