|
import os |
|
import sys |
|
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) |
|
|
|
import torch |
|
from BiSeNet import BiSeNet |
|
from torchvision.transforms import transforms |
|
import cv2 |
|
from tqdm import tqdm |
|
import numpy as np |
|
from glob import glob |
|
from correct_head_mask import correct_hair_mask |
|
import argparse |
|
|
|
|
|
class GenHeadMask(object): |
|
def __init__(self, gpu_id) -> None: |
|
super().__init__() |
|
|
|
self.device = torch.device("cuda:%s" % gpu_id) |
|
self.model_path = "ConfigModels/faceparsing_model.pth" |
|
|
|
self.init_model() |
|
self.lut = np.zeros((256, ), dtype=np.uint8) |
|
self.lut[1:14] = 1 |
|
self.lut[17] = 2 |
|
|
|
|
|
|
|
|
|
def init_model(self): |
|
n_classes = 19 |
|
net = BiSeNet(n_classes=n_classes).to(self.device) |
|
net.load_state_dict(torch.load(self.model_path, map_location=self.device)) |
|
self.net = net |
|
self.net.eval() |
|
|
|
self.to_tensor = transforms.Compose([ |
|
transforms.ToTensor(), |
|
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), |
|
]) |
|
|
|
|
|
def main_process(self, img_dir): |
|
img_path_list = [x for x in glob("%s/*.png" % img_dir) if "mask" not in x] |
|
img_path_list = img_path_list + [x for x in glob("%s/*.jpg" % img_dir) if "mask" not in x] |
|
if len(img_path_list) == 0: |
|
print("Dir: %s does include any .png and .jpg images." % img_dir) |
|
exit(0) |
|
img_path_list.sort() |
|
loop_bar = tqdm(img_path_list) |
|
loop_bar.set_description("Generate head masks") |
|
for img_path in loop_bar: |
|
save_path = img_path[:-4] + "_mask.png" |
|
|
|
bgr_img = cv2.imread(img_path) |
|
img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) |
|
img = self.to_tensor(img) |
|
img = img.unsqueeze(0) |
|
img = img.to(self.device) |
|
with torch.set_grad_enabled(False): |
|
pred_res = self.net(img) |
|
out = pred_res[0] |
|
|
|
res = out.squeeze(0).cpu().numpy().argmax(0) |
|
res = res.astype(np.uint8) |
|
cv2.LUT(res, self.lut, res) |
|
|
|
res = correct_hair_mask(res) |
|
res[res != 0] = 255 |
|
|
|
|
|
|
|
|
|
cv2.imwrite(save_path, res) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
parser = argparse.ArgumentParser(description='The code for generating head mask images.') |
|
parser.add_argument("--gpu_id", type=int, default=0) |
|
parser.add_argument("--img_dir", type=str, required=True) |
|
args = parser.parse_args() |
|
|
|
gpu_id = args.gpu_id |
|
img_dir = args.img_dir |
|
|
|
|
|
|
|
|
|
tt = GenHeadMask(gpu_id=gpu_id) |
|
tt.main_process(img_dir=img_dir) |
|
|