# MMFS/utils/face_parsing - commit "add basic files" (14.8 kB)
# -*- encoding: utf-8 -*-
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet import Resnet18
from models.modules.networks import init_net
class ConvBNReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        ks: convolution kernel size.
        stride: convolution stride.
        padding: zero padding added to each spatial side.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        # no conv bias: the following BatchNorm has its own learnable shift
        self.conv = nn.Conv2d(in_chan,
                              out_chan,
                              kernel_size=ks,
                              stride=stride,
                              padding=padding,
                              bias=False)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x

    def init_weight(self):
        """Kaiming-normal init (a=1) for direct Conv2d children; zero any conv bias."""
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
class BiSeNetOutput(nn.Module):
    """Segmentation head: 3x3 ConvBNReLU followed by a 1x1 conv producing class logits.

    Args:
        in_chan: channels of the incoming feature map.
        mid_chan: channels after the 3x3 ConvBNReLU.
        n_classes: number of output channels (one logit map per class).
    """

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        """Return ``n_classes``-channel logits at the input's spatial size."""
        x = self.conv(x)
        x = self.conv_out(x)
        return x

    def init_weight(self):
        """Kaiming-normal init (a=1) for direct Conv2d children; zero any conv bias."""
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into weight-decay and no-weight-decay groups.

        Returns:
            (wd_params, nowd_params): Conv2d/Linear weights go to ``wd_params``;
            their biases and every BatchNorm2d parameter go to ``nowd_params``.
        """
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
class AttentionRefinementModule(nn.Module):
    """ConvBNReLU followed by global channel attention.

    The attention vector is ``sigmoid(bn(conv1x1(avg_pool(feat))))`` computed
    from the spatially pooled feature; it rescales every channel of ``feat``.

    Args:
        in_chan: channels of the input feature map.
        out_chan: channels of the refined output.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        # init_weight was defined but never invoked; apply it like the other modules
        self.init_weight()

    def forward(self, x):
        """Return ``conv(x)`` scaled channel-wise by its attention vector."""
        feat = self.conv(x)
        # pooling window equals the full spatial size -> spatial dims collapse to 1x1
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        # 1x1 attention broadcasts over H and W
        out = torch.mul(feat, atten)
        return out

    def init_weight(self):
        """Kaiming-normal init (a=1) for direct Conv2d children; zero any conv bias."""
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
class ContextPath(nn.Module):
    """ResNet-18 context path with attention refinement and a global-context branch.

    ``forward(x)`` returns ``(feat8, feat16_up, feat32_up)``:
      * ``feat8``: the first feature returned by ``Resnet18`` (unchanged);
      * ``feat16_up``: refined ``feat16`` (plus upsampled deeper context),
        resized to ``feat8``'s spatial size, 128 channels;
      * ``feat32_up``: refined ``feat32`` plus global context, resized to
        ``feat16``'s spatial size, 128 channels.
    The module definitions require ``feat16`` to have 256 channels and ``feat32``
    512 (presumably strides 8/16/32 of the input, per the names and the
    "x8, x8, x16" note on the return line — confirm against ``Resnet18``).
    """

    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        # global context: pool feat32 to 1x1, project to 128 ch, broadcast back to feat32 size
        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16

    def init_weight(self):
        """Kaiming-normal init (a=1) for direct Conv2d children; zero any conv bias."""
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters (including the ResNet's) into weight-decay / no-decay groups.

        Returns:
            (wd_params, nowd_params): Conv2d/Linear weights go to ``wd_params``;
            their biases and every BatchNorm2d parameter go to ``nowd_params``.
        """
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
### NOTE: SpatialPath is unused; BiSeNet replaces its output with the ResNet feature map of the same size.
class SpatialPath(nn.Module):
    """Original BiSeNet spatial path: three stride-2 ConvBNReLUs then a 1x1 to 128 ch.

    Downsamples the input by 8 in each spatial dimension (three stride-2 convs).
    Expects a 3-channel input.
    """

    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

    def init_weight(self):
        """Kaiming-normal init (a=1) for direct Conv2d children; zero any conv bias."""
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into weight-decay and no-weight-decay groups.

        Returns:
            (wd_params, nowd_params): Conv2d/Linear weights go to ``wd_params``;
            their biases and every BatchNorm2d parameter go to ``nowd_params``.
        """
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
class FeatureFusionModule(nn.Module):
    """Fuse two feature maps by channel concatenation plus channel attention.

    ``forward(fsp, fcp)`` concatenates along channels (the sum of their channel
    counts must equal ``in_chan``), projects to ``out_chan`` with a 1x1
    ConvBNReLU, and returns ``feat * atten + feat`` where ``atten`` comes from a
    pooled out_chan -> out_chan//4 -> out_chan bottleneck with a sigmoid.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        # attention bottleneck; conv2 must return out_chan channels to rescale feat
        self.conv1 = nn.Conv2d(out_chan,
                               out_chan // 4,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        self.conv2 = nn.Conv2d(out_chan // 4,
                               out_chan,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        # pooling window equals the full spatial size -> 1x1 per channel
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        # residual: attention can only add to the projected feature
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        """Kaiming-normal init (a=1) for direct Conv2d children; zero any conv bias."""
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into weight-decay and no-weight-decay groups.

        Returns:
            (wd_params, nowd_params): Conv2d/Linear weights go to ``wd_params``;
            their biases and every BatchNorm2d parameter go to ``nowd_params``.
        """
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
class BiSeNet(nn.Module):
    """BiSeNet segmentation network whose spatial path is replaced by a ResNet feature.

    ``forward(x)`` returns ``(feat_out, feat_out16, feat_out32)``: three
    ``n_classes``-channel logit maps, each bilinearly resized to ``x``'s H x W.
    ``feat_out`` comes from the fused feature; the other two come from the
    context-path features (presumably auxiliary heads for training).

    The FFM expects 256 input channels = ``feat_res8`` + 128-channel ``feat_cp8``,
    so the ResNet's first returned feature must have 128 channels.
    """

    def __init__(self, n_classes, *args, **kwargs):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        # no separate spatial path (self.sp): see forward
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
        feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
        return feat_out, feat_out16, feat_out32

    def init_weight(self):
        """Kaiming-normal init (a=1) for direct Conv2d children; zero any conv bias."""
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Group parameters for an optimizer.

        Returns:
            (wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params):
            FeatureFusionModule / BiSeNetOutput parameters go ONLY to the
            ``lr_mul_*`` groups; every other child's parameters go to the plain
            groups. The split is exclusive because torch optimizers reject a
            parameter that appears in more than one parameter group.
        """
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for _, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
class PartWeightsGenerator():
    """Run a pretrained 19-class BiSeNet face parser to build part masks / pixel weights.

    The checkpoint is read from ``face_parsing.pth`` in this file's directory.

    Args:
        gpu_ids: GPU id list forwarded to ``init_net``. If non-empty and
            ``DDP_device`` is None, helper tensors are moved to ``gpu_ids[0]``.
        DDP_device: device used with DistributedDataParallel, or None; takes
            precedence over ``gpu_ids`` for helper-tensor placement.
    """

    def __init__(self, gpu_ids, DDP_device):
        super(PartWeightsGenerator, self).__init__()
        # init face parsing network
        self.net = BiSeNet(n_classes=19)
        self.net = init_net(self.net, gpu_ids=gpu_ids, DDP_device=DDP_device)
        # unwrap DP/DDP so load_state_dict and forward calls target the bare BiSeNet
        if isinstance(self.net, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            self.net = self.net.module
        cur_folder = os.path.split(__file__)[0]
        # NOTE: torch.load unpickles the checkpoint; only ship trusted files here.
        self.net.load_state_dict(torch.load(os.path.join(cur_folder, 'face_parsing.pth'),
                                            map_location=lambda storage, loc: storage))
        # Inference only: BatchNorm must use its running statistics. In train mode
        # the 1x1 global-context BN would also raise for a batch of size 1.
        self.net.eval()

        # init some tensors
        # per-channel normalisation (the standard ImageNet mean / std values)
        self.mu = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        self.sigma = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        # 5x5 integer Gaussian approximation; the entries sum to 273, so the kernel sums to 1
        self.gauss_kernel = torch.tensor([1, 4, 7, 4, 1,
                                          4, 16, 26, 16, 4,
                                          7, 26, 41, 26, 7,
                                          4, 16, 26, 16, 4,
                                          1, 4, 7, 4, 1]).view(1, 1, 5, 5) / 273.0
        if len(gpu_ids) > 0 or DDP_device is not None:
            device = DDP_device if DDP_device is not None else gpu_ids[0]
            self.mu = self.mu.to(device)
            self.sigma = self.sigma.to(device)
            self.gauss_kernel = self.gauss_kernel.to(device)

        # init attributes list: index == parsing label id
        self.atts_list = [
            'background',   # 0
            'skin',         # 1
            'left_brow',    # 2
            'right_brow',   # 3
            'left_eye',     # 4
            'right_eye',    # 5
            'eye_glasses',  # 6
            'left_ear',     # 7
            'right_ear',    # 8
            'ear_rings',    # 9
            'nose',         # 10
            'teeth',        # 11
            'upper_lip',    # 12
            'lower_lip',    # 13
            'neck',         # 14
            'necklace',     # 15
            'cloth',        # 16
            'hair',         # 17
            'hat',          # 18
        ]

    def _parse(self, img):
        """Return the (N, 1, h, w) argmax label map for ``img`` (call under no_grad).

        ``img`` is mapped with ``* 0.5 + 0.5`` before normalisation, so it is
        presumably an (N, 3, h, w) batch in [-1, 1] — TODO confirm with callers.
        The parser runs at a fixed 512x512 resolution (aspect ratio not kept).
        """
        _, _, h, w = img.size()
        img_512 = F.interpolate(img, size=512, mode='bilinear', align_corners=False)
        pred_512 = self.net((img_512 * 0.5 + 0.5 - self.mu) / self.sigma)[0]
        pred = F.interpolate(pred_512, size=(h, w), mode='bilinear', align_corners=False)
        return pred.argmax(1, keepdim=True)

    @staticmethod
    def _labels_mask(pred, labels):
        """Float mask that is 1.0 where ``pred`` equals any id in ``labels``, else 0.0."""
        mask = torch.zeros_like(pred).float()
        for label in labels:
            mask[pred == label] = 1.0
        return mask

    def _blur(self, x):
        """Smooth a 1-channel map with the 5x5 Gaussian, keeping its size.

        F.conv2d pads with zeros, so values within 2 px of the border are attenuated.
        """
        return F.conv2d(x, self.gauss_kernel, padding=2)

    def generate_masks(self, img):
        """Return blurred ``(skin_mask, eye_mask, mouth_mask, hair_mask)``, each (N, 1, h, w).

        skin = skin, ears, nose and neck; eye = both eyes;
        mouth = teeth and both lips; hair = hair.
        """
        with torch.no_grad():
            pred = self._parse(img)
            skin_mask = self._blur(self._labels_mask(pred, (1, 7, 8, 10, 14)))
            eye_mask = self._blur(self._labels_mask(pred, (4, 5)))
            mouth_mask = self._blur(self._labels_mask(pred, (11, 12, 13)))
            hair_mask = self._blur(self._labels_mask(pred, (17,)))
        return skin_mask, eye_mask, mouth_mask, hair_mask

    def generate_weights(self, img, weights_dict, blur=True):
        """Return an (N, 1, h, w) per-pixel weight map derived from the face parse.

        Args:
            img: input batch (see ``_parse`` for the assumed range).
            weights_dict: maps names from ``self.atts_list`` to weight values.
                Pixels of classes not in the dict keep weight 1.0; keys that are
                not in ``atts_list`` are ignored.
            blur: if True, smooth the weight map with the Gaussian kernel.
        """
        with torch.no_grad():
            pred = self._parse(img)
            weights = torch.ones_like(pred).float()
            for idx, att in enumerate(self.atts_list):
                if att in weights_dict:
                    weights[pred == idx] = weights_dict[att]
            if blur:
                return self._blur(weights)
        return weights
class GradWeightFunc(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient w.r.t. ``img`` by ``weights``.

    ``weights`` itself receives no gradient (backward returns None for it).
    """

    @staticmethod
    def forward(ctx, img, weights):
        # stash weights for backward; stored on ctx so non-tensor weights keep working
        ctx.param = (weights, )
        return img

    @staticmethod
    def backward(ctx, grad):
        weights = ctx.param[0]
        # forward is the identity, so d(loss)/d(img) is just the re-weighted upstream grad
        return weights * grad, None
class GradWeightLayer(nn.Module):
    """Module wrapper around ``GradWeightFunc``.

    The output equals ``img``; during backprop the gradient flowing into
    ``img`` is multiplied by ``weights``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, img, weights):
        weighted = GradWeightFunc.apply(img, weights)
        return weighted