#!/usr/bin/python
# -*- encoding: utf-8 -*-
"""BiSeNet face-parsing network and helpers that turn its output into
per-pixel masks / weights.

The module defines the BiSeNet building blocks, the ``BiSeNet`` model itself,
``PartWeightsGenerator`` (loads ``face_parsing.pth`` and produces masks or
weight maps from images) and ``GradWeightLayer`` (identity in the forward
pass, scales the gradient by a weight map in the backward pass).
"""

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from .resnet import Resnet18
from models.modules.networks import init_net


def _init_direct_conv_children(root):
    """Kaiming-initialise the ``nn.Conv2d`` layers that are direct children of *root*.

    Only ``root.children()`` is inspected (no recursion), so composite children
    such as ``ConvBNReLU`` are skipped here; they initialise themselves in
    their own constructors.
    """
    for ly in root.children():
        if isinstance(ly, nn.Conv2d):
            nn.init.kaiming_normal_(ly.weight, a=1)
            if ly.bias is not None:
                nn.init.constant_(ly.bias, 0)


def _split_weight_decay_params(root):
    """Split the parameters found under *root* into two lists.

    Walks ``root.named_modules()`` (i.e. *root* and all nested sub-modules).

    Returns:
        ``(wd_params, nowd_params)`` where ``wd_params`` holds the weights of
        every ``nn.Linear`` / ``nn.Conv2d`` and ``nowd_params`` holds their
        biases (when present) plus all ``nn.BatchNorm2d`` parameters. The
        names suggest "with / without weight decay"; how the lists are used
        is up to the caller.
    """
    wd_params, nowd_params = [], []
    for _, module in root.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            wd_params.append(module.weight)
            if module.bias is not None:
                nowd_params.append(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            nowd_params += list(module.parameters())
    return wd_params, nowd_params


class ConvBNReLU(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and ReLU."""

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        """Build the conv / batch-norm pair.

        Args:
            in_chan: number of input channels.
            out_chan: number of output channels.
            ks: convolution kernel size.
            stride: convolution stride.
            padding: convolution padding.
            *args, **kwargs: accepted and ignored.
        """
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_chan,
            out_chan,
            kernel_size=ks,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x

    def init_weight(self):
        """Kaiming-initialise the direct ``Conv2d`` child (``self.conv``)."""
        _init_direct_conv_children(self)


class BiSeNetOutput(nn.Module):
    """Prediction head: 3x3 ``ConvBNReLU`` followed by a 1x1 conv to ``n_classes`` channels."""

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        """Build the head.

        Args:
            in_chan: channels of the incoming feature map.
            mid_chan: channels after the intermediate 3x3 block.
            n_classes: number of output channels (one per class).
            *args, **kwargs: accepted and ignored.
        """
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        """Return the ``n_classes``-channel output for feature map *x*."""
        x = self.conv(x)
        x = self.conv_out(x)
        return x

    def init_weight(self):
        """Kaiming-initialise the direct ``Conv2d`` child (``self.conv_out``)."""
        _init_direct_conv_children(self)

    def get_params(self):
        """Return ``(wd_params, nowd_params)`` for this head.

        Conv / linear weights go to the first list; biases and batch-norm
        parameters go to the second.
        """
        return _split_weight_decay_params(self)


class AttentionRefinementModule(nn.Module):
    """Scale a refined feature map by a per-channel attention vector.

    The attention vector is computed from the globally average-pooled
    features via a 1x1 conv, batch-norm and sigmoid.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        """Build the module.

        Args:
            in_chan: channels of the incoming feature map.
            out_chan: channels of the refined output.
            *args, **kwargs: accepted and ignored.
        """
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        """Return ``conv(x)`` multiplied by its channel attention."""
        feat = self.conv(x)
        # Pooling with a kernel equal to the spatial size yields one value per channel.
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out

    def init_weight(self):
        """Kaiming-initialise the direct ``Conv2d`` child (``self.conv_atten``)."""
        _init_direct_conv_children(self)


class ContextPath(nn.Module):
    """ResNet-18 backbone plus attention refinement and top-down merging."""

    def __init__(self, *args, **kwargs):
        """Build the backbone and refinement layers (arguments are ignored)."""
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        """Run the backbone and merge its three feature maps top-down.

        Args:
            x: input batch passed straight to ``Resnet18``.

        Returns:
            ``(feat8, feat16_up, feat32_up)``: the first backbone feature map
            unchanged, the refined middle map resized to ``feat8``'s spatial
            size, and the refined last map resized to ``feat16``'s spatial
            size. (Names and the original ``x8, x8, x16`` note suggest strides
            of 8 / 8 / 16 relative to the input -- confirm against
            ``Resnet18``.)
        """
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        # Global context: pool to 1x1, project to 128 channels, broadcast back.
        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16

    def init_weight(self):
        """Kaiming-initialise direct ``Conv2d`` children.

        This module has none (all children are composite), so the call leaves
        the weights as the sub-modules initialised them.
        """
        _init_direct_conv_children(self)

    def get_params(self):
        """Return ``(wd_params, nowd_params)`` for this path, backbone included."""
        return _split_weight_decay_params(self)


### This is not used, since I replace this with the resnet feature with the same size
class SpatialPath(nn.Module):
    """Three stride-2 ``ConvBNReLU`` blocks followed by a 1x1 projection to 128 channels."""

    def __init__(self, *args, **kwargs):
        """Build the four conv blocks (arguments are ignored)."""
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        """Return the 128-channel feature map for 3-channel input *x*."""
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

    def init_weight(self):
        """Kaiming-initialise direct ``Conv2d`` children (none in this module)."""
        _init_direct_conv_children(self)

    def get_params(self):
        """Return ``(wd_params, nowd_params)`` for this path."""
        return _split_weight_decay_params(self)


class FeatureFusionModule(nn.Module):
    """Concatenate two feature maps, project them, and add a channel-attended copy."""

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        """Build the module.

        Args:
            in_chan: channel count of the concatenated input (both maps together).
            out_chan: channel count of the fused output.
            *args, **kwargs: accepted and ignored.
        """
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(
            out_chan,
            out_chan // 4,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            out_chan // 4,
            out_chan,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        """Fuse *fsp* and *fcp* (concatenated along the channel axis).

        Returns:
            ``feat * atten + feat`` where ``feat`` is the projected
            concatenation and ``atten`` is a per-channel sigmoid gate computed
            from its global average.
        """
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        # Bottleneck (out_chan -> out_chan // 4 -> out_chan) on the pooled features.
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        """Kaiming-initialise the direct ``Conv2d`` children (``conv1`` and ``conv2``)."""
        _init_direct_conv_children(self)

    def get_params(self):
        """Return ``(wd_params, nowd_params)`` for this module."""
        return _split_weight_decay_params(self)


class BiSeNet(nn.Module):
    """BiSeNet with the spatial path replaced by the first backbone feature map."""

    def __init__(self, n_classes, *args, **kwargs):
        """Build the network.

        Args:
            n_classes: number of output channels of every prediction head.
            *args, **kwargs: accepted and ignored.
        """
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        ## here self.sp is deleted
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def forward(self, x):
        """Return three ``n_classes``-channel maps resized to the input's spatial size.

        Returns:
            ``(feat_out, feat_out16, feat_out32)``: the prediction from the
            fused features followed by the predictions of the two heads
            attached directly to the context-path outputs. All three are
            bilinearly resized (``align_corners=True``) to ``x``'s height and
            width.
        """
        H, W = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
        feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
        return feat_out, feat_out16, feat_out32

    def init_weight(self):
        """Kaiming-initialise direct ``Conv2d`` children (none in this module)."""
        _init_direct_conv_children(self)

    def get_params(self):
        """Collect the parameter groups of every direct child.

        Returns:
            ``(wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)``.
            Parameters of ``FeatureFusionModule`` and ``BiSeNetOutput``
            children go into the two ``lr_mul_*`` lists; parameters of every
            other child (the context path) go into the first two lists.
        """
        wd_params, nowd_params = [], []
        lr_mul_wd_params, lr_mul_nowd_params = [], []
        for _, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


class PartWeightsGenerator():
    """Run a pretrained 19-class ``BiSeNet`` and derive masks / weight maps from it."""

    # Indices into ``atts_list`` that make up each region returned by
    # ``generate_masks``.
    _SKIN_LABELS = (1, 7, 8, 10, 14)  # skin, left_ear, right_ear, nose, neck
    _EYE_LABELS = (4, 5)  # left_eye, right_eye
    _MOUTH_LABELS = (11, 12, 13)  # teeth, upper_lip, lower_lip
    _HAIR_LABELS = (17,)  # hair

    def __init__(self, gpu_ids, DDP_device):
        """Load the parsing network and prepare the constant tensors.

        Args:
            gpu_ids: sequence of GPU ids forwarded to ``init_net``; its first
                entry is used as the device when ``DDP_device`` is ``None``.
            DDP_device: device forwarded to ``init_net``; when not ``None`` it
                takes precedence over ``gpu_ids`` for the constant tensors.
        """
        super(PartWeightsGenerator, self).__init__()

        # init face parsing network
        self.net = BiSeNet(n_classes=19)
        self.net = init_net(self.net, gpu_ids=gpu_ids, DDP_device=DDP_device)
        # Unwrap parallel containers so the state-dict keys match the bare model.
        if isinstance(self.net, (torch.nn.DataParallel,
                                 torch.nn.parallel.DistributedDataParallel)):
            self.net = self.net.module
        cur_folder = os.path.split(__file__)[0]
        # NOTE: torch.load unpickles the checkpoint; only ship / load
        # 'face_parsing.pth' from a trusted source.
        self.net.load_state_dict(
            torch.load(
                os.path.join(cur_folder, 'face_parsing.pth'),
                map_location=lambda storage, loc: storage,
            )
        )
        self.net.eval()

        # init some tensors
        # Per-channel mean / std applied to the input before parsing
        # (presumably the usual ImageNet statistics -- confirm against the
        # checkpoint's training setup).
        self.mu = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        self.sigma = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        # 5x5 smoothing kernel; the integer entries sum to 273, so the kernel
        # is normalised to sum to one.
        self.gauss_kernel = torch.tensor([
            1, 4, 7, 4, 1,
            4, 16, 26, 16, 4,
            7, 26, 41, 26, 7,
            4, 16, 26, 16, 4,
            1, 4, 7, 4, 1,
        ]).view(1, 1, 5, 5) / 273.0
        if len(gpu_ids) > 0 or DDP_device is not None:
            device = DDP_device if DDP_device is not None else gpu_ids[0]
            self.mu = self.mu.to(device)
            self.sigma = self.sigma.to(device)
            self.gauss_kernel = self.gauss_kernel.to(device)

        # init attributes list
        self.atts_list = [
            'background',   # 0
            'skin',         # 1
            'left_brow',    # 2
            'right_brow',   # 3
            'left_eye',     # 4
            'right_eye',    # 5
            'eye_glasses',  # 6
            'left_ear',     # 7
            'right_ear',    # 8
            'ear_rings',    # 9
            'nose',         # 10
            'teeth',        # 11
            'upper_lip',    # 12
            'lower_lip',    # 13
            'neck',         # 14
            'necklace',     # 15
            'cloth',        # 16
            'hair',         # 17
            'hat',          # 18
        ]

    def _predict_labels(self, img):
        """Return the per-pixel class index map for *img* at its own resolution.

        The image is resized to 512x512, shifted with ``* 0.5 + 0.5``
        (presumably mapping a [-1, 1] input to [0, 1] -- confirm with the
        caller), normalised with ``mu`` / ``sigma`` and parsed; the first
        network output is resized back and arg-maxed over the channel axis
        (``keepdim=True``, so the result keeps a singleton channel dimension).

        Gradient tracking is left to the caller; the public methods invoke
        this helper inside ``torch.no_grad()``.
        """
        _, _, h, w = img.size()
        # align_corners=False is what the default resolves to for bilinear
        # mode; passing it explicitly keeps results identical and avoids the
        # warning older torch versions emit.
        img_512 = F.interpolate(img, size=512, mode='bilinear', align_corners=False)
        pred_512 = self.net((img_512 * 0.5 + 0.5 - self.mu) / self.sigma)[0]
        pred = F.interpolate(pred_512, size=(h, w), mode='bilinear', align_corners=False)
        return pred.argmax(1, keepdim=True)

    def _blurred_region_mask(self, pred, labels):
        """Return a smoothed float mask that is 1 where *pred* equals any of *labels*."""
        mask = torch.zeros_like(pred).float()
        for label in labels:
            mask[pred == label] = 1.0
        return F.conv2d(mask, self.gauss_kernel, padding=2)

    def generate_masks(self, img):
        """Return smoothed masks for skin, eyes, mouth and hair.

        Args:
            img: 4-D image batch (unpacked as ``_, _, h, w``).

        Returns:
            ``(skin_mask, eye_mask, mouth_mask, hair_mask)``: float tensors
            with a single channel and the spatial size of *img*, each blurred
            with the 5x5 kernel. Skin also covers both ears, nose and neck;
            mouth covers teeth and both lips.
        """
        with torch.no_grad():
            pred = self._predict_labels(img)
            skin_mask = self._blurred_region_mask(pred, self._SKIN_LABELS)
            eye_mask = self._blurred_region_mask(pred, self._EYE_LABELS)
            mouth_mask = self._blurred_region_mask(pred, self._MOUTH_LABELS)
            hair_mask = self._blurred_region_mask(pred, self._HAIR_LABELS)
            return skin_mask, eye_mask, mouth_mask, hair_mask

    def generate_weights(self, img, weights_dict, blur=True):
        """Return a per-pixel weight map built from the parsing result.

        Args:
            img: 4-D image batch (unpacked as ``_, _, h, w``).
            weights_dict: maps attribute names from ``atts_list`` to the
                weight assigned to pixels of that class; classes not listed
                keep weight 1.
            blur: when true, smooth the map with the 5x5 kernel.

        Returns:
            Float tensor with a single channel and the spatial size of *img*.
        """
        with torch.no_grad():
            pred = self._predict_labels(img)
            weights = torch.ones_like(pred).float()
            for idx, att in enumerate(self.atts_list):
                if att in weights_dict:
                    weights[pred == idx] = weights_dict[att]
            if blur:
                return F.conv2d(weights, self.gauss_kernel, padding=2)
            return weights


class GradWeightFunc(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by *weights* on the way back."""

    @staticmethod
    def forward(ctx, img, weights):
        """Remember *weights* and return *img* unchanged."""
        # Stored as a plain ctx attribute (not save_for_backward), so any
        # object supporting ``weights * grad`` is accepted.
        ctx.param = (weights, )
        return img

    @staticmethod
    def backward(ctx, grad):
        """Return ``weights * grad`` for *img* and no gradient for *weights*."""
        weights = ctx.param[0]
        return weights * grad, None


class GradWeightLayer(nn.Module):
    """Module wrapper around ``GradWeightFunc``."""

    def __init__(self):
        super(GradWeightLayer, self).__init__()

    def forward(self, img, weights):
        """Return *img* unchanged while scaling its gradient by *weights*."""
        return GradWeightFunc.apply(img, weights)