File size: 3,895 Bytes
f9e4a6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import glob
import numpy as np
from os import makedirs, name
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn

from face3d.options.inference_options import InferenceOptions
from face3d.models import create_model
from face3d.util.preprocess import align_img
from face3d.util.load_mats import load_lm3d
from face3d.extract_kp_videos import KeypointExtractor


class CoeffDetector(nn.Module):
    """Regress 3DMM coefficients from a single face image and its 2D landmarks.

    Wraps the face3d reconstruction model built by ``create_model(opt)``. The
    image is aligned/cropped with ``align_img`` using the landmarks and fed to
    the model. The predicted coefficient groups are then concatenated, together
    with the alignment parameters, into a single row.
    """

    # Order in which the predicted coefficient groups are concatenated.
    _COEFF_ORDER = ('id', 'exp', 'tex', 'angle', 'gamma', 'trans')

    def __init__(self, opt):
        """
        param:
            opt: parsed options (e.g. from InferenceOptions().parse()); passed to
                 create_model / model.setup, and ``opt.bfm_folder`` is used to
                 load the standard 3D landmark template.
        """
        super().__init__()

        self.model = create_model(opt)
        self.model.setup(opt)
        # NOTE(review): the device is hard-coded to CUDA, so this class needs a GPU.
        self.model.device = 'cuda'
        self.model.parallelize()
        self.model.eval()

        # Standard 3D landmark template: the alignment target for align_img, and
        # (its first two columns) the fallback 2D landmarks in image_transform.
        self.lm3d_std = load_lm3d(opt.bfm_folder)

    def forward(self, img, lm):
        """Predict the 3DMM coefficients for one image.

        param:
            img: PIL image
            lm:  numpy array of 2D landmarks, reshapeable to (N, 2); an array
                 whose mean is exactly -1 (e.g. filled with -1) is treated as
                 "no landmarks" and replaced by the standard template
        return:
            dict with
              'coeff_3dmm': numpy array with one row: the id, exp, tex, angle,
                            gamma and trans coefficients followed by the 5
                            alignment parameters from align_img
              'crop_img':   the aligned crop as a PIL image
        """
        img, trans_params = self.image_transform(img, lm)

        # The model expects a batch, so add a leading batch dimension of 1.
        self.model.set_input({'imgs': img[None]})
        self.model.test()
        pred_coeff = {key: self.model.pred_coeffs_dict[key].cpu().numpy()
                      for key in self.model.pred_coeffs_dict}
        pred_coeff = np.concatenate(
            [pred_coeff[key] for key in self._COEFF_ORDER]
            + [trans_params[None].numpy()],
            1,
        )

        crop = (img.cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        return {'coeff_3dmm': pred_coeff,
                'crop_img': Image.fromarray(crop)}

    def image_transform(self, images, lm):
        """Align and crop the image and convert it to model input.

        param:
            images: PIL image
            lm:     numpy array of 2D (x, y) landmarks in pixel coordinates. It is
                    NOT modified: the y flip below is applied to a copy.
        return:
            img:          float32 tensor, channel-first (C, H', W'), pixel values
                          divided by 255
            trans_params: float32 tensor with the 5 alignment parameters
                          returned by align_img
        """
        W, H = images.size
        if np.mean(lm) == -1:
            # Landmarks missing: map the template's x/y from [-1, 1] to pixel coordinates.
            lm = (self.lm3d_std[:, :2] + 1) / 2.
            lm = np.concatenate(
                [lm[:, :1] * W, lm[:, 1:2] * H], 1
            )
        else:
            # Copy first so the caller's array is left untouched, then flip y
            # (presumably the vertical convention align_img expects).
            lm = np.array(lm, copy=True)
            lm[:, -1] = H - 1 - lm[:, -1]

        trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)
        img = torch.tensor(np.array(img) / 255., dtype=torch.float32).permute(2, 0, 1)
        # np.ravel(...)[0] extracts each 1-element piece as a scalar; float() on a
        # 1-D array is deprecated since NumPy 1.25.
        trans_params = np.array([float(np.ravel(item)[0]) for item in np.hsplit(trans_params, 5)])
        trans_params = torch.tensor(trans_params.astype(np.float32))
        return img, trans_params

def get_data_path(root, keypoint_root, extensions=None):
    """List the images in ``root`` and the landmark file path paired with each.

    Only the top level of ``root`` is scanned (no recursion). A file matches
    when its extension, compared case-insensitively, is in ``extensions``.
    Hidden files (names starting with '.') are skipped, as the previous
    glob-based lookup did. Each image is listed once, even on case-insensitive
    platforms where separate ``*.jpg`` / ``*.JPG`` globs would both match it.

    param:
        root:          directory containing the input images
        keypoint_root: directory where the per-image landmark files live
        extensions:    iterable of extensions (with or without the leading dot);
                       defaults to jpg, png, jpeg and webp
    return:
        (filenames, keypoint_filenames): two lists of equal length. The first
        holds the sorted image paths; for each image, the second holds
        ``<keypoint_root>/<image stem>.txt``. A missing ``root`` yields two
        empty lists.
    """
    if extensions is None:
        extensions = ('jpg', 'png', 'jpeg', 'webp')
    wanted = {ext.lower().lstrip('.') for ext in extensions}

    if not os.path.isdir(root):
        # Same as the old glob behaviour: a missing directory simply has no images.
        return [], []

    filenames = sorted(
        os.path.join(root, entry)
        for entry in os.listdir(root)
        if not entry.startswith('.')
        and os.path.splitext(entry)[1][1:].lower() in wanted
        and os.path.isfile(os.path.join(root, entry))
    )
    keypoint_filenames = [
        os.path.join(keypoint_root, os.path.splitext(os.path.basename(path))[0] + '.txt')
        for path in filenames
    ]
    return filenames, keypoint_filenames


if __name__ == "__main__":
    # For every image in opt.input_dir: load or detect its 2D landmarks, regress
    # the 3DMM coefficients, and write them as a flat text file to opt.output_dir.
    opt = InferenceOptions().parse()
    coeff_detector = CoeffDetector(opt)
    kp_extractor = KeypointExtractor()
    image_names, keypoint_names = get_data_path(opt.input_dir, opt.keypoint_dir)
    makedirs(opt.keypoint_dir, exist_ok=True)
    makedirs(opt.output_dir, exist_ok=True)

    for image_name, keypoint_name in tqdm(zip(image_names, keypoint_names),
                                          total=len(image_names)):
        # Force 3-channel RGB: grayscale/palette/RGBA files would otherwise yield
        # an array with the wrong channel count downstream. The `with` closes the file.
        with Image.open(image_name) as raw_image:
            image = raw_image.convert('RGB')
        if not os.path.isfile(keypoint_name):
            # No landmark file yet: detect landmarks (the extractor is given the
            # target path, presumably so it can cache them there).
            lm = kp_extractor.extract_keypoint(image, keypoint_name)
        else:
            lm = np.loadtxt(keypoint_name).astype(np.float32).reshape([-1, 2])
        predicted = coeff_detector(image, lm)
        stem = os.path.splitext(os.path.basename(image_name))[0]
        np.savetxt(
            os.path.join(opt.output_dir, '{}_3dmm_coeff.txt'.format(stem)),
            predicted['coeff_3dmm'].reshape(-1))