carsen-stringer commited on
Commit
49c6db7
·
1 Parent(s): fa1fd8e
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import gradio as gr
4
+ import cv2
5
+ from cellpose import models
6
+ from matplotlib.colors import hsv_to_rgb
7
+ import os
8
+
9
+ try:
10
+ model = models.CellposeModel(gpu=False, pretrained_model="cyto3")
11
+ except Exception as e:
12
+ print(f"Error loading model: {e}")
13
+ exit(1)
14
+
15
+ def plot_flows(y):
16
+ Y = (np.clip(normalize99(y[0][0]),0,1) - 0.5) * 2
17
+ X = (np.clip(normalize99(y[1][0]),0,1) - 0.5) * 2
18
+ H = (np.arctan2(Y, X) + np.pi) / (2*np.pi)
19
+ S = normalize99(y[0][0]**2 + y[1][0]**2)
20
+ HSV = np.concatenate((H[:,:,np.newaxis], S[:,:,np.newaxis], S[:,:,np.newaxis]), axis=-1)
21
+ HSV = np.clip(HSV, 0.0, 1.0)
22
+ flow = (hsv_to_rgb(HSV) * 255).astype(np.uint8)
23
+ return flow
24
+
25
+ def plot_outlines(masks):
26
+ outpix = []
27
+ contours, hierarchy = cv2.findContours(masks.astype(np.int32), mode=cv2.RETR_FLOODFILL, method=cv2.CHAIN_APPROX_SIMPLE)
28
+ for c in range(len(contours)):
29
+ pix = contours[c].astype(int).squeeze()
30
+ if len(pix)>4:
31
+ peri = cv2.arcLength(contours[c], True)
32
+ approx = cv2.approxPolyDP(contours[c], 0.001, True)[:,0,:]
33
+ outpix.append(approx)
34
+ return outpix
35
+
36
+ def plot_overlay(img, masks):
37
+ img = normalize99(img.astype(np.float32).mean(axis=-1))
38
+ img -= img.min()
39
+ img /= img.max()
40
+ HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
41
+ HSV[:,:,2] = np.clip(img*1.5, 0, 1.0)
42
+ for n in range(int(masks.max())):
43
+ ipix = (masks==n+1).nonzero()
44
+ HSV[ipix[0],ipix[1],0] = np.random.rand()
45
+ HSV[ipix[0],ipix[1],1] = 1.0
46
+ RGB = (hsv_to_rgb(HSV) * 255).astype(np.uint8)
47
+ return RGB
48
+
49
+ def normalize99(img):
50
+ X = img.copy()
51
+ X = (X - np.percentile(X, 1)) / (np.percentile(X, 99) - np.percentile(X, 1))
52
+ return X
53
+
54
+ def image_resize(img, resize=224):
55
+ ny,nx = img.shape[:2]
56
+ if np.array(img.shape).max() > resize:
57
+ if ny>nx:
58
+ nx = int(nx/ny * resize)
59
+ ny = resize
60
+ else:
61
+ ny = int(ny/nx * resize)
62
+ nx = resize
63
+ shape = (nx,ny)
64
+ img = cv2.resize(img, shape)
65
+ img = img.astype(np.uint8)
66
+ return img
67
+
68
+ def cellpose_segment(img):
69
+ img_input = image_resize(img)
70
+ masks, flows, _ = model.eval(img_input)
71
+ flows = flows[0]
72
+ # masks = np.zeros(img.shape[:2])
73
+ # flows = np.zeros_like(img)
74
+ target_size = (img_input.shape[1], img_input.shape[0])
75
+ if (target_size[0]!=img.shape[1] and target_size[1]!=img.shape[0]):
76
+ # scale it back to keep the orignal size
77
+ masks = cv2.resize(masks.astype('uint16'), target_size, interpolation=cv2.INTER_NEAREST).astype('uint16')
78
+ flows = cv2.resize(flows.astype('float32'), target_size).astype('uint8')
79
+
80
+ outpix = plot_outlines(masks)
81
+ overlay = plot_overlay(img, masks)
82
+
83
+ return outpix, overlay, flows, masks
84
+
85
+ # Gradio Interface
86
+ iface = gr.Interface(
87
+ fn=cellpose_segment,
88
+ inputs="image",
89
+ outputs=["image", "image", "image", "image"],
90
+ title="cellpose segmentation",
91
+ description="upload an image, then cellpose will segment it at a max size of 224x224"
92
+ )
93
+
94
+ iface.launch()