ZHIJI_cv_web_ui / NTED /edge_attention_layer.py
zejunyang
update
9667e74
raw
history blame
No virus
4.27 kB
# Date: 2023-03-14
# Creator: zejunyang
# Function: edge attention layer (边缘注意力层).
import torch
import torch.nn as nn
import torch.nn.functional as F
from NTED.base_function import Blur
class ResBlock(nn.Module):
    """Residual block built around a resampling projection.

    The input first passes through ``self.scale``, a projection from ``in_nc``
    to ``out_nc`` channels whose spatial effect depends on ``scale``:

    * ``'same'``: 3x3 convolution, stride 1, padding 1.
    * ``'up'``:   2x bilinear upsampling followed by a 1x1 convolution.
    * ``'down'``: 3x3 convolution, stride 2, padding 1.

    The projected tensor ``p`` is refined by two 3x3 convolutions with a ReLU
    between them, and the block returns ``relu(p + block(p))``.  Every
    convolution uses a bias.
    """

    def __init__(self, in_nc, out_nc, scale='down'):
        super(ResBlock, self).__init__()
        assert scale in ['up', 'down', 'same'], "ResBlock scale must be in 'up' 'down' 'same'"

        # The projection is created and registered before the residual branch,
        # so parameter order (and random-init order) is: scale, then block.
        if scale == 'up':
            projection = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True),
            )
        else:
            stride = 2 if scale == 'down' else 1
            projection = nn.Conv2d(
                in_nc, out_nc, kernel_size=3, stride=stride, padding=1, bias=True
            )
        self.scale = projection

        def conv3x3():
            return nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=True)

        # Sequential indices (0: conv, 1: ReLU, 2: conv) keep the state_dict
        # keys 'block.0.*' and 'block.2.*' stable for checkpoint loading.
        self.block = nn.Sequential(conv3x3(), nn.ReLU(inplace=True), conv3x3())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Project ``x`` with ``self.scale`` and add the residual refinement."""
        projected = self.scale(x)
        summed = projected + self.block(projected)
        return self.relu(summed)
class Edge_Attn(nn.Module):
    """Edge attention layer.

    Computes a per-(sample, channel) edge-magnitude map of the input (see
    ``gradient``), gates the input with that map, and turns the gated features
    into a sigmoid attention map with a ``ResBlock``.  The output is the
    residual re-weighting ``x * attn + x``.

    Args:
        in_channels: number of input (and output) channels.
    """

    def __init__(self, in_channels=3):
        super(Edge_Attn, self).__init__()
        self.in_channels = in_channels
        blur_kernel = [1, 3, 3, 3, 1]
        # NOTE(review): forward() multiplies x by the blurred edge map, so this
        # assumes Blur with a 5-tap kernel and pad=(2, 2) preserves the spatial
        # size - confirm against NTED.base_function.Blur.
        self.blur = Blur(blur_kernel, pad=(2, 2), upsample_factor=1)
        self.res_block = ResBlock(self.in_channels, self.in_channels, scale='same')
        self.sigmoid = nn.Sigmoid()

    def gradient(self, x, stride=3):
        """Return a blurred edge-magnitude map of ``x``.

        For every pixel, computes
        ``0.5 * sqrt((x[.., j + s] - x[.., j - s])**2 + (x[.., i - s, :] - x[.., i + s, :])**2)``
        with ``s = stride``, i.e. central differences along width (dim 3) and
        height (dim 2).  Borders use replicate padding, so the output has the
        same height and width as ``x`` before it is smoothed by ``self.blur``.

        Args:
            x: 4-D tensor; dims 2 and 3 are treated as height and width.
            stride: pixel offset between the two samples of each difference
                (defaults to 3).
        """
        h_x, w_x = x.size(2), x.size(3)
        # Shifted copies of x: r/l are x shifted left/right by `stride`
        # columns, t/b are x shifted down/up by `stride` rows.
        r = F.pad(x, (0, stride, 0, 0), mode='replicate')[:, :, :, stride:]
        l = F.pad(x, (stride, 0, 0, 0), mode='replicate')[:, :, :, :w_x]
        t = F.pad(x, (0, 0, stride, 0), mode='replicate')[:, :, :h_x, :]
        b = F.pad(x, (0, 0, 0, stride), mode='replicate')[:, :, stride:, :]
        xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5)
        return self.blur(xgrad)

    def forward(self, x):
        """Re-weight ``x`` by an edge-driven attention map: ``x * attn + x``."""
        batch, channels = x.shape[0], x.shape[1]
        # The edge map is a fixed (non-learned) cue, so compute it without an
        # autograd graph; this also keeps sqrt's infinite slope at 0 out of
        # any backward pass.
        with torch.no_grad():
            # Every (sample, channel) slice goes through gradient()/blur on its
            # own, so the result does not depend on how self.blur mixes
            # channels or samples.  Slices are gathered in lists and
            # concatenated once per axis (no repeated re-copying).
            feature_edge = torch.cat(
                [
                    torch.cat(
                        [self.gradient(x[n:n + 1, c:c + 1]) for c in range(channels)],
                        dim=1,
                    )
                    for n in range(batch)
                ],
                dim=0,
            )
        attn = self.sigmoid(self.res_block(x * feature_edge))
        return x * attn + x
if __name__ == '__main__':
    # Demo: save the edge map produced by Edge_Attn.gradient for one image.
    import argparse

    from PIL import Image
    import numpy as np
    import cv2

    parser = argparse.ArgumentParser(description='Visualise Edge_Attn.gradient on one image.')
    parser.add_argument(
        'image',
        nargs='?',
        default='/apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset/fake_images/001400.png',
        help='input image path',
    )
    parser.add_argument('--out', default='tmp.png', help='where to write the edge map')
    args = parser.parse_args()

    edg_atten = Edge_Attn()
    # Normalise to 3-channel RGB so grayscale / palette / RGBA inputs all
    # reach cvtColor with a supported channel count; `with` closes the file.
    with Image.open(args.image) as im:
        npim = np.array(im.convert('RGB'), dtype=np.float32)
    npim = cv2.cvtColor(npim, cv2.COLOR_RGB2GRAY)
    # (H, W) -> (1, 1, H, W): the 4-D layout gradient() indexes into.
    tim = torch.from_numpy(npim).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        edge = edg_atten.gradient(tim)
    npgrad = edge.squeeze(0).squeeze(0).clamp(0, 255).numpy()
    Image.fromarray(npgrad.astype('uint8')).save(args.out)