File size: 8,741 Bytes
9e7a9f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import torch
import torch.nn as nn
import model.ops as ops

'''
class Block(nn.Module):
    def __init__(self, 
                 in_channels, out_channels,
                 group=1):
        super(Block, self).__init__()

        self.b1 = ops.EResidualBlock(64, 64, group=group)
        self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
        self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
        self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)

    def forward(self, x):
        c0 = o0 = x

        b1 = self.b1(o0)
        c1 = torch.cat([c0, b1], dim=1)
        o1 = self.c1(c1)
        
        b2 = self.b1(o1)
        c2 = torch.cat([c1, b2], dim=1)
        o2 = self.c2(c2)
        
        b3 = self.b1(o2)
        c3 = torch.cat([c2, b3], dim=1)
        o3 = self.c3(c3)

        return o3
'''
        
class MFCModule(nn.Module):
    """Multi-feature fusion block for 64-channel feature maps.

    The input is split into two 32-channel halves that pass through a shared
    stack of three 3x3 convolutions; their outputs are re-concatenated and
    added to a parallel 64-channel trunk. Five further 3x3 conv+ReLU stages
    refine the sum, with skip connections from the block input and from the
    first refinement stage.

    Note: ``in_channels``/``out_channels``/``gropus`` (sic) are accepted for
    signature compatibility but unused — channel widths are fixed at 64/32.
    """

    def __init__(self, in_channels, out_channels, gropus=1):
        super(MFCModule, self).__init__()
        features = 64     # trunk width
        features1 = 32    # per-half width after the channel split
        distill_rate = 0.5
        self.distilled_channels = int(features * distill_rate)
        self.remaining_channels = int(features - self.distilled_channels)

        def conv3x3(width, relu):
            # 3x3 same-padding conv, optionally followed by an in-place ReLU.
            stages = [nn.Conv2d(in_channels=width, out_channels=width,
                                kernel_size=3, padding=1, groups=1, bias=False)]
            if relu:
                stages.append(nn.ReLU(inplace=True))
            return nn.Sequential(*stages)

        # 32-channel convs shared by both halves of the split.
        self.conv1_1 = conv3x3(features1, relu=False)
        self.conv2_1 = conv3x3(features1, relu=True)
        self.conv3_1 = conv3x3(features1, relu=False)
        # 64-channel trunk convs.
        self.conv1_1_1 = conv3x3(features, relu=True)
        self.conv2_1_1 = conv3x3(features, relu=True)
        self.conv3_1_1 = conv3x3(features, relu=True)
        # Refinement stages after the branch/trunk fusion.
        self.conv4_1 = conv3x3(features, relu=True)
        self.conv5_1 = conv3x3(features, relu=True)
        self.conv6_1 = conv3x3(features, relu=True)
        self.conv7_1 = conv3x3(features, relu=True)
        self.conv8_1 = conv3x3(features, relu=True)
        self.ReLU = nn.ReLU(inplace=True)

    def forward(self, input):
        # Split the 64-channel input into two 32-channel halves.
        half_a, half_b = torch.split(
            input, (self.distilled_channels, self.remaining_channels), dim=1)

        def branch(half):
            # Shared 32-channel path: conv -> ReLU -> conv+ReLU -> conv.
            return self.conv3_1(self.conv2_1(self.ReLU(self.conv1_1(half))))

        # Recombine the two processed halves and activate.
        merged = self.ReLU(torch.cat([branch(half_a), branch(half_b)], dim=1))

        # 64-channel trunk over the unsplit input, fused with the branches.
        trunk = self.conv3_1_1(self.conv2_1_1(self.conv1_1_1(input)))
        fused = merged + trunk

        # Refinement chain with skips from the input and the first stage.
        stage4 = self.conv4_1(fused)
        stage5 = self.conv5_1(stage4)
        stage6 = self.conv6_1(stage5)
        stage7 = self.conv7_1(stage6)
        stage8 = self.conv8_1(stage7)
        return stage8 + input + stage4


class Net(nn.Module):
    """Single-image super-resolution network built from six MFCModule blocks.

    Pipeline: subtract channel means -> 3x3 conv to 64 features -> six
    MFCModules (with skip connections b5+b1 and b6+input features) -> 3x3
    conv -> sub-pixel upsampling at the requested scale -> 3x3 conv back to
    3 channels -> re-add channel means.

    Keyword args:
        scale: integer upscaling factor passed to the upsampler.
        multi_scale: whether the upsampler supports multiple scales.
        Any other keyword (e.g. ``group``) is accepted and ignored.
    """

    def __init__(self, **kwargs):
        super(Net, self).__init__()

        scale = kwargs.get("scale")              # upscaling factor
        multi_scale = kwargs.get("multi_scale")  # multi-scale upsampler flag
        kernel_size = 3
        padding = 1
        features = 64  # width of the feature trunk
        channels = 3   # RGB in / RGB out

        # Normalise colours around fixed channel means (presumably the
        # DIV2K dataset means — confirm against the training pipeline).
        self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
        self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)

        # Shallow feature extraction: RGB -> 64 features.
        self.conv1_1 = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=features,
                      kernel_size=kernel_size, padding=padding,
                      groups=1, bias=False),
            nn.ReLU(inplace=True))

        # Deep feature trunk: six multi-feature fusion blocks.
        self.b1 = MFCModule(features, features)
        self.b2 = MFCModule(features, features)
        self.b3 = MFCModule(features, features)
        self.b4 = MFCModule(features, features)
        self.b5 = MFCModule(features, features)
        self.b6 = MFCModule(features, features)

        # Unused in forward(); kept so the module's attribute surface
        # (and any external references to it) stays unchanged.
        self.ReLU = nn.ReLU(inplace=True)

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=features, out_channels=features,
                      kernel_size=kernel_size, padding=padding,
                      groups=1, bias=False),
            nn.ReLU(inplace=True))
        # Reconstruction conv: 64 features -> 3 channels, no activation.
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=features, out_channels=3,
                      kernel_size=kernel_size, padding=padding,
                      groups=1, bias=False))
        self.upsample = ops.UpsampleBlock(
            features, scale=scale, multi_scale=multi_scale, group=1)

    def forward(self, x, scale):
        """Super-resolve ``x`` (N,3,H,W) by ``scale``; returns (N,3,H*s,W*s)."""
        x = self.sub_mean(x)
        x1 = self.conv1_1(x)
        b1 = self.b1(x1)
        b2 = self.b2(b1)
        b3 = self.b3(b2)
        b4 = self.b4(b3)
        b5 = self.b5(b4)
        b5 = b5 + b1   # skip from the first block
        b6 = self.b6(b5)
        b6 = b6 + x1   # global skip from the shallow features
        x2 = self.conv2(b6)
        temp = self.upsample(x2, scale=scale)
        out = self.conv3(temp)
        out = self.add_mean(out)
        return out