gavinyuan
update: PIPNet
4599493
raw
history blame contribute delete
No virus
5.81 kB
import cv2
import sys
import os
def make_pipnet():
cmds = [
"cd ./third_party/PIPNet/FaceBoxesV2/utils/ && chmod +x ./make.sh "
"&& bash ./make.sh "
"&& cd - ",
]
for cmd in cmds:
os.system(cmd)
print('[PIPNet.lib.tools] nms .o file built successfully.')
make_pipnet()
from math import floor
from third_party.PIPNet.FaceBoxesV2.faceboxes_detector import *
import torch
import torch.nn.parallel
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.models as models
from third_party.PIPNet.lib.networks import *
from third_party.PIPNet.lib.functions import *
from third_party.PIPNet.reverse_index import ri1, ri2
make_abs_path = lambda fn: os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), fn))
class Config:
def __init__(self):
self.det_head = "pip"
self.net_stride = 32
self.batch_size = 16
self.init_lr = 0.0001
self.num_epochs = 60
self.decay_steps = [30, 50]
self.input_size = 256
self.backbone = "resnet101"
self.pretrained = True
self.criterion_cls = "l2"
self.criterion_reg = "l1"
self.cls_loss_weight = 10
self.reg_loss_weight = 1
self.num_lms = 98
self.save_interval = self.num_epochs
self.num_nb = 10
self.use_gpu = True
self.gpu_id = 3
def get_lmk_model():
cfg = Config()
resnet101 = models.resnet101(pretrained=cfg.pretrained)
net = Pip_resnet101(
resnet101,
cfg.num_nb,
num_lms=cfg.num_lms,
input_size=cfg.input_size,
net_stride=cfg.net_stride,
)
if cfg.use_gpu:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
device = torch.device("cpu")
net = net.to(device)
weight_file = make_abs_path('../../../weights/PIPNet/epoch59.pth')
state_dict = torch.load(weight_file, map_location=device)
net.load_state_dict(state_dict)
detector = FaceBoxesDetector(
"FaceBoxes",
make_abs_path("../../../weights/PIPNet/FaceBoxesV2.pth"),
use_gpu=torch.cuda.is_available(),
device=device,
)
return net, detector
def demo_image(
image_file,
net,
detector,
input_size=256,
net_stride=32,
num_nb=10,
use_gpu=True,
device="cuda:0",
):
my_thresh = 0.6
det_box_scale = 1.2
net.eval()
preprocess = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
reverse_index1, reverse_index2, max_len = ri1, ri2, 17
# image = cv2.imread(image_file)
image = image_file
image_height, image_width, _ = image.shape
detections, _ = detector.detect(image, my_thresh, 1)
lmks = []
for i in range(len(detections)):
det_xmin = detections[i][2]
det_ymin = detections[i][3]
det_width = detections[i][4]
det_height = detections[i][5]
det_xmax = det_xmin + det_width - 1
det_ymax = det_ymin + det_height - 1
det_xmin -= int(det_width * (det_box_scale - 1) / 2)
# remove a part of top area for alignment, see paper for details
det_ymin += int(det_height * (det_box_scale - 1) / 2)
det_xmax += int(det_width * (det_box_scale - 1) / 2)
det_ymax += int(det_height * (det_box_scale - 1) / 2)
det_xmin = max(det_xmin, 0)
det_ymin = max(det_ymin, 0)
det_xmax = min(det_xmax, image_width - 1)
det_ymax = min(det_ymax, image_height - 1)
det_width = det_xmax - det_xmin + 1
det_height = det_ymax - det_ymin + 1
# cv2.rectangle(image, (det_xmin, det_ymin), (det_xmax, det_ymax), (0, 0, 255), 2)
det_crop = image[det_ymin:det_ymax, det_xmin:det_xmax, :]
det_crop = cv2.resize(det_crop, (input_size, input_size))
inputs = Image.fromarray(det_crop[:, :, ::-1].astype("uint8"), "RGB")
inputs = preprocess(inputs).unsqueeze(0)
inputs = inputs.to(device)
(
lms_pred_x,
lms_pred_y,
lms_pred_nb_x,
lms_pred_nb_y,
outputs_cls,
max_cls,
) = forward_pip(net, inputs, preprocess, input_size, net_stride, num_nb)
lms_pred = torch.cat((lms_pred_x, lms_pred_y), dim=1).flatten()
tmp_nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(98, max_len)
tmp_nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(98, max_len)
tmp_x = torch.mean(torch.cat((lms_pred_x, tmp_nb_x), dim=1), dim=1).view(-1, 1)
tmp_y = torch.mean(torch.cat((lms_pred_y, tmp_nb_y), dim=1), dim=1).view(-1, 1)
lms_pred_merge = torch.cat((tmp_x, tmp_y), dim=1).flatten()
lms_pred = lms_pred.cpu().numpy()
lms_pred_merge = lms_pred_merge.cpu().numpy()
lmk_ = []
for i in range(98):
x_pred = lms_pred_merge[i * 2] * det_width
y_pred = lms_pred_merge[i * 2 + 1] * det_height
# cv2.circle(
# image,
# (int(x_pred) + det_xmin, int(y_pred) + det_ymin),
# 1,
# (0, 0, 255),
# 1,
# )
lmk_.append([int(x_pred) + det_xmin, int(y_pred) + det_ymin])
lmks.append(np.array(lmk_))
# image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
# cv2.imwrite("./1_out.jpg", image_bgr)
return lmks
if __name__ == "__main__":
net, detector = get_lmk_model()
demo_image(
"/apdcephfs/private_ahbanliang/codes/Real-ESRGAN-master/tmp_frames/yanikefu/frame00000046.png",
net,
detector,
)