File size: 4,270 Bytes
9667e74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Date: 2023-03-14
# Creator: zejunyang
# Function: Edge attention layer.

import torch
import torch.nn as nn
import torch.nn.functional as F

from NTED.base_function import Blur


class ResBlock(nn.Module):
    """Residual block whose shortcut branch also changes channels/resolution.

    The input first goes through ``self.scale``, which maps ``in_nc`` to
    ``out_nc`` channels and sets the output resolution. The result ``s`` is
    both the shortcut and the input of two 3x3 convolutions; the block returns
    ``relu(s + block(s))``.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        scale: ``'same'`` keeps the spatial size (3x3 conv, stride 1);
            ``'down'`` halves it, rounding up for odd sizes (3x3 conv,
            stride 2, padding 1); ``'up'`` doubles it (bilinear x2
            upsampling followed by a 1x1 conv).

    Raises:
        ValueError: If ``scale`` is not one of ``'up'``, ``'down'``, ``'same'``.
    """

    def __init__(self, in_nc, out_nc, scale='down'):
        super(ResBlock, self).__init__()
        use_bias = True
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would leave ``self.scale`` unset and defer the
        # failure to the first forward() call.
        if scale not in ('up', 'down', 'same'):
            raise ValueError(
                "ResBlock scale must be in 'up' 'down' 'same', got {!r}".format(scale))

        # Shortcut/projection branch. Its module layout (a bare Conv2d for
        # 'same'/'down', Sequential(Upsample, Conv2d) for 'up') determines the
        # state_dict keys, so it is kept unchanged for checkpoint compatibility.
        if scale == 'same':
            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
        elif scale == 'up':
            self.scale = nn.Sequential(
                # align_corners=False is what bilinear interpolation uses when
                # the argument is left unset; spelled out to be explicit.
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=use_bias),
            )
        else:  # 'down'
            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)

        # Residual branch: two 3x3 convs at out_nc channels; spatial size preserved.
        self.block = nn.Sequential(
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(s + block(s))`` where ``s = self.scale(x)``."""
        shortcut = self.scale(x)
        return self.relu(shortcut + self.block(shortcut))


class Edge_Attn(nn.Module):
    """Edge-guided attention layer.

    For every (sample, channel) slice of the input, :meth:`gradient` computes
    a blurred gradient-magnitude ("edge") map. The edge map carries no
    gradient. It is multiplied with the input and passed through a
    :class:`ResBlock` (``scale='same'``) and a sigmoid to obtain ``attn``.
    The output is ``x * attn + x``.

    Args:
        in_channels: Channel count of the input feature map. The internal
            ResBlock is built as ``in_channels -> in_channels``, so
            ``x.shape[1]`` must equal this value.
    """

    def __init__(self, in_channels=3):
        super(Edge_Attn, self).__init__()
        self.in_channels = in_channels

        # Smoothing applied to the raw gradient magnitude. ``Blur`` is defined
        # in NTED.base_function; forward() multiplies its output elementwise
        # with the input, so with this kernel/pad it must preserve (H, W).
        blur_kernel = [1, 3, 3, 3, 1]
        self.blur = Blur(blur_kernel, pad=(2, 2), upsample_factor=1)

        self.res_block = ResBlock(self.in_channels, self.in_channels, scale='same')
        self.sigmoid = nn.Sigmoid()

    def gradient(self, x, stride=3):
        """Return the blurred gradient magnitude of ``x``.

        Computes, per pixel and with border values replicated,
        ``sqrt(((x[.., j+s] - x[.., j-s]) / 2) ** 2 + ((x[i-s, ..] - x[i+s, ..]) / 2) ** 2)``
        with ``s = stride``, then applies ``self.blur``.

        Args:
            x: 4-D tensor; dims 2 and 3 are treated as height and width.
            stride: Non-negative pixel offset of the differences. Defaults
                to 3, the previously hard-coded value.
        """
        h_x = x.size(2)
        w_x = x.size(3)
        # right[..., j] = x[..., j + stride], left[..., j] = x[..., j - stride],
        # top[..., i, :] = x[..., i - stride, :], bottom[..., i, :] = x[..., i + stride, :]
        # (indices clamped to the border via replicate padding).
        right = F.pad(x, (0, stride, 0, 0), mode='replicate')[:, :, :, stride:]
        left = F.pad(x, (stride, 0, 0, 0), mode='replicate')[:, :, :, :w_x]
        top = F.pad(x, (0, 0, stride, 0), mode='replicate')[:, :, :h_x, :]
        bottom = F.pad(x, (0, 0, 0, stride), mode='replicate')[:, :, stride:, :]
        xgrad = torch.pow(torch.pow((right - left) * 0.5, 2) + torch.pow((top - bottom) * 0.5, 2), 0.5)
        return self.blur(xgrad)

    def forward(self, x):
        """Apply edge-guided attention; returns ``x * attn + x`` (same shape as ``x``).

        Raises:
            ValueError: If ``x`` has an empty batch or channel dimension.
        """
        if x.shape[0] == 0 or x.shape[1] == 0:
            raise ValueError(
                'Edge_Attn expects non-empty batch and channel dimensions, got shape {}'.format(tuple(x.shape)))

        # The edge map is a fixed, non-differentiable cue. Build it without
        # recording autograd history; this gives the same values as detaching
        # afterwards. Each single-channel slice goes through gradient() on its
        # own, because whether self.blur treats channels independently depends
        # on Blur (defined elsewhere). Pieces are concatenated once per level
        # instead of repeatedly inside the loops.
        with torch.no_grad():
            per_sample = [
                torch.cat([self.gradient(x[n:n + 1, c:c + 1]) for c in range(x.shape[1])], dim=1)
                for n in range(x.shape[0])
            ]
            feature_edge = torch.cat(per_sample, dim=0)

        attn = self.sigmoid(self.res_block(x * feature_edge))
        return x * attn + x
        


if __name__ == '__main__':
    # Quick visual check of Edge_Attn.gradient: convert an image to grayscale,
    # compute its blurred gradient magnitude and save it as an 8-bit PNG.
    # Usage: python <this file> [input_image] [output_png]
    import sys

    from PIL import Image
    import numpy as np
    import cv2

    default_image = '/apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset/fake_images/001400.png'
    image_path = sys.argv[1] if len(sys.argv) > 1 else default_image
    output_path = sys.argv[2] if len(sys.argv) > 2 else 'tmp.png'

    edg_atten = Edge_Attn()

    # Context manager releases the file handle; converting to RGB lets
    # grayscale, palette and RGBA inputs all reach cvtColor as 3 channels.
    with Image.open(image_path) as im:
        npim = np.array(im.convert('RGB'), dtype=np.float32)
    npim = cv2.cvtColor(npim, cv2.COLOR_RGB2GRAY)

    # (H, W) -> (1, 1, H, W): gradient() indexes dims 2 and 3 as height/width.
    tim = torch.from_numpy(npim).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        edge = edg_atten.gradient(tim)
    npgrad = edge.squeeze(0).squeeze(0).clamp(0, 255).numpy()
    Image.fromarray(npgrad.astype('uint8')).save(output_path)