# cellpose / app.py — Hugging Face Space by mouseland
# Last commit: 736285e "Update app.py" (verified), 5.36 kB
import numpy as np
import gradio as gr
import spaces
import cv2
from cellpose import models
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import os, io, base64
from PIL import Image
# Load the Cellpose "cyto3" model once at import time so every request reuses it.
# The app cannot do anything without it, so a failure aborts startup.
try:
    model = models.CellposeModel(gpu=True, pretrained_model="cyto3")
except Exception as e:
    print(f"Error loading model: {e}")
    # `exit()` is a site-module convenience meant for the interactive shell and
    # may be missing (e.g. under `python -S`); SystemExit is the portable form.
    raise SystemExit(1)
def plot_flows(y):
    """Render a two-component vector field as a uint8 RGB image.

    Hue encodes the direction of the vector; saturation and value both encode
    the percentile-normalized squared magnitude. ``y[0][0]`` and ``y[1][0]``
    are read as the two components (presumably the y and x flows — TODO
    confirm; nothing in the visible code calls this function).

    Returns
    -------
    numpy.ndarray
        ``hsv_to_rgb`` output scaled to 0..255 and cast to uint8.
    """
    comp_y = y[0][0]
    comp_x = y[1][0]

    def _centered(a):
        # Percentile-normalize, clip to [0, 1], then map onto [-1, 1].
        return (np.clip(normalize99(a), 0, 1) - 0.5) * 2

    angle = np.arctan2(_centered(comp_y), _centered(comp_x))
    hue = (angle + np.pi) / (2 * np.pi)
    sat = normalize99(comp_y ** 2 + comp_x ** 2)
    channels = [c[:, :, np.newaxis] for c in (hue, sat, sat)]
    hsv = np.clip(np.concatenate(channels, axis=-1), 0.0, 1.0)
    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)
def plot_outlines(img, masks):
    """Draw red outlines of every labeled region on *img*; return a PIL image.

    Parameters
    ----------
    img : numpy.ndarray
        Image to draw on; ``img.shape[0]`` is its height and ``img.shape[1]``
        its width.
    masks : numpy.ndarray
        Label image with the same height/width as *img* (presumably 0 =
        background, 1..n = cells — TODO confirm against the model output).

    Returns
    -------
    PIL.Image.Image
        The rendered matplotlib figure, saved as PNG and decoded.
    """
    contours, _ = cv2.findContours(
        masks.astype(np.int32),
        mode=cv2.RETR_FLOODFILL,
        method=cv2.CHAIN_APPROX_SIMPLE,
    )
    # Keep one polygon per contour; contours with 4 or fewer points are skipped.
    outpix = [
        cv2.approxPolyDP(contour, 0.001, True)[:, 0, :]
        for contour in contours
        if len(contour) > 4
    ]

    height, width = img.shape[:2]
    # Longer side is 6 inches; the shorter side is scaled to keep the aspect ratio.
    if height > width:
        figsize = (6 * width / height, 6)
    else:
        figsize = (6, 6 * height / width)

    fig = plt.figure(figsize=figsize, facecolor='k')
    try:
        ax = fig.add_axes([0.0, 0.0, 1, 1])
        ax.set_xlim([0, width])
        ax.set_ylim([0, height])
        # The y-limits run from 0 (bottom) to height (top), so the image rows
        # are reversed and contour rows are mirrored (height - y) to match.
        ax.imshow(img[::-1], origin='upper')
        for poly in outpix:
            ax.plot(poly[:, 0], height - poly[:, 1], color=[1, 0, 0], lw=1)
        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # request would leak one figure.
        plt.close(fig)

    buf.seek(0)
    output_pil_img = Image.open(buf)
    # Image.open is lazy; decode now so the result does not depend on buf.
    output_pil_img.load()
    return output_pil_img
def plot_overlay(img, masks):
    """Tint each labeled region with a random hue over a grayscale copy of *img*.

    Parameters
    ----------
    img : numpy.ndarray
        Image of shape (H, W, C), whose channels are averaged to grayscale, or
        a 2-D (H, W) grayscale image.
    masks : numpy.ndarray
        Integer label image of shape (H, W). Pixels equal to 1..masks.max()
        are colored; 0 is left uncolored.

    Returns
    -------
    numpy.ndarray
        uint8 RGB array of shape (H, W, 3).
    """
    gray = img.astype(np.float32)
    if gray.ndim == 3:
        gray = gray.mean(axis=-1)
    gray = normalize99(gray)
    gray -= gray.min()
    peak = gray.max()
    # A flat image has peak 0; skip the division instead of producing NaNs.
    if peak > 0:
        gray /= peak

    HSV = np.zeros((gray.shape[0], gray.shape[1], 3), np.float32)
    HSV[:, :, 2] = np.clip(gray * 1.5, 0, 1.0)

    n_labels = int(masks.max())
    # One hue per label, drawn in label order: rand(n) yields the same values as
    # n consecutive rand() calls, so colors match the former per-label loop.
    hues = np.random.rand(max(n_labels, 0))
    labeled = masks > 0
    HSV[labeled, 0] = hues[masks[labeled].astype(np.intp) - 1]
    HSV[labeled, 1] = 1.0

    return (hsv_to_rgb(HSV) * 255).astype(np.uint8)
def normalize99(img):
    """Linearly rescale *img* so its 1st percentile maps to 0 and its 99th to 1.

    Values outside that percentile range fall below 0 or above 1; no clipping
    is done here.

    Parameters
    ----------
    img : array_like
        Numeric array of any shape. It is not modified.

    Returns
    -------
    numpy.ndarray
        New floating-point array of the same shape. When the 1st and 99th
        percentiles coincide (e.g. a constant image) there is no range to
        divide by. In that case the result is ``img - p1`` (all zeros for a
        constant image) instead of NaN/inf.
    """
    X = np.asarray(img)
    lo, hi = np.percentile(X, [1, 99])
    span = hi - lo
    if span == 0:
        return X - lo
    return (X - lo) / span
def image_resize(img, resize=400):
    """Downscale *img* so that its longer spatial side is at most *resize* px.

    The aspect ratio is preserved: the shorter side is truncated to an int but
    never drops below 1 px. Images already within the limit keep their size.
    In every case the result is cast to uint8.

    Parameters
    ----------
    img : numpy.ndarray
        Array whose first two axes are (height, width).
    resize : int
        Maximum allowed length of the longer spatial side.

    Returns
    -------
    numpy.ndarray
        The (possibly resized) image as uint8.
    """
    ny, nx = img.shape[:2]
    # Only the spatial axes count; the channel axis must not trigger a resize.
    if max(ny, nx) > resize:
        if ny > nx:
            nx = max(1, int(nx / ny * resize))
            ny = resize
        else:
            ny = max(1, int(ny / nx * resize))
            nx = resize
        # cv2.resize takes dsize as (width, height).
        img = cv2.resize(img, (nx, ny))
    return img.astype(np.uint8)
@spaces.GPU(duration=10)
def cellpose_segment(img_input):
    """Segment cells in *img_input* with the module-level Cellpose ``model``.

    The image is first downscaled with ``image_resize`` (longer side at most
    400 px by default). The resulting masks and flows are resized back to the
    input's size before plotting.

    Parameters
    ----------
    img_input : numpy.ndarray or None
        The uploaded image, as delivered by ``gr.Image(type="numpy")``.

    Returns
    -------
    tuple
        ``(outlines, overlay, flows, masks)``:

        * ``outlines``: PIL image with red outlines.
        * ``overlay``: uint8 RGB overlay.
        * ``flows``: the first element of the model's flows output, as uint8.
        * ``masks``: the label image with cell labels randomly permuted; the
          background stays 0.

    Raises
    ------
    gr.Error
        If no image was uploaded.
    """
    if img_input is None:
        raise gr.Error("Please upload an image first.")

    img = image_resize(img_input)
    masks, flows, _ = model.eval(img, channels=[0, 0])
    # presumably the RGB flow visualization — TODO confirm for the installed cellpose
    flows = flows[0]

    target_size = (img_input.shape[1], img_input.shape[0])  # (width, height) for cv2
    if target_size != (img.shape[1], img.shape[0]):
        # Scale results back so they line up with the original image.
        masks = cv2.resize(masks.astype('uint16'), target_size,
                           interpolation=cv2.INTER_NEAREST).astype('uint16')
        flows = cv2.resize(flows.astype('float32'), target_size).astype('uint8')

    outpix = plot_outlines(img_input, masks)
    overlay = plot_overlay(img_input, masks)

    # Shuffle labels 1..n so neighboring cells get distinct display values.
    # Background is pinned at 0; a plain permutation of 0..n would move it and
    # turn one cell into 0.
    n_cells = int(masks.max())
    iperm = np.concatenate(([0], np.random.permutation(n_cells) + 1))
    return outpix, overlay, flows, iperm[masks]
# Gradio interface: input image and run button on the left; the four results
# of cellpose_segment on the right.
with gr.Blocks(title="Hello",
               css=".gradio-container {background:purple;}") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input image", type="numpy")
            # NOTE(review): the button advertises Cellpose-SAM, but the model
            # loaded in this file is "cyto3" — confirm which one is intended.
            send_btn = gr.Button("Run Cellpose-SAM")
        with gr.Column(scale=1):
            # Distinct labels so users can tell the four outputs apart.
            img_overlay = gr.Image(label="Mask overlay", type="numpy")
            img_outlines = gr.Image(label="Outlines", type="pil")
            flows = gr.Image(label="Flows", type="numpy")
            masks = gr.Image(label="Masks", type="numpy")
    # Output order must match the tuple returned by cellpose_segment.
    send_btn.click(fn=cellpose_segment, inputs=[input_image],
                   outputs=[img_outlines, img_overlay, flows, masks])

demo.launch()