import torch
import torch.nn as nn
import torch.nn.functional as F

from NTED.base_function import Blur


class ResBlock(nn.Module):
    """Residual block whose skip path can keep, upsample, or downsample the resolution."""

    def __init__(self, in_nc, out_nc, scale='down'):
        super(ResBlock, self).__init__()
        use_bias = True
        assert scale in ['up', 'down', 'same'], "ResBlock scale must be one of 'up', 'down', 'same'"

        # Skip/rescale path: maps in_nc -> out_nc and changes the spatial size if requested.
        if scale == 'same':
            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
        if scale == 'up':
            self.scale = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=use_bias)
            )
        if scale == 'down':
            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)

        # Residual path operating at the rescaled resolution.
        self.block = nn.Sequential(
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Rescale first, then add the residual branch computed on the rescaled feature.
        residual = self.scale(x)
        return self.relu(residual + self.block(residual))
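
# Usage sketch (illustrative only, not part of the original module; the shapes
# below assume the 'down' branch, which halves the spatial size):
#   block = ResBlock(64, 128, scale='down')
#   y = block(torch.randn(1, 64, 32, 32))   # y.shape == (1, 128, 16, 16)
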
class Edge_Attn(nn.Module):
    """Edge-guided attention: gradient magnitudes of the input modulate a learned attention map."""

    def __init__(self, in_channels=3):
        super(Edge_Attn, self).__init__()
        self.in_channels = in_channels

        # Smooth the gradient map to suppress pixel-level noise.
        blur_kernel = [1, 3, 3, 3, 1]
        self.blur = Blur(blur_kernel, pad=(2, 2), upsample_factor=1)

        self.res_block = ResBlock(self.in_channels, self.in_channels, scale='same')
        self.sigmoid = nn.Sigmoid()

    def gradient(self, x):
        # Finite differences between pixels shifted by `stride` in each direction
        # (replicate padding at the borders), followed by blurring of the magnitude.
        h_x = x.size()[2]
        w_x = x.size()[3]
        stride = 3
        r = F.pad(x, (0, stride, 0, 0), mode='replicate')[:, :, :, stride:]   # neighbor to the right
        l = F.pad(x, (stride, 0, 0, 0), mode='replicate')[:, :, :, :w_x]      # neighbor to the left
        t = F.pad(x, (0, 0, stride, 0), mode='replicate')[:, :, :h_x, :]      # neighbor above
        b = F.pad(x, (0, 0, 0, stride), mode='replicate')[:, :, stride:, :]   # neighbor below
        xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5)
        xgrad = self.blur(xgrad)
        return xgrad

    def forward(self, x):
        # Compute the edge map per sample and per channel, then stack the results
        # back into a (B, C, H, W) tensor.
        for b in range(x.shape[0]):
            for c in range(x.shape[1]):
                if c == 0:
                    channel_edge = self.gradient(x[b:b+1, c:c+1])
                else:
                    channel_edge = torch.cat([channel_edge, self.gradient(x[b:b+1, c:c+1])], dim=1)
            if b == 0:
                feature_edge = channel_edge
            else:
                feature_edge = torch.cat([feature_edge, channel_edge], dim=0)

        # The edge map only guides the attention; no gradient flows back through it.
        feature_edge = feature_edge.detach()
        feature_edge = x * feature_edge
        attn = self.res_block(feature_edge)
        attn = self.sigmoid(attn)

        # Residual attention: re-weight the input and keep a skip connection.
        out = x * attn + x
        return out


if __name__ == '__main__':
    from PIL import Image
    import numpy as np
    import cv2

    edg_atten = Edge_Attn()

    # Quick visual check of the gradient operator on a single grayscale image.
    im = Image.open('/apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset/fake_images/001400.png')
    npim = np.array(im, dtype=np.float32)
    npim = cv2.cvtColor(npim, cv2.COLOR_RGB2GRAY)

    tim = torch.from_numpy(npim).unsqueeze_(0).unsqueeze_(0)
    edge = edg_atten.gradient(tim)
    npgrad = edge.squeeze(0).squeeze(0).data.clamp(0, 255).numpy()
    Image.fromarray(npgrad.astype('uint8')).save('tmp.png')
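
    # Additional sketch (not in the original script): exercise the full attention
    # module on a random 3-channel tensor. Assuming the NTED Blur with pad=(2, 2)
    # keeps the spatial size (as the forward pass already requires), the output
    # shape should match the input shape.
    feat = torch.randn(2, 3, 64, 64)
    out = edg_atten(feat)
    print(out.shape)  # expected: torch.Size([2, 3, 64, 64])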