charlesnchr committed
Commit c5a3315
1 Parent(s): 2b53c3e

Fix for Skimage deprecation of multichannel

Files changed (1)
  1. NNfunctions.py +98 -83
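
Background for the fix: scikit-image deprecated the multichannel argument of skimage.transform.rescale in favour of channel_axis (and later removed it), so the two rescale calls in the diff below simply drop multichannel=False, which is equivalent for these single-channel images. A minimal before/after sketch, assuming a hypothetical 2D grayscale array img:

    import numpy as np
    import skimage.transform

    img = np.random.rand(256, 256)  # hypothetical single-channel image

    # old, deprecated form:
    # up = skimage.transform.rescale(img, 1.5, order=3, multichannel=False)

    # equivalent on newer scikit-image: omit multichannel entirely
    # (channel_axis=None, the default, treats all axes as spatial)
    up = skimage.transform.rescale(img, 1.5, order=3)
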
NNfunctions.py CHANGED
@@ -26,21 +26,23 @@ from models import *

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

+
def remove_dataparallel_wrapper(state_dict):
-     r"""Converts a DataParallel model to a normal one by removing the "module."
-     wrapper in the module dictionary
-     Args:
-         state_dict: a torch.nn.DataParallel state dictionary
-     """
-     from collections import OrderedDict
-     new_state_dict = OrderedDict()
-     for k, vl in state_dict.items():
-         name = k[7:]  # remove 'module.' of DataParallel
-         new_state_dict[name] = vl
-     return new_state_dict
+     r"""Converts a DataParallel model to a normal one by removing the "module."
+     wrapper in the module dictionary
+
+     Args:
+         state_dict: a torch.nn.DataParallel state dictionary
+     """
+     from collections import OrderedDict
+
+     new_state_dict = OrderedDict()
+     for k, vl in state_dict.items():
+         name = k[7:]  # remove 'module.' of DataParallel
+         new_state_dict[name] = vl
+
+     return new_state_dict

from argparse import Namespace

@@ -48,105 +50,106 @@ from argparse import Namespace
def GetOptions():
    # training options
    opt = Namespace()
-     opt.model = 'rcan'
+     opt.model = "rcan"
    opt.n_resgroups = 3
    opt.n_resblocks = 10
    opt.n_feats = 96
    opt.reduction = 16
    opt.narch = 0
-     opt.norm = 'minmax'
+     opt.norm = "minmax"

    opt.cpu = False
    opt.multigpu = False
    opt.undomulti = False
-     opt.device = torch.device('cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')
+     opt.device = torch.device(
+         "cuda" if torch.cuda.is_available() and not opt.cpu else "cpu"
+     )

    opt.imageSize = 512
    opt.weights = "model/simrec_simin_gtout_rcan_512_2_ntrain790-final.pth"
    opt.root = "model/0080.jpg"
    opt.out = "model/myout"

-     opt.task = 'simin_gtout'
+     opt.task = "simin_gtout"
    opt.scale = 1
    opt.nch_in = 9
    opt.nch_out = 1

-
    return opt


def GetOptions_allRnd_0215():
    # training options
    opt = Namespace()
-     opt.model = 'rcan'
+     opt.model = "rcan"
    opt.n_resgroups = 3
    opt.n_resblocks = 10
    opt.n_feats = 48
    opt.reduction = 16
    opt.narch = 0
-     opt.norm = 'adapthist'
+     opt.norm = "adapthist"

    opt.cpu = False
    opt.multigpu = False
    opt.undomulti = False
-     opt.device = torch.device('cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')
+     opt.device = torch.device(
+         "cuda" if torch.cuda.is_available() and not opt.cpu else "cpu"
+     )

    opt.imageSize = 512
    opt.weights = "model/0216_SIMRec_0214_rndAll_rcan_continued.pth"
    opt.root = "model/0080.jpg"
    opt.out = "model/myout"

-     opt.task = 'simin_gtout'
+     opt.task = "simin_gtout"
    opt.scale = 1
    opt.nch_in = 9
    opt.nch_out = 1

-
    return opt


-
def GetOptions_allRnd_0317():
    # training options
    opt = Namespace()
-     opt.model = 'rcan'
+     opt.model = "rcan"
    opt.n_resgroups = 3
    opt.n_resblocks = 10
    opt.n_feats = 96
    opt.reduction = 16
    opt.narch = 0
-     opt.norm = 'minmax'
+     opt.norm = "minmax"

    opt.cpu = False
    opt.multigpu = False
    opt.undomulti = False
-     opt.device = torch.device('cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')
+     opt.device = torch.device(
+         "cuda" if torch.cuda.is_available() and not opt.cpu else "cpu"
+     )

    opt.imageSize = 512
    opt.weights = "model/DIV2K_randomised_3x3_20200317.pth"
    opt.root = "model/0080.jpg"
    opt.out = "model/myout"

-     opt.task = 'simin_gtout'
+     opt.task = "simin_gtout"
    opt.scale = 1
    opt.nch_in = 9
    opt.nch_out = 1

-
    return opt


-
def LoadModel(opt):
-     print('Loading model')
+     print("Loading model")
    print(opt)

    net = GetModel(opt)
-     print('loading checkpoint',opt.weights)
-     checkpoint = torch.load(opt.weights,map_location=opt.device)
+     print("loading checkpoint", opt.weights)
+     checkpoint = torch.load(opt.weights, map_location=opt.device)

    if type(checkpoint) is dict:
-         state_dict = checkpoint['state_dict']
+         state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

@@ -157,22 +160,22 @@ def LoadModel(opt):
    return net


- def prepimg(stack,self):
-
+ def prepimg(stack, self):
    inputimg = stack[:9]

    if self.nch_in == 6:
-         inputimg = inputimg[[0,1,3,4,6,7]]
+         inputimg = inputimg[[0, 1, 3, 4, 6, 7]]
    elif self.nch_in == 3:
-         inputimg = inputimg[[0,4,8]]
+         inputimg = inputimg[[0, 4, 8]]

    if inputimg.shape[1] > 512 or inputimg.shape[2] > 512:
-         print('Over 512x512! Cropping')
-         inputimg = inputimg[:,:512,:512]
+         print("Over 512x512! Cropping")
+         inputimg = inputimg[:, :512, :512]

-
-     if self.norm == 'convert': # raw img from microscope, needs normalisation and correct frame ordering
-         print('Raw input assumed - converting')
+     if (
+         self.norm == "convert"
+     ):  # raw img from microscope, needs normalisation and correct frame ordering
+         print("Raw input assumed - converting")
        # NCHW
        # I = np.zeros((9,opt.imageSize,opt.imageSize),dtype='uint16')

@@ -183,86 +186,96 @@ def prepimg(stack,self):
        # I[t,:,:] = frame
        # inputimg = I

-         inputimg = np.rot90(inputimg,axes=(1,2))
-         inputimg = inputimg[[6,7,8,3,4,5,0,1,2]] # could also do [8,7,6,5,4,3,2,1,0]
+         inputimg = np.rot90(inputimg, axes=(1, 2))
+         inputimg = inputimg[
+             [6, 7, 8, 3, 4, 5, 0, 1, 2]
+         ]  # could also do [8,7,6,5,4,3,2,1,0]
        for i in range(len(inputimg)):
            inputimg[i] = 100 / np.max(inputimg[i]) * inputimg[i]
-     elif 'convert' in self.norm:
+     elif "convert" in self.norm:
        fac = float(self.norm[7:])
-         inputimg = np.rot90(inputimg,axes=(1,2))
-         inputimg = inputimg[[6,7,8,3,4,5,0,1,2]] # could also do [8,7,6,5,4,3,2,1,0]
+         inputimg = np.rot90(inputimg, axes=(1, 2))
+         inputimg = inputimg[
+             [6, 7, 8, 3, 4, 5, 0, 1, 2]
+         ]  # could also do [8,7,6,5,4,3,2,1,0]
        for i in range(len(inputimg)):
            inputimg[i] = fac * 255 / np.max(inputimg[i]) * inputimg[i]

-     inputimg = inputimg.astype('float') / np.max(inputimg) # used to be /255
-     widefield = np.mean(inputimg,0)
-
+     inputimg = inputimg.astype("float") / np.max(inputimg)  # used to be /255
+     widefield = np.mean(inputimg, 0)

-     if self.norm == 'adapthist':
+     if self.norm == "adapthist":
        for i in range(len(inputimg)):
-             inputimg[i] = exposure.equalize_adapthist(inputimg[i],clip_limit=0.001)
-         widefield = exposure.equalize_adapthist(widefield,clip_limit=0.001)
+             inputimg[i] = exposure.equalize_adapthist(inputimg[i], clip_limit=0.001)
+         widefield = exposure.equalize_adapthist(widefield, clip_limit=0.001)
    else:
        # normalise
        inputimg = torch.tensor(inputimg).float()
        widefield = torch.tensor(widefield).float()
-         widefield = (widefield - torch.min(widefield)) / (torch.max(widefield) - torch.min(widefield))
+         widefield = (widefield - torch.min(widefield)) / (
+             torch.max(widefield) - torch.min(widefield)
+         )

-         if self.norm == 'minmax':
+         if self.norm == "minmax":
            for i in range(len(inputimg)):
-                 inputimg[i] = (inputimg[i] - torch.min(inputimg[i])) / (torch.max(inputimg[i]) - torch.min(inputimg[i]))
+                 inputimg[i] = (inputimg[i] - torch.min(inputimg[i])) / (
+                     torch.max(inputimg[i]) - torch.min(inputimg[i])
+                 )
-         elif 'minmax' in self.norm:
+         elif "minmax" in self.norm:
            fac = float(self.norm[6:])
            for i in range(len(inputimg)):
-                 inputimg[i] = fac * (inputimg[i] - torch.min(inputimg[i])) / (torch.max(inputimg[i]) - torch.min(inputimg[i]))
-
-
+                 inputimg[i] = (
+                     fac
+                     * (inputimg[i] - torch.min(inputimg[i]))
+                     / (torch.max(inputimg[i]) - torch.min(inputimg[i]))
+                 )

    # otf = torch.tensor(otf.astype('float') / np.max(otf)).unsqueeze(0).float()
    # gt = torch.tensor(gt.astype('float') / 255).unsqueeze(0).float()
    # simimg = torch.tensor(simimg.astype('float') / 255).unsqueeze(0).float()
    # widefield = torch.mean(inputimg,0).unsqueeze(0)

-
    # normalise
    # gt = (gt - torch.min(gt)) / (torch.max(gt) - torch.min(gt))
    # simimg = (simimg - torch.min(simimg)) / (torch.max(simimg) - torch.min(simimg))
    # widefield = (widefield - torch.min(widefield)) / (torch.max(widefield) - torch.min(widefield))
    inputimg = torch.tensor(inputimg).float()
    widefield = torch.tensor(widefield).float()
-     return inputimg,widefield
+     return inputimg, widefield
+


- def save_image(data, filename,cmap):
+ def save_image(data, filename, cmap):
    sizes = np.shape(data)
    fig = plt.figure()
-     fig.set_size_inches(1. * sizes[0] / sizes[1], 1, forward = False)
-     ax = plt.Axes(fig, [0., 0., 1., 1.])
+     fig.set_size_inches(1.0 * sizes[0] / sizes[1], 1, forward=False)
+     ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(data, cmap=cmap)
-     plt.savefig(filename, dpi = sizes[0])
+     plt.savefig(filename, dpi=sizes[0])
    plt.close()


- def EvaluateModel(net,opt,stack):
-
-     outfile = datetime.datetime.utcnow().strftime('%H-%M-%S')
-     outfile = 'ML-SIM_%s' % outfile
+ def EvaluateModel(net, opt, stack):
+     outfile = datetime.datetime.utcnow().strftime("%H-%M-%S")
+     outfile = "ML-SIM_%s" % outfile

    os.makedirs(opt.out, exist_ok=True)

    print(stack.shape)
    inputimg, widefield = prepimg(stack, opt)

-     if opt.norm == 'convert' or 'minmax' in opt.norm or 'adapthist' in opt.norm:
-         cmap = 'viridis'
+     if opt.norm == "convert" or "minmax" in opt.norm or "adapthist" in opt.norm:
+         cmap = "viridis"
    else:
-         cmap = 'gray'
+         cmap = "gray"

    # skimage.io.imsave('%s_wf.png' % outfile,(255*widefield.numpy()).astype('uint8'))
-     wf = (255*widefield.numpy()).astype('uint8')
-     wf_upscaled = skimage.transform.rescale(wf,1.5,order=3,multichannel=False) # should ideally be done by drawing on client side, in javascript
-     save_image(wf_upscaled,'%s_wf.png' % outfile,cmap)
+     wf = (255 * widefield.numpy()).astype("uint8")
+     wf_upscaled = skimage.transform.rescale(
+         wf, 1.5, order=3
+     )  # should ideally be done by drawing on client side, in javascript
+     save_image(wf_upscaled, "%s_wf.png" % outfile, cmap)

    # skimage.io.imsave('%s.tif' % outfile, inputimg.numpy())

@@ -271,21 +284,23 @@ def EvaluateModel(net,opt,stack):
    with torch.no_grad():
        sr = net(inputimg.to(opt.device))
        sr = sr.cpu()
-         sr = torch.clamp(sr,min=0,max=1)
-         print('min max',inputimg.min(),inputimg.max())
+         sr = torch.clamp(sr, min=0, max=1)
+         print("min max", inputimg.min(), inputimg.max())

    pil_sr_img = toPIL(sr[0])

-     if opt.norm == 'convert':
-         pil_sr_img = transforms.functional.rotate(pil_sr_img,-90)
+     if opt.norm == "convert":
+         pil_sr_img = transforms.functional.rotate(pil_sr_img, -90)

    # pil_sr_img.save('%s.png' % outfile) # true output for downloading, no LUT
    sr_img = np.array(pil_sr_img)
    # sr_img = exposure.equalize_adapthist(sr_img,clip_limit=0.01)
-     skimage.io.imsave('%s.png' % outfile, sr_img) # true out for downloading, no LUT
+     skimage.io.imsave("%s.png" % outfile, sr_img)  # true out for downloading, no LUT

-     sr_img = skimage.transform.rescale(sr_img,1.5,order=3,multichannel=False) # should ideally be done by drawing on client side, in javascript
+     sr_img = skimage.transform.rescale(
+         sr_img, 1.5, order=3
+     )  # should ideally be done by drawing on client side, in javascript

-     save_image(sr_img,'%s_sr.png' % outfile,cmap)
-     return outfile + '_sr.png', outfile + '_wf.png', outfile + '.png'
+     save_image(sr_img, "%s_sr.png" % outfile, cmap)
+     return outfile + "_sr.png", outfile + "_wf.png", outfile + ".png"
    # return wf, sr_img, outfile
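
For orientation, a hypothetical driver sketch for the functions touched above (the weights path comes from GetOptions_allRnd_0317; the input TIFF path and its 9-frame layout are assumptions, not part of this commit):

    import skimage.io

    # NNfunctions.py is assumed importable; these names are defined in the diff above
    from NNfunctions import GetOptions_allRnd_0317, LoadModel, EvaluateModel

    opt = GetOptions_allRnd_0317()
    net = LoadModel(opt)

    # hypothetical 9-frame SIM stack, shape (9, H, W), H and W up to 512
    stack = skimage.io.imread("my_sim_stack.tif")

    sr_png, wf_png, raw_png = EvaluateModel(net, opt, stack)
    print(sr_png, wf_png, raw_png)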