import gradio as gr
import tifffile as tiff
import numpy as np
import os
import cv2
from PIL import Image
import tensorflow as tf
from skimage.feature import peak_local_max
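# blank placeholder frames used to reset/clear the AnnotatedImage components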
fp0 = np.zeros((96, 128), dtype=np.uint8)
fp1 = np.ones((96, 128), dtype=np.uint8)*200
# generic image reader
def imread(filepath):
print('imread')
fpath, fext = os.path.splitext(filepath)
if fext in ['.tiff', '.tif']:
print('imread_tiff')
img = tiff.imread(filepath)
else:
print('imread_cv2')
img = cv2.imread(filepath)
return img
# tiff volume to png slice
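# show_depth=True renders planes z..z+2 as an RGB preview; otherwise plane z is repeated into all channels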
def tif_view(filepath, z, show_depth=True):
    z = int(z)  # gr.Number values may arrive as floats; array indexing needs ints
    fpath, fext = os.path.splitext(filepath)
print('tif'+filepath)
print('tif'+ fext)
if fext in ['.tiff', '.tif']:
print('happens?')
img = tiff.imread(filepath)
print(img.shape)
if img.ndim==2:
img = np.tile(img[:,:,np.newaxis], [1,1,3])
elif img.ndim==3:
imin = np.argmin(img.shape)
print(imin)
if imin<2:
img = np.moveaxis(img, imin, 2)
print(img.shape)
else:
raise ValueError("TIF cannot have more than three dimensions")
print(z)
if show_depth:
img = img[:, :, z:(z+3)]
else:
img = img[:, :, (z,z,z)]
Ly, Lx, nchan = img.shape
imgi = np.zeros((Ly, Lx, 3))
nn = np.minimum(3, img.shape[-1])
imgi[:,:,:nn] = img[:,:,:nn]
imgi = imgi/(np.max(imgi)+0.0000001)
imgi = (255. * imgi)
filepath = fpath+'z'+str(z)+'.png'
tiff.imwrite(filepath, imgi.astype('uint8'))
print('tif'+filepath)
return filepath
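# like tif_view, but for higher-dimensional stacks with explicit (t,)z,(c,)y,x axes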
def tif_view_3D(filepath, z):
fpath, fext = os.path.splitext(filepath)
print('tif'+filepath)
print('tif'+ fext)
# assumes (t,)z,(c,)y,x for now
if fext in ['.tiff', '.tif']:
print('happens?')
img = tiff.imread(filepath)
print(img.shape)
if img.ndim==2:
raise ValueError("TIF has only two dimensions")
# select first timepoint
if img.ndim==5:
img = img[0,:,:,:,:]
print(img.shape)
#distinguishes between z,y,x and z,c,y,x
if img.ndim==4:
img = img[z,:,:,:]
print(img.shape)
elif img.ndim==3:
img = img[z,:,:]
print(img.shape)
img = np.tile(img[:,:,np.newaxis], [1,1,3])
else:
raise ValueError("TIF cannot have more than five dimensions")
imin = np.argmin(img.shape)
img = np.moveaxis(img, imin, 2)
print(img.shape)
Ly, Lx, nchan = img.shape
imgi = np.zeros((Ly, Lx, 3))
nn = np.minimum(3, img.shape[-1])
imgi[:,:,:nn] = img[:,:,:nn]
imgi = imgi/(np.max(imgi)+0.0000001)
imgi = (255. * imgi)
filepath = fpath+'.png'
tiff.imwrite(filepath, imgi.astype('uint8'))
else:
raise ValueError("not a TIF/TIFF")
print('tif'+filepath)
return filepath
# normalize an image to [0, 255] and save it as a PNG next to the original file
def norm_path(filepath):
img = imread(filepath)
img = img/(np.max(img)+0.0000001)
#img = np.clip(img, 0, 1)
fpath, fext = os.path.splitext(filepath)
filepath = fpath +'.png'
pil_image = Image.fromarray((255. * img).astype(np.uint8))
pil_image.save(filepath)
#imsave(filepath, pil_image)
print('norm'+filepath)
return filepath
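# Gradio callbacks: the UploadButton value is a list of file paths,
# so the handlers below display the most recent upload (filepath[-1])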
def update_image(filepath, z):
print('update_img')
#for f in filepath:
#f = tif_view(f, z)
filepath_show = tif_view(filepath[-1], z)
filepath_show = norm_path(filepath_show)
print(filepath_show)
print(filepath)
return (filepath_show, [((5, 5, 10, 10), 'nothing')]), filepath, (fp0, [((5, 5, 10, 10), 'nothing')])
def update_with_example(filepath):
print('update_btn')
print(filepath)
filepath_show = filepath
fpath, fext = os.path.splitext(filepath)
filepath = fpath+ '.tif'
return (filepath_show, [((5, 5, 10, 10), 'nothing')]), [filepath], (fp0, [((5, 5, 10, 10), 'nothing')])
def example(filepath):
print(filepath)
return(filepath)
def update_button(filepath, z):
print('update_btn')
print(filepath)
filepath_show = tif_view(filepath, z)
filepath_show = norm_path(filepath_show)
print(filepath_show)
return (filepath_show, [((5, 5, 10, 10), 'nothing')]), [filepath], (fp0, [((5, 5, 10, 10), 'nothing')])
def update_z(filepath, filepath_result, filepath_coordinates, z):
print('update_img')
#for f in filepath:
#f = tif_view(f, z)
filepath_show = tif_view(filepath[-1], z)
filepath_show = norm_path(filepath_show)
if isinstance(filepath_result, str):
filepath_result_show = tif_view(filepath_result, z, show_depth=False)
filepath_result_show = norm_path(filepath_result_show)
else:
filepath_result_show = fp0
print(filepath_show)
print(filepath)
if filepath_coordinates is None:
display_boxes = []
else:
display_boxes = filter_coordinates(filepath_coordinates, z)
return (filepath_show, display_boxes), (filepath_result_show, display_boxes)
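# inference pipeline: normalize the volume, pad to tile multiples, split along z,
# run the position model on each tile, reassemble, then pick local maxima as cells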
def detect_cells(filepath, z):
    model = tf.keras.models.load_model('./model_positions', compile=False)
    img = tiff.imread(filepath[-1])
    img = img/np.max(img)
    # pad and tile the raw (z, y, x) volume first: pad() and split_z() both
    # treat axis 0 as z, so the batch/channel axes are added per tile below
    img = pad(img)
    tiles = split_z(img)
    results = []
    for tile in tiles:
        # duplicate the single channel and add a batch axis -> (1, z, y, x, 2)
        tile = np.tile(tile[:, :, :, np.newaxis], [1, 1, 1, 2])
        tile = tile[np.newaxis, :, :, :, :]
        tensor = tf.convert_to_tensor(tile)
        result = model(tensor).numpy()
        result = result[0, :, :, :, 0]
        results.append(result)
result = reconstruct_z(results)
print(result.shape)
print(filepath)
fpath, fext = os.path.splitext(filepath[-1])
filepath_result = fpath+'result'+'.tiff'
tiff.imwrite(filepath_result, result)
filepath_result_show = tif_view(filepath_result, z, show_depth=False)
filepath_result_show = norm_path(filepath_result_show)
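    # local maxima of the predicted position map are taken as cell detections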
coordinates = peak_local_max(result, min_distance=2, threshold_abs=0.2, exclude_border=False)
print(coordinates)
filepath_coordinates = fpath+'coordinates'+'.csv'
np.savetxt(filepath_coordinates, coordinates, delimiter=",")
display_boxes = filter_coordinates(filepath_coordinates, z)
return filepath_result, filepath_coordinates, (filepath_result_show, display_boxes)
def pad(img, z_tile=32, xy_tile=96):
    # pad the (z, y, x) volume so every axis is a multiple of its tile size
    # (negative modulo keeps already-divisible axes unpadded)
    pad_z = (-img.shape[0]) % z_tile
    pad_y = (-img.shape[1]) % xy_tile
    pad_x = (-img.shape[2]) % xy_tile
    print(pad_x)
    return np.pad(img, ((0, pad_z), (0, pad_y), (0, pad_x)))
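# split the z axis into overlapping chunks of z_tile slices (stride z_tile - z_buffer)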
def split_z(img, z_tile=32, z_buffer=2):
if img.shape[0]==32:
return([img])
tiles = []
height = 0
while height<img.shape[0]:
tiles.append(img[height:(height+z_tile), :, :])
height = height+z_tile-z_buffer
return tiles
def reconstruct_z(tiles, z_tile=32, z_buffer=2):
    if len(tiles)==1:
        return tiles[0]
    # drop the z_buffer overlap from each tile and join them back along z
    tiles = [tile[0:(z_tile-z_buffer), :, :] for tile in tiles]
    return np.concatenate(tiles, axis=0)
def filter_coordinates(filepath_coordinates, z):
    # load the saved (z, y, x) peaks and keep those within 2 planes of the displayed z
    coordinates = np.loadtxt(filepath_coordinates, delimiter=",")
    coordinates = np.atleast_2d(coordinates)
    if coordinates.size == 0:
        return []
    print(coordinates)
coordinates = coordinates[np.abs(coordinates[:,0]-z)<3, :]
print(coordinates)
    xy_coordinates = coordinates[:, (2,1)]
    # boxes shrink as a peak lies further from the displayed plane
    rel_z = np.abs(coordinates[:, 0]-z)
    rel_z = rel_z[:, np.newaxis]
    print(rel_z)
    boxes = np.concatenate((xy_coordinates-4+rel_z, xy_coordinates+4-rel_z), axis=1)
    boxes = np.clip(boxes, 0, None).astype('uint32')
print(boxes)
boxes = [(tuple(box.tolist()),'nothing') for box in boxes]
print(boxes)
return boxes
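# build the Gradio interface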
with gr.Blocks(title = "Hello",
css=".gradio-container {background:purple;}") as demo:
#filepath = ""
with gr.Row():
with gr.Column(scale=2):
gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:20pt; font-weight:bold; text-align:center; color:white;">Cellpose-SAM for cellular
segmentation <a style="color:#cfe7fe; font-size:14pt;" href="https://www.biorxiv.org/content/10.1101/2025.04.28.651001v1" target="_blank">[paper]</a>
<a style="color:white; font-size:14pt;" href="https://github.com/MouseLand/cellpose" target="_blank">[github]</a>
<a style="color:white; font-size:14pt;" href="https://www.youtube.com/watch?v=KIdYXgQemcI" target="_blank">[talk]</a>
</div>""")
gr.HTML("""<h4 style="color:white;">You may need to login/refresh for 5 minutes of free GPU compute per day (enough to process hundreds of images). </h4>""")
#input_image = gr.Image(label = "Input", type = "filepath")
input_image = gr.AnnotatedImage(label = "Input", show_legend=False, color_map = {'nothing': '#FFFF00'})
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
resize = gr.Number(label = 'max resize', value = 1000)
max_iter = gr.Number(label = 'max iterations', value = 250)
depth = gr.Number(label = 'z-scale', value = 10)
up_btn = gr.UploadButton("Multi-file upload (png, jpg, tif etc)", visible=True, file_count = "multiple")
#gr.HTML("""<h4 style="color:white;"> Note2: Only the first image of a tif will display the segmentations, but you can download segmentations for all planes. </h4>""")
with gr.Column(scale=1):
send_btn = gr.Button("Run Cellpose-SAM")
down_btn = gr.DownloadButton("Download masks (TIF)", visible=False)
down_btn2 = gr.DownloadButton("Download outlines (PNG)", visible=False)
with gr.Column(scale=2):
#
#output_image = gr.Image(label = "Output", type = "filepath")
output_image = gr.AnnotatedImage(label = "Output", show_legend=False, color_map = {'nothing': '#FFFF00'})
sample_list = os.listdir("./gradio_examples/jpegs")
#sample_list = [ ("./gradio_examples/jpegs/"+sample, [((5, 5, 10, 10), 'nothing')]) for sample in sample_list]
print(sample_list)
sample_list = [ "./gradio_examples/jpegs/"+sample for sample in sample_list]
#sample_list = []
#for j in range(23):
# sample_list.append("samples/img%0.2d.png"%j)
#gr.Examples(sample_list, fn = update_with_example, inputs=input_image, outputs = [input_image, up_btn, output_image], examples_per_page=50, label = "Click on an example to try it")
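    # example thumbnails load into a hidden Image component rather than the AnnotatedImage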
example_image = gr.Image(visible=False, type='filepath')
gr.Examples(sample_list, fn= example, inputs=example_image, outputs=[example_image], examples_per_page=5, label = "Click on an example to try it")
#input_image.upload(update_button, [input_image, depth], [input_image, up_btn, output_image])
up_btn.upload(update_image, [up_btn, depth], [input_image, up_btn, output_image])
depth.change(update_z, [up_btn, down_btn, down_btn2, depth], [input_image, output_image])
#depth.change(update_depth, [up_btn, depth], depth)
# DO NOT RENDER OUTPUT TWICE
send_btn.click(detect_cells, [up_btn, depth], [down_btn, down_btn2, output_image]).then(update_image, [up_btn, depth], [input_image, up_btn, output_image])# flows, down_btn, down_btn2])
#down_btn.click(download_function, None, [down_btn, down_btn2])
gr.HTML("""<h4 style="color:white;"> Notes:<br>
<li>you can load and process 2D, multi-channel tifs.
<li>the smallest dimension of a tif --> channels
<li>you can upload multiple files and download a zip of the segmentations
<li>install Cellpose-SAM locally for full functionality.
</h4>""")
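# running locally presumably requires the ./model_positions SavedModel and the
# ./gradio_examples/jpegs folder referenced above, alongside the packages imported at the top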
demo.launch()