# Date: 2023-03-14
# Creator: zejunyang
# Function: Edge attention layer.
import torch
import torch.nn as nn
import torch.nn.functional as F

from NTED.base_function import Blur


class ResBlock(nn.Module):
    """Residual block whose first conv rescales the input ('up', 'down' or 'same')
    before two 3x3 convolutions with a residual connection."""

    def __init__(self, in_nc, out_nc, scale='down'):  # , norm_layer=nn.BatchNorm2d
        super(ResBlock, self).__init__()

        use_bias = True
        assert scale in ['up', 'down', 'same'], "ResBlock scale must be in 'up' 'down' 'same'"

        if scale == 'same':
            # self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=True)
        if scale == 'up':
            self.scale = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
            )
        if scale == 'down':
            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)

        self.block = nn.Sequential(
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
            # norm_layer(out_nc),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
            # norm_layer(out_nc)
        )
        self.relu = nn.ReLU(inplace=True)
        # self.padding = nn.ReplicationPad2d(padding=(0, 1, 0, 0))

    def forward(self, x):
        residual = self.scale(x)
        return self.relu(residual + self.block(residual))


class Edge_Attn(nn.Module):
    """Edge attention layer: extracts a blurred gradient-magnitude edge map and uses it
    to build a residual attention mask over the input features."""

    def __init__(self, in_channels=3):
        super(Edge_Attn, self).__init__()

        self.in_channels = in_channels
        blur_kernel = [1, 3, 3, 3, 1]
        self.blur = Blur(blur_kernel, pad=(2, 2), upsample_factor=1)
        # self.conv = nn.Conv2d(self.in_channels, self.in_channels, 3, padding=1, bias=False)
        self.res_block = ResBlock(self.in_channels, self.in_channels, scale='same')
        self.sigmoid = nn.Sigmoid()

    def gradient(self, x):
        """Central-difference gradient magnitude with a stride-3 offset, smoothed by the blur kernel."""
        h_x = x.size()[2]
        w_x = x.size()[3]
        stride = 3
        # Shifted copies of x (replicate-padded) for horizontal and vertical differences.
        r = F.pad(x, (0, stride, 0, 0), mode='replicate')[:, :, :, stride:]
        l = F.pad(x, (stride, 0, 0, 0), mode='replicate')[:, :, :, :w_x]
        t = F.pad(x, (0, 0, stride, 0), mode='replicate')[:, :, :h_x, :]
        b = F.pad(x, (0, 0, 0, stride), mode='replicate')[:, :, stride:, :]

        xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5)
        xgrad = self.blur(xgrad)
        return xgrad

    def forward(self, x):
        # feature_edge = self.gradient(x).detach()
        # attn = self.conv(feature_edge)

        # Build the edge map per sample and per channel, then reassemble to the input shape.
        for b in range(x.shape[0]):
            for c in range(x.shape[1]):
                if c == 0:
                    channel_edge = self.gradient(x[b:b+1, c:c+1])
                else:
                    channel_edge = torch.cat([channel_edge, self.gradient(x[b:b+1, c:c+1])], dim=1)
            if b == 0:
                feature_edge = channel_edge
            else:
                feature_edge = torch.cat([feature_edge, channel_edge], dim=0)

        # Detach so gradients do not flow through the edge extraction itself.
        feature_edge = feature_edge.detach()
        feature_edge = x * feature_edge
        attn = self.res_block(feature_edge)
        attn = self.sigmoid(attn)
        # out = x * attn
        # Residual attention: keep the original features and add the edge-weighted ones.
        out = x * attn + x
        return out


if __name__ == '__main__':
    from PIL import Image
    import numpy as np
    import cv2

    edg_atten = Edge_Attn()

    im = Image.open('/apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset/fake_images/001400.png')
    npim = np.array(im, dtype=np.float32)
    npim = cv2.cvtColor(npim, cv2.COLOR_RGB2GRAY)
    # npim = npim[:, :, 2]
    tim = torch.from_numpy(npim).unsqueeze_(0).unsqueeze_(0)
    edge = edg_atten.gradient(tim)
    npgrad = edge.squeeze(0).squeeze(0).data.clamp(0, 255).numpy()
    Image.fromarray(npgrad.astype('uint8')).save('tmp.png')

    # tim = torch.from_numpy(npim).unsqueeze_(0)
    # edge = edg_atten.gradient_1order(tim)
    # npgrad = edge.squeeze(0).data.clamp(0,255).numpy()[:, :, 0]
    # Image.fromarray(npgrad.astype('uint8')).save('tmp.png')
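
    # Quick shape check of the full edge-attention forward pass: a minimal sketch,
    # assuming NTED.base_function.Blur also accepts random CPU tensors (the same
    # assumption the gradient() call above already relies on). The dummy tensor
    # and printed shape are illustrative only.
    dummy = torch.rand(2, 3, 64, 64)
    with torch.no_grad():
        out = edg_atten(dummy)
    print(out.shape)  # expected: torch.Size([2, 3, 64, 64]), same as the input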