hafidhsoekma's picture
First commit
49bceed
raw
history blame
16.2 kB
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
import time
from itertools import product as product
from math import ceil
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return F.relu(x, inplace=True)
class Inception(nn.Module):
def __init__(self):
super(Inception, self).__init__()
self.branch1x1 = BasicConv2d(128, 32, kernel_size=1, padding=0)
self.branch1x1_2 = BasicConv2d(128, 32, kernel_size=1, padding=0)
self.branch3x3_reduce = BasicConv2d(128, 24, kernel_size=1, padding=0)
self.branch3x3 = BasicConv2d(24, 32, kernel_size=3, padding=1)
self.branch3x3_reduce_2 = BasicConv2d(128, 24, kernel_size=1, padding=0)
self.branch3x3_2 = BasicConv2d(24, 32, kernel_size=3, padding=1)
self.branch3x3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
def forward(self, x):
branch1x1 = self.branch1x1(x)
branch1x1_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
branch1x1_2 = self.branch1x1_2(branch1x1_pool)
branch3x3_reduce = self.branch3x3_reduce(x)
branch3x3 = self.branch3x3(branch3x3_reduce)
branch3x3_reduce_2 = self.branch3x3_reduce_2(x)
branch3x3_2 = self.branch3x3_2(branch3x3_reduce_2)
branch3x3_3 = self.branch3x3_3(branch3x3_2)
outputs = (branch1x1, branch1x1_2, branch3x3, branch3x3_3)
return torch.cat(outputs, 1)
class CRelu(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(CRelu, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = torch.cat((x, -x), 1)
x = F.relu(x, inplace=True)
return x
class FaceBoxes(nn.Module):
def __init__(self, phase, size, num_classes):
super(FaceBoxes, self).__init__()
self.phase = phase
self.num_classes = num_classes
self.size = size
self.conv1 = CRelu(3, 24, kernel_size=7, stride=4, padding=3)
self.conv2 = CRelu(48, 64, kernel_size=5, stride=2, padding=2)
self.inception1 = Inception()
self.inception2 = Inception()
self.inception3 = Inception()
self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
self.loc, self.conf = self.multibox(self.num_classes)
if self.phase == "test":
self.softmax = nn.Softmax(dim=-1)
if self.phase == "train":
for m in self.modules():
if isinstance(m, nn.Conv2d):
if m.bias is not None:
nn.init.xavier_normal_(m.weight.data)
m.bias.data.fill_(0.02)
else:
m.weight.data.normal_(0, 0.01)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def multibox(self, num_classes):
loc_layers = []
conf_layers = []
loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)]
conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)]
loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers)
def forward(self, x):
detection_sources = list()
loc = list()
conf = list()
x = self.conv1(x)
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
x = self.conv2(x)
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
x = self.inception1(x)
x = self.inception2(x)
x = self.inception3(x)
detection_sources.append(x)
x = self.conv3_1(x)
x = self.conv3_2(x)
detection_sources.append(x)
x = self.conv4_1(x)
x = self.conv4_2(x)
detection_sources.append(x)
for x, l, c in zip(detection_sources, self.loc, self.conf):
loc.append(l(x).permute(0, 2, 3, 1).contiguous())
conf.append(c(x).permute(0, 2, 3, 1).contiguous())
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
if self.phase == "test":
output = (
loc.view(loc.size(0), -1, 4),
self.softmax(conf.view(-1, self.num_classes)),
)
else:
output = (
loc.view(loc.size(0), -1, 4),
conf.view(conf.size(0), -1, self.num_classes),
)
return output
class PriorBox(object):
def __init__(self, cfg, image_size=None, phase="train"):
super(PriorBox, self).__init__()
# self.aspect_ratios = cfg['aspect_ratios']
self.min_sizes = cfg["min_sizes"]
self.steps = cfg["steps"]
self.clip = cfg["clip"]
self.image_size = image_size
self.feature_maps = [
(ceil(self.image_size[0] / step), ceil(self.image_size[1] / step))
for step in self.steps
]
self.feature_maps = tuple(self.feature_maps)
def forward(self):
anchors = []
for k, f in enumerate(self.feature_maps):
min_sizes = self.min_sizes[k]
for i, j in product(range(f[0]), range(f[1])):
for min_size in min_sizes:
s_kx = min_size / self.image_size[1]
s_ky = min_size / self.image_size[0]
if min_size == 32:
dense_cx = [
x * self.steps[k] / self.image_size[1]
for x in [j + 0, j + 0.25, j + 0.5, j + 0.75]
]
dense_cy = [
y * self.steps[k] / self.image_size[0]
for y in [i + 0, i + 0.25, i + 0.5, i + 0.75]
]
for cy, cx in product(dense_cy, dense_cx):
anchors += [cx, cy, s_kx, s_ky]
elif min_size == 64:
dense_cx = [
x * self.steps[k] / self.image_size[1]
for x in [j + 0, j + 0.5]
]
dense_cy = [
y * self.steps[k] / self.image_size[0]
for y in [i + 0, i + 0.5]
]
for cy, cx in product(dense_cy, dense_cx):
anchors += [cx, cy, s_kx, s_ky]
else:
cx = (j + 0.5) * self.steps[k] / self.image_size[1]
cy = (i + 0.5) * self.steps[k] / self.image_size[0]
anchors += [cx, cy, s_kx, s_ky]
# back to torch land
output = torch.Tensor(anchors).view(-1, 4)
if self.clip:
output.clamp_(max=1, min=0)
return output
def mymax(a, b):
if a >= b:
return a
else:
return b
def mymin(a, b):
if a >= b:
return b
else:
return a
def cpu_nms(dets, thresh):
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
ndets = dets.shape[0]
suppressed = np.zeros((ndets), dtype=int)
keep = []
for _i in range(ndets):
i = order[_i]
if suppressed[i] == 1:
continue
keep.append(i)
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
xx1 = mymax(ix1, x1[j])
yy1 = mymax(iy1, y1[j])
xx2 = mymin(ix2, x2[j])
yy2 = mymin(iy2, y2[j])
w = mymax(0.0, xx2 - xx1 + 1)
h = mymax(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (iarea + areas[j] - inter)
if ovr >= thresh:
suppressed[j] = 1
return tuple(keep)
def nms(dets, thresh, force_cpu=False):
"""Dispatch to either CPU or GPU NMS implementations."""
if dets.shape[0] == 0:
return ()
if force_cpu:
# return cpu_soft_nms(dets, thresh, method = 0)
return cpu_nms(dets, thresh)
return cpu_nms(dets, thresh)
# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat(
(
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]),
),
1,
)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
return boxes
def check_keys(model, pretrained_state_dict):
ckpt_keys = set(pretrained_state_dict.keys())
model_keys = set(model.state_dict().keys())
used_pretrained_keys = model_keys & ckpt_keys
unused_pretrained_keys = ckpt_keys - model_keys
missing_keys = model_keys - ckpt_keys
# print('Missing keys:{}'.format(len(missing_keys)))
# print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
# print('Used keys:{}'.format(len(used_pretrained_keys)))
assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint"
return True
def remove_prefix(state_dict, prefix):
"""Old style model is stored with all names of parameters sharing common prefix 'module.'"""
# print('remove prefix \'{}\''.format(prefix))
def f(x):
return x.split(prefix, 1)[-1] if x.startswith(prefix) else x
return {f(key): value for key, value in state_dict.items()}
def load_model(model, pretrained_path, load_to_cpu):
# print('Loading pretrained model from {}'.format(pretrained_path))
if load_to_cpu:
pretrained_dict = torch.load(
pretrained_path, map_location=lambda storage, loc: storage
)
else:
device = torch.cuda.current_device()
pretrained_dict = torch.load(
pretrained_path, map_location=lambda storage, loc: storage.cuda(device)
)
if "state_dict" in pretrained_dict.keys():
pretrained_dict = remove_prefix(pretrained_dict["state_dict"], "module.")
else:
pretrained_dict = remove_prefix(pretrained_dict, "module.")
check_keys(model, pretrained_dict)
model.load_state_dict(pretrained_dict, strict=False)
return model
class SingleShotDetectorModel:
def __init__(
self,
path_to_weights: str = "./weights/anime_face_detection/ssd_anime_face_detect.pth",
confidence_threshold: float = 0.5,
nms_threshold: float = 0.3,
top_k: int = 5000,
keep_top_k: int = 750,
):
self.path_to_weights = path_to_weights
self.confidence_threshold = confidence_threshold
self.nms_threshold = nms_threshold
self.top_k = top_k
self.keep_top_k = keep_top_k
self.cfg = {
"name": "FaceBoxes",
#'min_dim': 1024,
#'feature_maps': [[32, 32], [16, 16], [8, 8]],
# 'aspect_ratios': [[1], [1], [1]],
"min_sizes": [[32, 64, 128], [256], [512]],
"steps": [32, 64, 128],
"variance": [0.1, 0.2],
"clip": False,
"loc_weight": 2.0,
"gpu_train": True,
}
self.cpu = False if torch.cuda.is_available() else True
torch.set_grad_enabled(False)
self.net = FaceBoxes(phase="test", size=None, num_classes=2)
self.net = load_model(self.net, path_to_weights, self.cpu)
self.net.eval()
self.device = torch.device("cpu" if self.cpu else "cuda")
self.net = self.net.to(self.device)
def detect_anime_face(self, image: np.ndarray) -> dict:
image = np.float32(image)
im_height, im_width, _ = image.shape
scale = torch.Tensor(
(image.shape[1], image.shape[0], image.shape[1], image.shape[0])
)
image -= (104, 117, 123)
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).unsqueeze(0)
start_time = time.perf_counter()
image = image.to(self.device)
end_time = time.perf_counter() - start_time
scale = scale.to(self.device)
loc, conf = self.net(image) # forward pass
priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
priors = priorbox.forward()
priors = priors.to(self.device)
prior_data = priors.data
boxes = decode(loc.data.squeeze(0), prior_data, self.cfg["variance"])
boxes = boxes * scale
boxes = boxes.cpu().numpy()
scores = conf.data.cpu().numpy()[:, 1]
# ignore low scores
inds = np.where(scores > self.confidence_threshold)[0]
boxes = boxes[inds]
scores = scores[inds]
# keep top-K before NMS
order = scores.argsort()[::-1][: self.top_k]
boxes = boxes[order]
scores = scores[order]
# do NMS
dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
# keep = py_cpu_nms(dets, args.nms_threshold)
keep = nms(dets, self.nms_threshold, force_cpu=self.cpu)
dets = dets[keep, :]
# keep top-K faster NMS
dets = dets[: self.keep_top_k, :]
return_data = []
for k in range(dets.shape[0]):
xmin = dets[k, 0]
ymin = dets[k, 1]
xmax = dets[k, 2]
ymax = dets[k, 3]
ymin += 0.2 * (ymax - ymin + 1)
score = dets[k, 4]
return_data.append([xmin, ymin, xmax, ymax, score])
return {"anime_face": tuple(return_data), "inference_time": end_time}
if __name__ == "__main__":
model = SingleShotDetectorModel()
image = cv2.imread(
"../../assets/example_images/others/d29492bbe7604505a6f1b5394f62b393.png"
)
data = model.detect_anime_face(image)
for d in data:
cv2.rectangle(
image, (int(d[0]), int(d[1])), (int(d[2]), int(d[3])), (0, 255, 0), 2
)
print(data)
cv2.imshow("image", image)
cv2.waitKey(0)
cv2.destroyAllWindows()