File size: 3,122 Bytes
f12ab4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

import torch
from BiSeNet import BiSeNet
from torchvision.transforms import transforms
import cv2
from tqdm import tqdm
import numpy as np
from glob import glob
from correct_head_mask import correct_hair_mask
import argparse


class GenHeadMask(object):
    def __init__(self, gpu_id) -> None:
        super().__init__()
        
        self.device = torch.device("cuda:%s" % gpu_id)
        self.model_path = "ConfigModels/faceparsing_model.pth"
        
        self.init_model()
        self.lut = np.zeros((256, ), dtype=np.uint8)
        self.lut[1:14] = 1
        self.lut[17] = 2
        
        #  ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
        #     'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']

    def init_model(self):
        n_classes = 19
        net = BiSeNet(n_classes=n_classes).to(self.device)
        net.load_state_dict(torch.load(self.model_path, map_location=self.device))
        self.net = net
        self.net.eval()

        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        
        
    def main_process(self, img_dir):
        img_path_list = [x for x in glob("%s/*.png" % img_dir) if "mask" not in x]  
        img_path_list = img_path_list + [x for x in glob("%s/*.jpg" % img_dir) if "mask" not in x]  
        if len(img_path_list) == 0:
            print("Dir: %s does include any .png and .jpg images." % img_dir)
            exit(0)
        img_path_list.sort()
        loop_bar = tqdm(img_path_list)
        loop_bar.set_description("Generate head masks")
        for img_path in loop_bar:
            save_path = img_path[:-4] + "_mask.png"
            
            bgr_img = cv2.imread(img_path)
            img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
            img = self.to_tensor(img)
            img = img.unsqueeze(0)
            img = img.to(self.device)
            with torch.set_grad_enabled(False):
                pred_res = self.net(img)
                out = pred_res[0]

            res = out.squeeze(0).cpu().numpy().argmax(0)
            res = res.astype(np.uint8)
            cv2.LUT(res, self.lut, res)
            
            res = correct_hair_mask(res)
            res[res != 0] = 255
            # temp_img = bgr_img.copy()
            # temp_img[res == 0] = 255
            
            # res = np.concatenate([bgr_img, temp_img], axis=1)
            cv2.imwrite(save_path, res)
            
            
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='The code for generating head mask images.')
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--img_dir", type=str, required=True)
    args = parser.parse_args()

    gpu_id = args.gpu_id
    img_dir = args.img_dir

    # assert len(sys.argv) == 2
    # img_dir = sys.argv[1]

    tt = GenHeadMask(gpu_id=gpu_id)
    tt.main_process(img_dir=img_dir)