charlesnchr commited on
Commit
3715c63
1 Parent(s): ba9a83a

First version with RGB image input

Browse files
Files changed (5) hide show
  1. NNfunctions.py +290 -0
  2. app.py +11 -51
  3. model/DIV2K_randomised_3x3_20200317.pth +3 -0
  4. models.py +1997 -0
  5. requirements.txt +6 -1
NNfunctions.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import math
3
+ import os
4
+
5
+ import torch
6
+ import time
7
+
8
+ import skimage.io
9
+ import skimage.transform
10
+ import matplotlib.pyplot as plt
11
+ import glob
12
+
13
+ import torch.optim as optim
14
+ import torchvision
15
+ import torchvision.transforms as transforms
16
+ from skimage import exposure
17
+
18
+ toTensor = transforms.ToTensor()
19
+ toPIL = transforms.ToPILImage()
20
+
21
+
22
+ import numpy as np
23
+ from PIL import Image
24
+
25
+ from models import *
26
+
27
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
28
+
29
+ def remove_dataparallel_wrapper(state_dict):
30
+ r"""Converts a DataParallel model to a normal one by removing the "module."
31
+ wrapper in the module dictionary
32
+
33
+ Args:
34
+ state_dict: a torch.nn.DataParallel state dictionary
35
+ """
36
+ from collections import OrderedDict
37
+
38
+ new_state_dict = OrderedDict()
39
+ for k, vl in state_dict.items():
40
+ name = k[7:] # remove 'module.' of DataParallel
41
+ new_state_dict[name] = vl
42
+
43
+ return new_state_dict
44
+
45
+ from argparse import Namespace
46
+
47
+
48
+ def GetOptions():
49
+ # training options
50
+ opt = Namespace()
51
+ opt.model = 'rcan'
52
+ opt.n_resgroups = 3
53
+ opt.n_resblocks = 10
54
+ opt.n_feats = 96
55
+ opt.reduction = 16
56
+ opt.narch = 0
57
+ opt.norm = 'minmax'
58
+
59
+ opt.cpu = False
60
+ opt.multigpu = False
61
+ opt.undomulti = False
62
+ opt.device = torch.device('cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')
63
+
64
+ opt.imageSize = 512
65
+ opt.weights = "model/simrec_simin_gtout_rcan_512_2_ntrain790-final.pth"
66
+ opt.root = "model/0080.jpg"
67
+ opt.out = "model/myout"
68
+
69
+ opt.task = 'simin_gtout'
70
+ opt.scale = 1
71
+ opt.nch_in = 9
72
+ opt.nch_out = 1
73
+
74
+
75
+ return opt
76
+
77
+
78
+ def GetOptions_allRnd_0215():
79
+ # training options
80
+ opt = Namespace()
81
+ opt.model = 'rcan'
82
+ opt.n_resgroups = 3
83
+ opt.n_resblocks = 10
84
+ opt.n_feats = 48
85
+ opt.reduction = 16
86
+ opt.narch = 0
87
+ opt.norm = 'adapthist'
88
+
89
+ opt.cpu = False
90
+ opt.multigpu = False
91
+ opt.undomulti = False
92
+ opt.device = torch.device('cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')
93
+
94
+ opt.imageSize = 512
95
+ opt.weights = "model/0216_SIMRec_0214_rndAll_rcan_continued.pth"
96
+ opt.root = "model/0080.jpg"
97
+ opt.out = "model/myout"
98
+
99
+ opt.task = 'simin_gtout'
100
+ opt.scale = 1
101
+ opt.nch_in = 9
102
+ opt.nch_out = 1
103
+
104
+
105
+ return opt
106
+
107
+
108
+
109
+ def GetOptions_allRnd_0317():
110
+ # training options
111
+ opt = Namespace()
112
+ opt.model = 'rcan'
113
+ opt.n_resgroups = 3
114
+ opt.n_resblocks = 10
115
+ opt.n_feats = 96
116
+ opt.reduction = 16
117
+ opt.narch = 0
118
+ opt.norm = 'minmax'
119
+
120
+ opt.cpu = False
121
+ opt.multigpu = False
122
+ opt.undomulti = False
123
+ opt.device = torch.device('cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')
124
+
125
+ opt.imageSize = 512
126
+ opt.weights = "model/DIV2K_randomised_3x3_20200317.pth"
127
+ opt.root = "model/0080.jpg"
128
+ opt.out = "model/myout"
129
+
130
+ opt.task = 'simin_gtout'
131
+ opt.scale = 1
132
+ opt.nch_in = 9
133
+ opt.nch_out = 1
134
+
135
+
136
+ return opt
137
+
138
+
139
+
140
+ def LoadModel(opt):
141
+ print('Loading model')
142
+ print(opt)
143
+
144
+ net = GetModel(opt)
145
+ print('loading checkpoint',opt.weights)
146
+ checkpoint = torch.load(opt.weights,map_location=opt.device)
147
+
148
+ if type(checkpoint) is dict:
149
+ state_dict = checkpoint['state_dict']
150
+ else:
151
+ state_dict = checkpoint
152
+
153
+ if opt.undomulti:
154
+ state_dict = remove_dataparallel_wrapper(state_dict)
155
+ net.load_state_dict(state_dict)
156
+
157
+ return net
158
+
159
+
160
+ def prepimg(stack,self):
161
+
162
+ inputimg = stack[:9]
163
+
164
+ if self.nch_in == 6:
165
+ inputimg = inputimg[[0,1,3,4,6,7]]
166
+ elif self.nch_in == 3:
167
+ inputimg = inputimg[[0,4,8]]
168
+
169
+ if inputimg.shape[1] > 512 or inputimg.shape[2] > 512:
170
+ print('Over 512x512! Cropping')
171
+ inputimg = inputimg[:,:512,:512]
172
+
173
+
174
+ if self.norm == 'convert': # raw img from microscope, needs normalisation and correct frame ordering
175
+ print('Raw input assumed - converting')
176
+ # NCHW
177
+ # I = np.zeros((9,opt.imageSize,opt.imageSize),dtype='uint16')
178
+
179
+ # for t in range(9):
180
+ # frame = inputimg[t]
181
+ # frame = 120 / np.max(frame) * frame
182
+ # frame = np.rot90(np.rot90(np.rot90(frame)))
183
+ # I[t,:,:] = frame
184
+ # inputimg = I
185
+
186
+ inputimg = np.rot90(inputimg,axes=(1,2))
187
+ inputimg = inputimg[[6,7,8,3,4,5,0,1,2]] # could also do [8,7,6,5,4,3,2,1,0]
188
+ for i in range(len(inputimg)):
189
+ inputimg[i] = 100 / np.max(inputimg[i]) * inputimg[i]
190
+ elif 'convert' in self.norm:
191
+ fac = float(self.norm[7:])
192
+ inputimg = np.rot90(inputimg,axes=(1,2))
193
+ inputimg = inputimg[[6,7,8,3,4,5,0,1,2]] # could also do [8,7,6,5,4,3,2,1,0]
194
+ for i in range(len(inputimg)):
195
+ inputimg[i] = fac * 255 / np.max(inputimg[i]) * inputimg[i]
196
+
197
+
198
+ inputimg = inputimg.astype('float') / np.max(inputimg) # used to be /255
199
+ widefield = np.mean(inputimg,0)
200
+
201
+ if self.norm == 'adapthist':
202
+ for i in range(len(inputimg)):
203
+ inputimg[i] = exposure.equalize_adapthist(inputimg[i],clip_limit=0.001)
204
+ widefield = exposure.equalize_adapthist(widefield,clip_limit=0.001)
205
+ else:
206
+ # normalise
207
+ inputimg = torch.tensor(inputimg).float()
208
+ widefield = torch.tensor(widefield).float()
209
+ widefield = (widefield - torch.min(widefield)) / (torch.max(widefield) - torch.min(widefield))
210
+
211
+ if self.norm == 'minmax':
212
+ for i in range(len(inputimg)):
213
+ inputimg[i] = (inputimg[i] - torch.min(inputimg[i])) / (torch.max(inputimg[i]) - torch.min(inputimg[i]))
214
+ elif 'minmax' in self.norm:
215
+ fac = float(self.norm[6:])
216
+ for i in range(len(inputimg)):
217
+ inputimg[i] = fac * (inputimg[i] - torch.min(inputimg[i])) / (torch.max(inputimg[i]) - torch.min(inputimg[i]))
218
+
219
+
220
+
221
+ # otf = torch.tensor(otf.astype('float') / np.max(otf)).unsqueeze(0).float()
222
+ # gt = torch.tensor(gt.astype('float') / 255).unsqueeze(0).float()
223
+ # simimg = torch.tensor(simimg.astype('float') / 255).unsqueeze(0).float()
224
+ # widefield = torch.mean(inputimg,0).unsqueeze(0)
225
+
226
+
227
+ # normalise
228
+ # gt = (gt - torch.min(gt)) / (torch.max(gt) - torch.min(gt))
229
+ # simimg = (simimg - torch.min(simimg)) / (torch.max(simimg) - torch.min(simimg))
230
+ # widefield = (widefield - torch.min(widefield)) / (torch.max(widefield) - torch.min(widefield))
231
+ inputimg = torch.tensor(inputimg).float()
232
+ widefield = torch.tensor(widefield).float()
233
+ return inputimg,widefield
234
+
235
+ def save_image(data, filename,cmap):
236
+ sizes = np.shape(data)
237
+ fig = plt.figure()
238
+ fig.set_size_inches(1. * sizes[0] / sizes[1], 1, forward = False)
239
+ ax = plt.Axes(fig, [0., 0., 1., 1.])
240
+ ax.set_axis_off()
241
+ fig.add_axes(ax)
242
+ ax.imshow(data, cmap=cmap)
243
+ plt.savefig(filename, dpi = sizes[0])
244
+ plt.close()
245
+
246
+
247
+ def EvaluateModel(net,opt,stack):
248
+
249
+ os.makedirs(opt.out, exist_ok=True)
250
+
251
+ print(stack.shape)
252
+ inputimg, widefield = prepimg(stack, opt)
253
+
254
+ if opt.norm == 'convert' or 'minmax' in opt.norm or 'adapthist' in opt.norm:
255
+ cmap = 'magma'
256
+ else:
257
+ cmap = 'gray'
258
+
259
+ # skimage.io.imsave('%s_wf.png' % outfile,(255*widefield.numpy()).astype('uint8'))
260
+ wf = (255*widefield.numpy()).astype('uint8')
261
+ wf_upscaled = skimage.transform.rescale(wf,1.5,order=3,multichannel=False) # should ideally be done by drawing on client side, in javascript
262
+ # save_image(wf_upscaled,'%s_wf.png' % outfile,cmap)
263
+
264
+ # skimage.io.imsave('%s.tif' % outfile, inputimg.numpy())
265
+
266
+ inputimg = inputimg.unsqueeze(0)
267
+
268
+ with torch.no_grad():
269
+ if opt.cpu:
270
+ sr = net(inputimg)
271
+ else:
272
+ sr = net(inputimg.cuda())
273
+ sr = sr.cpu()
274
+ sr = torch.clamp(sr,min=0,max=1)
275
+ print('min max',inputimg.min(),inputimg.max())
276
+
277
+ pil_sr_img = toPIL(sr[0])
278
+
279
+ if opt.norm == 'convert':
280
+ pil_sr_img = transforms.functional.rotate(pil_sr_img,-90)
281
+
282
+ #pil_sr_img.save('%s.png' % outfile) # true output for downloading, no LUT
283
+ sr_img = np.array(pil_sr_img)
284
+ sr_img = exposure.equalize_adapthist(sr_img,clip_limit=0.01)
285
+ # skimage.io.imsave('%s.png' % outfile, sr_img) # true out for downloading, no LUT
286
+
287
+ # sr_img = skimage.transform.rescale(sr_img,1.5,order=3,multichannel=False) # should ideally be done by drawing on client side, in javascript
288
+ # save_image(sr_img,'%s_sr.png' % outfile,cmap)
289
+ # return outfile + '_sr.png', outfile + '_wf.png', outfile + '.png'
290
+ return sr_img
app.py CHANGED
@@ -7,67 +7,27 @@
7
 
8
  from turtle import title
9
  import gradio as gr
10
- from huggingface_hub import from_pretrained_keras
11
- import tensorflow as tf
12
  import numpy as np
13
  from PIL import Image
14
  import io
15
  import base64
 
16
 
17
-
18
- model = tf.keras.models.load_model("./tf_model.h5")
19
-
20
 
21
  def predict(image):
22
  img = np.array(image)
23
- original_shape = img.shape[:2]
24
-
25
- im = tf.image.resize(img, (128, 128))
26
- im = tf.cast(im, tf.float32) / 255.0
27
- pred_mask = model.predict(im[tf.newaxis, ...])
28
-
29
-
30
- # take the best performing class for each pixel
31
- # the output of argmax looks like this [[1, 2, 0], ...]
32
- pred_mask_arg = tf.argmax(pred_mask, axis=-1)
33
-
34
-
35
- # convert the prediction mask into binary masks for each class
36
- binary_masks = {}
37
-
38
- # when we take tf.argmax() over pred_mask, it becomes a tensor object
39
- # the shape becomes TensorShape object, looking like this TensorShape([128])
40
- # we need to take get shape, convert to list and take the best one
41
-
42
- rows = pred_mask_arg[0][1].get_shape().as_list()[0]
43
- cols = pred_mask_arg[0][2].get_shape().as_list()[0]
44
-
45
- for cls in range(pred_mask.shape[-1]):
46
-
47
- binary_masks[f"mask_{cls}"] = np.zeros(shape = (pred_mask.shape[1], pred_mask.shape[2])) #create masks for each class
48
-
49
- for row in range(rows):
50
-
51
- for col in range(cols):
52
-
53
- if pred_mask_arg[0][row][col] == cls:
54
-
55
- binary_masks[f"mask_{cls}"][row][col] = 1
56
- else:
57
- binary_masks[f"mask_{cls}"][row][col] = 0
58
-
59
- mask = binary_masks[f"mask_{cls}"]
60
- mask *= 255
61
 
62
- mask = np.array(Image.fromarray(mask).convert("L"))
63
- mask = tf.image.resize(mask[..., tf.newaxis], original_shape)
64
- mask = tf.cast(mask, tf.uint8)
65
- mask = mask.numpy().squeeze()
66
 
67
- return mask
68
 
69
 
70
- title = '<h1 style="text-align: center;">Segment Pets</h1>'
71
 
72
  description = """
73
  ## About
@@ -77,9 +37,9 @@ according to the pixels.
77
  Upload a pet image and hit submit or select one from the given examples
78
  """
79
 
80
- inputs = gr.inputs.Image(label="Upload a pet image", type = 'pil', optional=False)
81
  outputs = [
82
- gr.outputs.Image(label="Segmentation")
83
  # , gr.outputs.Textbox(type="auto",label="Pet Prediction")
84
  ]
85
 
7
 
8
  from turtle import title
9
  import gradio as gr
 
 
10
  import numpy as np
11
  from PIL import Image
12
  import io
13
  import base64
14
+ from NNfunctions import *
15
 
16
+ opt = GetOptions_allRnd_0317()
17
+ net = LoadModel(opt)
 
18
 
19
  def predict(image):
20
  img = np.array(image)
21
+ img = np.concatenate((img,img,img),axis=2)
22
+ img = np.transpose(img, (2,0,1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ # sr,wf,out = EvaluateModel(net,opt,img,outfile)
25
+ sr_img = EvaluateModel(net,opt,img)
 
 
26
 
27
+ return sr_img
28
 
29
 
30
+ title = '<h1 style="text-align: center;">ML-SIM: Reconstruction of SIM images with deep learning</h1>'
31
 
32
  description = """
33
  ## About
37
  Upload a pet image and hit submit or select one from the given examples
38
  """
39
 
40
+ inputs = gr.inputs.Image(label="Upload a TIFF image", type = 'pil', optional=False)
41
  outputs = [
42
+ gr.outputs.Image(label="SIM Reconstruction")
43
  # , gr.outputs.Textbox(type="auto",label="Pet Prediction")
44
  ]
45
 
model/DIV2K_randomised_3x3_20200317.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4936f5ccf5db42009fa23ca3cbd63f53125dbd787240ca78f09e2b85c682a08
3
+ size 64467635
models.py ADDED
@@ -0,0 +1,1997 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.init
6
+ import torch.nn.functional as F
7
+ import functools # used by RRDBNet
8
+
9
+
10
+ def GetModel(opt):
11
+ if opt.model.lower() == 'edsr':
12
+ net = EDSR(opt)
13
+ elif opt.model.lower() == 'edsr2max':
14
+ net = EDSR2Max(normalization=opt.norm,nch_in=opt.nch_in,nch_out=opt.nch_out,scale=opt.scale)
15
+ elif opt.model.lower() == 'edsr3max':
16
+ net = EDSR3Max(normalization=opt.norm,nch_in=opt.nch_in,nch_out=opt.nch_out,scale=opt.scale)
17
+ elif opt.model.lower() == 'rcan':
18
+ net = RCAN(opt)
19
+ elif opt.model.lower() == 'rnan':
20
+ net = RNAN(opt)
21
+ elif opt.model.lower() == 'rrdb':
22
+ net = RRDBNet(opt)
23
+ elif opt.model.lower() == 'srresnet' or opt.model.lower() == 'srgan':
24
+ net = Generator(16, opt)
25
+ elif opt.model.lower() == 'unet':
26
+ net = UNet(opt.nch_in,opt.nch_out,opt)
27
+ elif opt.model.lower() == 'unet_n2n':
28
+ net = UNet_n2n(opt.nch_in,opt.nch_out,opt)
29
+ elif opt.model.lower() == 'unet60m':
30
+ net = UNet60M(opt.nch_in,opt.nch_out)
31
+ elif opt.model.lower() == 'unetrep':
32
+ net = UNetRep(opt.nch_in,opt.nch_out)
33
+ elif opt.model.lower() == 'unetgreedy':
34
+ net = UNetGreedy(opt.nch_in,opt.nch_out)
35
+ elif opt.model.lower() == 'mlpnet':
36
+ net = MLPNet()
37
+ elif opt.model.lower() == 'ffdnet':
38
+ net = FFDNet(opt.nch_in)
39
+ elif opt.model.lower() == 'dncnn':
40
+ net = DNCNN(opt.nch_in)
41
+ elif opt.model.lower() == 'fouriernet':
42
+ net = FourierNet()
43
+ elif opt.model.lower() == 'fourierconvnet':
44
+ net = FourierConvNet()
45
+ else:
46
+ print("model undefined")
47
+ return None
48
+
49
+ net.to(opt.device)
50
+ if opt.multigpu:
51
+ net = nn.DataParallel(net)
52
+
53
+ return net
54
+
55
+
56
+
57
+ class MeanShift(nn.Conv2d):
58
+ def __init__(
59
+ self, rgb_range,
60
+ rgb_mean, rgb_std, sign=-1):
61
+
62
+ super(MeanShift, self).__init__(3, 3, kernel_size=1)
63
+ std = torch.Tensor(rgb_std)
64
+ self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
65
+ self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
66
+ self.requires_grad = False
67
+
68
+
69
+ def normalizationTransforms(normtype):
70
+ if normtype.lower() == 'div2k':
71
+ normalize = MeanShift(1, [0.4485, 0.4375, 0.4045], [0.2436, 0.2330, 0.2424])
72
+ unnormalize = MeanShift(1, [-1.8411, -1.8777, -1.6687], [4.1051, 4.2918, 4.1254])
73
+ print('using div2k normalization')
74
+ elif normtype.lower() == 'pcam':
75
+ normalize = MeanShift(1, [0.6975, 0.5348, 0.688], [0.2361, 0.2786, 0.2146])
76
+ unnormalize = MeanShift(1, [-2.9547, -1.9198, -3.20643], [4.2363, 3.58972, 4.66049])
77
+ print('using pcam normalization')
78
+ elif normtype.lower() == 'div2k_std1':
79
+ normalize = MeanShift(1, [0.4485, 0.4375, 0.4045], [1,1,1])
80
+ unnormalize = MeanShift(1, [-0.4485, -0.4375, -0.4045], [1,1,1])
81
+ print('using div2k normalization with std 1')
82
+ elif normtype.lower() == 'pcam_std1':
83
+ normalize = MeanShift(1, [0.6975, 0.5348, 0.688], [1,1,1])
84
+ unnormalize = MeanShift(1, [-0.6975, -0.5348, -0.688], [1,1,1])
85
+ print('using pcam normalization with std 1')
86
+ else:
87
+ print('not using normalization')
88
+ return None, None
89
+ return normalize, unnormalize
90
+
91
+
92
+ def conv(in_channels, out_channels, kernel_size, bias=True):
93
+ return nn.Conv2d(
94
+ in_channels, out_channels, kernel_size,
95
+ padding=(kernel_size//2), bias=bias)
96
+
97
+ class BasicBlock(nn.Sequential):
98
+ def __init__(
99
+ self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
100
+ bn=True, act=nn.ReLU(True)):
101
+
102
+ m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
103
+ if bn: m.append(nn.BatchNorm2d(out_channels))
104
+ if act is not None: m.append(act)
105
+ super(BasicBlock, self).__init__(*m)
106
+
107
+
108
+
109
+ class ResBlock(nn.Module):
110
+ def __init__(
111
+ self, conv, n_feats, kernel_size,
112
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
113
+
114
+ super(ResBlock, self).__init__()
115
+ m = []
116
+
117
+ m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
118
+ m.append(nn.ReLU(True))
119
+
120
+ m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
121
+
122
+ self.body = nn.Sequential(*m)
123
+ self.res_scale = res_scale
124
+
125
+ def forward(self, x):
126
+ res = self.body(x).mul(self.res_scale)
127
+ res += x
128
+
129
+ return res
130
+
131
+
132
+ class ResBlock2Max(nn.Module):
133
+ def __init__(
134
+ self, conv, n_feats, kernel_size,
135
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
136
+
137
+ super(ResBlock2Max, self).__init__()
138
+ m = []
139
+
140
+ m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
141
+
142
+ m.append(nn.MaxPool2d(2))
143
+ m.append(nn.ReLU(True))
144
+
145
+ m.append(conv(n_feats, 2*n_feats, kernel_size, bias=bias))
146
+
147
+ m.append(nn.MaxPool2d(2))
148
+ m.append(nn.ReLU(True))
149
+
150
+ m.append(conv(2*n_feats, 4*n_feats, kernel_size, bias=bias))
151
+ m.append(nn.ReLU(True))
152
+
153
+ m.append(nn.ConvTranspose2d(4*n_feats,2*n_feats,3,stride=2, padding=1, output_padding=1))
154
+
155
+ m.append(nn.ConvTranspose2d(2*n_feats,n_feats,3,stride=2, padding=1, output_padding=1))
156
+
157
+ self.body = nn.Sequential(*m)
158
+ self.res_scale = res_scale
159
+
160
+ def forward(self, x):
161
+ res = self.body(x).mul(self.res_scale)
162
+ res += x
163
+
164
+ return res
165
+
166
+
167
+
168
+
169
+ class ResBlock3Max(nn.Module):
170
+ def __init__(
171
+ self, conv, n_feats, kernel_size,
172
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
173
+
174
+ super(ResBlock3Max, self).__init__()
175
+ m = []
176
+
177
+ m.append(conv(n_feats, 2*n_feats, kernel_size, bias=bias))
178
+ m.append(nn.MaxPool2d(2))
179
+ m.append(nn.ReLU(True))
180
+
181
+ m.append(conv(2*n_feats, 2*n_feats, kernel_size, bias=bias))
182
+ m.append(nn.MaxPool2d(2))
183
+ m.append(nn.ReLU(True))
184
+
185
+ m.append(conv(2*n_feats, 4*n_feats, kernel_size, bias=bias))
186
+ m.append(nn.MaxPool2d(2))
187
+ m.append(nn.ReLU(True))
188
+
189
+ m.append(conv(4*n_feats, 8*n_feats, kernel_size, bias=bias))
190
+ m.append(nn.ReLU(True))
191
+
192
+ m.append(nn.ConvTranspose2d(8*n_feats,4*n_feats,3,stride=2, padding=1, output_padding=1))
193
+ m.append(nn.ConvTranspose2d(4*n_feats,2*n_feats,3,stride=2, padding=1, output_padding=1))
194
+ m.append(nn.ConvTranspose2d(2*n_feats,n_feats,3,stride=2, padding=1, output_padding=1))
195
+
196
+ self.body = nn.Sequential(*m)
197
+ self.res_scale = res_scale
198
+
199
+ def forward(self, x):
200
+ res = self.body(x).mul(self.res_scale)
201
+ res += x
202
+
203
+ return res
204
+
205
+
206
+
207
+ class Upsampler(nn.Sequential):
208
+ def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
209
+
210
+ m = []
211
+ if (scale & (scale - 1)) == 0: # Is scale = 2^n?
212
+ for _ in range(int(math.log(scale, 2))):
213
+ m.append(conv(n_feats, 4 * n_feats, 3, bias))
214
+ m.append(nn.PixelShuffle(2))
215
+ if bn: m.append(nn.BatchNorm2d(n_feats))
216
+
217
+ if act == 'relu':
218
+ m.append(nn.ReLU(True))
219
+ elif act == 'prelu':
220
+ m.append(nn.PReLU(n_feats))
221
+
222
+ elif scale == 3:
223
+ m.append(conv(n_feats, 9 * n_feats, 3, bias))
224
+ m.append(nn.PixelShuffle(3))
225
+ if bn: m.append(nn.BatchNorm2d(n_feats))
226
+
227
+ if act == 'relu':
228
+ m.append(nn.ReLU(True))
229
+ elif act == 'prelu':
230
+ m.append(nn.PReLU(n_feats))
231
+ else:
232
+ raise NotImplementedError
233
+
234
+ super(Upsampler, self).__init__(*m)
235
+
236
+
237
+ class EDSR(nn.Module):
238
+ def __init__(self,opt):
239
+ super(EDSR, self).__init__()
240
+
241
+ n_resblocks = 16
242
+ n_feats = 64
243
+ kernel_size = 3
244
+ act = nn.ReLU(True)
245
+
246
+ if not opt.norm == None:
247
+ self.normalize, self.unnormalize = normalizationTransforms(opt.norm)
248
+ else:
249
+ self.normalize, self.unnormalize = None, None
250
+
251
+
252
+ # define head module
253
+ m_head = [conv(opt.nch_in, n_feats, kernel_size)]
254
+
255
+ # define body module
256
+ m_body = [
257
+ ResBlock(
258
+ conv, n_feats, kernel_size, act=act, res_scale=0.1
259
+ ) for _ in range(n_resblocks)
260
+ ]
261
+ m_body.append(conv(n_feats, n_feats, kernel_size))
262
+
263
+ # define tail module
264
+ if opt.scale == 1:
265
+ if opt.task == 'segment':
266
+ m_tail = [nn.Conv2d(n_feats, 2, 1)]
267
+ else:
268
+ m_tail = [conv(n_feats, opt.nch_out, kernel_size)]
269
+ else:
270
+ m_tail = [
271
+ Upsampler(conv, opt.scale, n_feats, act=False),
272
+ conv(n_feats, opt.nch_out, kernel_size)]
273
+
274
+ self.head = nn.Sequential(*m_head)
275
+ self.body = nn.Sequential(*m_body)
276
+ self.tail = nn.Sequential(*m_tail)
277
+
278
+ def forward(self, x):
279
+
280
+ if not self.normalize == None:
281
+ x = self.normalize(x)
282
+
283
+ x = self.head(x)
284
+
285
+ res = self.body(x)
286
+ res += x
287
+
288
+ x = self.tail(res)
289
+
290
+ if not self.unnormalize == None:
291
+ x = self.unnormalize(x)
292
+
293
+ return x
294
+
295
+
296
+ class EDSR2Max(nn.Module):
297
+ def __init__(self, normalization=None,nch_in=3,nch_out=3,scale=4):
298
+ super(EDSR2Max, self).__init__()
299
+
300
+ n_resblocks = 16
301
+ n_feats = 64
302
+ kernel_size = 3
303
+ act = nn.ReLU(True)
304
+
305
+ if not opt.norm == None:
306
+ self.normalize, self.unnormalize = normalizationTransforms(normalization)
307
+ else:
308
+ self.normalize, self.unnormalize = None, None
309
+
310
+
311
+ # define head module
312
+ m_head = [conv(nch_in, n_feats, kernel_size)]
313
+
314
+ # define body module
315
+ m_body = [
316
+ ResBlock2Max(
317
+ conv, n_feats, kernel_size, act=act, res_scale=0.1
318
+ ) for _ in range(n_resblocks)
319
+ ]
320
+ m_body.append(conv(n_feats, n_feats, kernel_size))
321
+
322
+ # define tail module
323
+ m_tail = [
324
+ conv(n_feats, nch_out, kernel_size)
325
+ ]
326
+
327
+ self.head = nn.Sequential(*m_head)
328
+ self.body = nn.Sequential(*m_body)
329
+ self.tail = nn.Sequential(*m_tail)
330
+
331
+ def forward(self, x):
332
+
333
+ if not self.normalize == None:
334
+ x = self.normalize(x)
335
+
336
+ x = self.head(x)
337
+
338
+ res = self.body(x)
339
+ res += x
340
+
341
+ x = self.tail(res)
342
+
343
+ if not self.unnormalize == None:
344
+ x = self.unnormalize(x)
345
+
346
+ return x
347
+
348
+
349
+
350
+
351
+ class EDSR3Max(nn.Module):
352
+ def __init__(self, normalization=None,nch_in=3,nch_out=3,scale=4):
353
+ super(EDSR3Max, self).__init__()
354
+
355
+ n_resblocks = 16
356
+ n_feats = 64
357
+ kernel_size = 3
358
+ act = nn.ReLU(True)
359
+
360
+ if not opt.norm == None:
361
+ self.normalize, self.unnormalize = normalizationTransforms(normalization)
362
+ else:
363
+ self.normalize, self.unnormalize = None, None
364
+
365
+
366
+ # define head module
367
+ m_head = [conv(nch_in, n_feats, kernel_size)]
368
+
369
+ # define body module
370
+ m_body = [
371
+ ResBlock3Max(
372
+ conv, n_feats, kernel_size, act=act, res_scale=0.1
373
+ ) for _ in range(n_resblocks)
374
+ ]
375
+ m_body.append(conv(n_feats, n_feats, kernel_size))
376
+
377
+ # define tail module
378
+ m_tail = [
379
+ conv(n_feats, nch_out, kernel_size)
380
+ ]
381
+
382
+ self.head = nn.Sequential(*m_head)
383
+ self.body = nn.Sequential(*m_body)
384
+ self.tail = nn.Sequential(*m_tail)
385
+
386
+ def forward(self, x):
387
+
388
+ if not self.normalize == None:
389
+ x = self.normalize(x)
390
+
391
+ x = self.head(x)
392
+
393
+ res = self.body(x)
394
+ res += x
395
+
396
+ x = self.tail(res)
397
+
398
+ if not self.unnormalize == None:
399
+ x = self.unnormalize(x)
400
+
401
+ return x
402
+
403
+
404
+
405
+ # ----------------------------------- RCAN ------------------------------------------
406
+
407
+ ## Channel Attention (CA) Layer
408
+ class CALayer(nn.Module):
409
+ def __init__(self, channel, reduction=16):
410
+ super(CALayer, self).__init__()
411
+ # global average pooling: feature --> point
412
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
413
+ # feature channel downscale and upscale --> channel weight
414
+ self.conv_du = nn.Sequential(
415
+ nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
416
+ nn.ReLU(inplace=True),
417
+ nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
418
+ nn.Sigmoid()
419
+ )
420
+
421
+ def forward(self, x):
422
+ y = self.avg_pool(x)
423
+ y = self.conv_du(y)
424
+ return x * y
425
+
426
+ ## Residual Channel Attention Block (RCAB)
427
+ class RCAB(nn.Module):
428
+ def __init__(
429
+ self, conv, n_feat, kernel_size, reduction,
430
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
431
+
432
+ super(RCAB, self).__init__()
433
+ modules_body = []
434
+ for i in range(2):
435
+ modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
436
+ if bn: modules_body.append(nn.BatchNorm2d(n_feat))
437
+ if i == 0: modules_body.append(act)
438
+ modules_body.append(CALayer(n_feat, reduction))
439
+ self.body = nn.Sequential(*modules_body)
440
+ self.res_scale = res_scale
441
+
442
+ def forward(self, x):
443
+ res = self.body(x)
444
+ #res = self.body(x).mul(self.res_scale)
445
+ res += x
446
+ return res
447
+
448
+ ## Residual Group (RG)
449
+ class ResidualGroup(nn.Module):
450
+ def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
451
+ super(ResidualGroup, self).__init__()
452
+ modules_body = []
453
+ modules_body = [
454
+ RCAB(
455
+ conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
456
+ for _ in range(n_resblocks)]
457
+ modules_body.append(conv(n_feat, n_feat, kernel_size))
458
+ self.body = nn.Sequential(*modules_body)
459
+
460
+ def forward(self, x):
461
+ res = self.body(x)
462
+ res += x
463
+ return res
464
+
465
+ ## Residual Channel Attention Network (RCAN)
466
+ class RCAN(nn.Module):
467
+ def __init__(self, opt):
468
+ super(RCAN, self).__init__()
469
+
470
+ n_resgroups = opt.n_resgroups
471
+ n_resblocks = opt.n_resblocks
472
+ n_feats = opt.n_feats
473
+ kernel_size = 3
474
+ reduction = opt.reduction
475
+ act = nn.ReLU(True)
476
+ self.narch = opt.narch
477
+
478
+ if not opt.norm == None:
479
+ self.normalize, self.unnormalize = normalizationTransforms(opt.norm)
480
+ else:
481
+ self.normalize, self.unnormalize = None, None
482
+
483
+
484
+ # define head module
485
+ if self.narch == 0:
486
+ modules_head = [conv(opt.nch_in, n_feats, kernel_size)]
487
+ self.head = nn.Sequential(*modules_head)
488
+ else:
489
+ self.head0 = conv(1, n_feats, kernel_size)
490
+ self.head1 = conv(1, n_feats, kernel_size)
491
+ self.head2 = conv(1, n_feats, kernel_size)
492
+ self.head3 = conv(1, n_feats, kernel_size)
493
+ self.head4 = conv(1, n_feats, kernel_size)
494
+ self.head5 = conv(1, n_feats, kernel_size)
495
+ self.head6 = conv(1, n_feats, kernel_size)
496
+ self.head7 = conv(1, n_feats, kernel_size)
497
+ self.head8 = conv(1, n_feats, kernel_size)
498
+ self.combineHead = conv(9*n_feats, n_feats, kernel_size)
499
+
500
+
501
+
502
+ # define body module
503
+ modules_body = [
504
+ ResidualGroup(
505
+ conv, n_feats, kernel_size, reduction, act=act, res_scale=1, n_resblocks=n_resblocks) \
506
+ for _ in range(n_resgroups)]
507
+
508
+ modules_body.append(conv(n_feats, n_feats, kernel_size))
509
+
510
+ # define tail module
511
+ if opt.scale == 1:
512
+ if opt.task == 'segment':
513
+ modules_tail = [nn.Conv2d(n_feats, opt.nch_out, 1)]
514
+ else:
515
+ modules_tail = [conv(n_feats, opt.nch_out, kernel_size)]
516
+ else:
517
+ modules_tail = [
518
+ Upsampler(conv, opt.scale, n_feats, act=False),
519
+ conv(n_feats, opt.nch_out, kernel_size)]
520
+
521
+ self.body = nn.Sequential(*modules_body)
522
+ self.tail = nn.Sequential(*modules_tail)
523
+
524
+ def forward(self, x):
525
+
526
+ if not self.normalize == None:
527
+ x = self.normalize(x)
528
+
529
+ if self.narch == 0:
530
+ x = self.head(x)
531
+ else:
532
+ x0 = self.head0(x[:,0:0+1,:,:])
533
+ x1 = self.head1(x[:,1:1+1,:,:])
534
+ x2 = self.head2(x[:,2:2+1,:,:])
535
+ x3 = self.head3(x[:,3:3+1,:,:])
536
+ x4 = self.head4(x[:,4:4+1,:,:])
537
+ x5 = self.head5(x[:,5:5+1,:,:])
538
+ x6 = self.head6(x[:,6:6+1,:,:])
539
+ x7 = self.head7(x[:,7:7+1,:,:])
540
+ x8 = self.head8(x[:,8:8+1,:,:])
541
+ x = torch.cat((x0,x1,x2,x3,x4,x5,x6,x7,x8), 1)
542
+ x = self.combineHead(x)
543
+
544
+ res = self.body(x)
545
+ res += x
546
+
547
+ x = self.tail(res)
548
+
549
+ if not self.unnormalize == None:
550
+ x = self.unnormalize(x)
551
+
552
+ return x
553
+
554
+
555
+
556
+
557
+
558
+ # ----------------------------------- RNAN ------------------------------------------
559
+
560
+
561
+ # add NonLocalBlock2D
562
+ # reference: https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_simple_version.py
563
+ class NonLocalBlock2D(nn.Module):
564
+ def __init__(self, in_channels, inter_channels):
565
+ super(NonLocalBlock2D, self).__init__()
566
+
567
+ self.in_channels = in_channels
568
+ self.inter_channels = inter_channels
569
+
570
+ self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
571
+
572
+ self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0)
573
+ # for pytorch 0.3.1
574
+ #nn.init.constant(self.W.weight, 0)
575
+ #nn.init.constant(self.W.bias, 0)
576
+ # for pytorch 0.4.0
577
+ nn.init.constant_(self.W.weight, 0)
578
+ nn.init.constant_(self.W.bias, 0)
579
+ self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
580
+
581
+ self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
582
+
583
+ def forward(self, x):
584
+
585
+ batch_size = x.size(0)
586
+
587
+ g_x = self.g(x).view(batch_size, self.inter_channels, -1)
588
+
589
+ g_x = g_x.permute(0,2,1)
590
+
591
+ theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
592
+
593
+ theta_x = theta_x.permute(0,2,1)
594
+
595
+ phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
596
+
597
+ f = torch.matmul(theta_x, phi_x)
598
+
599
+ f_div_C = F.softmax(f, dim=1)
600
+
601
+
602
+ y = torch.matmul(f_div_C, g_x)
603
+
604
+ y = y.permute(0,2,1).contiguous()
605
+
606
+ y = y.view(batch_size, self.inter_channels, *x.size()[2:])
607
+ W_y = self.W(y)
608
+ z = W_y + x
609
+
610
+ return z
611
+
612
+
613
+ ## define trunk branch
614
+ class TrunkBranch(nn.Module):
615
+ def __init__(
616
+ self, conv, n_feat, kernel_size,
617
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
618
+
619
+ super(TrunkBranch, self).__init__()
620
+ modules_body = []
621
+ for i in range(2):
622
+ modules_body.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
623
+ self.body = nn.Sequential(*modules_body)
624
+
625
+ def forward(self, x):
626
+ tx = self.body(x)
627
+
628
+ return tx
629
+
630
+
631
+
632
+ ## define mask branch
633
+ class MaskBranchDownUp(nn.Module):
634
+ def __init__(
635
+ self, conv, n_feat, kernel_size,
636
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
637
+
638
+ super(MaskBranchDownUp, self).__init__()
639
+
640
+ MB_RB1 = []
641
+ MB_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
642
+
643
+ MB_Down = []
644
+ MB_Down.append(nn.Conv2d(n_feat,n_feat, 3, stride=2, padding=1))
645
+
646
+ MB_RB2 = []
647
+ for i in range(2):
648
+ MB_RB2.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
649
+
650
+ MB_Up = []
651
+ MB_Up.append(nn.ConvTranspose2d(n_feat,n_feat, 6, stride=2, padding=2))
652
+
653
+ MB_RB3 = []
654
+ MB_RB3.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
655
+
656
+ MB_1x1conv = []
657
+ MB_1x1conv.append(nn.Conv2d(n_feat,n_feat, 1, padding=0, bias=True))
658
+
659
+ MB_sigmoid = []
660
+ MB_sigmoid.append(nn.Sigmoid())
661
+
662
+ self.MB_RB1 = nn.Sequential(*MB_RB1)
663
+ self.MB_Down = nn.Sequential(*MB_Down)
664
+ self.MB_RB2 = nn.Sequential(*MB_RB2)
665
+ self.MB_Up = nn.Sequential(*MB_Up)
666
+ self.MB_RB3 = nn.Sequential(*MB_RB3)
667
+ self.MB_1x1conv = nn.Sequential(*MB_1x1conv)
668
+ self.MB_sigmoid = nn.Sequential(*MB_sigmoid)
669
+
670
+ def forward(self, x):
671
+ x_RB1 = self.MB_RB1(x)
672
+ x_Down = self.MB_Down(x_RB1)
673
+ x_RB2 = self.MB_RB2(x_Down)
674
+ x_Up = self.MB_Up(x_RB2)
675
+ x_preRB3 = x_RB1 + x_Up
676
+ x_RB3 = self.MB_RB3(x_preRB3)
677
+ x_1x1 = self.MB_1x1conv(x_RB3)
678
+ mx = self.MB_sigmoid(x_1x1)
679
+
680
+ return mx
681
+
682
+ ## define nonlocal mask branch
683
+ class NLMaskBranchDownUp(nn.Module):
684
+ def __init__(
685
+ self, conv, n_feat, kernel_size,
686
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
687
+
688
+ super(NLMaskBranchDownUp, self).__init__()
689
+
690
+ MB_RB1 = []
691
+ MB_RB1.append(NonLocalBlock2D(n_feat, n_feat // 2))
692
+ MB_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
693
+
694
+ MB_Down = []
695
+ MB_Down.append(nn.Conv2d(n_feat,n_feat, 3, stride=2, padding=1))
696
+
697
+ MB_RB2 = []
698
+ for i in range(2):
699
+ MB_RB2.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
700
+
701
+ MB_Up = []
702
+ MB_Up.append(nn.ConvTranspose2d(n_feat,n_feat, 6, stride=2, padding=2))
703
+
704
+ MB_RB3 = []
705
+ MB_RB3.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
706
+
707
+ MB_1x1conv = []
708
+ MB_1x1conv.append(nn.Conv2d(n_feat,n_feat, 1, padding=0, bias=True))
709
+
710
+ MB_sigmoid = []
711
+ MB_sigmoid.append(nn.Sigmoid())
712
+
713
+ self.MB_RB1 = nn.Sequential(*MB_RB1)
714
+ self.MB_Down = nn.Sequential(*MB_Down)
715
+ self.MB_RB2 = nn.Sequential(*MB_RB2)
716
+ self.MB_Up = nn.Sequential(*MB_Up)
717
+ self.MB_RB3 = nn.Sequential(*MB_RB3)
718
+ self.MB_1x1conv = nn.Sequential(*MB_1x1conv)
719
+ self.MB_sigmoid = nn.Sequential(*MB_sigmoid)
720
+
721
+ def forward(self, x):
722
+ x_RB1 = self.MB_RB1(x)
723
+ x_Down = self.MB_Down(x_RB1)
724
+ x_RB2 = self.MB_RB2(x_Down)
725
+ x_Up = self.MB_Up(x_RB2)
726
+ x_preRB3 = x_RB1 + x_Up
727
+ x_RB3 = self.MB_RB3(x_preRB3)
728
+ x_1x1 = self.MB_1x1conv(x_RB3)
729
+ mx = self.MB_sigmoid(x_1x1)
730
+
731
+ return mx
732
+
733
+
734
+
735
+
736
+ ## define residual attention module
737
+ class ResAttModuleDownUpPlus(nn.Module):
738
+ def __init__(
739
+ self, conv, n_feat, kernel_size,
740
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
741
+ super(ResAttModuleDownUpPlus, self).__init__()
742
+ RA_RB1 = []
743
+ RA_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
744
+ RA_TB = []
745
+ RA_TB.append(TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
746
+ RA_MB = []
747
+ RA_MB.append(MaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
748
+ RA_tail = []
749
+ for i in range(2):
750
+ RA_tail.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
751
+
752
+ self.RA_RB1 = nn.Sequential(*RA_RB1)
753
+ self.RA_TB = nn.Sequential(*RA_TB)
754
+ self.RA_MB = nn.Sequential(*RA_MB)
755
+ self.RA_tail = nn.Sequential(*RA_tail)
756
+
757
+ def forward(self, input):
758
+ RA_RB1_x = self.RA_RB1(input)
759
+ tx = self.RA_TB(RA_RB1_x)
760
+ mx = self.RA_MB(RA_RB1_x)
761
+ txmx = tx * mx
762
+ hx = txmx + RA_RB1_x
763
+ hx = self.RA_tail(hx)
764
+
765
+ return hx
766
+
767
+
768
+ ## define nonlocal residual attention module
769
+ class NLResAttModuleDownUpPlus(nn.Module):
770
+ def __init__(
771
+ self, conv, n_feat, kernel_size,
772
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
773
+ super(NLResAttModuleDownUpPlus, self).__init__()
774
+ RA_RB1 = []
775
+ RA_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
776
+ RA_TB = []
777
+ RA_TB.append(TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
778
+ RA_MB = []
779
+ RA_MB.append(NLMaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
780
+ RA_tail = []
781
+ for i in range(2):
782
+ RA_tail.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
783
+
784
+ self.RA_RB1 = nn.Sequential(*RA_RB1)
785
+ self.RA_TB = nn.Sequential(*RA_TB)
786
+ self.RA_MB = nn.Sequential(*RA_MB)
787
+ self.RA_tail = nn.Sequential(*RA_tail)
788
+
789
+ def forward(self, input):
790
+ RA_RB1_x = self.RA_RB1(input)
791
+ tx = self.RA_TB(RA_RB1_x)
792
+ mx = self.RA_MB(RA_RB1_x)
793
+ txmx = tx * mx
794
+ hx = txmx + RA_RB1_x
795
+ hx = self.RA_tail(hx)
796
+
797
+ return hx
798
+
799
+
800
+ class _ResGroup(nn.Module):
801
+ def __init__(self, conv, n_feats, kernel_size, act, res_scale):
802
+ super(_ResGroup, self).__init__()
803
+ modules_body = []
804
+ modules_body.append(ResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
805
+ modules_body.append(conv(n_feats, n_feats, kernel_size))
806
+ self.body = nn.Sequential(*modules_body)
807
+
808
+ def forward(self, x):
809
+ res = self.body(x)
810
+ return res
811
+
812
+ class _NLResGroup(nn.Module):
813
+ def __init__(self, conv, n_feats, kernel_size, act, res_scale):
814
+ super(_NLResGroup, self).__init__()
815
+ modules_body = []
816
+ modules_body.append(NLResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
817
+ modules_body.append(conv(n_feats, n_feats, kernel_size))
818
+ self.body = nn.Sequential(*modules_body)
819
+
820
+ def forward(self, x):
821
+ res = self.body(x)
822
+ return res
823
+
824
+ class RNAN(nn.Module):
825
+ def __init__(self, opt):
826
+ super(RNAN, self).__init__()
827
+
828
+ n_resgroups = opt.n_resgroups
829
+ n_feats = opt.n_feats
830
+ kernel_size = 3
831
+ reduction = opt.reduction
832
+ act = nn.ReLU(True)
833
+
834
+
835
+ print(n_resgroup2,n_resblock,n_feats,kernel_size,reduction,act)
836
+
837
+ # RGB mean for DIV2K 1-800
838
+ # rgb_mean = (0.4488, 0.4371, 0.4040)
839
+ # rgb_std = (1.0, 1.0, 1.0)
840
+ # self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std)
841
+
842
+ # define head module
843
+ modules_head = [conv(opt.nch_in, n_feats, kernel_size)]
844
+
845
+ # define body module
846
+ modules_body_nl_low = [
847
+ _NLResGroup(
848
+ conv, n_feats, kernel_size, act=act, res_scale=1)]
849
+ modules_body = [
850
+ _ResGroup(
851
+ conv, n_feats, kernel_size, act=act, res_scale=1) \
852
+ for _ in range(n_resgroups - 2)]
853
+ modules_body_nl_high = [
854
+ _NLResGroup(
855
+ conv, n_feats, kernel_size, act=act, res_scale=1)]
856
+ modules_body.append(conv(n_feats, n_feats, kernel_size))
857
+
858
+ # define tail module
859
+ modules_tail = [
860
+ Upsampler(conv, opt.scale, n_feats, act=False),
861
+ conv(n_feats, opt.nch_out, kernel_size)]
862
+
863
+ # self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
864
+
865
+ self.head = nn.Sequential(*modules_head)
866
+ self.body_nl_low = nn.Sequential(*modules_body_nl_low)
867
+ self.body = nn.Sequential(*modules_body)
868
+ self.body_nl_high = nn.Sequential(*modules_body_nl_high)
869
+ self.tail = nn.Sequential(*modules_tail)
870
+
871
+ def forward(self, x):
872
+
873
+ # x = self.sub_mean(x)
874
+ feats_shallow = self.head(x)
875
+
876
+ res = self.body_nl_low(feats_shallow)
877
+ res = self.body(res)
878
+ res = self.body_nl_high(res)
879
+ res += feats_shallow
880
+
881
+ res_main = self.tail(res)
882
+
883
+ # res_main = self.add_mean(res_main)
884
+
885
+ return res_main
886
+
887
+
888
+
889
+
890
+
891
+
892
+
893
+
894
+
895
+
896
+
897
+
898
+ class FourierNet(nn.Module):
899
+
900
+ def __init__(self):
901
+ super(FourierNet, self).__init__()
902
+ self.inp = nn.Linear(85*85*9,85*85)
903
+
904
+
905
+ def forward(self, x):
906
+ x = x.view(-1,85*85*9)
907
+ x = (self.inp(x))
908
+ # x = (self.lay1(x))
909
+ x = x.view(-1,1,85,85)
910
+ return x
911
+
912
+
913
+ class FourierConvNet(nn.Module):
914
+
915
+ def __init__(self):
916
+ super(FourierConvNet, self).__init__()
917
+
918
+
919
+ # self.inp = nn.Conv2d(18,32,3, stride=1, padding=1)
920
+ # self.lay1 = nn.Conv2d(32,32,3, stride=1, padding=1)
921
+ # self.lay2 = nn.Conv2d(32,32,3, stride=1, padding=1)
922
+ # self.lay3 = nn.Conv2d(32,32,3, stride=1, padding=1)
923
+
924
+ # self.pool = nn.MaxPool2d(2,2)
925
+ # self.out = nn.Conv2d(32,1,3, stride=1, padding=1)
926
+
927
+ # self.labels = nn.Linear(4096,18)
928
+
929
+ self.inc = inconv(18, 64)
930
+ self.down1 = down(64, 128)
931
+ self.down2 = down(128, 256)
932
+ self.down3 = down(256, 512)
933
+ self.down4 = down(512, 512)
934
+ self.up1 = up(1024, 256)
935
+ self.up2 = up(512, 128)
936
+ self.up3 = up(256, 64)
937
+ self.up4 = up(128, 64)
938
+ self.outc = outconv(64, 9) # two channels for complex
939
+
940
+
941
+ def forward(self, x):
942
+ # x = self.inp(x)
943
+
944
+ # x = torch.rfft(x,2,onesided=False)
945
+ # # x = torch.log( torch.abs(x) + 1 )
946
+
947
+ # x = x.permute(0,1,4,2,3) # put real and imag parts after stack index
948
+ # x = x.contiguous().view(-1,18,256,256)
949
+
950
+ # x = F.relu(self.inp(x))
951
+
952
+ # x = self.pool(x) # to 128
953
+ # x = F.relu(self.lay2(x))
954
+ # x = self.pool(x) # to 64
955
+ # x = F.relu(self.lay3(x))
956
+
957
+ # x = self.out(x)
958
+
959
+ # x = x.view(-1,4096)
960
+
961
+ # x = self.labels(x)
962
+
963
+ x1 = self.inc(x)
964
+ x2 = self.down1(x1)
965
+ x3 = self.down2(x2)
966
+ x4 = self.down3(x3)
967
+ x5 = self.down4(x4)
968
+ x = self.up1(x5, x4)
969
+ x = self.up2(x, x3)
970
+ x = self.up3(x, x2)
971
+ x = self.up4(x, x1)
972
+ x = self.outc(x)
973
+
974
+ x = torch.log(torch.abs(x))
975
+
976
+ # x = x.permute(0,2,3,1)
977
+ # x = torch.irfft(x,2,onesided=False)
978
+ return x
979
+
980
+
981
+ # super(UNet, self).__init__()
982
+ # self.inc = inconv(n_channels, 64)
983
+ # self.down1 = down(64, 128)
984
+ # self.down2 = down(128, 256)
985
+ # self.down3 = down(256, 512)
986
+ # self.down4 = down(512, 512)
987
+ # self.up1 = up(1024, 256)
988
+ # self.up2 = up(512, 128)
989
+ # self.up3 = up(256, 64)
990
+ # self.up4 = up(128, 64)
991
+
992
+ # if opt.task == 'segment':
993
+ # self.outc = outconv(64, 2)
994
+ # else:
995
+ # self.outc = outconv(64, n_classes)
996
+
997
+ # # Initialize weights
998
+ # # self._init_weights()
999
+
1000
+
1001
+ # def _init_weights(self):
1002
+ # """Initializes weights using He et al. (2015)."""
1003
+
1004
+ # for m in self.modules():
1005
+ # if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
1006
+ # nn.init.kaiming_normal_(m.weight.data)
1007
+ # m.bias.data.zero_()
1008
+
1009
+
1010
+ # def forward(self, x):
1011
+ # x1 = self.inc(x)
1012
+ # x2 = self.down1(x1)
1013
+ # x3 = self.down2(x2)
1014
+ # x4 = self.down3(x3)
1015
+ # x5 = self.down4(x4)
1016
+ # x = self.up1(x5, x4)
1017
+ # x = self.up2(x, x3)
1018
+ # x = self.up3(x, x2)
1019
+ # x = self.up4(x, x1)
1020
+ # x = self.outc(x)
1021
+ # return F.sigmoid(x)
1022
+
1023
+
1024
+ # ----------------------------------- RRDB (ESRGAN) ------------------------------------------
1025
+
1026
+
1027
+ def initialize_weights(net_l, scale=1):
1028
+ if not isinstance(net_l, list):
1029
+ net_l = [net_l]
1030
+ for net in net_l:
1031
+ for m in net.modules():
1032
+ if isinstance(m, nn.Conv2d):
1033
+ torch.nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
1034
+ m.weight.data *= scale # for residual block
1035
+ if m.bias is not None:
1036
+ m.bias.data.zero_()
1037
+ elif isinstance(m, nn.Linear):
1038
+ torch.nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
1039
+ m.weight.data *= scale
1040
+ if m.bias is not None:
1041
+ m.bias.data.zero_()
1042
+ elif isinstance(m, nn.BatchNorm2d):
1043
+ torch.nn.init.constant_(m.weight, 1)
1044
+ torch.nn.init.constant_(m.bias.data, 0.0)
1045
+
1046
+
1047
+ def make_layer(block, n_layers):
1048
+ layers = []
1049
+ for _ in range(n_layers):
1050
+ layers.append(block())
1051
+ return nn.Sequential(*layers)
1052
+
1053
+
1054
+
1055
+
1056
+ class ResidualDenseBlock_5C(nn.Module):
1057
+ def __init__(self, nf=64, gc=32, bias=True):
1058
+ super(ResidualDenseBlock_5C, self).__init__()
1059
+ # gc: growth channel, i.e. intermediate channels
1060
+ self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
1061
+ self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
1062
+ self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
1063
+ self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
1064
+ self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
1065
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1066
+
1067
+ # initialization
1068
+ initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5],0.1)
1069
+
1070
+ def forward(self, x):
1071
+ x1 = self.lrelu(self.conv1(x))
1072
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
1073
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
1074
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
1075
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
1076
+ return x5 * 0.2 + x
1077
+
1078
+
1079
+ class RRDB(nn.Module):
1080
+ '''Residual in Residual Dense Block'''
1081
+
1082
+ def __init__(self, nf, gc=32):
1083
+ super(RRDB, self).__init__()
1084
+ self.RDB1 = ResidualDenseBlock_5C(nf, gc)
1085
+ self.RDB2 = ResidualDenseBlock_5C(nf, gc)
1086
+ self.RDB3 = ResidualDenseBlock_5C(nf, gc)
1087
+
1088
+ def forward(self, x):
1089
+ out = self.RDB1(x)
1090
+ out = self.RDB2(out)
1091
+ out = self.RDB3(out)
1092
+ return out * 0.2 + x
1093
+
1094
+
1095
+ class RRDBNet(nn.Module):
1096
+ def __init__(self, opt, gc=32):
1097
+ super(RRDBNet, self).__init__()
1098
+ RRDB_block_f = functools.partial(RRDB, nf=opt.n_feats, gc=gc)
1099
+
1100
+ self.conv_first = nn.Conv2d(opt.nch_in, opt.n_feats, 3, 1, 1, bias=True)
1101
+ self.RRDB_trunk = make_layer(RRDB_block_f, opt.n_resblocks)
1102
+ self.trunk_conv = nn.Conv2d(opt.n_feats, opt.n_feats, 3, 1, 1, bias=True)
1103
+ #### upsampling
1104
+ self.upconv1 = nn.Conv2d(opt.n_feats, opt.n_feats, 3, 1, 1, bias=True)
1105
+ self.upconv2 = nn.Conv2d(opt.n_feats, opt.n_feats, 3, 1, 1, bias=True)
1106
+ self.HRconv = nn.Conv2d(opt.n_feats, opt.n_feats, 3, 1, 1, bias=True)
1107
+ self.conv_last = nn.Conv2d(opt.n_feats, opt.nch_out, 3, 1, 1, bias=True)
1108
+
1109
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1110
+ self.scale = opt.scale
1111
+
1112
+ def forward(self, x):
1113
+ fea = self.conv_first(x)
1114
+ trunk = self.trunk_conv(self.RRDB_trunk(fea))
1115
+ fea = fea + trunk
1116
+
1117
+ fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=self.scale, mode='nearest')))
1118
+ # fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=self.scale, mode='nearest')))
1119
+ out = self.conv_last(self.lrelu(self.HRconv(fea)))
1120
+
1121
+ return out
1122
+
1123
+
1124
+
1125
+
1126
+ # ----------------------------------- SRGAN ------------------------------------------
1127
+
1128
+
1129
+ def swish(x):
1130
+ return x * torch.sigmoid(x)
1131
+
1132
+ class FeatureExtractor(nn.Module):
1133
+ def __init__(self, cnn, feature_layer=11):
1134
+ super(FeatureExtractor, self).__init__()
1135
+ self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer+1)])
1136
+
1137
+ def forward(self, x):
1138
+ return self.features(x)
1139
+
1140
+
1141
+ class residualBlock(nn.Module):
1142
+ def __init__(self, in_channels=64, k=3, n=64, s=1):
1143
+ super(residualBlock, self).__init__()
1144
+
1145
+ self.conv1 = nn.Conv2d(in_channels, n, k, stride=s, padding=1)
1146
+ self.bn1 = nn.BatchNorm2d(n)
1147
+ self.conv2 = nn.Conv2d(n, n, k, stride=s, padding=1)
1148
+ self.bn2 = nn.BatchNorm2d(n)
1149
+
1150
+ def forward(self, x):
1151
+ y = swish(self.bn1(self.conv1(x)))
1152
+ return self.bn2(self.conv2(y)) + x
1153
+
1154
+ class upsampleBlock(nn.Module):
1155
+ # Implements resize-convolution
1156
+ def __init__(self, in_channels, out_channels):
1157
+ super(upsampleBlock, self).__init__()
1158
+ self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
1159
+ self.shuffler = nn.PixelShuffle(2)
1160
+
1161
+ def forward(self, x):
1162
+ return swish(self.shuffler(self.conv(x)))
1163
+
1164
+ class Generator(nn.Module):
1165
+ def __init__(self, n_residual_blocks, opt):
1166
+ super(Generator, self).__init__()
1167
+ self.n_residual_blocks = n_residual_blocks
1168
+ self.upsample_factor = opt.scale
1169
+
1170
+ self.conv1 = nn.Conv2d(opt.nch_in, 64, 9, stride=1, padding=4)
1171
+
1172
+ if not opt.norm == None:
1173
+ self.normalize, self.unnormalize = normalizationTransforms(opt.norm)
1174
+ else:
1175
+ self.normalize, self.unnormalize = None, None
1176
+
1177
+
1178
+ for i in range(self.n_residual_blocks):
1179
+ self.add_module('residual_block' + str(i+1), residualBlock())
1180
+
1181
+ self.conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
1182
+ self.bn2 = nn.BatchNorm2d(64)
1183
+
1184
+ # for i in range(int(self.upsample_factor/2)):
1185
+ # self.add_module('upsample' + str(i+1), upsampleBlock(64, 256))
1186
+
1187
+ if opt.task == 'segment':
1188
+ self.conv3 = nn.Conv2d(64, 2, 1)
1189
+ else:
1190
+ self.conv3 = nn.Conv2d(64, opt.nch_out, 9, stride=1, padding=4)
1191
+
1192
+ def forward(self, x):
1193
+
1194
+ if not self.normalize == None:
1195
+ x = self.normalize(x)
1196
+
1197
+ x = swish(self.conv1(x))
1198
+
1199
+ y = x.clone()
1200
+ for i in range(self.n_residual_blocks):
1201
+ y = self.__getattr__('residual_block' + str(i+1))(y)
1202
+
1203
+ x = self.bn2(self.conv2(y)) + x
1204
+
1205
+ # for i in range(int(self.upsample_factor/2)):
1206
+ # x = self.__getattr__('upsample' + str(i+1))(x)
1207
+
1208
+ x = self.conv3(x)
1209
+
1210
+ if not self.unnormalize == None:
1211
+ x = self.unnormalize(x)
1212
+
1213
+ return x
1214
+
1215
+ class Discriminator(nn.Module):
1216
+ def __init__(self,opt):
1217
+ super(Discriminator, self).__init__()
1218
+ self.conv1 = nn.Conv2d(opt.nch_out, 64, 3, stride=1, padding=1)
1219
+
1220
+ self.conv2 = nn.Conv2d(64, 64, 3, stride=2, padding=1)
1221
+ self.bn2 = nn.BatchNorm2d(64)
1222
+ self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
1223
+ self.bn3 = nn.BatchNorm2d(128)
1224
+ self.conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
1225
+ self.bn4 = nn.BatchNorm2d(128)
1226
+ self.conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
1227
+ self.bn5 = nn.BatchNorm2d(256)
1228
+ self.conv6 = nn.Conv2d(256, 256, 3, stride=2, padding=1)
1229
+ self.bn6 = nn.BatchNorm2d(256)
1230
+ self.conv7 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
1231
+ self.bn7 = nn.BatchNorm2d(512)
1232
+ self.conv8 = nn.Conv2d(512, 512, 3, stride=2, padding=1)
1233
+ self.bn8 = nn.BatchNorm2d(512)
1234
+
1235
+ # Replaced original paper FC layers with FCN
1236
+ self.conv9 = nn.Conv2d(512, 1, 1, stride=1, padding=1)
1237
+
1238
+ def forward(self, x):
1239
+ x = swish(self.conv1(x))
1240
+
1241
+ x = swish(self.bn2(self.conv2(x)))
1242
+ x = swish(self.bn3(self.conv3(x)))
1243
+ x = swish(self.bn4(self.conv4(x)))
1244
+ x = swish(self.bn5(self.conv5(x)))
1245
+ x = swish(self.bn6(self.conv6(x)))
1246
+ x = swish(self.bn7(self.conv7(x)))
1247
+ x = swish(self.bn8(self.conv8(x)))
1248
+
1249
+ x = self.conv9(x)
1250
+ return torch.sigmoid(F.avg_pool2d(x, x.size()[2:])).view(x.size()[0], -1)
1251
+
1252
+
1253
+
1254
+
1255
+
1256
+
1257
+
1258
+
1259
+
1260
+ class UNet_n2n(nn.Module):
1261
+ """Custom U-Net architecture for Noise2Noise (see Appendix, Table 2)."""
1262
+
1263
+ def __init__(self, in_channels=3, out_channels=3, opt = {}):
1264
+ """Initializes U-Net."""
1265
+
1266
+ super(UNet_n2n, self).__init__()
1267
+
1268
+ # Layers: enc_conv0, enc_conv1, pool1
1269
+ self._block1 = nn.Sequential(
1270
+ nn.Conv2d(in_channels, 48, 3, stride=1, padding=1),
1271
+ nn.ReLU(inplace=True),
1272
+ nn.Conv2d(48, 48, 3, padding=1),
1273
+ nn.ReLU(inplace=True),
1274
+ nn.MaxPool2d(2))
1275
+
1276
+ # Layers: enc_conv(i), pool(i); i=2..5
1277
+ self._block2 = nn.Sequential(
1278
+ nn.Conv2d(48, 48, 3, stride=1, padding=1),
1279
+ nn.ReLU(inplace=True),
1280
+ nn.MaxPool2d(2))
1281
+
1282
+ # Layers: enc_conv6, upsample5
1283
+ self._block3 = nn.Sequential(
1284
+ nn.Conv2d(48, 48, 3, stride=1, padding=1),
1285
+ nn.ReLU(inplace=True),
1286
+ nn.ConvTranspose2d(48, 48, 3, stride=2, padding=1, output_padding=1))
1287
+ #nn.Upsample(scale_factor=2, mode='nearest'))
1288
+
1289
+ # Layers: dec_conv5a, dec_conv5b, upsample4
1290
+ self._block4 = nn.Sequential(
1291
+ nn.Conv2d(96, 96, 3, stride=1, padding=1),
1292
+ nn.ReLU(inplace=True),
1293
+ nn.Conv2d(96, 96, 3, stride=1, padding=1),
1294
+ nn.ReLU(inplace=True),
1295
+ nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))
1296
+ #nn.Upsample(scale_factor=2, mode='nearest'))
1297
+
1298
+ # Layers: dec_deconv(i)a, dec_deconv(i)b, upsample(i-1); i=4..2
1299
+ self._block5 = nn.Sequential(
1300
+ nn.Conv2d(144, 96, 3, stride=1, padding=1),
1301
+ nn.ReLU(inplace=True),
1302
+ nn.Conv2d(96, 96, 3, stride=1, padding=1),
1303
+ nn.ReLU(inplace=True),
1304
+ nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))
1305
+ #nn.Upsample(scale_factor=2, mode='nearest'))
1306
+
1307
+ # Layers: dec_conv1a, dec_conv1b, dec_conv1c,
1308
+ self._block6 = nn.Sequential(
1309
+ nn.Conv2d(96 + in_channels, 64, 3, stride=1, padding=1),
1310
+ nn.ReLU(inplace=True),
1311
+ nn.Conv2d(64, 32, 3, stride=1, padding=1),
1312
+ nn.ReLU(inplace=True),
1313
+ nn.Conv2d(32, out_channels, 3, stride=1, padding=1),
1314
+ nn.LeakyReLU(0.1))
1315
+
1316
+ # Initialize weights
1317
+ self._init_weights()
1318
+
1319
+ self.task = opt.task
1320
+ if opt.task == 'segment':
1321
+ self._block6 = nn.Sequential(
1322
+ nn.Conv2d(96 + in_channels, 64, 3, stride=1, padding=1),
1323
+ nn.ReLU(inplace=True),
1324
+ nn.Conv2d(64, 32, 3, stride=1, padding=1),
1325
+ nn.ReLU(inplace=True),
1326
+ nn.Conv2d(32, 2, 1))
1327
+
1328
+
1329
+
1330
+ def _init_weights(self):
1331
+ """Initializes weights using He et al. (2015)."""
1332
+
1333
+ for m in self.modules():
1334
+ if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
1335
+ nn.init.kaiming_normal_(m.weight.data)
1336
+ m.bias.data.zero_()
1337
+
1338
+
1339
+ def forward(self, x):
1340
+ """Through encoder, then decoder by adding U-skip connections. """
1341
+
1342
+ # Encoder
1343
+ pool1 = self._block1(x)
1344
+ pool2 = self._block2(pool1)
1345
+ pool3 = self._block2(pool2)
1346
+ pool4 = self._block2(pool3)
1347
+ pool5 = self._block2(pool4)
1348
+
1349
+ # Decoder
1350
+ upsample5 = self._block3(pool5)
1351
+ concat5 = torch.cat((upsample5, pool4), dim=1)
1352
+ upsample4 = self._block4(concat5)
1353
+ concat4 = torch.cat((upsample4, pool3), dim=1)
1354
+ upsample3 = self._block5(concat4)
1355
+ concat3 = torch.cat((upsample3, pool2), dim=1)
1356
+ upsample2 = self._block5(concat3)
1357
+ concat2 = torch.cat((upsample2, pool1), dim=1)
1358
+ upsample1 = self._block5(concat2)
1359
+ concat1 = torch.cat((upsample1, x), dim=1)
1360
+
1361
+ # Final activation
1362
+ return self._block6(concat1)
1363
+
1364
+
1365
+
1366
+ # ------------------ Alternative UNet implementation (batchnorm. outcommented)
1367
+
1368
+
1369
+ class double_conv(nn.Module):
1370
+ '''(conv => BN => ReLU) * 2'''
1371
+ def __init__(self, in_ch, out_ch):
1372
+ super(double_conv, self).__init__()
1373
+ self.conv = nn.Sequential(
1374
+ nn.Conv2d(in_ch, out_ch, 3, padding=1),
1375
+ # nn.BatchNorm2d(out_ch),
1376
+ nn.ReLU(inplace=True),
1377
+ nn.Conv2d(out_ch, out_ch, 3, padding=1),
1378
+ # nn.BatchNorm2d(out_ch),
1379
+ nn.ReLU(inplace=True)
1380
+ )
1381
+
1382
+ def forward(self, x):
1383
+ x = self.conv(x)
1384
+ return x
1385
+
1386
+
1387
+ class inconv(nn.Module):
1388
+ def __init__(self, in_ch, out_ch):
1389
+ super(inconv, self).__init__()
1390
+ self.conv = double_conv(in_ch, out_ch)
1391
+
1392
+ def forward(self, x):
1393
+ x = self.conv(x)
1394
+ return x
1395
+
1396
+
1397
+ class down(nn.Module):
1398
+ def __init__(self, in_ch, out_ch):
1399
+ super(down, self).__init__()
1400
+ self.mpconv = nn.Sequential(
1401
+ nn.MaxPool2d(2),
1402
+ # nn.Conv2d(in_ch,in_ch, 2, stride=2),
1403
+ double_conv(in_ch, out_ch)
1404
+ )
1405
+
1406
+ def forward(self, x):
1407
+ x = self.mpconv(x)
1408
+ return x
1409
+
1410
+
1411
+ class up(nn.Module):
1412
+ def __init__(self, in_ch, out_ch, bilinear=False):
1413
+ super(up, self).__init__()
1414
+
1415
+ # would be a nice idea if the upsampling could be learned too,
1416
+ # but my machine do not have enough memory to handle all those weights
1417
+ if bilinear:
1418
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
1419
+ else:
1420
+ self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
1421
+
1422
+ self.conv = double_conv(in_ch, out_ch)
1423
+
1424
+ def forward(self, x1, x2):
1425
+ x1 = self.up(x1)
1426
+
1427
+ # input is CHW
1428
+ diffY = x2.size()[2] - x1.size()[2]
1429
+ diffX = x2.size()[3] - x1.size()[3]
1430
+
1431
+ x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
1432
+ diffY // 2, diffY - diffY//2))
1433
+
1434
+ # for padding issues, see
1435
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
1436
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
1437
+
1438
+ x = torch.cat([x2, x1], dim=1)
1439
+ x = self.conv(x)
1440
+ return x
1441
+
1442
+
1443
+ class outconv(nn.Module):
1444
+ def __init__(self, in_ch, out_ch):
1445
+ super(outconv, self).__init__()
1446
+ self.conv = nn.Conv2d(in_ch, out_ch, 1)
1447
+
1448
+ def forward(self, x):
1449
+ x = self.conv(x)
1450
+ return x
1451
+
1452
+
1453
+ class UNet(nn.Module):
1454
+ def __init__(self, n_channels, n_classes,opt):
1455
+ super(UNet, self).__init__()
1456
+ self.inc = inconv(n_channels, 64)
1457
+ self.down1 = down(64, 128)
1458
+ self.down2 = down(128, 256)
1459
+ self.down3 = down(256, 512)
1460
+ self.down4 = down(512, 512)
1461
+ self.up1 = up(1024, 256)
1462
+ self.up2 = up(512, 128)
1463
+ self.up3 = up(256, 64)
1464
+ self.up4 = up(128, 64)
1465
+
1466
+ if opt.task == 'segment':
1467
+ self.outc = outconv(64, 2)
1468
+ else:
1469
+ self.outc = outconv(64, n_classes)
1470
+
1471
+ # Initialize weights
1472
+ # self._init_weights()
1473
+
1474
+
1475
+ def _init_weights(self):
1476
+ """Initializes weights using He et al. (2015)."""
1477
+
1478
+ for m in self.modules():
1479
+ if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
1480
+ nn.init.kaiming_normal_(m.weight.data)
1481
+ m.bias.data.zero_()
1482
+
1483
+
1484
+ def forward(self, x):
1485
+ x1 = self.inc(x)
1486
+ x2 = self.down1(x1)
1487
+ x3 = self.down2(x2)
1488
+ x4 = self.down3(x3)
1489
+ x5 = self.down4(x4)
1490
+ x = self.up1(x5, x4)
1491
+ x = self.up2(x, x3)
1492
+ x = self.up3(x, x2)
1493
+ x = self.up4(x, x1)
1494
+ x = self.outc(x)
1495
+ return F.sigmoid(x)
1496
+
1497
+
1498
+ class UNet60M(nn.Module):
1499
+ def __init__(self, n_channels, n_classes):
1500
+ super(UNet60M, self).__init__()
1501
+ self.inc = inconv(n_channels, 64)
1502
+ self.down1 = down(64, 128)
1503
+ self.down2 = down(128, 256)
1504
+ self.down3 = down(256, 512)
1505
+ self.down4 = down(512, 1024)
1506
+ self.down5 = down(1024, 1024)
1507
+ self.up1 = up(2048, 512)
1508
+ self.up2 = up(1024, 256)
1509
+ self.up3 = up(512, 128)
1510
+ self.up4 = up(256, 64)
1511
+ self.up5 = up(128, 64)
1512
+ self.outc = outconv(64, n_classes)
1513
+
1514
+ def forward(self, x):
1515
+ x1 = self.inc(x)
1516
+ x2 = self.down1(x1)
1517
+ x3 = self.down2(x2)
1518
+ x4 = self.down3(x3)
1519
+ x5 = self.down4(x4)
1520
+ x6 = self.down5(x5)
1521
+ x = self.up1(x6, x5)
1522
+ x = self.up2(x, x4)
1523
+ x = self.up3(x, x3)
1524
+ x = self.up4(x, x2)
1525
+ x = self.up5(x, x1)
1526
+ x = self.outc(x)
1527
+ return F.sigmoid(x)
1528
+
1529
+
1530
+ class UNetRep(nn.Module):
1531
+ def __init__(self, n_channels, n_classes):
1532
+ super(UNetRep, self).__init__()
1533
+ self.inc = inconv(n_channels, 64)
1534
+ self.down1 = down(64, 128)
1535
+ self.down2 = down(128, 128)
1536
+ self.up1 = up1(256, 128, 128)
1537
+ self.up2 = up1(192, 64, 128)
1538
+
1539
+ self.outc = outconv(64, n_classes)
1540
+
1541
+ def forward(self, x):
1542
+ x1 = self.inc(x)
1543
+
1544
+ for _ in range(3):
1545
+ x2 = self.down1(x1)
1546
+ x3 = self.down2(x2)
1547
+ x = self.up1(x3,x2)
1548
+ x1 = self.up2(x,x1)
1549
+
1550
+ # x6 = self.down5(x5)
1551
+ # x = self.up1(x6, x5)
1552
+ # x = self.up2(x, x4)
1553
+ # x = self.up3(x, x3)
1554
+ # x = self.up4(x, x2)
1555
+ # x = self.up5(x, x1)
1556
+ x = self.outc(x1)
1557
+ return F.sigmoid(x)
1558
+
1559
+
1560
+
1561
+
1562
+ # ------------------- UNet Noise2noise implementation
1563
+
1564
+ class single_conv(nn.Module):
1565
+ '''(conv => BN => ReLU) * 2'''
1566
+ def __init__(self, in_ch, out_ch):
1567
+ super(single_conv, self).__init__()
1568
+ self.conv = nn.Sequential(
1569
+ nn.Conv2d(in_ch, out_ch, 3, padding=1),
1570
+ # nn.BatchNorm2d(out_ch),
1571
+ nn.ReLU(inplace=True),
1572
+ )
1573
+
1574
+ def forward(self, x):
1575
+ x = self.conv(x)
1576
+ return x
1577
+
1578
+
1579
+ class outconv2(nn.Module):
1580
+ def __init__(self, in_ch, out_ch):
1581
+ super(outconv2, self).__init__()
1582
+ self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
1583
+
1584
+ def forward(self, x):
1585
+ x = self.conv(x)
1586
+ return x
1587
+
1588
+
1589
+ class up1(nn.Module):
1590
+ def __init__(self, in_ch, out_ch, convtr, bilinear=False):
1591
+ super(up1, self).__init__()
1592
+
1593
+ # would be a nice idea if the upsampling could be learned too,
1594
+ # but my machine do not have enough memory to handle all those weights
1595
+ if bilinear:
1596
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
1597
+ else:
1598
+ self.up = nn.ConvTranspose2d(convtr, convtr, 3, stride=2)
1599
+
1600
+ self.conv = double_conv(in_ch, out_ch)
1601
+
1602
+ def forward(self, x1, x2):
1603
+ x1 = self.up(x1)
1604
+
1605
+ # input is CHW
1606
+ diffY = x2.size()[2] - x1.size()[2]
1607
+ diffX = x2.size()[3] - x1.size()[3]
1608
+
1609
+ x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
1610
+ diffY // 2, diffY - diffY//2))
1611
+
1612
+ # for padding issues, see
1613
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
1614
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
1615
+
1616
+ x = torch.cat([x2, x1], dim=1)
1617
+ x = self.conv(x)
1618
+ return x
1619
+
1620
+ class up2(nn.Module):
1621
+ def __init__(self, in_ch, in_ch2, out_ch,out_ch2,convtr, bilinear=False):
1622
+ super(up2, self).__init__()
1623
+
1624
+ # would be a nice idea if the upsampling could be learned too,
1625
+ # but my machine do not have enough memory to handle all those weights
1626
+ if bilinear:
1627
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
1628
+ else:
1629
+ self.up = nn.ConvTranspose2d(convtr, convtr, 3, stride=2)
1630
+
1631
+ # self.conv = double_conv(in_ch, out_ch)
1632
+ self.conv = nn.Conv2d(in_ch + in_ch2, out_ch, 3, padding=1)
1633
+ self.conv2 = nn.Conv2d(out_ch, out_ch2, 3, padding=1)
1634
+
1635
+
1636
+ def forward(self, x1, x2):
1637
+ x1 = self.up(x1)
1638
+
1639
+ # input is CHW
1640
+ diffY = x2.size()[2] - x1.size()[2]
1641
+ diffX = x2.size()[3] - x1.size()[3]
1642
+
1643
+ x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
1644
+ diffY // 2, diffY - diffY//2))
1645
+
1646
+ # for padding issues, see
1647
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
1648
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
1649
+ x = torch.cat([x2, x1], dim=1)
1650
+ x = self.conv(x)
1651
+ x = self.conv2(x)
1652
+ return x
1653
+
1654
+ class down2(nn.Module):
1655
+ def __init__(self, in_ch, out_ch):
1656
+ super(down2, self).__init__()
1657
+ self.mpconv = nn.Sequential(
1658
+ # nn.MaxPool2d(2),
1659
+ nn.Conv2d(in_ch,in_ch, 2, stride=2),
1660
+ single_conv(in_ch, out_ch)
1661
+ )
1662
+
1663
+ def forward(self, x):
1664
+ x = self.mpconv(x)
1665
+ return x
1666
+
1667
+
1668
+ class UNetGreedy(nn.Module):
1669
+ def __init__(self, n_channels, n_classes):
1670
+ super(UNetGreedy, self).__init__()
1671
+ self.inc = inconv(n_channels, 144)
1672
+ self.down1 = down(144, 144)
1673
+ self.down2 = down2(144, 144)
1674
+ self.down3 = down2(144, 144)
1675
+ self.down4 = down2(144, 144)
1676
+ self.down5 = down2(144, 144)
1677
+ self.up1 = up1(288, 288,144)
1678
+ self.up2 = up1(432, 288,288)
1679
+ self.up3 = up1(432, 288,288)
1680
+ self.up4 = up1(432, 288,288)
1681
+ self.up5 = up2(288, n_channels, 64, 32,288)
1682
+ self.outc = outconv2(32, n_classes)
1683
+
1684
+ def forward(self, x0):
1685
+ x1 = self.inc(x0)
1686
+ x2 = self.down1(x1)
1687
+ x3 = self.down2(x2)
1688
+ x4 = self.down3(x3)
1689
+ x5 = self.down4(x4)
1690
+ x6 = self.down5(x5)
1691
+ x = self.up1(x6, x5)
1692
+ x = self.up2(x, x4)
1693
+ x = self.up3(x, x3)
1694
+ x = self.up4(x, x2)
1695
+ x = self.up5(x, x0)
1696
+ x = self.outc(x)
1697
+ return F.sigmoid(x)
1698
+
1699
+
1700
+ class UNet2(nn.Module):
1701
+ def __init__(self, n_channels, n_classes):
1702
+ super(UNet2, self).__init__()
1703
+ self.inc = inconv(n_channels, 48)
1704
+ self.down1 = down(48, 48)
1705
+ self.down2 = down2(48, 48)
1706
+ self.down3 = down2(48, 48)
1707
+ self.down4 = down2(48, 48)
1708
+ self.down5 = down2(48, 48)
1709
+ self.up1 = up1(96, 96,48)
1710
+ self.up2 = up1(144, 96,96)
1711
+ self.up3 = up1(144, 96,96)
1712
+ self.up4 = up1(144, 96,96)
1713
+ self.up5 = up2(96, n_channels, 64, 32,96)
1714
+ self.outc = outconv2(32, n_classes)
1715
+
1716
+ def forward(self, x0):
1717
+ x1 = self.inc(x0)
1718
+ x2 = self.down1(x1)
1719
+ x3 = self.down2(x2)
1720
+ x4 = self.down3(x3)
1721
+ x5 = self.down4(x4)
1722
+ x6 = self.down5(x5)
1723
+ x = self.up1(x6, x5)
1724
+ x = self.up2(x, x4)
1725
+ x = self.up3(x, x3)
1726
+ x = self.up4(x, x2)
1727
+ x = self.up5(x, x0)
1728
+ x = self.outc(x)
1729
+ return F.sigmoid(x)
1730
+
1731
+
1732
+ class MLPNet(nn.Module):
1733
+ def __init__(self):
1734
+ super(MLPNet, self).__init__()
1735
+ # 1 input image channel, 6 output channels, 5x5 square convolution kernel
1736
+ self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
1737
+ self.conv12 = nn.Conv2d(64, 64, 3, padding=1)
1738
+ self.pool = nn.MaxPool2d(2,2)
1739
+ self.conv2 = nn.Conv2d(64, 96, 3, padding=1)
1740
+ self.conv22 = nn.Conv2d(96, 128, 3, padding=1)
1741
+ self.conv3 = nn.Conv2d(96, 128, 3, padding=1)
1742
+ self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
1743
+ self.conv5 = nn.Conv2d(128, 64, 3, padding=1)
1744
+ self.conv6 = nn.Conv2d(64, 32, 5)
1745
+ # self.conv3 = nn.Conv2d(24, 48, 3, padding=1)
1746
+ self.fc = nn.Sequential(
1747
+ nn.Linear(6*6*32, 100),
1748
+ nn.ReLU(),
1749
+ nn.Dropout2d(p=0.2),
1750
+ nn.Linear(100, 1),
1751
+ nn.ReLU()
1752
+ )
1753
+ # self.fc2 = nn.Linear(100,50)
1754
+ # self.fc3 = nn.Linear(50,20)
1755
+ # self.fc4 = nn.Linear(20,1)
1756
+
1757
+ def forward(self, x):
1758
+ x = F.relu(self.conv1(x))
1759
+ x = F.relu(self.conv12(x))
1760
+ x = self.pool(x)
1761
+ x = F.relu(self.conv2(x))
1762
+ # x = F.relu(self.conv22(x))
1763
+ x = self.pool(x)
1764
+ x = F.relu(self.conv3(x))
1765
+ x = F.relu(self.conv4(x))
1766
+ x = self.pool(x)
1767
+ x = F.relu(self.conv5(x))
1768
+ x = self.pool(x)
1769
+ x = F.relu(self.conv6(x))
1770
+ x = self.pool(x)
1771
+ x = x.view(-1, 6*6*32)
1772
+ x = self.fc(x)
1773
+ # x = F.relu(self.fc1(x))
1774
+ # x = F.relu(self.fc2(x))
1775
+ return x
1776
+
1777
+
1778
+
1779
+ # --------------------- FFDNet
1780
+ from torch.autograd import Function, Variable
1781
+
1782
+ def concatenate_input_noise_map(input, noise_sigma):
1783
+ r"""Implements the first layer of FFDNet. This function returns a
1784
+ torch.autograd.Variable composed of the concatenation of the downsampled
1785
+ input image and the noise map. Each image of the batch of size CxHxW gets
1786
+ converted to an array of size 4*CxH/2xW/2. Each of the pixels of the
1787
+ non-overlapped 2x2 patches of the input image are placed in the new array
1788
+ along the first dimension.
1789
+
1790
+ Args:
1791
+ input: batch containing CxHxW images
1792
+ noise_sigma: the value of the pixels of the CxH/2xW/2 noise map
1793
+ """
1794
+ # noise_sigma is a list of length batch_size
1795
+ N, C, H, W = input.size()
1796
+ dtype = input.type()
1797
+ sca = 2
1798
+ sca2 = sca*sca
1799
+ Cout = sca2*C
1800
+ Hout = H//sca
1801
+ Wout = W//sca
1802
+ idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
1803
+
1804
+ # Fill the downsampled image with zeros
1805
+ if 'cuda' in dtype:
1806
+ downsampledfeatures = torch.cuda.FloatTensor(N, Cout, Hout, Wout).fill_(0)
1807
+ else:
1808
+ downsampledfeatures = torch.FloatTensor(N, Cout, Hout, Wout).fill_(0)
1809
+
1810
+ # Build the CxH/2xW/2 noise map
1811
+ noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout)
1812
+
1813
+ # Populate output
1814
+ for idx in range(sca2):
1815
+ downsampledfeatures[:, idx:Cout:sca2, :, :] = \
1816
+ input[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
1817
+
1818
+ # concatenate de-interleaved mosaic with noise map
1819
+ return torch.cat((noise_map, downsampledfeatures), 1)
1820
+
1821
+ class UpSampleFeaturesFunction(Function):
1822
+ r"""Extends PyTorch's modules by implementing a torch.autograd.Function.
1823
+ This class implements the forward and backward methods of the last layer
1824
+ of FFDNet. It basically performs the inverse of
1825
+ concatenate_input_noise_map(): it converts each of the images of a
1826
+ batch of size CxH/2xW/2 to images of size C/4xHxW
1827
+ """
1828
+ @staticmethod
1829
+ def forward(ctx, input):
1830
+ N, Cin, Hin, Win = input.size()
1831
+ dtype = input.type()
1832
+ sca = 2
1833
+ sca2 = sca*sca
1834
+ Cout = Cin//sca2
1835
+ Hout = Hin*sca
1836
+ Wout = Win*sca
1837
+ idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
1838
+
1839
+ assert (Cin%sca2 == 0), \
1840
+ 'Invalid input dimensions: number of channels should be divisible by 4'
1841
+
1842
+ result = torch.zeros((N, Cout, Hout, Wout)).type(dtype)
1843
+ for idx in range(sca2):
1844
+ result[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] = \
1845
+ input[:, idx:Cin:sca2, :, :]
1846
+
1847
+ return result
1848
+
1849
+ @staticmethod
1850
+ def backward(ctx, grad_output):
1851
+ N, Cg_out, Hg_out, Wg_out = grad_output.size()
1852
+ dtype = grad_output.data.type()
1853
+ sca = 2
1854
+ sca2 = sca*sca
1855
+ Cg_in = sca2*Cg_out
1856
+ Hg_in = Hg_out//sca
1857
+ Wg_in = Wg_out//sca
1858
+ idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
1859
+
1860
+ # Build output
1861
+ grad_input = torch.zeros((N, Cg_in, Hg_in, Wg_in)).type(dtype)
1862
+ # Populate output
1863
+ for idx in range(sca2):
1864
+ grad_input[:, idx:Cg_in:sca2, :, :] = \
1865
+ grad_output.data[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
1866
+
1867
+ return Variable(grad_input)
1868
+
1869
+ # Alias functions
1870
+ upsamplefeatures = UpSampleFeaturesFunction.apply
1871
+
1872
+
1873
+
1874
+
1875
+ class UpSampleFeatures(nn.Module):
1876
+ r"""Implements the last layer of FFDNet
1877
+ """
1878
+ def __init__(self):
1879
+ super(UpSampleFeatures, self).__init__()
1880
+ def forward(self, x):
1881
+ return upsamplefeatures(x)
1882
+
1883
+ class IntermediateDnCNN(nn.Module):
1884
+ r"""Implements the middel part of the FFDNet architecture, which
1885
+ is basically a DnCNN net
1886
+ """
1887
+ def __init__(self, input_features, middle_features, num_conv_layers):
1888
+ super(IntermediateDnCNN, self).__init__()
1889
+ self.kernel_size = 3
1890
+ self.padding = 1
1891
+ self.input_features = input_features
1892
+ self.num_conv_layers = num_conv_layers
1893
+ self.middle_features = middle_features
1894
+ if self.input_features == 5:
1895
+ self.output_features = 4 #Grayscale image
1896
+ elif self.input_features == 15:
1897
+ self.output_features = 12 #RGB image
1898
+ else:
1899
+ self.output_features = 3
1900
+ # raise Exception('Invalid number of input features')
1901
+
1902
+
1903
+ layers = []
1904
+ layers.append(nn.Conv2d(in_channels=self.input_features,\
1905
+ out_channels=self.middle_features,\
1906
+ kernel_size=self.kernel_size,\
1907
+ padding=self.padding,\
1908
+ bias=False))
1909
+ layers.append(nn.ReLU(inplace=True))
1910
+ for _ in range(self.num_conv_layers-2):
1911
+ layers.append(nn.Conv2d(in_channels=self.middle_features,\
1912
+ out_channels=self.middle_features,\
1913
+ kernel_size=self.kernel_size,\
1914
+ padding=self.padding,\
1915
+ bias=False))
1916
+ # layers.append(nn.BatchNorm2d(self.middle_features))
1917
+ layers.append(nn.ReLU(inplace=True))
1918
+ layers.append(nn.Conv2d(in_channels=self.middle_features,\
1919
+ out_channels=self.output_features,\
1920
+ kernel_size=self.kernel_size,\
1921
+ padding=self.padding,\
1922
+ bias=False))
1923
+ self.itermediate_dncnn = nn.Sequential(*layers)
1924
+ def forward(self, x):
1925
+ out = self.itermediate_dncnn(x)
1926
+ return out
1927
+
1928
+ class FFDNet(nn.Module):
1929
+ r"""Implements the FFDNet architecture
1930
+ """
1931
+ def __init__(self, num_input_channels, test_mode=False):
1932
+ super(FFDNet, self).__init__()
1933
+ self.num_input_channels = num_input_channels
1934
+ self.test_mode = test_mode
1935
+ if self.num_input_channels == 1:
1936
+ # Grayscale image
1937
+ self.num_feature_maps = 64
1938
+ self.num_conv_layers = 15
1939
+ self.downsampled_channels = 5
1940
+ self.output_features = 4
1941
+ elif self.num_input_channels == 3:
1942
+ # RGB image
1943
+ self.num_feature_maps = 96
1944
+ self.num_conv_layers = 12
1945
+ self.downsampled_channels = 15
1946
+ self.output_features = 12
1947
+ else:
1948
+ raise Exception('Invalid number of input features')
1949
+
1950
+ self.intermediate_dncnn = IntermediateDnCNN(\
1951
+ input_features=self.downsampled_channels,\
1952
+ middle_features=self.num_feature_maps,\
1953
+ num_conv_layers=self.num_conv_layers)
1954
+ self.upsamplefeatures = UpSampleFeatures()
1955
+
1956
+ def forward(self, x, noise_sigma):
1957
+ concat_noise_x = concatenate_input_noise_map(\
1958
+ x.data, noise_sigma.data)
1959
+ if self.test_mode:
1960
+ concat_noise_x = Variable(concat_noise_x, volatile=True)
1961
+ else:
1962
+ concat_noise_x = Variable(concat_noise_x)
1963
+ h_dncnn = self.intermediate_dncnn(concat_noise_x)
1964
+ pred_noise = self.upsamplefeatures(h_dncnn)
1965
+ return pred_noise
1966
+
1967
+
1968
+ class DNCNN(nn.Module):
1969
+ r"""Implements the DNCNNNet architecture
1970
+ """
1971
+ def __init__(self, num_input_channels, test_mode=False):
1972
+ super(DNCNN, self).__init__()
1973
+ self.num_input_channels = num_input_channels
1974
+ self.test_mode = test_mode
1975
+ if self.num_input_channels == 1:
1976
+ # Grayscale image
1977
+ self.num_feature_maps = 64
1978
+ self.num_conv_layers = 15
1979
+ self.downsampled_channels = 5
1980
+ self.output_features = 4
1981
+ elif self.num_input_channels == 3:
1982
+ # RGB image
1983
+ self.num_feature_maps = 96
1984
+ self.num_conv_layers = 12
1985
+ self.downsampled_channels = 15
1986
+ self.output_features = 12
1987
+ else:
1988
+ raise Exception('Invalid number of input features')
1989
+
1990
+ self.intermediate_dncnn = IntermediateDnCNN(\
1991
+ input_features=self.num_input_channels,\
1992
+ middle_features=self.num_feature_maps,\
1993
+ num_conv_layers=self.num_conv_layers)
1994
+
1995
+ def forward(self, x):
1996
+ dncnn = self.intermediate_dncnn(x)
1997
+ return dncnn
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
  huggingface_hub
2
- tensorflow
 
3
  pillow
 
 
 
 
1
  huggingface_hub
2
+ torch
3
+ torchvision
4
  pillow
5
+ scikit-image
6
+ opencv-python
7
+ numpy
8
+ matplotlib