File size: 4,093 Bytes
9074b85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
"""

import os.path
from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
import numpy as np
import json
import torch
from scipy.io import loadmat, savemat
import pickle
from util.preprocess import align_img, estimate_norm
from util.load_mats import load_lm3d


def default_flist_reader(flist):
    """Read a plain-text file list and return its entries.

    The file is expected to hold one entry per line. Each line is kept whole
    (nothing is split off, so any extra columns stay part of the entry) with
    leading/trailing whitespace removed.

    Parameters:
        flist (str) -- path to the list file

    Returns:
        list of str -- the stripped, non-empty lines in file order
    """
    with open(flist, 'r') as rf:
        stripped_lines = (line.strip() for line in rf)
        # Skip blank lines: an empty entry would otherwise be joined onto the
        # data root by the caller and counted as a (non-existent) sample.
        return [entry for entry in stripped_lines if entry]

def jason_flist_reader(flist):
    """Load a JSON file list.

    Parameters:
        flist (str) -- path to a JSON file

    Returns the decoded JSON content, whatever its top-level type is.
    """
    with open(flist, 'r') as json_file:
        return json.load(json_file)

def parse_label(label):
    """Convert `label` (anything `np.array` accepts) to a float32 torch tensor."""
    as_float32 = np.array(label).astype(np.float32)
    return torch.tensor(as_float32)


class FlistDataset(BaseDataset):
    """Dataset driven by a text file that lists mask images.

    `opt.flist` names a file with one mask path per line; each path is taken
    relative to `opt.data_root`. For every mask, the matching image and
    landmark file paths are derived from the mask path by string substitution
    (see `__getitem__`).
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                                  Read here: bfm_folder, flist, data_root, isTrain
                                  (use_aug is read later, in __getitem__).
        """
        BaseDataset.__init__(self, opt)

        self.lm3d_std = load_lm3d(opt.bfm_folder)

        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        # Dataset tag: 'train' / 'val', optionally suffixed with the part of
        # the list file's *name* that precedes its first underscore. Only the
        # file name is inspected, so an underscore in a directory component
        # does not add a suffix.
        self.name = 'train' if opt.isTrain else 'val'
        flist_basename = os.path.basename(opt.flist)
        if '_' in flist_basename:
            self.name += '_' + flist_basename.split('_')[0]

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int)      -- sample index; wrapped modulo the dataset size

        Returns a dictionary with the keys:
            imgs (tensor)      -- the aligned (and possibly augmented) image
            lms (tensor)       -- float32 landmarks returned by align_img / augmentation
            msks (tensor)      -- the transformed mask, first slice along dim 0 only
            M (tensor)         -- float32 result of estimate_norm(lm, H)
            im_paths (str)     -- path of the image file
            aug_flag           -- value of `opt.use_aug and opt.isTrain`; tells whether
                                  augmentation was applied
            dataset (str)      -- the dataset tag built in __init__
        """
        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
        # NOTE: str.replace substitutes every occurrence, so 'mask' appearing
        # anywhere else in the path (directory or file name) is rewritten too.
        img_path = msk_path.replace('mask/', '')
        # Swap the mask file's extension for '.txt'; splitext only looks at the
        # final path component, so dots in directory names are left alone.
        lm_path = os.path.splitext(msk_path.replace('mask', 'landmarks'))[0] + '.txt'

        raw_img = Image.open(img_path).convert('RGB')
        raw_msk = Image.open(msk_path).convert('RGB')
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        # First return value of align_img is unused here.
        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        _, H = img.size  # PIL reports size as (width, height)
        M = estimate_norm(lm, H)
        transform = get_transform()
        img_tensor = transform(img)
        # Keep only the first slice along dim 0 (presumably the first colour
        # channel of the RGB-converted mask -- TODO confirm against get_transform).
        msk_tensor = transform(msk)[:1, ...]
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)

        return {'imgs': img_tensor,
                'lms': lm_tensor,
                'msks': msk_tensor,
                'M': M_tensor,
                'im_paths': img_path,
                'aug_flag': aug_flag,
                'dataset': self.name}

    def _augmentation(self, img, lm, opt, msk=None):
        """Apply one affine augmentation to an image, its landmarks and (optionally) its mask.

        The affine matrices and flip flag come from a single get_affine_mat call,
        so all three outputs share the same transform.

        Parameters:
            img (PIL.Image)        -- image to warp
            lm                     -- landmarks to transform alongside the image
            opt (Option class)     -- forwarded to get_affine_mat
            msk (PIL.Image | None) -- optional mask; warped with bilinear resampling

        Returns (img, lm, msk); msk is returned unchanged (None) when not given.
        """
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk

    def __len__(self):
        """Return the total number of mask entries (samples) in the dataset."""
        return self.size