vivym's picture
init
4a582ec
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.models import layers
from paddleseg import utils
from paddleseg.cvlibs import manager
from ppmatting.models.losses import MRSD
@manager.MODELS.add_component
class DIM(nn.Layer):
    """
    The DIM implementation based on PaddlePaddle.

    The original article refers to
    Ning Xu, et al. "Deep Image Matting"
    (https://arxiv.org/abs/1703.03872).

    The model is trained in stages (see ``loss``):
        * stage 1: backbone + decoder; optimises the raw-alpha and
          composition losses.
        * stage 2: backbone and decoder are frozen; only the refine head is
          trained, on the refined-alpha loss.
        * stage 3: everything is trained jointly on all three losses.

    Args:
        backbone (nn.Layer): Backbone model. It receives the image
            concatenated with the trimap (divided by 255) and must return a
            list of at least five feature maps.
        stage (int, optional): The training stage of the model. Default: 3.
        decoder_input_channels (int, optional): The channel of decoder input.
            Default: 512.
        pretrained (str, optional): The path of pretrained model.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 stage=3,
                 decoder_input_channels=512,
                 pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.pretrained = pretrained
        self.stage = stage
        # Loss functions are created lazily on the first call to ``loss``.
        self.loss_func_dict = None

        decoder_output_channels = [64, 128, 256, 512]
        self.decoder = Decoder(
            input_channels=decoder_input_channels,
            output_channels=decoder_output_channels)
        if self.stage == 2:
            # Stage 2 trains only the refine head, so freeze the encoder and
            # decoder.
            for param in self.backbone.parameters():
                param.stop_gradient = True
            for param in self.decoder.parameters():
                param.stop_gradient = True
        if self.stage >= 2:
            self.refine = Refine()
        self.init_weight()

    def forward(self, inputs):
        """
        Run the encoder-decoder and, for stage >= 2, the refine head.

        Args:
            inputs (dict): Must contain 'img' and 'trimap' (the trimap is
                divided by 255 here, so presumably it is in [0, 255]). In
                training mode it must also carry the labels read by ``loss``.

        Returns:
            Training mode: ``(logit_dict, loss_dict)``.
            Eval mode: the refined alpha clipped to [0, 1] for stage >= 2;
            for stage < 2 the ``logit_dict`` holding only 'alpha_raw'
            (kept as-is for backward compatibility).
        """
        input_shape = paddle.shape(inputs['img'])[-2:]
        x = paddle.concat([inputs['img'], inputs['trimap'] / 255], axis=1)
        fea_list = self.backbone(x)

        # decoder stage: spatial sizes of the first five feature maps are
        # the upsampling targets of the decoder's skip connections.
        up_shape = [paddle.shape(fea_list[i])[-2:] for i in range(5)]
        alpha_raw = self.decoder(fea_list, up_shape)
        alpha_raw = F.interpolate(
            alpha_raw, input_shape, mode='bilinear', align_corners=False)
        logit_dict = {'alpha_raw': alpha_raw}

        if self.stage >= 2:
            # refine stage: predict a residual from image + coarse alpha.
            refine_input = paddle.concat([inputs['img'], alpha_raw], axis=1)
            alpha_refine = self.refine(refine_input)

            # finally alpha
            alpha_pred = alpha_refine + alpha_raw
            alpha_pred = F.interpolate(
                alpha_pred, input_shape, mode='bilinear', align_corners=False)
            if not self.training:
                alpha_pred = paddle.clip(alpha_pred, min=0, max=1)
            logit_dict['alpha_pred'] = alpha_pred

        if self.training:
            # Computed for every stage; previously stage < 2 returned before
            # reaching this point, so stage-1 training produced no loss.
            loss_dict = self.loss(logit_dict, inputs)
            return logit_dict, loss_dict
        if self.stage < 2:
            return logit_dict
        return alpha_pred

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        """
        Compute the stage-dependent losses on the unknown trimap region.

        Args:
            logit_dict (dict): Output of ``forward`` ('alpha_raw' and, for
                stage >= 2, 'alpha_pred').
            label_dict (dict): Ground truth with 'alpha', 'trimap', 'img' and,
                for stages 1 and 3, 'fg' and 'bg'.
            loss_func_dict (dict, optional): Mapping from loss name
                ('alpha_raw', 'comp', 'alpha_pred') to a list whose first
                element is the loss callable. When given it replaces the
                cached functions; when None the cached functions are reused,
                defaulting to MRSD for every term on first use.

        Returns:
            dict: The individual loss terms plus their sum under 'all'.
        """
        if loss_func_dict is None:
            if self.loss_func_dict is None:
                self.loss_func_dict = defaultdict(list)
                self.loss_func_dict['alpha_raw'].append(MRSD())
                self.loss_func_dict['comp'].append(MRSD())
                self.loss_func_dict['alpha_pred'].append(MRSD())
        else:
            self.loss_func_dict = loss_func_dict

        loss = {}
        # Losses are restricted to pixels whose trimap value is 128.
        mask = label_dict['trimap'] == 128
        loss['all'] = 0

        if self.stage != 2:
            loss['alpha_raw'] = self.loss_func_dict['alpha_raw'][0](
                logit_dict['alpha_raw'], label_dict['alpha'], mask)
            loss['alpha_raw'] = 0.5 * loss['alpha_raw']
            loss['all'] = loss['all'] + loss['alpha_raw']

        if self.stage in (1, 3):
            # Composition loss: re-composite the image from the ground-truth
            # foreground/background using the predicted alpha.
            comp_pred = logit_dict['alpha_raw'] * label_dict['fg'] + \
                (1 - logit_dict['alpha_raw']) * label_dict['bg']
            loss['comp'] = self.loss_func_dict['comp'][0](
                comp_pred, label_dict['img'], mask)
            loss['comp'] = 0.5 * loss['comp']
            loss['all'] = loss['all'] + loss['comp']

        if self.stage in (2, 3):
            loss['alpha_pred'] = self.loss_func_dict['alpha_pred'][0](
                logit_dict['alpha_pred'], label_dict['alpha'], mask)
            loss['all'] = loss['all'] + loss['alpha_pred']

        return loss

    def init_weight(self):
        """Load the whole model from ``self.pretrained`` when it is set."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
# bilinear interpolate skip connect
class Up(nn.Layer):
    """
    Decoder upsampling block.

    Bilinearly resizes ``x`` to ``output_shape``, adds the skip feature
    element-wise, then applies a 5x5 ``ConvBNReLU``.

    Args:
        input_channels (int): Channels of ``x``. The skip feature must have
            the same channel count, since the two are summed.
        output_channels (int): Channels produced by the convolution.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = layers.ConvBNReLU(
            input_channels,
            output_channels,
            kernel_size=5,
            padding=2,
            bias_attr=False)

    def forward(self, x, skip, output_shape):
        x = F.interpolate(
            x, size=output_shape, mode='bilinear', align_corners=False)
        x = x + skip
        # ConvBNReLU already ends with a ReLU; the extra F.relu that used to
        # follow was a no-op because ReLU is idempotent.
        return self.conv(x)
class Decoder(nn.Layer):
    """
    Decode backbone features into a single-channel alpha map.

    The deepest feature ``fea_list[-1]`` goes through a 1x1 conv. It is then
    passed through five ``Up`` blocks; block k resizes to ``shape_list[k]``
    and fuses ``fea_list[k]``, for k = 4, 3, 2, 1, 0. A final 5x5 conv plus
    sigmoid yields values in (0, 1).

    Because ``Up`` sums the skip feature element-wise, the channel counts of
    ``fea_list[4] .. fea_list[0]`` must match the tensors they are added to.

    Args:
        input_channels (int): Channels of ``fea_list[-1]``.
        output_channels (sequence of int, optional): Output widths of the
            Up blocks, indexed from the end. Default: (64, 128, 256, 512).
    """

    def __init__(self, input_channels, output_channels=(64, 128, 256, 512)):
        super().__init__()
        ch = output_channels
        # Creation order and attribute names are kept stable so existing
        # checkpoints load unchanged.
        self.deconv6 = nn.Conv2D(
            input_channels, input_channels, kernel_size=1, bias_attr=False)
        self.deconv5 = Up(input_channels, ch[-1])
        self.deconv4 = Up(ch[-1], ch[-2])
        self.deconv3 = Up(ch[-2], ch[-3])
        self.deconv2 = Up(ch[-3], ch[-4])
        self.deconv1 = Up(ch[-4], 64)
        self.alpha_conv = nn.Conv2D(
            64, 1, kernel_size=5, padding=2, bias_attr=False)

    def forward(self, fea_list, shape_list):
        out = self.deconv6(fea_list[-1])
        ups = (self.deconv5, self.deconv4, self.deconv3, self.deconv2,
               self.deconv1)
        # Walk from the deepest skip level (4) down to the shallowest (0).
        for level, up in zip(range(4, -1, -1), ups):
            out = up(out, fea_list[level], shape_list[level])
        return F.sigmoid(self.alpha_conv(out))
class Refine(nn.Layer):
    """
    Refinement head: four 3x3 ``ConvBNReLU`` layers mapping a 4-channel
    input to a 1-channel output.

    In ``DIM`` the input is the image concatenated with the coarse alpha,
    and the output is added to that coarse alpha as a residual.

    NOTE(review): the last layer is a ``ConvBNReLU``, so the residual is
    presumably non-negative (refinement can only raise alpha). Confirm this
    is intended before changing it, because it would alter pretrained
    checkpoints.
    """

    def __init__(self):
        super().__init__()

        def _conv3x3(in_ch, out_ch):
            # 3x3 ConvBNReLU with 'same' padding and no conv bias.
            return layers.ConvBNReLU(
                in_ch, out_ch, kernel_size=3, padding=1, bias_attr=False)

        self.conv1 = _conv3x3(4, 64)
        self.conv2 = _conv3x3(64, 64)
        self.conv3 = _conv3x3(64, 64)
        self.alpha_pred = _conv3x3(64, 1)

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.conv3, self.alpha_pred):
            x = stage(x)
        return x