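# Repository page metadata (not part of the module): marta-0's picture; add files; 6da6215
#
# Licensed under the Apache License, Version 2.0 (the "License");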
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddleseg
from paddleseg.models import layers, losses
from paddleseg import utils
from paddleseg.cvlibs import manager, param_init
import numpy as np
import scipy
@manager.MODELS.add_component
class MODNet(nn.Layer):
    """
    The MODNet implementation based on PaddlePaddle.

    The original article refers to
    Zhanghan Ke, et, al. "Is a Green Screen Really Necessary for Real-Time Portrait Matting?"
    (https://arxiv.org/pdf/2011.11961.pdf).

    Args:
        backbone: backbone model. It must expose ``feat_channels`` and return
            a list of feature maps when called on the input image.
        hr_channels(int, optional): The channels of high resolutions branch. Default: 32.
        pretrained(str, optional): The path of pretrained model. Default: None.
    """

    def __init__(self, backbone, hr_channels=32, pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.pretrained = pretrained

        self.head = MODNetHead(
            hr_channels=hr_channels, backbone_channels=backbone.feat_channels)
        self.init_weight()
        # Created after init_weight(); GaussianBlurLayer sets and freezes its
        # own kernel, it is only used to smooth targets inside loss().
        self.blurer = GaussianBlurLayer(1, 3)

    def forward(self, inputs):
        """
        If training, return a dict.
        If evaluation, return the final alpha prediction.

        Args:
            inputs(dict): Must contain the input image under the key 'img'.
        """
        x = inputs['img']
        feat_list = self.backbone(x)
        y = self.head(inputs=inputs, feat_list=feat_list)

        return y

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        """
        Compute the semantic, detail and fusion losses.

        Args:
            logit_dict(dict): Predictions stored under the keys 'semantic',
                'detail' and 'matte'.
            label_dict(dict): Ground truth stored under the keys 'img',
                'alpha' and 'trimap'. Pixels whose trimap value is 128 are
                treated as the transition region.
            loss_func_dict(dict, optional): Lists of loss callables under the
                keys 'semantic' (item 0 is used), 'detail' (item 0 is used)
                and 'fusion' (item 0 for the l1 term, item 1 for the
                composition term). If None, MSELoss is used for 'semantic'
                and L1Loss for the others. Default: None.

        Returns:
            dict: The losses under the keys 'semantic', 'detail', 'fusion',
                'fusion_l1', 'fusion_comp', 'fusion_con_sem' and 'all'.
        """
        if loss_func_dict is None:
            loss_func_dict = defaultdict(list)
            loss_func_dict['semantic'].append(paddleseg.models.MSELoss())
            loss_func_dict['detail'].append(paddleseg.models.L1Loss())
            # Two entries on purpose: [0] is the l1 term, [1] the composition term.
            loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
            loss_func_dict['fusion'].append(paddleseg.models.L1Loss())

        img = label_dict['img']
        alpha = label_dict['alpha']
        trimap = label_dict['trimap']
        matte = logit_dict['matte']

        loss = {}

        # semantic loss: compare against the blurred, 16x downsampled alpha
        semantic_gt = F.interpolate(
            alpha, scale_factor=1 / 16, mode='bilinear', align_corners=False)
        semantic_gt = self.blurer(semantic_gt)
        loss['semantic'] = loss_func_dict['semantic'][0](logit_dict['semantic'],
                                                         semantic_gt)

        # detail loss: only the transition region contributes
        transition_mask = trimap == 128
        mask = transition_mask.astype('float32')
        logit_detail = logit_dict['detail'] * mask
        label_detail = alpha * mask
        loss_detail = loss_func_dict['detail'][0](logit_detail, label_detail)
        # Normalise by the fraction of transition pixels; the epsilon guards
        # against a trimap without any transition pixel.
        loss_detail = loss_detail / (mask.mean() + 1e-6)
        loss['detail'] = 10 * loss_detail

        # fusion loss
        # Prediction inside the transition region, ground truth elsewhere.
        matte_boundary = paddle.where(transition_mask, matte, alpha)
        # l1 loss
        loss_fusion_l1 = loss_func_dict['fusion'][0](
            matte,
            alpha) + 4 * loss_func_dict['fusion'][0](matte_boundary, alpha)
        # composition loss
        alpha_img = alpha * img
        loss_fusion_comp = loss_func_dict['fusion'][1](
            matte * img, alpha_img) + 4 * loss_func_dict['fusion'][1](
                matte_boundary * img, alpha_img)
        # consistency loss with semantic
        transition_mask = F.interpolate(
            trimap, scale_factor=1 / 16, mode='nearest', align_corners=False)
        transition_mask = transition_mask == 128
        matte_con_sem = F.interpolate(
            matte, scale_factor=1 / 16, mode='bilinear', align_corners=False)
        matte_con_sem = self.blurer(matte_con_sem)
        logit_semantic = logit_dict['semantic'].clone()
        logit_semantic.stop_gradient = True
        # In the transition region the target is replaced by the semantic
        # prediction itself (a copy marked stop_gradient).
        matte_con_sem = paddle.where(transition_mask, logit_semantic,
                                     matte_con_sem)

        mse_loss = paddleseg.models.MSELoss()
        loss_fusion_con_sem = mse_loss(matte_con_sem, logit_dict['semantic'])
        loss_fusion = loss_fusion_l1 + loss_fusion_comp + loss_fusion_con_sem
        loss['fusion'] = loss_fusion
        loss['fusion_l1'] = loss_fusion_l1
        loss['fusion_comp'] = loss_fusion_comp
        loss['fusion_con_sem'] = loss_fusion_con_sem

        loss['all'] = loss['semantic'] + loss['detail'] + loss['fusion']

        return loss

    def init_weight(self):
        """Load the whole model from ``self.pretrained`` when a path is set."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
class MODNetHead(nn.Layer):
    """
    Head of MODNet: chains the low resolution, high resolution and fusion
    branches on top of the backbone features.

    Args:
        hr_channels(int): The channels of the high resolution branch.
        backbone_channels(list): Channel counts of the backbone feature maps,
            forwarded to every branch.
    """

    def __init__(self, hr_channels, backbone_channels):
        super().__init__()

        self.lr_branch = LRBranch(backbone_channels)
        self.hr_branch = HRBranch(hr_channels, backbone_channels)
        self.f_branch = FusionBranch(hr_channels, backbone_channels)
        self.init_weight()

    def forward(self, inputs, feat_list):
        """
        Return the dict of the three predictions when training, otherwise
        only the fused matte.
        """
        image = inputs['img']
        pred_semantic, lr8x, (enc2x, enc4x) = self.lr_branch(feat_list)
        pred_detail, hr2x = self.hr_branch(image, enc2x, enc4x, lr8x)
        pred_matte = self.f_branch(image, lr8x, hr2x)

        if not self.training:
            return pred_matte

        return {
            'semantic': pred_semantic,
            'detail': pred_detail,
            'matte': pred_matte
        }

    def init_weight(self):
        """Apply kaiming uniform initialisation to every Conv2D weight."""
        conv_layers = (
            sub for sub in self.sublayers() if isinstance(sub, nn.Conv2D))
        for conv in conv_layers:
            param_init.kaiming_uniform(conv.weight)
class FusionBranch(nn.Layer):
    """
    Fusion Branch of MODNet: merges the low resolution and high resolution
    features with the input image and predicts the final matte.

    Args:
        hr_channels(int): The channels of the high resolution branch.
        enc_channels(list): Channel counts of the encoder features; only
            index 2 is read here.
    """

    def __init__(self, hr_channels, enc_channels):
        super().__init__()
        half_channels = int(hr_channels / 2)

        self.conv_lr4x = Conv2dIBNormRelu(
            enc_channels[2], hr_channels, 5, stride=1, padding=2)
        self.conv_f2x = Conv2dIBNormRelu(
            2 * hr_channels, hr_channels, 3, stride=1, padding=1)

        # "+ 3" accounts for the image concatenated in forward()
        # (assumes a 3-channel image -- TODO confirm against the data pipeline).
        merge_conv = Conv2dIBNormRelu(
            hr_channels + 3, half_channels, 3, stride=1, padding=1)
        matte_conv = Conv2dIBNormRelu(
            half_channels,
            1,
            1,
            stride=1,
            padding=0,
            with_ibn=False,
            with_relu=False)
        self.conv_f = nn.Sequential(merge_conv, matte_conv)

    def forward(self, img, lr8x, hr2x):
        """Return the sigmoid-activated matte prediction."""

        def upsample2x(feat):
            return F.interpolate(
                feat, scale_factor=2, mode='bilinear', align_corners=False)

        lr4x = self.conv_lr4x(upsample2x(lr8x))
        lr2x = upsample2x(lr4x)

        f2x = self.conv_f2x(paddle.concat((lr2x, hr2x), axis=1))
        fused = upsample2x(f2x)
        fused = self.conv_f(paddle.concat((fused, img), axis=1))

        return F.sigmoid(fused)
class HRBranch(nn.Layer):
    """
    High Resolution Branch of MODNet

    Args:
        hr_channels(int): The channels used inside this branch.
        enc_channels(list): Channel counts of the encoder features; indexes
            0, 1 and 2 are read here.
    """

    def __init__(self, hr_channels, enc_channels):
        super().__init__()
        wide = 2 * hr_channels

        self.tohr_enc2x = Conv2dIBNormRelu(
            enc_channels[0], hr_channels, 1, stride=1, padding=0)
        # "+ 3" accounts for the downsampled image concatenated in forward().
        self.conv_enc2x = Conv2dIBNormRelu(
            hr_channels + 3, hr_channels, 3, stride=2, padding=1)

        self.tohr_enc4x = Conv2dIBNormRelu(
            enc_channels[1], hr_channels, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(
            wide, wide, 3, stride=1, padding=1)

        self.conv_hr4x = nn.Sequential(
            Conv2dIBNormRelu(
                wide + enc_channels[2] + 3, wide, 3, stride=1, padding=1),
            Conv2dIBNormRelu(wide, wide, 3, stride=1, padding=1),
            Conv2dIBNormRelu(wide, hr_channels, 3, stride=1, padding=1))

        self.conv_hr2x = nn.Sequential(
            Conv2dIBNormRelu(wide, wide, 3, stride=1, padding=1),
            Conv2dIBNormRelu(wide, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1))

        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(
                hr_channels + 3, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(
                hr_channels,
                1,
                1,
                stride=1,
                padding=0,
                with_ibn=False,
                with_relu=False))

    def forward(self, img, enc2x, enc4x, lr8x):
        """
        Return ``(pred_detail, hr2x)``; ``pred_detail`` is None unless the
        layer is in training mode.
        """

        def resize(feat, factor):
            return F.interpolate(
                feat,
                scale_factor=factor,
                mode='bilinear',
                align_corners=False)

        img2x = resize(img, 1 / 2)
        img4x = resize(img, 1 / 4)

        enc2x = self.tohr_enc2x(enc2x)
        hr4x = self.conv_enc2x(paddle.concat((img2x, enc2x), axis=1))

        enc4x = self.tohr_enc4x(enc4x)
        hr4x = self.conv_enc4x(paddle.concat((hr4x, enc4x), axis=1))

        lr4x = resize(lr8x, 2)
        hr4x = self.conv_hr4x(paddle.concat((hr4x, lr4x, img4x), axis=1))

        hr2x = resize(hr4x, 2)
        hr2x = self.conv_hr2x(paddle.concat((hr2x, enc2x), axis=1))

        # The detail prediction is only computed while training.
        if not self.training:
            return None, hr2x

        hr = resize(hr2x, 2)
        hr = self.conv_hr(paddle.concat((hr, img), axis=1))
        return F.sigmoid(hr), hr2x
class LRBranch(nn.Layer):
    """
    Low Resolution Branch of MODNet

    Args:
        backbone_channels(list): Channel counts of the backbone features;
            indexes 2, 3 and 4 are read here.
    """

    def __init__(self, backbone_channels):
        super().__init__()
        ch8x, ch16x, ch32x = (backbone_channels[2], backbone_channels[3],
                              backbone_channels[4])

        self.se_block = SEBlock(ch32x, reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(
            ch32x, ch16x, 5, stride=1, padding=2)
        self.conv_lr8x = Conv2dIBNormRelu(
            ch16x, ch8x, 5, stride=1, padding=2)
        self.conv_lr = Conv2dIBNormRelu(
            ch8x,
            1,
            3,
            stride=2,
            padding=1,
            with_ibn=False,
            with_relu=False)

    def forward(self, feat_list):
        """
        Return ``(pred_semantic, lr8x, [enc2x, enc4x])``; ``pred_semantic``
        is None unless the layer is in training mode.
        """

        def upsample2x(feat):
            return F.interpolate(
                feat, scale_factor=2, mode='bilinear', align_corners=False)

        enc2x = feat_list[0]
        enc4x = feat_list[1]
        enc32x = feat_list[4]

        enc32x = self.se_block(enc32x)
        lr16x = self.conv_lr16x(upsample2x(enc32x))
        lr8x = self.conv_lr8x(upsample2x(lr16x))

        # The semantic prediction is only computed while training.
        pred_semantic = (F.sigmoid(self.conv_lr(lr8x))
                         if self.training else None)

        return pred_semantic, lr8x, [enc2x, enc4x]
class IBNorm(nn.Layer):
    """
    Combine Instance Norm and Batch Norm into One Layer

    The first ``in_channels // 2`` channels go through batch norm, the
    remaining channels through instance norm.

    Args:
        in_channels(int): Number of channels of the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        bn_channels = in_channels // 2
        self.bnorm_channels = bn_channels
        self.inorm_channels = in_channels - bn_channels

        self.bnorm = nn.BatchNorm2D(self.bnorm_channels)
        self.inorm = nn.InstanceNorm2D(self.inorm_channels)

    def forward(self, x):
        """Normalise both channel groups and concatenate them on axis 1."""
        split_at = self.bnorm_channels
        batch_part = self.bnorm(x[:, :split_at, :, :])
        instance_part = self.inorm(x[:, split_at:, :, :])

        return paddle.concat((batch_part, instance_part), axis=1)
class Conv2dIBNormRelu(nn.Layer):
    """
    Convolution + IBNorm + Relu

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias_attr: Forwarded unchanged to ``nn.Conv2D``.
        with_ibn(bool, optional): Append an IBNorm after the convolution. Default: True.
        with_relu(bool, optional): Append a ReLU as the last layer. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias_attr=None,
                 with_ibn=True,
                 with_relu=True):

        super().__init__()

        conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr)

        # Optional stages are appended in order: norm first, activation last.
        stages = [conv]
        if with_ibn:
            stages.append(IBNorm(out_channels))
        if with_relu:
            stages.append(nn.ReLU())

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through the assembled sequence."""
        return self.layers(x)
class SEBlock(nn.Layer):
    """
    SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf

    Args:
        num_channels(int): Channels of the input (and of the output).
        reduction(int, optional): Divisor applied to ``num_channels`` for the
            hidden 1x1 convolution. Default: 1.
    """

    def __init__(self, num_channels, reduction=1):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2D(1)

        hidden_channels = int(num_channels // reduction)
        self.conv = nn.Sequential(
            nn.Conv2D(num_channels, hidden_channels, 1, bias_attr=False),
            nn.ReLU(),
            nn.Conv2D(hidden_channels, num_channels, 1, bias_attr=False),
            nn.Sigmoid())

    def forward(self, x):
        """Scale the input by the per-channel weights computed from its pooled features."""
        weights = self.conv(self.pool(x))
        return weights * x
class GaussianBlurLayer(nn.Layer):
    """ Add Gaussian Blur to a 4D tensors
    This layer takes a 4D tensor of {N, C, H, W} as input.
    The Gaussian blur will be performed in given channel number (C) splitly.
    """

    def __init__(self, channels, kernel_size):
        """
        Args:
            channels (int): Channel for input tensor
            kernel_size (int): Size of the kernel used in blurring. It must
                be odd so that the reflect padding keeps the spatial size.
        """

        super(GaussianBlurLayer, self).__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        assert self.kernel_size % 2 != 0, \
            "'GaussianBlurLayer' requires an odd kernel_size"

        self.op = nn.Sequential(
            nn.Pad2D(self.kernel_size // 2, mode='reflect'),
            nn.Conv2D(
                channels,
                channels,
                self.kernel_size,
                stride=1,
                padding=0,
                bias_attr=False,
                groups=channels))

        self._init_kernel()
        # The blur kernel is fixed: exclude it from gradient computation.
        self.op[1].weight.stop_gradient = True

    def forward(self, x):
        """
        Args:
            x (paddle.Tensor): input 4D tensor
        Returns:
            paddle.Tensor: Blurred version of the input
        Raises:
            ValueError: If the input is not 4D or its channel number differs
                from the one given at construction.
        """

        # Raise instead of print + exit(): a bad input must not terminate
        # the whole process from inside a layer.
        if len(x.shape) != 4:
            raise ValueError(
                "'GaussianBlurLayer' requires a 4D tensor as input")
        if x.shape[1] != self.channels:
            raise ValueError(
                "In 'GaussianBlurLayer', the required channel ({0}) is "
                "not the same as input ({1})".format(self.channels,
                                                     x.shape[1]))

        return self.op(x)

    def _init_kernel(self):
        """Fill the convolution weight with a normalised Gaussian kernel."""
        # Import the submodule explicitly: a bare `import scipy` does not
        # guarantee that `scipy.ndimage` is loaded.
        from scipy.ndimage import gaussian_filter

        sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8

        # Filtering a unit impulse yields the Gaussian kernel itself.
        impulse = np.zeros((self.kernel_size, self.kernel_size))
        center = self.kernel_size // 2
        impulse[center, center] = 1
        kernel = gaussian_filter(impulse, sigma).astype('float32')

        # The depthwise convolution weight is (channels, 1, k, k); repeat the
        # kernel per channel so the assigned value matches that shape.
        kernel = np.tile(kernel[np.newaxis, np.newaxis, :, :],
                         (self.channels, 1, 1, 1))
        paddle.assign(kernel, self.op[1].weight)