ericup commited on
Commit
20b8a33
1 Parent(s): bd5e5cb

General Update

Browse files
Files changed (4) hide show
  1. app.py +143 -29
  2. cpn.py +6 -2
  3. examples/bbbc039_test_00014.png +0 -0
  4. util.py +9 -9
app.py CHANGED
@@ -1,24 +1,53 @@
1
  import spaces
2
  import gradio as gr
3
- from util import imread, imsave, get_examples
4
  import torch
 
 
 
 
5
 
6
  def torch_compile(*args, **kwargs):
7
  def decorator(func):
8
  return func
 
9
  return decorator
10
 
 
11
  torch.compile = torch_compile # temporary workaround
12
 
13
  default_model = 'ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  @spaces.GPU
17
- def predict(filename, model=None, device=None, reduce_labels=True):
 
 
 
 
 
 
 
 
18
  from cpn import CpnInterface
19
  from prep import multi_norm
20
- from celldetection import label_cmap
21
-
22
  global default_model
23
  assert isinstance(filename, str)
24
 
@@ -27,40 +56,125 @@ def predict(filename, model=None, device=None, reduce_labels=True):
27
  device = 'cuda'
28
  else:
29
  device = 'cpu'
30
-
31
- print(dict(
32
- filename=filename,
 
33
  model=model,
34
  device=device,
35
- reduce_labels=reduce_labels
36
- ), flush=True)
 
 
 
 
 
 
 
37
 
38
- img = imread(filename)
39
  print('Image:', img.dtype, img.shape, (img.min(), img.max()), flush=True)
40
  if model is None or len(str(model)) <= 0:
41
  model = default_model
42
 
43
  img = multi_norm(img, 'cstm-mix') # TODO
44
 
45
- m = CpnInterface(model.strip(), device=device)
46
- y = m(img, reduce_labels=reduce_labels)
 
 
 
 
 
 
 
 
 
47
 
48
- labels = y['labels']
 
 
 
 
49
 
 
50
  vis_labels = label_cmap(labels)
51
- dst = '.'.join(filename.split('.')[:-1]) + '_labels.tiff'
52
- imsave(dst, labels)
53
-
54
- return img, vis_labels, dst
55
-
56
-
57
- gr.Interface(
58
- predict,
59
- inputs=[gr.components.Image(label="Upload Input Image", type="filepath"),
60
- gr.components.Textbox(label='Model Name', value=default_model, max_lines=1)],
61
- outputs=[gr.Image(label="Processed Image"),
62
- gr.Image(label="Label Image"),
63
- gr.File(label="Download Label Image")],
64
- title="Cell Detection with Contour Proposal Networks",
65
- examples=get_examples(default_model)
66
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import spaces
2
  import gradio as gr
3
+ from util import imread, imsave, copy_skimage_data
4
  import torch
5
+ from PIL import Image, ImageDraw
6
+ import numpy as np
7
+ from os.path import join
8
+
9
 
10
  def torch_compile(*args, **kwargs):
11
  def decorator(func):
12
  return func
13
+
14
  return decorator
15
 
16
+
17
  torch.compile = torch_compile # temporary workaround
18
 
19
  default_model = 'ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c'
20
+ default_score_thresh = .9
21
+ default_nms_thresh = np.round(np.pi / 10, 4)
22
+ default_samples = 128
23
+ default_order = 5
24
+
25
+ examples_dir = 'examples'
26
+ copy_skimage_data(examples_dir)
27
+ examples = [
28
+ [join(examples_dir, 'bbbc039_test_00014.png'), 'ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c', False, default_score_thresh, False,
29
+ default_nms_thresh, True, 64, True],
30
+ [join(examples_dir, 'coins.png'), 'ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c', False, default_score_thresh, False,
31
+ default_nms_thresh, True, 64, True],
32
+ [join(examples_dir, 'cell.png'), 'ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c', False, default_score_thresh, False,
33
+ default_nms_thresh, True, 64, True],
34
+ ]
35
 
36
 
37
  @spaces.GPU
38
+ def predict(
39
+ filename, model=None,
40
+ enable_score_threshold=False, score_threshold=.9,
41
+ enable_nms_threshold=False, nms_threshold=0.3141592653589793,
42
+ enable_samples=False, samples=128,
43
+ use_label_channels=False,
44
+ enable_order=False, order=5,
45
+ device=None,
46
+ ):
47
  from cpn import CpnInterface
48
  from prep import multi_norm
49
+ from celldetection import label_cmap, to_h5, data, __version__
50
+
51
  global default_model
52
  assert isinstance(filename, str)
53
 
 
56
  device = 'cuda'
57
  else:
58
  device = 'cpu'
59
+
60
+ meta = dict(
61
+ cd_version=__version__,
62
+ filename=str(filename),
63
  model=model,
64
  device=device,
65
+ use_label_channels=use_label_channels,
66
+ enable_score_threshold=enable_score_threshold,
67
+ score_threshold=float(score_threshold),
68
+ enable_order=enable_order,
69
+ order=order,
70
+ enable_nms_threshold=enable_nms_threshold,
71
+ nms_threshold=float(nms_threshold),
72
+ )
73
+ print(meta, flush=True)
74
 
75
+ raw = img = imread(filename)
76
  print('Image:', img.dtype, img.shape, (img.min(), img.max()), flush=True)
77
  if model is None or len(str(model)) <= 0:
78
  model = default_model
79
 
80
  img = multi_norm(img, 'cstm-mix') # TODO
81
 
82
+ kw = {}
83
+ if enable_score_threshold:
84
+ kw['score_thresh'] = score_threshold
85
+ if enable_nms_threshold:
86
+ kw['nms_thresh'] = nms_threshold
87
+ if enable_order:
88
+ kw['order'] = order
89
+ if enable_samples:
90
+ kw['samples'] = samples
91
+ m = CpnInterface(model.strip(), device=device, **kw)
92
+ y = m(img, reduce_labels=not use_label_channels)
93
 
94
+ dst_h5 = '.'.join(filename.split('.')[:-1]) + '.h5'
95
+ to_h5(
96
+ dst_h5, inputs=img, **y,
97
+ attributes=dict(inputs=meta)
98
+ )
99
 
100
+ labels = y['labels']
101
  vis_labels = label_cmap(labels)
102
+
103
+ dst_csv = '.'.join(filename.split('.')[:-1]) + '.csv'
104
+ data.labels2property_table(
105
+ labels,
106
+ "label", "area", "feret_diameter_max", "bbox", "centroid", "convex_area",
107
+ "eccentricity", "equivalent_diameter",
108
+ "extent", "filled_area", "major_axis_length",
109
+ "minor_axis_length", "orientation", "perimeter",
110
+ "solidity", "mean_intensity", "max_intensity", "min_intensity",
111
+ intensity_image=raw
112
+ ).to_csv(dst_csv)
113
+
114
+ return vis_labels, img, dst_h5, dst_csv
115
+
116
+
117
+ with gr.Blocks(title='Cell Segmentation with Contour Proposal Networks') as app:
118
+ with gr.Row():
119
+ gr.Markdown("<center><strong><font size='7'>"
120
+ "Cell Segmentation with Contour Proposal Networks 🤗</font></strong></center>")
121
+
122
+ with gr.Row():
123
+ with gr.Column():
124
+ img = gr.components.Image(label="Upload Input Image", type="filepath", interactive=True,
125
+ value=examples[0][0])
126
+ with gr.Column():
127
+ model_name = gr.components.Textbox(label='Model Name', value=default_model, max_lines=1)
128
+ with gr.Row():
129
+ score_thresh_ck = gr.components.Checkbox(label="Use custom Score Threshold", value=False)
130
+ score_thresh = gr.components.Slider(minimum=0, maximum=1, label="Score Threshold",
131
+ value=default_score_thresh)
132
+ with gr.Row():
133
+ nms_thresh_ck = gr.components.Checkbox(label="Use custom NMS Threshold", value=False)
134
+ nms_thresh = gr.components.Slider(minimum=0, maximum=1, label="NMS Threshold", value=default_nms_thresh)
135
+ # with gr.Row():
136
+ # # The range of this would need to be model dependent
137
+ # order_ck = gr.components.Checkbox(label="Use custom Order", value=False)
138
+ # order = gr.components.Slider(minimum=0, maximum=1, label="Order", value=default_order)
139
+ with gr.Row():
140
+ samples_ck = gr.components.Checkbox(label="Use custom Sample Points", value=False)
141
+ samples = gr.components.Slider(minimum=8, maximum=256, label="Sample Points", value=default_samples)
142
+ with gr.Row():
143
+ channels = gr.components.Checkbox(label="Allow overlapping objects", value=True)
144
+ with gr.Row():
145
+ clr = gr.Button('Reset')
146
+ btn = gr.Button('Run')
147
+ with gr.Row():
148
+ with gr.Column():
149
+ out_img = gr.Image(label="Processed Image")
150
+ with gr.Column():
151
+ out_vis = gr.Image(label="Label Image (random colors, transparent overlap)")
152
+ with gr.Row():
153
+ out_h5 = gr.File(label="Download Results as HDF5 File")
154
+ out_csv = gr.File(label="Download Properties as CSV File")
155
+
156
+ with gr.Row():
157
+ gr.Examples(
158
+ fn=predict,
159
+ examples=examples,
160
+ inputs=[img, model_name, score_thresh_ck, score_thresh, nms_thresh_ck, nms_thresh, samples_ck, samples,
161
+ channels],
162
+ outputs=[out_vis, out_img, out_h5, out_csv],
163
+ cache_examples=True,
164
+ batch=False
165
+ )
166
+
167
+ btn.click(
168
+ predict,
169
+ inputs=[img, model_name, score_thresh_ck, score_thresh, nms_thresh_ck, nms_thresh, samples_ck, samples,
170
+ channels],
171
+ outputs=[out_vis, out_img, out_h5, out_csv]
172
+ )
173
+ clr.click(
174
+ lambda: (
175
+ None, default_score_thresh, default_nms_thresh, False, False, None, None, None, False, default_samples),
176
+ inputs=[],
177
+ outputs=[img, score_thresh, nms_thresh, score_thresh_ck, nms_thresh_ck, out_img, out_h5, out_vis, samples_ck,
178
+ samples]
179
+ )
180
+ app.launch()
cpn.py CHANGED
@@ -32,10 +32,14 @@ def contours2labels(contours, size, overlap=False, max_iter=999):
32
 
33
 
34
  class CpnInterface:
35
- def __init__(self, model, device=None):
36
  self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
37
- self.model = cd.models.LitCpn(model).to(device)
 
 
 
38
  self.model.eval()
 
39
  self.tile_size = 1664
40
  self.overlap = 384
41
 
 
32
 
33
 
34
  class CpnInterface:
35
+ def __init__(self, model, device=None, **kwargs):
36
  self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
37
+ model = cd.resolve_model(model, **kwargs)
38
+ if not isinstance(model, cd.models.LitCpn):
39
+ model = cd.models.LitCpn(model)
40
+ self.model = model.to(device)
41
  self.model.eval()
42
+ self.model.requires_grad_(False)
43
  self.tile_size = 1664
44
  self.overlap = 384
45
 
examples/bbbc039_test_00014.png ADDED
util.py CHANGED
@@ -1,7 +1,8 @@
1
  from imageio.v2 import imread as _imread
 
2
  import tifffile as tif
3
 
4
- __all__ = ['imread', 'imsave', 'get_examples']
5
 
6
 
7
  def imread(filename):
@@ -14,14 +15,13 @@ def imsave(filename, img, compression="zlib"):
14
  tif.imwrite(filename, img, compression=compression)
15
 
16
 
17
- def get_examples(default_model):
18
  from skimage import data
 
19
  from os.path import dirname, join, isfile
 
20
 
21
- examples = []
22
- for f in ['coins.png']:
23
- f = join(dirname(data.__file__), f)
24
- if isfile(f):
25
- examples.append([f, default_model])
26
- if len(examples):
27
- return examples
 
1
  from imageio.v2 import imread as _imread
2
+ from shutil import copy2
3
  import tifffile as tif
4
 
5
+ __all__ = ['imread', 'imsave', 'copy_skimage_data']
6
 
7
 
8
  def imread(filename):
 
15
  tif.imwrite(filename, img, compression=compression)
16
 
17
 
18
+ def copy_skimage_data(dst='examples'):
19
  from skimage import data
20
+ from os import makedirs
21
  from os.path import dirname, join, isfile
22
+ from glob import glob
23
 
24
+ makedirs(dst, exist_ok=True)
25
+
26
+ for f in glob(join(dirname(data.__file__), '*.png')):
27
+ copy2(f, dst)