|
|
import numpy as np |
|
|
import gradio as gr |
|
|
import spaces |
|
|
import cv2 |
|
|
from cellpose import models |
|
|
from matplotlib.colors import hsv_to_rgb |
|
|
import matplotlib.pyplot as plt |
|
|
import os, io, base64 |
|
|
from PIL import Image |
|
|
from cellpose.io import imread, imsave |
|
|
|
|
|
|
|
|
def download_weights():
    """Download the Cellpose-SAM weight file(s) into the working directory if missing.

    Each name in ``fnames`` is fetched from the URL at the same position in
    ``urls``, but only when no file with that name exists yet. This is a
    best-effort step. Network or disk failures are reported on stdout and then
    ignored, so callers must still cope with the file being absent.

    The body is streamed to a ``<name>.part`` temporary file and only renamed
    into place once complete. An interrupted download therefore never leaves a
    truncated weight file that later runs would mistake for a valid one.
    """
    import os
    import requests

    fnames = ['cpsam']
    urls = ["https://osf.io/d7c8e/download"]

    for fname, url in zip(fnames, urls):
        if os.path.isfile(fname):
            continue
        tmp_name = fname + ".part"
        try:
            # timeout bounds the connect and each read gap; without it a stalled
            # server would hang application start-up indefinitely.
            with requests.get(url, stream=True, timeout=60) as r:
                if r.status_code != requests.codes.ok:
                    print("!!! Failed to download data !!!")
                    continue
                # Stream in 1 MiB chunks instead of holding the whole file in memory.
                with open(tmp_name, "wb") as fid:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        fid.write(chunk)
            os.replace(tmp_name, fname)
        except (requests.RequestException, OSError):
            print("!!! Failed to download data !!!")
            # Drop any partial download so a later run retries cleanly.
            if os.path.exists(tmp_name):
                try:
                    os.remove(tmp_name)
                except OSError:
                    pass
|
|
|
|
|
import sys

# Fetch the weights (best effort) and build the global model used by the
# run_model_gpu* functions. Without a model the app cannot do anything useful,
# so start-up is aborted on failure.
try:
    download_weights()
    model = models.CellposeModel(gpu=True, pretrained_model="cpsam")
except Exception as e:
    print(f"Error loading model: {e}", file=sys.stderr)
    # sys.exit, not the site-injected builtin exit(), which is absent under
    # `python -S` and in some embedded/frozen interpreters.
    sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_flows(y):
    """Render a two-component flow field as an RGB image.

    Hue encodes flow direction: the angle of the percentile-normalised
    (Y, X) vector, mapped to [0, 1]. Saturation and value both encode the
    percentile-normalised squared magnitude.

    Args:
        y: indexable so that ``y[0][0]`` and ``y[1][0]`` are equally shaped
            2-D arrays holding the Y and X components. These are presumably
            Cellpose ``dP`` flows; confirm at the call site, since nothing in
            the visible code calls this function.

    Returns:
        uint8 array of shape (H, W, 3).
    """
    comp_y = y[0][0]
    comp_x = y[1][0]

    # Squash each component into [-1, 1] after robust percentile scaling.
    unit_y = 2 * (np.clip(normalize99(comp_y), 0, 1) - 0.5)
    unit_x = 2 * (np.clip(normalize99(comp_x), 0, 1) - 0.5)

    hue = (np.arctan2(unit_y, unit_x) + np.pi) / (2 * np.pi)
    magnitude = normalize99(comp_y ** 2 + comp_x ** 2)

    hsv = np.clip(np.stack((hue, magnitude, magnitude), axis=-1), 0.0, 1.0)
    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)
|
|
|
|
|
def plot_outlines(img, masks):
    """Draw a red outline around every labelled region on top of *img*.

    Args:
        img: image array (H, W) or (H, W, C) accepted by ``Axes.imshow``.
        masks: label image with the same height and width as *img*. It is
            cast to int32 for ``cv2.findContours``. Presumably 0 is background
            and 1..N are objects (Cellpose convention); confirm.

    Returns:
        ``PIL.Image.Image`` holding the rendered figure. It is fully decoded,
        so it does not depend on the in-memory buffer it was read from.
    """
    # RETR_FLOODFILL operates on a CV_32SC1 label image, hence the int32 cast.
    contours, _ = cv2.findContours(
        masks.astype(np.int32), mode=cv2.RETR_FLOODFILL, method=cv2.CHAIN_APPROX_SIMPLE
    )
    outlines = []
    for contour in contours:
        # Ignore degenerate contours with 4 points or fewer.
        if len(contour) > 4:
            # epsilon=0.001 px keeps essentially every vertex; the call mainly
            # reshapes (N, 1, 2) -> (N, 2).
            outlines.append(cv2.approxPolyDP(contour, 0.001, True)[:, 0, :])

    height, width = img.shape[:2]
    # Longer image side maps to 6 inches; keep the aspect ratio.
    if height > width:
        figsize = (6 * width / height, 6)
    else:
        figsize = (6, 6 * height / width)

    fig = plt.figure(figsize=figsize, facecolor='k')
    try:
        ax = fig.add_axes([0.0, 0.0, 1, 1])
        ax.set_xlim([0, width])
        ax.set_ylim([0, height])
        # The y-axis runs bottom-up here, so flip the image and the contour y's.
        ax.imshow(img[::-1], origin='upper', aspect='auto')
        for outline in outlines:
            ax.plot(outline[:, 0], height - outline[:, 1], color=[1, 0, 0], lw=1)
        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; without
        # this each request leaks a figure in the long-running server.
        plt.close(fig)

    buf.seek(0)
    pil_img = Image.open(buf)
    # Image.open is lazy; decode now so the result is self-contained.
    pil_img.load()
    return pil_img
|
|
|
|
|
def plot_overlay(img, masks):
    """Tint each labelled region of *masks* with a random hue over a grey image.

    Args:
        img: (H, W, C) or (H, W) image. Channels, if present, are averaged to
            grey.
        masks: integer label image (H, W). Pixels with label k in 1..max get
            the k-th random hue; 0 and negative labels stay uncoloured.

    Returns:
        uint8 RGB array of shape (H, W, 3).
    """
    gray = img.astype(np.float32)
    # Only average over a channel axis; a 2-D grayscale image used to be
    # collapsed along its width here, which then crashed below.
    if gray.ndim == 3:
        gray = gray.mean(axis=-1)

    # Min-max scale to [0, 1]. This is what the former normalize99 + min/max
    # sequence reduced to, because min-max rescaling undoes any positive
    # affine map. That sequence produced NaN whenever the 1st and 99th
    # percentiles coincided (e.g. mostly-empty fluorescence images).
    lo, hi = gray.min(), gray.max()
    gray = (gray - lo) / (hi - lo) if hi > lo else np.zeros_like(gray)

    hsv = np.zeros((gray.shape[0], gray.shape[1], 3), np.float32)
    hsv[:, :, 2] = np.clip(gray * 1.5, 0, 1.0)

    n_labels = int(masks.max())
    if n_labels > 0:
        labels = masks.astype(np.int64)
        # One hue per label, drawn in label order from the global RNG (same
        # sequence as one np.random.rand() per label), then applied with a
        # single lookup instead of one full-image comparison per label.
        hues = np.random.rand(n_labels)
        fg = labels > 0
        hsv[fg, 0] = hues[labels[fg] - 1]
        hsv[fg, 1] = 1.0

    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)
|
|
|
|
|
def normalize99(img, lower=1, upper=99):
    """Linearly rescale *img* so its ``lower``/``upper`` percentiles map to 0/1.

    Values outside the percentile range are not clipped, so the result may
    fall below 0 or above 1. The input is not modified.

    Args:
        img: array-like of numbers.
        lower: lower percentile (default 1).
        upper: upper percentile (default 99).

    Returns:
        Floating-point array with the same shape as *img*. If the two
        percentiles coincide (or are NaN), e.g. for a constant or almost
        entirely uniform image, an all-zero array is returned instead of the
        inf/NaN a division by zero would produce.
    """
    X = np.asarray(img)
    lo, hi = np.percentile(X, [lower, upper])
    if not hi > lo:
        return np.zeros(X.shape, dtype=np.result_type(X, lo))
    return (X - lo) / (hi - lo)
|
|
|
|
|
def image_resize(img, resize=400):
    """Downscale *img* so that its longer spatial side is at most *resize* px.

    The aspect ratio is preserved: the short side is truncated, with a minimum
    of 1 px. Images already within the limit are not resampled. The result is
    always cast to uint8.

    Args:
        img: (H, W) or (H, W, C) array.
        resize: maximum side length in pixels. Floats are accepted and
            truncated to int; ``gr.Number`` delivers floats by default, and
            ``cv2.resize`` rejects a non-integer ``dsize``.

    Returns:
        uint8 array.

    Raises:
        ValueError: if *resize* is smaller than 1.
    """
    max_side = int(resize)
    if max_side < 1:
        raise ValueError("resize must be at least 1")

    ny, nx = img.shape[:2]
    # Only the spatial dimensions count; the channel axis is not a side length.
    if max(ny, nx) > max_side:
        if ny > nx:
            nx = max(1, int(nx / ny * max_side))
            ny = max_side
        else:
            ny = max(1, int(ny / nx * max_side))
            nx = max_side
        # cv2 expects dsize as (width, height).
        img = cv2.resize(img, (nx, ny))

    # NOTE(review): a plain astype wraps values above 255 (e.g. 16-bit TIFF
    # data); consider rescaling non-uint8 inputs instead.
    return img.astype(np.uint8)
|
|
|
|
|
|
|
|
@spaces.GPU(duration=10)
def run_model_gpu(img):
    """Segment *img* with the global Cellpose model inside a 10 s GPU slot.

    Returns:
        ``(masks, flows)`` from ``model.eval``; the third output (styles) is
        discarded.
    """
    segmentation, flow_outputs, _styles = model.eval(img)
    return segmentation, flow_outputs
|
|
|
|
|
@spaces.GPU(duration=60)
def run_model_gpu60(img):
    """Segment *img* with the global Cellpose model inside a 60 s GPU slot.

    Returns:
        ``(masks, flows)`` from ``model.eval``; the third output (styles) is
        discarded.
    """
    segmentation, flow_outputs, _styles = model.eval(img)
    return segmentation, flow_outputs
|
|
|
|
|
@spaces.GPU(duration=240)
def run_model_gpu240(img):
    """Segment *img* with the global Cellpose model inside a 240 s GPU slot.

    Returns:
        ``(masks, flows)`` from ``model.eval``; the third output (styles) is
        discarded.
    """
    segmentation, flow_outputs, _styles = model.eval(img)
    return segmentation, flow_outputs
|
|
|
|
|
@spaces.GPU(duration=1000)
def run_model_gpu1000(img):
    """Segment *img* with the global Cellpose model inside a 1000 s GPU slot.

    Not selected by ``cellpose_segment`` in the visible code, which rejects
    images of 20,000 px or more.

    Returns:
        ``(masks, flows)`` from ``model.eval``; the third output (styles) is
        discarded.
    """
    segmentation, flow_outputs, _styles = model.eval(img)
    return segmentation, flow_outputs
|
|
|
|
|
|
|
|
def cellpose_segment(img_pil, resize = 400):
    """Run Cellpose-SAM on an uploaded image and build the UI outputs.

    The image is downscaled so its longer side is at most *resize* px. It is
    then segmented on a GPU slot sized to the image, and the masks are
    upsampled back to the input resolution before being saved.

    Args:
        img_pil: path of the uploaded image. Despite the name this is a file
            path, because the input component uses ``type="filepath"``.
        resize: maximum side length for processing. Floats (as produced by
            ``gr.Number``) are truncated to int; ``None`` falls back to 400.

    Returns:
        Tuple ``(outlines, overlay, flows, masks_button, outlines_button)``.
        The first three are PIL images at the processed resolution. The last
        two are visible ``gr.DownloadButton``s pointing at
        ``<input>_masks.tif`` and ``<input>_outlines.png``, written next to
        the input file.

    Raises:
        gr.Error: no image was provided, or it could not be read.
        ValueError: the processed image has a side of 20,000 px or more.
    """
    if img_pil is None:
        raise gr.Error("Please upload an input image first.")
    img_input = imread(img_pil)
    if img_input is None:
        # NOTE(review): assumes cellpose.io.imread signals failure with None.
        raise gr.Error("Could not read the input image.")

    max_side = 400 if resize is None else int(resize)
    img = image_resize(img_input, resize=max_side)

    # Pick the shortest GPU reservation that should cover this image size.
    size = int(np.max(img.shape))
    if size < 1000:
        run_model = run_model_gpu
    elif size < 5000:
        run_model = run_model_gpu60
    elif size < 20000:
        run_model = run_model_gpu240
    else:
        raise ValueError("Image size must be less than 20,000")
    masks, flows = run_model(img)

    # First flow output is the RGB flow visualisation.
    flow_rgb = flows[0]

    outpix = plot_outlines(img, masks)
    overlay = Image.fromarray(plot_overlay(img, masks))

    # Bring the masks back to the original resolution; nearest-neighbour keeps
    # label values intact. NOTE(review): uint16 wraps if there are > 65535 labels.
    target_size = (img_input.shape[1], img_input.shape[0])
    if target_size != (img.shape[1], img.shape[0]):
        masks = cv2.resize(masks.astype('uint16'), target_size,
                           interpolation=cv2.INTER_NEAREST).astype('uint16')

    # Display images are shown at the processed resolution. The overlay
    # already has that size, so only the outline figure and flows need resizing.
    Ly, Lx = img.shape[:2]
    outpix = outpix.resize((Lx, Ly), resample=Image.BICUBIC)
    flows = Image.fromarray(flow_rgb).resize((Lx, Ly), resample=Image.BICUBIC)

    base = os.path.splitext(img_pil)[0]
    fname_out = base + "_outlines.png"
    fname_masks = base + "_masks.tif"
    imsave(fname_masks, masks)
    outpix.save(fname_out)

    b1 = gr.DownloadButton(visible=True, value=fname_masks)
    b2 = gr.DownloadButton(visible=True, value=fname_out)
    return outpix, overlay, flows, b1, b2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def download_function():
    """Create the two download buttons in their initial hidden state.

    Returns:
        Tuple ``(masks_button, outlines_button)`` of ``gr.DownloadButton``.
    """
    labels = ("Download masks as TIFF", "Download outline image as PNG")
    return tuple(gr.DownloadButton(label, visible=False) for label in labels)
|
|
|
|
|
# Gradio UI: inputs and download buttons on the left, results on the right.
with gr.Blocks(title="Cellpose-SAM",
               css=".gradio-container {background:purple;}") as demo:

    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:white;">Cellpose-SAM for cellular segmentation</div>""")
            gr.HTML("""<h4 style="color:white;">You may need to refresh/login for 5 minutes of free GPU compute/day. </h4>""")
            gr.HTML("""<h4 style="color:white;">"pip install cellpose" for full functionality. </h4>""")

            input_image = gr.Image(label="Input image", type="filepath")

            with gr.Row():
                # precision=0 makes the component deliver an int; the default
                # float breaks cv2.resize's integer dsize downstream.
                resize = gr.Number(label='max resize', value=400, precision=0)
                send_btn = gr.Button("Run Cellpose-SAM")

            with gr.Row():
                down_btn = gr.DownloadButton("Download masks (TIF)", visible=False)
                down_btn2 = gr.DownloadButton("Download outlines (PNG)", visible=False)

            gr.HTML("""<a style="color:white;" href="https://github.com/MouseLand/cellpose" target="_blank">github page for cellpose</a>""")
            # NOTE(review): this "paper" link points to the GitHub repo, not
            # the Cellpose-SAM paper; replace the href with the paper URL.
            gr.HTML("""<a style="color:white;" href="https://github.com/MouseLand/cellpose" target="_blank">Cellpose-SAM paper</a>""")

        with gr.Column(scale=2):
            img_outlines = gr.Image(label="Outlines", type="pil", format='png')
            img_overlay = gr.Image(label="Overlay", type="pil", format='png')
            flows = gr.Image(label="Cellpose flows", type="pil", format='png')

    send_btn.click(fn=cellpose_segment, inputs=[input_image, resize],
                   outputs=[img_outlines, img_overlay, flows, down_btn, down_btn2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Start the server only when run as a script (as Hugging Face Spaces does), so
# importing this module does not block on a running web server.
if __name__ == "__main__":
    demo.launch()
|
|
|