ericup commited on
Commit
0e6708a
1 Parent(s): 8f5506b
Files changed (6) hide show
  1. README.md +1 -1
  2. app.py +48 -0
  3. cpn.py +71 -0
  4. prep.py +63 -0
  5. requirements.txt +2 -0
  6. util.py +27 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Celldetection
3
  emoji: ⚡
4
  colorFrom: pink
5
  colorTo: gray
 
1
  ---
2
+ title: CellDetection
3
  emoji: ⚡
4
  colorFrom: pink
5
  colorTo: gray
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from cpn import CpnInterface
3
+ from prep import multi_norm
4
+ from util import imread, imsave, get_examples
5
+ from celldetection import label_cmap
6
+
7
+ default_model = 'ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c'
8
+
9
+
10
+ def predict(filename, model=None, device=None, reduce_labels=True):
11
+ global default_model
12
+ assert isinstance(filename, str)
13
+ print(dict(
14
+ filename=filename,
15
+ model=model,
16
+ device=device,
17
+ reduce_labels=reduce_labels
18
+ ), flush=True)
19
+
20
+ img = imread(filename)
21
+ print('Image:', img.dtype, img.shape, (img.min(), img.max()), flush=True)
22
+ if model is None or len(str(model)) <= 0:
23
+ model = default_model
24
+
25
+ img = multi_norm(img, 'cstm-mix') # TODO
26
+
27
+ m = CpnInterface(model.strip(), device=device)
28
+ y = m(img, reduce_labels=reduce_labels)
29
+
30
+ labels = y['labels']
31
+
32
+ vis_labels = label_cmap(labels)
33
+ dst = '.'.join(filename.split('.')[:-1]) + '_labels.tiff'
34
+ imsave(dst, labels)
35
+
36
+ return img, vis_labels, dst
37
+
38
+
39
+ gr.Interface(
40
+ predict,
41
+ inputs=[gr.components.Image(label="Upload Input Image", type="filepath"),
42
+ gr.components.Textbox(label='Model Name', value=default_model, max_lines=1)],
43
+ outputs=[gr.Image(label="Processed Image"),
44
+ gr.Image(label="Label Image"),
45
+ gr.File(label="Download Label Image")],
46
+ title="Cell Detection with Contour Proposal Networks",
47
+ examples=get_examples(default_model)
48
+ ).launch()
cpn.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import celldetection as cd
3
+ import cv2
4
+ import numpy as np
5
+
6
+ __all__ = ['contours2labels', 'CpnInterface']
7
+
8
+
9
+ def contours2labels(contours, size, overlap=False, max_iter=999):
10
+ labels = cd.data.contours2labels(cd.asnumpy(contours), size, initial_depth=3)
11
+
12
+ if not overlap:
13
+ kernel = cv2.getStructuringElement(1, (3, 3))
14
+ mask_sm = np.sum(labels > 0, axis=-1)
15
+ mask = mask_sm > 1 # all overlaps
16
+ if mask.any():
17
+ mask_ = mask_sm == 1 # all cores
18
+ lbl = np.zeros(labels.shape[:2], dtype='float64')
19
+ lbl[mask_] = labels.max(-1)[mask_]
20
+ for _ in range(max_iter):
21
+ lbl_ = np.copy(lbl)
22
+ m = mask & (lbl <= 0)
23
+ if not np.any(m):
24
+ break
25
+ lbl[m] = cv2.dilate(lbl, kernel=kernel)[m]
26
+ if np.allclose(lbl_, lbl):
27
+ break
28
+ else:
29
+ lbl = labels.max(-1)
30
+ labels = lbl.astype('int')
31
+ return labels
32
+
33
+
34
+ class CpnInterface:
35
+ def __init__(self, model, device=None):
36
+ self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
37
+ self.model = cd.models.LitCpn(model).to(device)
38
+ self.model.eval()
39
+ self.tile_size = 768
40
+ self.overlap = 384
41
+
42
+ def __call__(
43
+ self,
44
+ img,
45
+ div=255,
46
+ reduce_labels=True,
47
+ return_labels=True,
48
+ return_viewable_contours=True,
49
+ ):
50
+ if img.ndim == 2:
51
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
52
+ img = img / div
53
+ x = cd.data.to_tensor(img, transpose=True, dtype=torch.float32)[None]
54
+ with torch.no_grad():
55
+ out = cd.asnumpy(self.model(x, crop_size=self.tile_size,
56
+ stride=max(64, self.tile_size - self.overlap)))
57
+
58
+ contours, = out['contours']
59
+ boxes, = out['boxes']
60
+ scores, = out['scores']
61
+
62
+ labels = None
63
+ if return_labels or return_viewable_contours:
64
+ labels = contours2labels(contours, img.shape[:2], overlap=not reduce_labels)
65
+
66
+ return dict(
67
+ contours=contours,
68
+ labels=labels,
69
+ boxes=boxes,
70
+ scores=scores
71
+ )
prep.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import celldetection as cd
2
+ import numpy as np
3
+ from skimage import img_as_ubyte, exposure
4
+ from PIL import ImageFile
5
+
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+ __all__ = ['normalize_img', 'normalize_channel', 'multi_norm']
9
+
10
+
11
+ def normalize_img(img, gamma_spread=17, lower_gamma_bound=.6, percentile=99.88):
12
+ log = []
13
+ if img.dtype.kind == 'f': # floats
14
+ if img.max() < 256:
15
+ img = img_as_ubyte(img / 255)
16
+ log.append('img_as_ubyte')
17
+ else:
18
+ v = 99.95
19
+ img = cd.data.normalize_percentile(img, v)
20
+ log.append(f'cd.data.normalize_percentile(img, {v})')
21
+ elif img.itemsize > 1:
22
+ img = cd.data.normalize_percentile(img, percentile)
23
+ log.append(f'cd.data.normalize_percentile(img, {percentile})')
24
+ mean_thresh = np.pi * gamma_spread
25
+ if img.mean() < mean_thresh:
26
+ gamma = (1 - ((np.cos(1 / gamma_spread * img.mean()) + 1) / 2)) * (1 - lower_gamma_bound) + lower_gamma_bound
27
+ log.append(f'(img / 255) ** {gamma}')
28
+ img = (img / 255) ** gamma
29
+ img = img_as_ubyte(img)
30
+ return img, log
31
+
32
+
33
+ def normalize_channel(img, lower=1, upper=99):
34
+ non_zero_vals = img[np.nonzero(img)]
35
+ percentiles = np.percentile(non_zero_vals, [lower, upper])
36
+ if percentiles[1] - percentiles[0] > 0.001:
37
+ img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
38
+ else:
39
+ img_norm = img
40
+ return img_norm.astype(np.uint8)
41
+
42
+
43
+ def multi_norm(img, method):
44
+ if method == 'prov':
45
+ img = normalize_channel(img)
46
+ elif method == 'rand-mix' or method == 'cstm-mix':
47
+ img0 = normalize_channel(img)
48
+ img1, log = normalize_img(img)
49
+ if method == 'rand-mix':
50
+ alpha = np.random.uniform(0., 1.)
51
+ else:
52
+ is_grayscale = img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1)
53
+ alpha = 0.
54
+ if not is_grayscale:
55
+ if img[..., 2].mean() > 200 and img[..., 2].std() < 20:
56
+ alpha = 1.
57
+ else:
58
+ if img1.mean() < 45 and img1.std() < 33:
59
+ alpha = .5
60
+ img = np.clip(alpha * img0 + (1 - alpha) * img1, 0, 255).astype(img0.dtype)
61
+ else:
62
+ img, log = normalize_img(img)
63
+ return img
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ celldetection @ git+https://github.com/FZJ-INM1-BDA/celldetection.git
2
+ tifffile
util.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from imageio.v2 import imread as _imread
2
+ import tifffile as tif
3
+
4
+ __all__ = ['imread', 'imsave', 'get_examples']
5
+
6
+
7
+ def imread(filename):
8
+ if filename.split('.')[-1] in ('tiff', 'tif'):
9
+ return tif.imread(filename)
10
+ return _imread(filename)
11
+
12
+
13
+ def imsave(filename, img, compression="zlib"):
14
+ tif.imwrite(filename, img, compression=compression)
15
+
16
+
17
+ def get_examples(default_model):
18
+ from skimage import data
19
+ from os.path import dirname, join, isfile
20
+
21
+ examples = []
22
+ for f in ['coins.png']:
23
+ f = join(dirname(data.__file__), f)
24
+ if isfile(f):
25
+ examples.append([f, default_model])
26
+ if len(examples):
27
+ return examples