# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.models import layers
from paddleseg import utils
from paddleseg.cvlibs import manager

from ppmatting.models.losses import MRSD


@manager.MODELS.add_component
class DIM(nn.Layer):
    """
    The DIM implementation based on PaddlePaddle.

    The original article refers to
    Ning Xu, et al. "Deep Image Matting"
    (https://arxiv.org/pdf/1703.03872.pdf).

    The ``stage`` value selects which parts of the network are built,
    frozen and supervised:

    * ``stage < 2``: only backbone + decoder are built; no refine head.
    * ``stage == 2``: backbone and decoder parameters get
      ``stop_gradient = True``; a refine head is added.
    * ``stage > 2``: backbone, decoder and refine head, nothing frozen.

    Args:
        backbone: backbone model.
        stage (int, optional): The stage of model. Default: 3.
        decoder_input_channels(int, optional): The channel of decoder input.
            Default: 512.
        pretrained(str, optional): The path of pretrained model.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 stage=3,
                 decoder_input_channels=512,
                 pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.pretrained = pretrained
        self.stage = stage
        # Lazily populated by `loss()` on its first call.
        self.loss_func_dict = None

        decoder_output_channels = [64, 128, 256, 512]
        self.decoder = Decoder(
            input_channels=decoder_input_channels,
            output_channels=decoder_output_channels)

        if self.stage == 2:
            # Stage 2 updates only the refine head: freeze everything else.
            for param in self.backbone.parameters():
                param.stop_gradient = True
            for param in self.decoder.parameters():
                param.stop_gradient = True
        if self.stage >= 2:
            self.refine = Refine()

        # Must run after all sub-layers exist so a full checkpoint can load.
        self.init_weight()

    def forward(self, inputs):
        """
        Run the matting network.

        Args:
            inputs (dict): Must contain ``'img'`` and ``'trimap'``, which are
                concatenated along axis 1 (trimap is divided by 255 first).
                In training mode with ``stage >= 2`` it is also passed to
                :meth:`loss` as the label dict.

        Returns:
            * ``stage < 2``: ``{'alpha_raw': alpha_raw}``.
            * ``stage >= 2`` and training: ``(logit_dict, loss_dict)``.
            * ``stage >= 2`` and not training: ``alpha_pred`` clipped to
              ``[0, 1]``.
        """
        input_shape = paddle.shape(inputs['img'])[-2:]
        x = paddle.concat([inputs['img'], inputs['trimap'] / 255], axis=1)
        fea_list = self.backbone(x)

        # decoder stage: spatial size of each skip feature is the target
        # size for the matching upsampling step.
        up_shape = [paddle.shape(fea_list[i])[-2:] for i in range(5)]
        alpha_raw = self.decoder(fea_list, up_shape)
        alpha_raw = F.interpolate(
            alpha_raw, input_shape, mode='bilinear', align_corners=False)
        logit_dict = {'alpha_raw': alpha_raw}
        if self.stage < 2:
            # NOTE(review): this returns the bare logit dict even in training
            # mode, so `loss()` is never invoked from here for stage < 2 and
            # the return type differs from the stage >= 2 paths. Kept as-is
            # to preserve the existing interface -- confirm against callers.
            return logit_dict

        # refine stage (stage >= 2 from here on)
        refine_input = paddle.concat([inputs['img'], alpha_raw], axis=1)
        alpha_refine = self.refine(refine_input)

        # final alpha: the refine head predicts a residual on top of
        # the raw alpha.
        alpha_pred = alpha_refine + alpha_raw
        alpha_pred = F.interpolate(
            alpha_pred, input_shape, mode='bilinear', align_corners=False)
        if not self.training:
            alpha_pred = paddle.clip(alpha_pred, min=0, max=1)
        logit_dict['alpha_pred'] = alpha_pred

        if self.training:
            loss_dict = self.loss(logit_dict, inputs)
            return logit_dict, loss_dict
        return alpha_pred

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        """
        Compute the stage-dependent training losses.

        Args:
            logit_dict (dict): Predictions; ``'alpha_raw'`` is read unless
                ``stage == 2`` and ``'alpha_pred'`` is read for stages 2
                and 3.
            label_dict (dict): Reads ``'trimap'`` and ``'alpha'``, plus
                ``'fg'``, ``'bg'`` and ``'img'`` for stages 1 and 3.
            loss_func_dict (dict, optional): Maps ``'alpha_raw'``, ``'comp'``
                and ``'alpha_pred'`` to sequences whose first element is the
                loss callable. When given it replaces (and is stored as)
                ``self.loss_func_dict``. When None, the stored dict is used,
                being created with one ``MRSD()`` per key on first use.
                Default: None.

        Returns:
            dict: ``'all'`` (sum of the active terms) plus one entry per
            active term: ``'alpha_raw'`` (weight 0.5), ``'comp'``
            (weight 0.5) and ``'alpha_pred'`` (weight 1).
        """
        if loss_func_dict is not None:
            self.loss_func_dict = loss_func_dict
        elif self.loss_func_dict is None:
            self.loss_func_dict = defaultdict(list)
            for key in ('alpha_raw', 'comp', 'alpha_pred'):
                self.loss_func_dict[key].append(MRSD())

        loss = {}
        # Losses are restricted to pixels whose trimap value is 128
        # (presumably the "unknown" region -- confirm against the dataset).
        mask = label_dict['trimap'] == 128
        loss['all'] = 0

        if self.stage != 2:
            loss['alpha_raw'] = self.loss_func_dict['alpha_raw'][0](
                logit_dict['alpha_raw'], label_dict['alpha'], mask)
            loss['alpha_raw'] = 0.5 * loss['alpha_raw']
            loss['all'] = loss['all'] + loss['alpha_raw']

        if self.stage in (1, 3):
            # Composition loss: re-composite with the predicted alpha and
            # compare against the input image.
            comp_pred = logit_dict['alpha_raw'] * label_dict['fg'] + \
                (1 - logit_dict['alpha_raw']) * label_dict['bg']
            loss['comp'] = self.loss_func_dict['comp'][0](
                comp_pred, label_dict['img'], mask)
            loss['comp'] = 0.5 * loss['comp']
            loss['all'] = loss['all'] + loss['comp']

        if self.stage in (2, 3):
            loss['alpha_pred'] = self.loss_func_dict['alpha_pred'][0](
                logit_dict['alpha_pred'], label_dict['alpha'], mask)
            loss['all'] = loss['all'] + loss['alpha_pred']

        return loss

    def init_weight(self):
        """Load the whole model from ``self.pretrained`` when it is set."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)


# bilinear interpolate skip connect
class Up(nn.Layer):
    """
    Decoder step: bilinear upsample, add the skip feature, then convolve.

    Args:
        input_channels (int): Channels of the incoming (and skip) feature.
        output_channels (int): Channels produced by the 5x5 convolution.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = layers.ConvBNReLU(
            input_channels,
            output_channels,
            kernel_size=5,
            padding=2,
            bias_attr=False)

    def forward(self, x, skip, output_shape):
        """
        Args:
            x: Feature to upsample.
            skip: Skip feature added element-wise after upsampling.
            output_shape: Target spatial size passed to ``F.interpolate``.

        Returns:
            The fused feature after convolution and ReLU.
        """
        x = F.interpolate(
            x, size=output_shape, mode='bilinear', align_corners=False)
        x = x + skip
        x = self.conv(x)
        # NOTE(review): ConvBNReLU presumably already ends in a ReLU, which
        # would make this call a no-op; kept to preserve exact behavior.
        x = F.relu(x)
        return x


class Decoder(nn.Layer):
    """
    Upsampling decoder that produces a single-channel alpha map.

    Args:
        input_channels (int): Channels of the deepest backbone feature.
        output_channels (sequence of int, optional): Output widths of the
            ``deconv2`` .. ``deconv5`` steps, listed shallow to deep.
            Default: (64, 128, 256, 512).
    """

    def __init__(self, input_channels, output_channels=(64, 128, 256, 512)):
        super().__init__()
        self.deconv6 = nn.Conv2D(
            input_channels, input_channels, kernel_size=1, bias_attr=False)
        self.deconv5 = Up(input_channels, output_channels[-1])
        self.deconv4 = Up(output_channels[-1], output_channels[-2])
        self.deconv3 = Up(output_channels[-2], output_channels[-3])
        self.deconv2 = Up(output_channels[-3], output_channels[-4])
        self.deconv1 = Up(output_channels[-4], 64)

        self.alpha_conv = nn.Conv2D(
            64, 1, kernel_size=5, padding=2, bias_attr=False)

    def forward(self, fea_list, shape_list):
        """
        Args:
            fea_list: Backbone features; the last entry seeds the decoder and
                entries 0..4 are used as skip connections (presumably six
                feature maps in total -- confirm against the backbone).
            shape_list: Target spatial sizes for entries 0..4 of ``fea_list``.

        Returns:
            Alpha map after a sigmoid.
        """
        x = fea_list[-1]
        x = self.deconv6(x)
        x = self.deconv5(x, fea_list[4], shape_list[4])
        x = self.deconv4(x, fea_list[3], shape_list[3])
        x = self.deconv3(x, fea_list[2], shape_list[2])
        x = self.deconv2(x, fea_list[1], shape_list[1])
        x = self.deconv1(x, fea_list[0], shape_list[0])
        alpha = self.alpha_conv(x)
        alpha = F.sigmoid(alpha)

        return alpha


class Refine(nn.Layer):
    """
    Refinement head: four 3x3 ConvBNReLU layers mapping a 4-channel input
    (image concatenated with the raw alpha in ``DIM.forward``) to a
    1-channel output that ``DIM.forward`` adds to the raw alpha.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = layers.ConvBNReLU(
            4, 64, kernel_size=3, padding=1, bias_attr=False)
        self.conv2 = layers.ConvBNReLU(
            64, 64, kernel_size=3, padding=1, bias_attr=False)
        self.conv3 = layers.ConvBNReLU(
            64, 64, kernel_size=3, padding=1, bias_attr=False)
        self.alpha_pred = layers.ConvBNReLU(
            64, 1, kernel_size=3, padding=1, bias_attr=False)

    def forward(self, x):
        """Apply the four convolution blocks in sequence and return the result."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        alpha = self.alpha_pred(x)

        return alpha