# you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict import paddle import paddle.nn as nn import paddle.nn.functional as F import paddleseg from paddleseg.models import layers, losses from paddleseg import utils from paddleseg.cvlibs import manager, param_init import numpy as np import scipy @manager.MODELS.add_component class MODNet(nn.Layer): """ The MODNet implementation based on PaddlePaddle. The original article refers to Zhanghan Ke, et, al. "Is a Green Screen Really Necessary for Real-Time Portrait Matting?" (https://arxiv.org/pdf/2011.11961.pdf). Args: backbone: backbone model. hr(int, optional): The channels of high resolutions branch. Defautl: None. pretrained(str, optional): The path of pretrianed model. Defautl: None. """ def __init__(self, backbone, hr_channels=32, pretrained=None): super().__init__() self.backbone = backbone self.pretrained = pretrained self.head = MODNetHead( hr_channels=hr_channels, backbone_channels=backbone.feat_channels) self.init_weight() self.blurer = GaussianBlurLayer(1, 3) def forward(self, inputs): """ If training, return a dict. If evaluation, return the final alpha prediction. """ x = inputs['img'] feat_list = self.backbone(x) y = self.head(inputs=inputs, feat_list=feat_list) return y def loss(self, logit_dict, label_dict, loss_func_dict=None): if loss_func_dict is None: loss_func_dict = defaultdict(list) loss_func_dict['semantic'].append(paddleseg.models.MSELoss()) loss_func_dict['detail'].append(paddleseg.models.L1Loss()) loss_func_dict['fusion'].append(paddleseg.models.L1Loss()) loss_func_dict['fusion'].append(paddleseg.models.L1Loss()) loss = {} # semantic loss semantic_gt = F.interpolate( label_dict['alpha'], scale_factor=1 / 16, mode='bilinear', align_corners=False) semantic_gt = self.blurer(semantic_gt) # semantic_gt.stop_gradient=True loss['semantic'] = loss_func_dict['semantic'][0](logit_dict['semantic'], semantic_gt) # detail loss trimap = label_dict['trimap'] mask = (trimap == 128).astype('float32') logit_detail = logit_dict['detail'] * mask label_detail = label_dict['alpha'] * mask loss_detail = loss_func_dict['detail'][0](logit_detail, label_detail) loss_detail = loss_detail / (mask.mean() + 1e-6) loss['detail'] = 10 * loss_detail # fusion loss matte = logit_dict['matte'] alpha = label_dict['alpha'] transition_mask = label_dict['trimap'] == 128 matte_boundary = paddle.where(transition_mask, matte, alpha) # l1 loss loss_fusion_l1 = loss_func_dict['fusion'][0]( matte, alpha) + 4 * loss_func_dict['fusion'][0](matte_boundary, alpha) # composition loss loss_fusion_comp = loss_func_dict['fusion'][1]( matte * label_dict['img'], alpha * label_dict['img']) + 4 * loss_func_dict['fusion'][1]( matte_boundary * label_dict['img'], alpha * label_dict['img']) # consisten loss with semantic transition_mask = F.interpolate( label_dict['trimap'], scale_factor=1 / 16, mode='nearest', align_corners=False) transition_mask = transition_mask == 128 matte_con_sem = F.interpolate( matte, scale_factor=1 / 16, mode='bilinear', align_corners=False) matte_con_sem = self.blurer(matte_con_sem) logit_semantic = logit_dict['semantic'].clone() logit_semantic.stop_gradient = True matte_con_sem = paddle.where(transition_mask, logit_semantic, matte_con_sem) if False: import cv2 matte_con_sem_num = matte_con_sem.numpy() matte_con_sem_num = matte_con_sem_num[0].squeeze() matte_con_sem_num = (matte_con_sem_num * 255).astype('uint8') semantic = logit_dict['semantic'].numpy() semantic = semantic[0].squeeze() semantic = (semantic * 255).astype('uint8') transition_mask = transition_mask.astype('uint8') transition_mask = transition_mask.numpy() transition_mask = (transition_mask[0].squeeze()) * 255 cv2.imwrite('matte_con.png', matte_con_sem_num) cv2.imwrite('semantic.png', semantic) cv2.imwrite('transition.png', transition_mask) mse_loss = paddleseg.models.MSELoss() loss_fusion_con_sem = mse_loss(matte_con_sem, logit_dict['semantic']) loss_fusion = loss_fusion_l1 + loss_fusion_comp + loss_fusion_con_sem loss['fusion'] = loss_fusion loss['fusion_l1'] = loss_fusion_l1 loss['fusion_comp'] = loss_fusion_comp loss['fusion_con_sem'] = loss_fusion_con_sem loss['all'] = loss['semantic'] + loss['detail'] + loss['fusion'] return loss def init_weight(self): if self.pretrained is not None: utils.load_entire_model(self, self.pretrained) class MODNetHead(nn.Layer): def __init__(self, hr_channels, backbone_channels): super().__init__() self.lr_branch = LRBranch(backbone_channels) self.hr_branch = HRBranch(hr_channels, backbone_channels) self.f_branch = FusionBranch(hr_channels, backbone_channels) self.init_weight() def forward(self, inputs, feat_list): pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(feat_list) pred_detail, hr2x = self.hr_branch(inputs['img'], enc2x, enc4x, lr8x) pred_matte = self.f_branch(inputs['img'], lr8x, hr2x) if self.training: logit_dict = { 'semantic': pred_semantic, 'detail': pred_detail, 'matte': pred_matte } return logit_dict else: return pred_matte def init_weight(self): for layer in self.sublayers(): if isinstance(layer, nn.Conv2D): param_init.kaiming_uniform(layer.weight) class FusionBranch(nn.Layer): def __init__(self, hr_channels, enc_channels): super().__init__() self.conv_lr4x = Conv2dIBNormRelu( enc_channels[2], hr_channels, 5, stride=1, padding=2) self.conv_f2x = Conv2dIBNormRelu( 2 * hr_channels, hr_channels, 3, stride=1, padding=1) self.conv_f = nn.Sequential( Conv2dIBNormRelu( hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), Conv2dIBNormRelu( int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False)) def forward(self, img, lr8x, hr2x): lr4x = F.interpolate( lr8x, scale_factor=2, mode='bilinear', align_corners=False) lr4x = self.conv_lr4x(lr4x) lr2x = F.interpolate( lr4x, scale_factor=2, mode='bilinear', align_corners=False) f2x = self.conv_f2x(paddle.concat((lr2x, hr2x), axis=1)) f = F.interpolate( f2x, scale_factor=2, mode='bilinear', align_corners=False) f = self.conv_f(paddle.concat((f, img), axis=1)) pred_matte = F.sigmoid(f) return pred_matte class HRBranch(nn.Layer): """ High Resolution Branch of MODNet """ def __init__(self, hr_channels, enc_channels): super().__init__() self.tohr_enc2x = Conv2dIBNormRelu( enc_channels[0], hr_channels, 1, stride=1, padding=0) self.conv_enc2x = Conv2dIBNormRelu( hr_channels + 3, hr_channels, 3, stride=2, padding=1) self.tohr_enc4x = Conv2dIBNormRelu( enc_channels[1], hr_channels, 1, stride=1, padding=0) self.conv_enc4x = Conv2dIBNormRelu( 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) self.conv_hr4x = nn.Sequential( Conv2dIBNormRelu( 2 * hr_channels + enc_channels[2] + 3, 2 * hr_channels, 3, stride=1, padding=1), Conv2dIBNormRelu( 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), Conv2dIBNormRelu( 2 * hr_channels, hr_channels, 3, stride=1, padding=1)) self.conv_hr2x = nn.Sequential( Conv2dIBNormRelu( 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), Conv2dIBNormRelu( 2 * hr_channels, hr_channels, 3, stride=1, padding=1), Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1)) self.conv_hr = nn.Sequential( Conv2dIBNormRelu( hr_channels + 3, hr_channels, 3, stride=1, padding=1), Conv2dIBNormRelu( hr_channels, 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False)) def forward(self, img, enc2x, enc4x, lr8x): img2x = F.interpolate( img, scale_factor=1 / 2, mode='bilinear', align_corners=False) img4x = F.interpolate( img, scale_factor=1 / 4, mode='bilinear', align_corners=False) enc2x = self.tohr_enc2x(enc2x) hr4x = self.conv_enc2x(paddle.concat((img2x, enc2x), axis=1)) enc4x = self.tohr_enc4x(enc4x) hr4x = self.conv_enc4x(paddle.concat((hr4x, enc4x), axis=1)) lr4x = F.interpolate( lr8x, scale_factor=2, mode='bilinear', align_corners=False) hr4x = self.conv_hr4x(paddle.concat((hr4x, lr4x, img4x), axis=1)) hr2x = F.interpolate( hr4x, scale_factor=2, mode='bilinear', align_corners=False) hr2x = self.conv_hr2x(paddle.concat((hr2x, enc2x), axis=1)) pred_detail = None if self.training: hr = F.interpolate( hr2x, scale_factor=2, mode='bilinear', align_corners=False) hr = self.conv_hr(paddle.concat((hr, img), axis=1)) pred_detail = F.sigmoid(hr) return pred_detail, hr2x class LRBranch(nn.Layer): def __init__(self, backbone_channels): super().__init__() self.se_block = SEBlock(backbone_channels[4], reduction=4) self.conv_lr16x = Conv2dIBNormRelu( backbone_channels[4], backbone_channels[3], 5, stride=1, padding=2) self.conv_lr8x = Conv2dIBNormRelu( backbone_channels[3], backbone_channels[2], 5, stride=1, padding=2) self.conv_lr = Conv2dIBNormRelu( backbone_channels[2], 1, 3, stride=2, padding=1, with_ibn=False, with_relu=False) def forward(self, feat_list): enc2x, enc4x, enc32x = feat_list[0], feat_list[1], feat_list[4] enc32x = self.se_block(enc32x) lr16x = F.interpolate( enc32x, scale_factor=2, mode='bilinear', align_corners=False) lr16x = self.conv_lr16x(lr16x) lr8x = F.interpolate( lr16x, scale_factor=2, mode='bilinear', align_corners=False) lr8x = self.conv_lr8x(lr8x) pred_semantic = None if self.training: lr = self.conv_lr(lr8x) pred_semantic = F.sigmoid(lr) return pred_semantic, lr8x, [enc2x, enc4x] class IBNorm(nn.Layer): """ Combine Instance Norm and Batch Norm into One Layer """ def __init__(self, in_channels): super().__init__() self.bnorm_channels = in_channels // 2 self.inorm_channels = in_channels - self.bnorm_channels self.bnorm = nn.BatchNorm2D(self.bnorm_channels) self.inorm = nn.InstanceNorm2D(self.inorm_channels) def forward(self, x): bn_x = self.bnorm(x[:, :self.bnorm_channels, :, :]) in_x = self.inorm(x[:, self.bnorm_channels:, :, :]) return paddle.concat((bn_x, in_x), 1) class Conv2dIBNormRelu(nn.Layer): """ Convolution + IBNorm + Relu """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias_attr=None, with_ibn=True, with_relu=True): super().__init__() layers = [ nn.Conv2D( in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias_attr=bias_attr) ] if with_ibn: layers.append(IBNorm(out_channels)) if with_relu: layers.append(nn.ReLU()) self.layers = nn.Sequential(*layers) def forward(self, x): return self.layers(x) class SEBlock(nn.Layer): """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf """ def __init__(self, num_channels, reduction=1): super().__init__() self.pool = nn.AdaptiveAvgPool2D(1) self.conv = nn.Sequential( nn.Conv2D( num_channels, int(num_channels // reduction), 1, bias_attr=False), nn.ReLU(), nn.Conv2D( int(num_channels // reduction), num_channels, 1, bias_attr=False), nn.Sigmoid()) def forward(self, x): w = self.pool(x) w = self.conv(w) return w * x class GaussianBlurLayer(nn.Layer): """ Add Gaussian Blur to a 4D tensors This layer takes a 4D tensor of {N, C, H, W} as input. The Gaussian blur will be performed in given channel number (C) splitly. """ def __init__(self, channels, kernel_size): """ Args: channels (int): Channel for input tensor kernel_size (int): Size of the kernel used in blurring """ super(GaussianBlurLayer, self).__init__() self.channels = channels self.kernel_size = kernel_size assert self.kernel_size % 2 != 0 self.op = nn.Sequential( nn.Pad2D(int(self.kernel_size / 2), mode='reflect'), nn.Conv2D( channels, channels, self.kernel_size, stride=1, padding=0, bias_attr=False, groups=channels)) self._init_kernel() self.op[1].weight.stop_gradient = True def forward(self, x): """ Args: x (paddle.Tensor): input 4D tensor Returns: paddle.Tensor: Blurred version of the input """ if not len(list(x.shape)) == 4: print('\'GaussianBlurLayer\' requires a 4D tensor as input\n') exit() elif not x.shape[1] == self.channels: print('In \'GaussianBlurLayer\', the required channel ({0}) is' 'not the same as input ({1})\n'.format( self.channels, x.shape[1])) exit() return self.op(x) def _init_kernel(self): sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8 n = np.zeros((self.kernel_size, self.kernel_size)) i = int(self.kernel_size / 2) n[i, i] = 1 kernel = scipy.ndimage.gaussian_filter(n, sigma) kernel = kernel.astype('float32') kernel = kernel[np.newaxis, np.newaxis, :, :] paddle.assign(kernel, self.op[1].weight)