charlesnchr commited on
Commit
3715c63
1 Parent(s): ba9a83a

First version with RGB image input

Browse files
Files changed (5) hide show
  1. NNfunctions.py +290 -0
  2. app.py +11 -51
  3. model/DIV2K_randomised_3x3_20200317.pth +3 -0
  4. models.py +1997 -0
  5. requirements.txt +6 -1
NNfunctions.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import math
3
+ import os
4
+
5
+ import torch
6
+ import time
7
+
8
+ import skimage.io
9
+ import skimage.transform
10
+ import matplotlib.pyplot as plt
11
+ import glob
12
+
13
+ import torch.optim as optim
14
+ import torchvision
15
+ import torchvision.transforms as transforms
16
+ from skimage import exposure
17
+
18
+ toTensor = transforms.ToTensor()
19
+ toPIL = transforms.ToPILImage()
20
+
21
+
22
+ import numpy as np
23
+ from PIL import Image
24
+
25
+ from models import *
26
+
27
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
28
+
29
def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary

    Args:
        state_dict: a torch.nn.DataParallel state dictionary

    Returns:
        collections.OrderedDict: the same weights, in the same order, with
        the ``'module.'`` key prefix removed. Keys that do not carry the
        prefix are kept unchanged.
    """
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, vl in state_dict.items():
        # Only strip the prefix when it is actually present. The previous
        # version unconditionally dropped the first 7 characters, which
        # mangled keys of checkpoints not saved from a DataParallel model.
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = vl

    return new_state_dict
44
+
45
+ from argparse import Namespace
46
+
47
+
48
def GetOptions():
    """Inference options for the 512px RCAN SIM-reconstruction model
    (simulated input -> ground-truth output task, 9 input channels)."""
    run_on_cpu = False
    opt = Namespace(
        # network architecture
        model='rcan',
        n_resgroups=3,
        n_resblocks=10,
        n_feats=96,
        reduction=16,
        narch=0,
        norm='minmax',
        # execution
        cpu=run_on_cpu,
        multigpu=False,
        undomulti=False,
        device=torch.device('cuda' if torch.cuda.is_available() and not run_on_cpu else 'cpu'),
        # data / checkpoint locations
        imageSize=512,
        weights="model/simrec_simin_gtout_rcan_512_2_ntrain790-final.pth",
        root="model/0080.jpg",
        out="model/myout",
        # task definition
        task='simin_gtout',
        scale=1,
        nch_in=9,
        nch_out=1,
    )
    return opt
76
+
77
+
78
def GetOptions_allRnd_0215():
    """Inference options for the 2021-02-15 randomised-training checkpoint:
    a lighter RCAN (48 feature maps) with adaptive-histogram normalisation."""
    run_on_cpu = False
    opt = Namespace(
        # network architecture (smaller: 48 feature maps)
        model='rcan',
        n_resgroups=3,
        n_resblocks=10,
        n_feats=48,
        reduction=16,
        narch=0,
        norm='adapthist',
        # execution
        cpu=run_on_cpu,
        multigpu=False,
        undomulti=False,
        device=torch.device('cuda' if torch.cuda.is_available() and not run_on_cpu else 'cpu'),
        # data / checkpoint locations
        imageSize=512,
        weights="model/0216_SIMRec_0214_rndAll_rcan_continued.pth",
        root="model/0080.jpg",
        out="model/myout",
        # task definition
        task='simin_gtout',
        scale=1,
        nch_in=9,
        nch_out=1,
    )
    return opt
106
+
107
+
108
+
109
def GetOptions_allRnd_0317():
    """Inference options for the 2020-03-17 DIV2K randomised 3x3 checkpoint
    (the model shipped with this app): full-width RCAN, min-max normalisation."""
    run_on_cpu = False
    opt = Namespace(
        # network architecture
        model='rcan',
        n_resgroups=3,
        n_resblocks=10,
        n_feats=96,
        reduction=16,
        narch=0,
        norm='minmax',
        # execution
        cpu=run_on_cpu,
        multigpu=False,
        undomulti=False,
        device=torch.device('cuda' if torch.cuda.is_available() and not run_on_cpu else 'cpu'),
        # data / checkpoint locations
        imageSize=512,
        weights="model/DIV2K_randomised_3x3_20200317.pth",
        root="model/0080.jpg",
        out="model/myout",
        # task definition
        task='simin_gtout',
        scale=1,
        nch_in=9,
        nch_out=1,
    )
    return opt
137
+
138
+
139
+
140
def LoadModel(opt):
    """Build the network described by ``opt`` and load its checkpoint.

    Args:
        opt: options Namespace; fields read here: weights, device, undomulti
            (plus whatever ``GetModel`` consumes).

    Returns:
        the network with the checkpoint weights loaded.
    """
    print('Loading model')
    print(opt)

    net = GetModel(opt)
    print('loading checkpoint', opt.weights)
    checkpoint = torch.load(opt.weights, map_location=opt.device)

    # A full training checkpoint is a plain dict holding the weights under
    # 'state_dict'; anything else (e.g. an OrderedDict) is taken to be the
    # bare state dict itself. NOTE: the exact `type(...) is dict` check is
    # deliberate — isinstance() would wrongly match OrderedDict state dicts.
    state_dict = checkpoint['state_dict'] if type(checkpoint) is dict else checkpoint

    if opt.undomulti:
        # strip the 'module.' prefix left behind by nn.DataParallel
        state_dict = remove_dataparallel_wrapper(state_dict)
    net.load_state_dict(state_dict)
    return net
158
+
159
+
160
def prepimg(stack,self):
    """Prepare a raw SIM frame stack for the network.

    Args:
        stack: array-like of shape (N, H, W); only the first 9 frames are
            used. Presumably the 3x3 phase/angle SIM acquisition — TODO
            confirm against callers.
        self: an options Namespace (named ``self`` but NOT a class
            instance); fields read here: ``nch_in`` and ``norm``.

    Returns:
        (inputimg, widefield): float torch tensors — the normalised network
        input of shape (nch_in, H, W) and the mean-projection widefield
        image of shape (H, W).
    """

    # keep at most the 9 SIM frames
    inputimg = stack[:9]

    # channel subsetting for reduced-input models (6- or 3-frame variants)
    if self.nch_in == 6:
        inputimg = inputimg[[0,1,3,4,6,7]]
    elif self.nch_in == 3:
        inputimg = inputimg[[0,4,8]]

    # the network handles at most 512x512 — crop the top-left corner
    if inputimg.shape[1] > 512 or inputimg.shape[2] > 512:
        print('Over 512x512! Cropping')
        inputimg = inputimg[:,:512,:512]


    if self.norm == 'convert': # raw img from microscope, needs normalisation and correct frame ordering
        print('Raw input assumed - converting')
        # NCHW
        # I = np.zeros((9,opt.imageSize,opt.imageSize),dtype='uint16')

        # for t in range(9):
        #     frame = inputimg[t]
        #     frame = 120 / np.max(frame) * frame
        #     frame = np.rot90(np.rot90(np.rot90(frame)))
        #     I[t,:,:] = frame
        # inputimg = I

        # rotate frames 90 degrees and reorder frame groups to match the
        # ordering the network was trained on
        inputimg = np.rot90(inputimg,axes=(1,2))
        inputimg = inputimg[[6,7,8,3,4,5,0,1,2]] # could also do [8,7,6,5,4,3,2,1,0]
        # per-frame rescale so each frame peaks at 100
        for i in range(len(inputimg)):
            inputimg[i] = 100 / np.max(inputimg[i]) * inputimg[i]
    elif 'convert' in self.norm:
        # e.g. norm == 'convert0.5': trailing number is a brightness factor
        fac = float(self.norm[7:])
        inputimg = np.rot90(inputimg,axes=(1,2))
        inputimg = inputimg[[6,7,8,3,4,5,0,1,2]] # could also do [8,7,6,5,4,3,2,1,0]
        for i in range(len(inputimg)):
            inputimg[i] = fac * 255 / np.max(inputimg[i]) * inputimg[i]


    # global rescale to [0, 1] by the stack-wide maximum
    inputimg = inputimg.astype('float') / np.max(inputimg) # used to be /255
    # widefield = mean projection over the frame axis
    widefield = np.mean(inputimg,0)

    if self.norm == 'adapthist':
        # CLAHE per frame and on the widefield (stays in numpy here;
        # converted to tensors at the end of the function)
        for i in range(len(inputimg)):
            inputimg[i] = exposure.equalize_adapthist(inputimg[i],clip_limit=0.001)
        widefield = exposure.equalize_adapthist(widefield,clip_limit=0.001)
    else:
        # normalise
        inputimg = torch.tensor(inputimg).float()
        widefield = torch.tensor(widefield).float()
        widefield = (widefield - torch.min(widefield)) / (torch.max(widefield) - torch.min(widefield))

        if self.norm == 'minmax':
            # per-frame min-max normalisation to [0, 1]
            for i in range(len(inputimg)):
                inputimg[i] = (inputimg[i] - torch.min(inputimg[i])) / (torch.max(inputimg[i]) - torch.min(inputimg[i]))
        elif 'minmax' in self.norm:
            # e.g. norm == 'minmax0.9': trailing number scales the range
            fac = float(self.norm[6:])
            for i in range(len(inputimg)):
                inputimg[i] = fac * (inputimg[i] - torch.min(inputimg[i])) / (torch.max(inputimg[i]) - torch.min(inputimg[i]))



    # otf = torch.tensor(otf.astype('float') / np.max(otf)).unsqueeze(0).float()
    # gt = torch.tensor(gt.astype('float') / 255).unsqueeze(0).float()
    # simimg = torch.tensor(simimg.astype('float') / 255).unsqueeze(0).float()
    # widefield = torch.mean(inputimg,0).unsqueeze(0)


    # normalise
    # gt = (gt - torch.min(gt)) / (torch.max(gt) - torch.min(gt))
    # simimg = (simimg - torch.min(simimg)) / (torch.max(simimg) - torch.min(simimg))
    # widefield = (widefield - torch.min(widefield)) / (torch.max(widefield) - torch.min(widefield))
    # NOTE(review): in the non-adapthist path these are already tensors, so
    # torch.tensor(...) re-wraps them (extra copy); needed for the adapthist
    # path, which arrives here as numpy arrays.
    inputimg = torch.tensor(inputimg).float()
    widefield = torch.tensor(widefield).float()
    return inputimg,widefield
234
+
235
def save_image(data, filename, cmap):
    """Save ``data`` to ``filename`` as a borderless image.

    The figure is sized so the axes fill it completely (no ticks, no
    margins) and the dpi is chosen so the output matches the array's
    first-dimension size in pixels.

    Args:
        data: 2-D array to render.
        filename: output path (format inferred by matplotlib).
        cmap: matplotlib colormap name for ``imshow``.
    """
    dims = np.shape(data)
    figure = plt.figure()
    # aspect ratio dims[0]/dims[1], one inch tall
    figure.set_size_inches(1. * dims[0] / dims[1], 1, forward=False)
    axes = plt.Axes(figure, [0., 0., 1., 1.])
    axes.set_axis_off()
    figure.add_axes(axes)
    axes.imshow(data, cmap=cmap)
    plt.savefig(filename, dpi=dims[0])
    plt.close()
245
+
246
+
247
def EvaluateModel(net,opt,stack):
    """Run SIM reconstruction on a prepared frame stack.

    Args:
        net: the loaded reconstruction network (see ``LoadModel``).
        opt: options Namespace; fields read here: out, norm, cpu.
        stack: raw frame stack, shape (N, H, W) — passed to ``prepimg``.

    Returns:
        numpy array: the reconstructed image after contrast-limited
        adaptive histogram equalisation (values in [0, 1]).
    """

    # kept from the file-saving version of this function; outputs are now
    # returned directly instead of written under opt.out
    os.makedirs(opt.out, exist_ok=True)

    print(stack.shape)
    inputimg, widefield = prepimg(stack, opt)

    # NOTE(review): cmap is computed but no longer used now that
    # save_image calls below are commented out
    if opt.norm == 'convert' or 'minmax' in opt.norm or 'adapthist' in opt.norm:
        cmap = 'magma'
    else:
        cmap = 'gray'

    # skimage.io.imsave('%s_wf.png' % outfile,(255*widefield.numpy()).astype('uint8'))
    # NOTE(review): wf_upscaled is likewise unused since saving is disabled
    wf = (255*widefield.numpy()).astype('uint8')
    wf_upscaled = skimage.transform.rescale(wf,1.5,order=3,multichannel=False) # should ideally be done by drawing on client side, in javascript
    # save_image(wf_upscaled,'%s_wf.png' % outfile,cmap)

    # skimage.io.imsave('%s.tif' % outfile, inputimg.numpy())

    # add the batch dimension: (C, H, W) -> (1, C, H, W)
    inputimg = inputimg.unsqueeze(0)

    with torch.no_grad():
        if opt.cpu:
            sr = net(inputimg)
        else:
            # assumes a CUDA device is present when opt.cpu is False
            sr = net(inputimg.cuda())
        sr = sr.cpu()
        # clamp the reconstruction into the displayable [0, 1] range
        sr = torch.clamp(sr,min=0,max=1)
        print('min max',inputimg.min(),inputimg.max())

    pil_sr_img = toPIL(sr[0])

    if opt.norm == 'convert':
        # undo the 90-degree rotation applied by prepimg in 'convert' mode
        pil_sr_img = transforms.functional.rotate(pil_sr_img,-90)

    #pil_sr_img.save('%s.png' % outfile) # true output for downloading, no LUT
    # final contrast enhancement for display
    sr_img = np.array(pil_sr_img)
    sr_img = exposure.equalize_adapthist(sr_img,clip_limit=0.01)
    # skimage.io.imsave('%s.png' % outfile, sr_img) # true out for downloading, no LUT

    # sr_img = skimage.transform.rescale(sr_img,1.5,order=3,multichannel=False) # should ideally be done by drawing on client side, in javascript
    # save_image(sr_img,'%s_sr.png' % outfile,cmap)
    # return outfile + '_sr.png', outfile + '_wf.png', outfile + '.png'
    return sr_img
app.py CHANGED
@@ -7,67 +7,27 @@
7
 
8
  from turtle import title
9
  import gradio as gr
10
- from huggingface_hub import from_pretrained_keras
11
- import tensorflow as tf
12
  import numpy as np
13
  from PIL import Image
14
  import io
15
  import base64
 
16
 
17
-
18
- model = tf.keras.models.load_model("./tf_model.h5")
19
-
20
 
21
  def predict(image):
22
  img = np.array(image)
23
- original_shape = img.shape[:2]
24
-
25
- im = tf.image.resize(img, (128, 128))
26
- im = tf.cast(im, tf.float32) / 255.0
27
- pred_mask = model.predict(im[tf.newaxis, ...])
28
-
29
-
30
- # take the best performing class for each pixel
31
- # the output of argmax looks like this [[1, 2, 0], ...]
32
- pred_mask_arg = tf.argmax(pred_mask, axis=-1)
33
-
34
-
35
- # convert the prediction mask into binary masks for each class
36
- binary_masks = {}
37
-
38
- # when we take tf.argmax() over pred_mask, it becomes a tensor object
39
- # the shape becomes TensorShape object, looking like this TensorShape([128])
40
- # we need to take get shape, convert to list and take the best one
41
-
42
- rows = pred_mask_arg[0][1].get_shape().as_list()[0]
43
- cols = pred_mask_arg[0][2].get_shape().as_list()[0]
44
-
45
- for cls in range(pred_mask.shape[-1]):
46
-
47
- binary_masks[f"mask_{cls}"] = np.zeros(shape = (pred_mask.shape[1], pred_mask.shape[2])) #create masks for each class
48
-
49
- for row in range(rows):
50
-
51
- for col in range(cols):
52
-
53
- if pred_mask_arg[0][row][col] == cls:
54
-
55
- binary_masks[f"mask_{cls}"][row][col] = 1
56
- else:
57
- binary_masks[f"mask_{cls}"][row][col] = 0
58
-
59
- mask = binary_masks[f"mask_{cls}"]
60
- mask *= 255
61
 
62
- mask = np.array(Image.fromarray(mask).convert("L"))
63
- mask = tf.image.resize(mask[..., tf.newaxis], original_shape)
64
- mask = tf.cast(mask, tf.uint8)
65
- mask = mask.numpy().squeeze()
66
 
67
- return mask
68
 
69
 
70
- title = '<h1 style="text-align: center;">Segment Pets</h1>'
71
 
72
  description = """
73
  ## About
@@ -77,9 +37,9 @@ according to the pixels.
77
  Upload a pet image and hit submit or select one from the given examples
78
  """
79
 
80
- inputs = gr.inputs.Image(label="Upload a pet image", type = 'pil', optional=False)
81
  outputs = [
82
- gr.outputs.Image(label="Segmentation")
83
  # , gr.outputs.Textbox(type="auto",label="Pet Prediction")
84
  ]
85
 
7
 
8
  from turtle import title
9
  import gradio as gr
 
 
10
  import numpy as np
11
  from PIL import Image
12
  import io
13
  import base64
14
+ from NNfunctions import *
15
 
16
+ opt = GetOptions_allRnd_0317()
17
+ net = LoadModel(opt)
 
18
 
19
def predict(image):
    """Run SIM reconstruction on an uploaded image (Gradio callback).

    The incoming frame is replicated three times along the channel axis to
    build the 9-channel stack the network expects (assumes a 3-channel RGB
    input — TODO confirm), reordered to channels-first, and reconstructed.

    Args:
        image: PIL image supplied by the Gradio Image input.

    Returns:
        numpy array: the reconstructed image from ``EvaluateModel``.
    """
    frame = np.array(image)
    # HxWx3 -> HxWx9: repeat the colour planes to fill the 9 input channels
    stack = np.concatenate((frame, frame, frame), axis=2)
    # channels-first (C, H, W) ordering for the network
    stack = np.transpose(stack, (2, 0, 1))
    return EvaluateModel(net, opt, stack)
28
 
29
 
30
+ title = '<h1 style="text-align: center;">ML-SIM: Reconstruction of SIM images with deep learning</h1>'
31
 
32
  description = """
33
  ## About
37
  Upload a pet image and hit submit or select one from the given examples
38
  """
39
 
40
+ inputs = gr.inputs.Image(label="Upload a TIFF image", type = 'pil', optional=False)
41
  outputs = [
42
+ gr.outputs.Image(label="SIM Reconstruction")
43
  # , gr.outputs.Textbox(type="auto",label="Pet Prediction")
44
  ]
45
 
model/DIV2K_randomised_3x3_20200317.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4936f5ccf5db42009fa23ca3cbd63f53125dbd787240ca78f09e2b85c682a08
3
+ size 64467635
models.py ADDED
@@ -0,0 +1,1997 @@