# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The GCA code was heavily based on https://github.com/Yaoyi-Li/GCA-Matting
# and https://github.com/open-mmlab/mmediting

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.cvlibs import param_init


class GuidedCxtAtten(nn.Layer):
    """Guided contextual attention for image matting.

    Propagates alpha features between image regions according to the
    similarity of the guidance (image) features, following GCA-Matting.
    """

    def __init__(self,
                 out_channels,
                 guidance_channels,
                 kernel_size=3,
                 stride=1,
                 rate=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.rate = rate
        self.stride = stride
        # 1x1 conv that halves the guidance channels before patch matching
        self.guidance_conv = nn.Conv2D(
            in_channels=guidance_channels,
            out_channels=guidance_channels // 2,
            kernel_size=1)
        # 1x1 conv + BN applied to the propagated alpha feature
        self.out_conv = nn.Sequential(
            nn.Conv2D(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias_attr=False),
            nn.BatchNorm(out_channels))
        self.init_weight()

    def init_weight(self):
        param_init.xavier_uniform(self.guidance_conv.weight)
        param_init.constant_init(self.guidance_conv.bias, value=0.0)
        param_init.xavier_uniform(self.out_conv[0].weight)
        # near-zero BN scale so the attention branch starts close to an
        # identity mapping through the residual connection in forward()
        param_init.constant_init(self.out_conv[1].weight, value=1e-3)
        param_init.constant_init(self.out_conv[1].bias, value=0.0)

    def forward(self, img_feat, alpha_feat, unknown=None, softmax_scale=1.):
        img_feat = self.guidance_conv(img_feat)
        img_feat = F.interpolate(
            img_feat, scale_factor=1 / self.rate, mode='nearest')

        # process unknown mask
        unknown, softmax_scale = self.process_unknown_mask(unknown, img_feat,
                                                           softmax_scale)

        img_ps, alpha_ps, unknown_ps = self.extract_feature_maps_patches(
            img_feat, alpha_feat, unknown)

        self_mask = self.get_self_correlation_mask(img_feat)

        # split tensors along the batch dimension so each sample is handled
        # independently; with num_or_sections equal to the batch size,
        # paddle.split returns one slice per sample
        n = img_feat.shape[0]
        img_groups = paddle.split(img_feat, n, axis=0)
        img_ps_groups = paddle.split(img_ps, n, axis=0)
        alpha_ps_groups = paddle.split(alpha_ps, n, axis=0)
        unknown_ps_groups = paddle.split(unknown_ps, n, axis=0)
        scale_groups = paddle.split(softmax_scale, n, axis=0)
        groups = (img_groups, img_ps_groups, alpha_ps_groups,
                  unknown_ps_groups, scale_groups)

        y = []
        for img_i, img_ps_i, alpha_ps_i, unknown_ps_i, scale_i in zip(*groups):
            # correlate the sample's features with its own patches
            similarity_map = self.compute_similarity_map(img_i, img_ps_i)
            gca_score = self.compute_guided_attention_score(
                similarity_map, unknown_ps_i, scale_i, self_mask)
            yi = self.propagate_alpha_feature(gca_score, alpha_ps_i)
            y.append(yi)

        y = paddle.concat(y, axis=0)  # back to the mini-batch
        y = paddle.reshape(y, alpha_feat.shape)
        y = self.out_conv(y) + alpha_feat  # residual connection
        return y

    def extract_feature_maps_patches(self, img_feat, alpha_feat, unknown):
        # extract image feature patches with shape:
        # (N, img_h*img_w, img_c, img_ks, img_ks)
        img_ks = self.kernel_size
        img_ps = self.extract_patches(img_feat, img_ks, self.stride)

        # extract alpha feature patches with shape:
        # (N, img_h*img_w, alpha_c, alpha_ks, alpha_ks)
        alpha_ps = self.extract_patches(alpha_feat, self.rate * 2, self.rate)

        # extract unknown mask patches with shape: (N, img_h*img_w, 1, 1)
        unknown_ps = self.extract_patches(unknown, img_ks, self.stride)
        unknown_ps = unknown_ps.squeeze(axis=2)  # squeeze channel dimension
        # fraction of unknown pixels within each patch
        unknown_ps = unknown_ps.mean(axis=[2, 3], keepdim=True)
        return img_ps, alpha_ps, unknown_ps

    def extract_patches(self, x, kernel_size, stride):
        n, c, _, _ = x.shape
        x = self.pad(x, kernel_size, stride)
        x = F.unfold(x, [kernel_size, kernel_size], strides=[stride, stride])
        x = paddle.transpose(x, (0, 2, 1))
        x = paddle.reshape(x, (n, -1, c, kernel_size, kernel_size))
        return x
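
    # Shape walk-through for extract_patches: given x of shape (N, C, H, W)
    # with kernel_size=3 and stride=1, the reflect padding preserves the
    # spatial size, so F.unfold returns (N, C*9, H*W); the transpose gives
    # (N, H*W, C*9) and the reshape (N, H*W, C, 3, 3), i.e. one 3x3 patch
    # centred at every spatial location.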

    def pad(self, x, kernel_size, stride):
        # reflect padding sized so that unfold tiles the full feature map,
        # producing one patch per stride step
        left = (kernel_size - stride + 1) // 2
        right = (kernel_size - stride) // 2
        pad = (left, right, left, right)
        return F.pad(x, pad, mode='reflect')

    def compute_guided_attention_score(self, similarity_map, unknown_ps, scale,
                                       self_mask):
        # scale the correlation with the predicted scale factors for the
        # known and unknown areas; cast the boolean masks explicitly before
        # multiplying with the float scales
        unknown_scale, known_scale = scale[0]
        out = similarity_map * (
            unknown_scale * paddle.cast(unknown_ps > 0., 'float32') +
            known_scale * paddle.cast(unknown_ps <= 0., 'float32'))
        # mask itself; the self-mask is only applied to the unknown area
        out = out + self_mask * unknown_ps
        gca_score = F.softmax(out, axis=1)
        return gca_score

    def propagate_alpha_feature(self, gca_score, alpha_ps):
        alpha_ps = alpha_ps[0]  # squeeze dim 0
        if self.rate == 1:
            gca_score = self.pad(gca_score, kernel_size=2, stride=1)
            alpha_ps = paddle.transpose(alpha_ps, (1, 0, 2, 3))
            out = F.conv2d(gca_score, alpha_ps) / 4.
        else:
            # each output pixel is covered by a 2x2 grid of overlapping
            # (2*rate x 2*rate) patches, hence the division by 4
            out = F.conv2d_transpose(
                gca_score, alpha_ps, stride=self.rate, padding=1) / 4.
        return out

    def compute_similarity_map(self, img_feat, img_ps):
        img_ps = img_ps[0]  # squeeze dim 0
        # L2-normalize the patches, then convolve the feature map with them
        # to get the correlation (similarity) map
        img_ps_normed = img_ps / paddle.clip(self.l2_norm(img_ps), 1e-4)
        img_feat = F.pad(img_feat, (1, 1, 1, 1), mode='reflect')
        similarity_map = F.conv2d(img_feat, img_ps_normed)
        return similarity_map

    def get_self_correlation_mask(self, img_feat):
        _, _, h, w = img_feat.shape
        # one-hot mask marking, for every location, its own patch index
        self_mask = F.one_hot(
            paddle.reshape(paddle.arange(h * w), (h, w)),
            num_classes=int(h * w))
        self_mask = paddle.transpose(self_mask, (2, 0, 1))
        self_mask = paddle.reshape(self_mask, (1, h * w, h, w))
        # large negative value suppresses self-matching in the softmax
        return self_mask * (-1e4)

    def process_unknown_mask(self, unknown, img_feat, softmax_scale):
        n, _, h, w = img_feat.shape
        if unknown is not None:
            unknown = unknown.clone()
            unknown = F.interpolate(
                unknown, scale_factor=1 / self.rate, mode='nearest')
            unknown_mean = unknown.mean(axis=[2, 3])
            known_mean = 1 - unknown_mean
            unknown_scale = paddle.clip(
                paddle.sqrt(unknown_mean / known_mean), 0.1, 10)
            known_scale = paddle.clip(
                paddle.sqrt(known_mean / unknown_mean), 0.1, 10)
            softmax_scale = paddle.concat([unknown_scale, known_scale], axis=1)
        else:
            # no trimap given: treat every pixel as unknown and use the same
            # scalar scale for both areas
            unknown = paddle.ones([n, 1, h, w])
            softmax_scale = paddle.reshape(
                paddle.to_tensor([softmax_scale, softmax_scale]), (1, 2))
            softmax_scale = paddle.expand(softmax_scale, (n, 2))
        return unknown, softmax_scale

    @staticmethod
    def l2_norm(x):
        # L2 norm over all but the first (patch) dimension; defined as a
        # staticmethod because it is invoked as self.l2_norm above
        x = x**2
        x = x.sum(axis=[1, 2, 3], keepdim=True)
        return paddle.sqrt(x)
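

if __name__ == '__main__':
    # Minimal smoke-test sketch; the channel counts and the 32x32 spatial
    # size below are arbitrary example values, not part of the module. With
    # rate=2, the guidance and alpha features share the same spatial size,
    # and the output keeps the shape of alpha_feat.
    gca = GuidedCxtAtten(out_channels=128, guidance_channels=64)
    img_feat = paddle.rand([2, 64, 32, 32])  # guidance (image) features
    alpha_feat = paddle.rand([2, 128, 32, 32])  # alpha features
    # binary mask of the trimap's unknown region
    unknown = (paddle.rand([2, 1, 32, 32]) > 0.5).astype('float32')
    out = gca(img_feat, alpha_feat, unknown)
    print(out.shape)  # [2, 128, 32, 32]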