File size: 5,362 Bytes
49c6db7 8786ac3 49c6db7 f0821bf 44c9541 49c6db7 f16ef34 49c6db7 f0821bf 49c6db7 f0821bf c034f55 736285e c034f55 49c6db7 32b0ba4 49c6db7 bce0fa8 86bfbdc bce0fa8 ded4361 958ea27 86bfbdc 49c6db7 958ea27 49c6db7 958ea27 5fa8922 fc03022 49c6db7 2c27168 49c6db7 2c27168 77e473c 2c27168 6ba53a6 fc03022 c034f55 fc03022 6ba53a6 fc03022 a3dc09f fc03022 a3dc09f 2c27168 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import numpy as np
import gradio as gr
import spaces
import cv2
from cellpose import models
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import os, io, base64
from PIL import Image
# Load the Cellpose model once at import time; abort start-up if it cannot be loaded.
try:
    model = models.CellposeModel(gpu=True, pretrained_model="cyto3")
except Exception as e:
    # `exit()` is injected by the `site` module for interactive use only.
    # Raising SystemExit with a string prints it to stderr and exits with status 1.
    raise SystemExit(f"Error loading model: {e}") from e
def plot_flows(y):
    """Colour-code a two-component flow field as an RGB image.

    Hue encodes the flow direction.  Saturation and value both encode the
    percentile-normalised squared magnitude.

    Args:
        y: indexable so that ``y[0][0]`` and ``y[1][0]`` are same-shaped 2-D
            arrays holding the two flow components (assumed to be the
            vertical and horizontal components -- TODO confirm with caller).

    Returns:
        ``uint8`` RGB array of shape ``(H, W, 3)``.
    """
    comp_a = y[0][0]
    comp_b = y[1][0]
    # Robustly rescale each component to [0, 1], then map it onto [-1, 1].
    vert = 2 * (np.clip(normalize99(comp_a), 0, 1) - 0.5)
    horiz = 2 * (np.clip(normalize99(comp_b), 0, 1) - 0.5)
    hue = (np.pi + np.arctan2(vert, horiz)) / (2 * np.pi)
    strength = normalize99(comp_a ** 2 + comp_b ** 2)
    channels = [chan[:, :, np.newaxis] for chan in (hue, strength, strength)]
    hsv = np.clip(np.concatenate(channels, axis=-1), 0.0, 1.0)
    return (255 * hsv_to_rgb(hsv)).astype(np.uint8)
def plot_outlines(img, masks):
    """Draw red outlines of every labelled region of ``masks`` on top of ``img``.

    Args:
        img: image array indexed as ``img[row, col, ...]``.  Its first two
            axes set the figure aspect ratio and the axis limits.
        masks: label image with the same height/width as ``img``.  It is cast
            to int32 for OpenCV's flood-fill contour mode.

    Returns:
        A ``PIL.Image.Image`` holding the rendered figure (matplotlib's
        default savefig format).
    """
    # RETR_FLOODFILL requires a 32-bit signed single-channel image.
    contours, _ = cv2.findContours(
        masks.astype(np.int32),
        mode=cv2.RETR_FLOODFILL,
        method=cv2.CHAIN_APPROX_SIMPLE,
    )
    # Keep contours with more than 4 points.  The tiny epsilon means
    # approxPolyDP leaves the outline essentially unchanged.
    outlines = [
        cv2.approxPolyDP(contour, 0.001, True)[:, 0, :]
        for contour in contours
        if len(contour) > 4
    ]

    height, width = img.shape[:2]
    # The longer side is 6 inches; the shorter side keeps the aspect ratio.
    if height > width:
        figsize = (6 * width / height, 6)
    else:
        figsize = (6, 6 * height / width)

    fig = plt.figure(figsize=figsize, facecolor='k')
    try:
        ax = fig.add_axes([0.0, 0.0, 1, 1])
        ax.set_xlim([0, width])
        ax.set_ylim([0, height])
        # Rows are flipped here, and contour y-coordinates are mirrored below,
        # so that image and outlines share the bottom-up y-limits set above.
        ax.imshow(img[::-1], origin='upper')
        for outline in outlines:
            ax.plot(outline[:, 0], height - outline[:, 1], color=[1, 0, 0], lw=1)
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed.  Without this,
        # each call leaks a figure.
        plt.close(fig)

    buf.seek(0)
    output_pil_img = Image.open(buf)
    # Image.open is lazy; decode now so the returned image is fully loaded.
    output_pil_img.load()
    return output_pil_img
def plot_overlay(img, masks):
    """Colour each labelled region of ``masks`` with a random hue over a grey ``img``.

    Args:
        img: image array.  A 3-D input has its trailing channel axis averaged
            to grey; a 2-D input is used as-is.
        masks: integer label image with the same height/width as ``img``.
            Pixels with label ``k >= 1`` get the k-th random hue at full
            saturation.  All other pixels stay grey.

    Returns:
        ``uint8`` RGB array of shape ``(H, W, 3)``.
    """
    grey = img.astype(np.float32)
    if grey.ndim == 3:
        grey = grey.mean(axis=-1)
    grey = normalize99(grey)
    grey = grey - grey.min()
    peak = grey.max()
    # Skip the rescale when everything is 0, to avoid 0/0 -> NaN on flat images.
    if peak > 0:
        grey = grey / peak

    n_labels = int(masks.max())
    hsv = np.zeros((grey.shape[0], grey.shape[1], 3), np.float32)
    hsv[:, :, 2] = np.clip(grey * 1.5, 0, 1.0)

    # One hue per label, drawn in label order.  This gives the same values as
    # calling np.random.rand() once per label, but colours all labels in a
    # single pass instead of one full-image comparison per label.
    hues = np.random.rand(n_labels)
    labelled = masks > 0
    hsv[labelled, 0] = hues[masks[labelled].astype(np.intp) - 1]
    hsv[labelled, 1] = 1.0
    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)
def normalize99(img, lower=1, upper=99):
    """Rescale ``img`` so its ``lower``/``upper`` percentiles map to 0 and 1.

    Values outside that percentile range land outside [0, 1]; callers clip
    as needed.  The input array is never modified.

    Args:
        img: numeric array (or array-like).
        lower: lower percentile mapped to 0 (default 1).
        upper: upper percentile mapped to 1 (default 99).

    Returns:
        Float array with the same shape as ``img``.  If the two percentiles
        coincide (e.g. a constant image), an all-zero array is returned
        instead of dividing by zero.
    """
    X = np.asarray(img)
    lo, hi = np.percentile(X, [lower, upper])
    shifted = X - lo  # new array, so the caller's input is untouched
    spread = hi - lo
    if spread == 0:
        return np.zeros_like(shifted)
    return shifted / spread
def image_resize(img, resize=400):
    """Shrink ``img`` so that its longer spatial side is at most ``resize`` pixels.

    The aspect ratio is preserved.  The shorter side is truncated but never
    drops below 1 pixel.  Images already within the limit keep their size.
    The result is always cast to ``uint8``.

    Args:
        img: array whose first two axes are (rows, cols).  Any further
            channel axis is ignored when deciding whether to resize.
        resize: maximum allowed length of the longer side.

    Returns:
        The (possibly resized) image as ``uint8``.
    """
    ny, nx = img.shape[:2]
    # Compare only the spatial dims; the channel axis must not trigger a resize.
    if max(ny, nx) > resize:
        if ny > nx:
            nx = max(1, int(nx / ny * resize))
            ny = resize
        else:
            ny = max(1, int(ny / nx * resize))
            nx = resize
        # cv2.resize takes the target size as (width, height).
        img = cv2.resize(img, (nx, ny))
    return img.astype(np.uint8)
@spaces.GPU(duration=10)
def cellpose_segment(img_input):
    """Segment ``img_input`` with the module-level Cellpose ``model``.

    The image is downscaled (longer side <= 400 px) for inference.  The
    masks and the flow image are then scaled back to the input's size.

    Args:
        img_input: numpy image from the Gradio input component, or ``None``
            when nothing was uploaded.

    Returns:
        Tuple ``(outlines, overlay, flows, masks)``:
            outlines: result of ``plot_outlines`` (PIL image).
            overlay: result of ``plot_overlay`` (uint8 RGB).
            flows: ``flows[0]`` from ``model.eval``, cast to uint8 when resized.
            masks: label image with cell labels randomly permuted for display;
                background (0) stays 0.

    Raises:
        gr.Error: if no image was provided.
    """
    if img_input is None:
        raise gr.Error("Please upload an image first.")
    img = image_resize(img_input)
    masks, flows, _ = model.eval(img, channels=[0, 0])
    flows = flows[0]  # presumably Cellpose's RGB flow visualisation -- TODO confirm
    # cv2 sizes are (width, height).
    target_size = (img_input.shape[1], img_input.shape[0])
    if target_size != (img.shape[1], img.shape[0]):
        # Scale back to the original size; nearest-neighbour keeps label ids intact.
        masks = cv2.resize(masks.astype('uint16'), target_size, interpolation=cv2.INTER_NEAREST).astype('uint16')
        flows = cv2.resize(flows.astype('float32'), target_size).astype('uint8')
    outpix = plot_outlines(img_input, masks)
    overlay = plot_overlay(img_input, masks)
    # Shuffle cell ids so neighbouring cells get distinct display values.
    # Entry 0 is pinned to 0, so background never takes a cell's value.
    n_labels = int(masks.max())
    iperm = np.concatenate(([0], 1 + np.random.permutation(n_labels)))
    return outpix, overlay, flows, iperm[masks]
# Gradio interface: input image and run button on the left, four results on the right.
with gr.Blocks(title = "Hello",
               css=".gradio-container {background:purple;}") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input image", type="numpy")
            send_btn = gr.Button("Run Cellpose-SAM")
        with gr.Column(scale=1):
            # Distinct labels so users can tell the four outputs apart.
            img_overlay = gr.Image(label="Mask overlay", type="numpy")
            img_outlines = gr.Image(label="Mask outlines", type="pil")
            flows = gr.Image(label="Flows", type="numpy")
            masks = gr.Image(label="Masks", type="numpy")
    # Output order matches cellpose_segment's return order:
    # (outlines, overlay, flows, masks).
    send_btn.click(fn=cellpose_segment, inputs=[input_image], outputs=[img_outlines, img_overlay, flows, masks])
demo.launch()
|