File size: 1,932 Bytes
0691d6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
import torch.nn.functional as F


class CSDN_Tem(nn.Module):
    """Depthwise-separable 2-D convolution block.

    A 3x3 depthwise convolution (``groups=in_ch``: one filter per input
    channel, zero padding 1, default stride 1, so H and W are preserved)
    followed by a 1x1 pointwise convolution that maps ``in_ch`` channels to
    ``out_ch`` channels. Both convolutions use PyTorch's default bias.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names become state_dict keys ("depth_conv.*", "point_conv.*");
        # keep them (and their creation order) stable for checkpoint compatibility.
        self.depth_conv = nn.Conv2d(in_ch, in_ch, 3, padding=1, groups=in_ch)
        self.point_conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, input):
        """Return ``point_conv(depth_conv(input))``."""
        return self.point_conv(self.depth_conv(input))

class enhance_net_nopool(nn.Module):
    """Curve-estimation network built from depthwise-separable convs (no pooling).

    The network predicts a curve-parameter map ``x_r`` (``tanh`` output, so
    values lie in ``[-1, 1]``) from an optionally downscaled copy of the
    input. It resizes that map back to the input's spatial size and applies
    the iterative update in :meth:`enhance` to the full-resolution input.

    NOTE(review): the ``zerodce`` comment below suggests this is the
    Zero-DCE++ curve network; inputs are presumably RGB images in [0, 1]
    shaped (N, 3, H, W) — confirm against callers.

    Args:
        scale_factor: factor by which the input is downscaled before the
            convolutions; ``1`` runs them at full resolution.
    """

    def __init__(self, scale_factor):
        super(enhance_net_nopool, self).__init__()

        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        # Retained for backward compatibility (external attribute access and
        # pickled whole-module checkpoints). It has no parameters, so the
        # state_dict is unaffected. forward() resizes with F.interpolate to the
        # exact input size instead of using this module.
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=self.scale_factor)
        number_f = 32

        # zerodce DWC + p-shared
        self.e_conv1 = CSDN_Tem(3, number_f)
        self.e_conv2 = CSDN_Tem(number_f, number_f)
        self.e_conv3 = CSDN_Tem(number_f, number_f)
        self.e_conv4 = CSDN_Tem(number_f, number_f)
        self.e_conv5 = CSDN_Tem(number_f * 2, number_f)
        self.e_conv6 = CSDN_Tem(number_f * 2, number_f)
        self.e_conv7 = CSDN_Tem(number_f * 2, 3)

    def enhance(self, x, x_r):
        """Apply ``x <- x + x_r * (x**2 - x)`` eight times and return the result.

        Args:
            x: tensor to enhance.
            x_r: curve-parameter map; must be broadcast-compatible with ``x``
                (the update is purely element-wise).

        Returns:
            The tensor after eight updates (``x`` itself is not modified).
        """
        for _ in range(8):
            x = x + x_r * (torch.pow(x, 2) - x)
        return x

    def forward(self, x):
        """Estimate the curve map for ``x`` and apply it.

        Args:
            x: input tensor with spatial dims last (``(..., H, W)``).

        Returns:
            Tuple ``(enhance_image, x_r)``:
            - ``enhance_image``: the result of :meth:`enhance` on ``x``.
            - ``x_r``: the curve map, resized to ``x``'s spatial size.
        """
        if self.scale_factor == 1:
            x_down = x
        else:
            # align_corners=False is the default for bilinear; stated explicitly.
            x_down = F.interpolate(
                x, scale_factor=1 / self.scale_factor, mode='bilinear', align_corners=False
            )

        x1 = self.relu(self.e_conv1(x_down))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))
        # Skip connections: concatenate earlier features along the channel dim.
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
        x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))

        if self.scale_factor != 1:
            # Resize to the exact input size rather than multiplying by
            # scale_factor. When H or W is not a multiple of scale_factor, the
            # floor() in the downscale means "* scale_factor" does not land back
            # on (H, W), and enhance() would fail with a shape mismatch.
            # align_corners=True matches nn.UpsamplingBilinear2d, so results are
            # unchanged whenever the sizes already divided evenly.
            x_r = F.interpolate(x_r, size=x.shape[-2:], mode='bilinear', align_corners=True)
        enhance_image = self.enhance(x, x_r)

        return enhance_image, x_r