Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import tifffile as tiff | |
| import numpy as np | |
| import os | |
| import cv2 | |
| from PIL import Image | |
| import tensorflow as tf | |
| from skimage.feature import peak_local_max | |
# Blank 96x128 uint8 frame used as the placeholder image for AnnotatedImage values.
fp0 = np.zeros((96, 128), dtype=np.uint8)
# Uniform grey (value 200) 96x128 uint8 frame; not referenced elsewhere in the code shown.
fp1 = np.full((96, 128), 200, dtype=np.uint8)
# generic image reader
def imread(filepath):
    """Read an image file into a numpy array.

    '.tif'/'.tiff' files (any extension case) are read with tifffile, which
    keeps the stored axis order. Everything else is read with OpenCV. OpenCV
    returns colour images in BGR order, so those are converted to RGB here;
    this matches tifffile and PIL, which is what callers hand the array to.

    Parameters
    ----------
    filepath : str
        Path of the image to read.

    Returns
    -------
    numpy.ndarray
        The image data.

    Raises
    ------
    OSError
        If OpenCV cannot read the file (``cv2.imread`` reports failure by
        returning ``None`` rather than raising).
    """
    fext = os.path.splitext(filepath)[1].lower()
    if fext in ('.tiff', '.tif'):
        return tiff.imread(filepath)
    img = cv2.imread(filepath)
    if img is None:
        raise OSError(f"could not read image: {filepath}")
    if img.ndim == 3 and img.shape[2] == 3:
        # OpenCV channel order is B, G, R
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
# tiff volume to png slice
def tif_view(filepath, z, show_depth=True):
    """Render plane ``z`` of a TIF/TIFF image as an 8-bit RGB PNG for display.

    Files whose extension is not '.tif'/'.tiff' (any case) are returned
    unchanged, so callers can pass other image types straight on to
    ``norm_path``.

    The TIF must be 2-D or 3-D. A 2-D image has a single plane, which is
    replicated into three identical channels. For a 3-D image the smallest
    axis is treated as the plane axis and moved last.

    Parameters
    ----------
    filepath : str
        Path of the image to render.
    z : int or float
        Plane index. Gradio's ``gr.Number`` delivers floats by default, so the
        value is converted to ``int``. It is then clamped to the available
        planes; a 2-D image always uses plane 0.
    show_depth : bool, default True
        If True, planes ``z, z+1, z+2`` fill the R, G, B channels. Channels
        past the last plane stay 0. If False, plane ``z`` is shown in
        grayscale.

    Returns
    -------
    str
        Path of the written PNG, ``<stem>z<z>.png``, or ``filepath`` itself
        for non-TIF input.

    Raises
    ------
    ValueError
        If the TIF has more than three dimensions.
    """
    fpath, fext = os.path.splitext(filepath)
    if fext.lower() not in ('.tiff', '.tif'):
        return filepath
    img = tiff.imread(filepath)
    if img.ndim == 2:
        img = np.tile(img[:, :, np.newaxis], [1, 1, 3])
        # a 2-D image has one plane; the 3 channels are copies of it
        z = 0
    elif img.ndim == 3:
        imin = int(np.argmin(img.shape))
        if imin < 2:
            img = np.moveaxis(img, imin, 2)
    else:
        raise ValueError("TIF cannot have more than three dimensions")
    # slice/fancy indices must be ints; keep z inside [0, n_planes - 1]
    z = min(max(int(z), 0), img.shape[2] - 1)
    if show_depth:
        img = img[:, :, z:(z + 3)]
    else:
        img = img[:, :, (z, z, z)]
    Ly, Lx = img.shape[:2]
    imgi = np.zeros((Ly, Lx, 3))
    nn = min(3, img.shape[-1])
    imgi[:, :, :nn] = img[:, :, :nn]
    imgi = imgi / (np.max(imgi) + 0.0000001)
    imgi = 255. * imgi
    out_path = fpath + 'z' + str(z) + '.png'
    # write an actual PNG; tifffile would put TIFF bytes behind the .png name
    Image.fromarray(imgi.astype('uint8')).save(out_path)
    return out_path
def tif_view_3D(filepath, z):
    """Render plane ``z`` of a TIF stack as an 8-bit RGB PNG ``<stem>.png``.

    Axis order is assumed to be (t,) z, (c,) y, x:

    * a 5-D stack uses its first timepoint;
    * a 4-D (z, c, y, x) stack yields up to three channels of plane ``z``;
    * a 3-D (z, y, x) stack yields plane ``z`` in grayscale.

    After slicing, the smallest remaining axis is treated as the channel axis
    and moved last.

    Parameters
    ----------
    filepath : str
        Path of a '.tif'/'.tiff' file (any extension case).
    z : int or float
        Plane index. Floats (as delivered by ``gr.Number``) are converted to
        ``int``. An out-of-range index raises IndexError from numpy.

    Returns
    -------
    str
        Path of the written PNG.

    Raises
    ------
    ValueError
        For non-TIF paths, 2-D TIFs, or stacks with more than five dimensions.
    """
    fpath, fext = os.path.splitext(filepath)
    # assumes (t,)z,(c,)y,x for now
    if fext.lower() not in ('.tiff', '.tif'):
        raise ValueError("not a TIF/TIFF")
    img = tiff.imread(filepath)
    if img.ndim < 3:
        raise ValueError("TIF has only two dimensions")
    if img.ndim > 5:
        raise ValueError("TIF cannot have more than five dimensions")
    z = int(z)
    if img.ndim == 5:
        # select first timepoint
        img = img[0]
    if img.ndim == 4:
        # z, c, y, x -> c, y, x
        img = img[z]
    else:
        # z, y, x -> y, x, replicated into three identical channels
        img = np.tile(img[z][:, :, np.newaxis], [1, 1, 3])
    imin = int(np.argmin(img.shape))
    img = np.moveaxis(img, imin, 2)
    imgi = np.zeros((img.shape[0], img.shape[1], 3))
    nn = min(3, img.shape[-1])
    imgi[:, :, :nn] = img[:, :, :nn]
    imgi = imgi / (np.max(imgi) + 0.0000001)
    imgi = 255. * imgi
    out_path = fpath + '.png'
    # write an actual PNG; tifffile would put TIFF bytes behind the .png name
    Image.fromarray(imgi.astype('uint8')).save(out_path)
    return out_path
# function to change image appearance
def norm_path(filepath):
    """Rescale an image so its maximum maps to 255 and save it as a PNG.

    The image is loaded with ``imread``, divided by its maximum (plus a tiny
    epsilon to avoid dividing by zero), scaled to 0-255 and cast to uint8.
    It is written next to the input as ``<stem>.png``, overwriting the input
    when it already has that name.

    Returns the path of the written PNG.
    """
    pixels = imread(filepath)
    unit = pixels / (np.max(pixels) + 0.0000001)
    png_path = os.path.splitext(filepath)[0] + '.png'
    as_bytes = (255. * unit).astype(np.uint8)
    Image.fromarray(as_bytes).save(png_path)
    print('norm' + png_path)
    return png_path
def update_image(filepath, z):
    """Upload handler: display the most recently uploaded file at plane ``z``.

    ``filepath`` is the list of uploaded paths; only the last one is rendered.
    Returns three values:

    * the input-view value: the rendered PNG with a dummy box labelled
      'nothing';
    * the upload list, unchanged;
    * a blank ``fp0`` value for the output view.
    """
    print('update_img')
    latest_upload = filepath[-1]
    shown_png = norm_path(tif_view(latest_upload, z))
    print(shown_png)
    print(filepath)
    input_value = (shown_png, [((5, 5, 10, 10), 'nothing')])
    output_value = (fp0, [((5, 5, 10, 10), 'nothing')])
    return input_value, filepath, output_value
def update_with_example(filepath):
    """Turn a clicked example image into display values plus its '.tif' twin.

    Returns three values:

    * the example image itself with a dummy 'nothing' box;
    * a one-element list holding ``<stem>.tif``;
    * a blank ``fp0`` output value.

    NOTE(review): the '.tif' file is assumed to sit next to the example image;
    this function does not check that it exists.
    """
    print('update_btn')
    print(filepath)
    tif_path = os.path.splitext(filepath)[0] + '.tif'
    return ((filepath, [((5, 5, 10, 10), 'nothing')]),
            [tif_path],
            (fp0, [((5, 5, 10, 10), 'nothing')]))
def example(filepath):
    """Identity callback for ``gr.Examples``: log the chosen path and return it."""
    print(filepath)
    chosen = filepath
    return chosen
def update_button(filepath, z):
    """Display one uploaded file at plane ``z``.

    Returns three values:

    * the rendered PNG with a dummy 'nothing' box;
    * a one-element list holding ``filepath``;
    * a blank ``fp0`` output value.
    """
    print('update_btn')
    print(filepath)
    png_path = norm_path(tif_view(filepath, z))
    print(png_path)
    input_value = (png_path, [((5, 5, 10, 10), 'nothing')])
    output_value = (fp0, [((5, 5, 10, 10), 'nothing')])
    return input_value, [filepath], output_value
def update_z(filepath, filepath_result, filepath_coordinates, z):
    """Redraw the input and result views at plane ``z``.

    Parameters
    ----------
    filepath : list of str or None
        Uploaded file paths; the last one is displayed. ``None`` or empty
        before anything has been uploaded.
    filepath_result : str or None
        Result TIF path produced by ``detect_cells``. When it is not a
        string, the result view shows the blank ``fp0`` placeholder.
    filepath_coordinates : str or None
        Peak-coordinate CSV path produced by ``detect_cells``. When it is
        ``None``, no boxes are drawn.
    z : int or float
        Plane to display.

    Returns
    -------
    tuple
        ``(input_value, output_value)``. Each is an ``(image, boxes)`` pair,
        and both share the same boxes. If nothing has been uploaded, two
        ``gr.update()`` no-ops are returned so both views stay as they are.
    """
    if not filepath:
        # the z control can change before any upload; filepath[-1] would fail
        return gr.update(), gr.update()
    filepath_show = norm_path(tif_view(filepath[-1], z))
    if isinstance(filepath_result, str):
        filepath_result_show = norm_path(tif_view(filepath_result, z, show_depth=False))
    else:
        filepath_result_show = fp0
    if filepath_coordinates is None:
        display_boxes = []
    else:
        display_boxes = filter_coordinates(filepath_coordinates, z)
    return (filepath_show, display_boxes), (filepath_result_show, display_boxes)
# Keras model from './model_positions'. Loaded on the first detect_cells call
# and reused afterwards; reloading it from disk on every click is slow.
_position_model = None


def _get_position_model():
    """Return the './model_positions' Keras model, loading it only once."""
    global _position_model
    if _position_model is None:
        _position_model = tf.keras.models.load_model('./model_positions', compile=False)
    return _position_model


def detect_cells(filepath, z):
    """Detect cell positions in the most recently uploaded 3-D TIF.

    Processing steps:

    1. Scale the stack to [0, 1].
    2. Zero-pad it with ``pad`` and split it along axis 0 with ``split_z``.
    3. Feed each tile to the model as a (1, z, y, x, 2) batch, with the
       single intensity channel duplicated into two channels.
    4. Stitch output channel 0 of every tile back together with
       ``reconstruct_z`` and crop it to the input shape.

    Peaks of the stitched map are then located with ``peak_local_max``.

    Side effects
    ------------
    Writes three files:

    * ``<stem>result.tiff``: the stitched model output.
    * ``<stem>coordinates.csv``: one ``z,y,x`` row per peak.
    * The PNG preview written by ``tif_view``/``norm_path``.

    Parameters
    ----------
    filepath : list of str or None
        Uploaded file paths; the last one is processed.
    z : int or float
        Plane used for the returned preview and boxes.

    Returns
    -------
    tuple
        ``(filepath_result, filepath_coordinates, (preview_png, boxes))``.

    Raises
    ------
    gr.Error
        If nothing has been uploaded or the file is not a 3-D stack.
    """
    if not filepath:
        raise gr.Error("Upload a 3-D TIF before running detection")
    img = tiff.imread(filepath[-1])
    if img.ndim != 3:
        raise gr.Error(f"Expected a 3-D (z, y, x) TIF, got shape {img.shape}")
    n_z, n_y, n_x = img.shape
    img = img / (np.max(img) + 0.0000001)
    model = _get_position_model()
    results = []
    # pad/split the 3-D stack first; batch and channel axes are added per tile
    for tile in split_z(pad(img)):
        batch = np.tile(tile[np.newaxis, :, :, :, np.newaxis], [1, 1, 1, 1, 2])
        prediction = model(tf.convert_to_tensor(batch.astype(np.float32))).numpy()
        results.append(prediction[0, :, :, :, 0])
    # drop the padding so the result lines up with the input stack
    result = reconstruct_z(results)[:n_z, :n_y, :n_x]
    fpath, fext = os.path.splitext(filepath[-1])
    filepath_result = fpath + 'result' + '.tiff'
    tiff.imwrite(filepath_result, result)
    filepath_result_show = norm_path(tif_view(filepath_result, z, show_depth=False))
    coordinates = peak_local_max(result, min_distance=2, threshold_abs=0.2, exclude_border=False)
    filepath_coordinates = fpath + 'coordinates' + '.csv'
    np.savetxt(filepath_coordinates, coordinates, delimiter=",")
    display_boxes = filter_coordinates(filepath_coordinates, z)
    return filepath_result, filepath_coordinates, (filepath_result_show, display_boxes)
def pad(img, z_tile = 32, xy_tile = 96):
    """Zero-pad the first three axes of ``img`` up to multiples of the tile sizes.

    Axis 0 is padded to a multiple of ``z_tile``, and axes 1 and 2 to
    multiples of ``xy_tile``. Padding is appended at the end of each axis.

    An axis that is already a multiple is left as is; the previous
    ``tile - n % tile`` formula added a whole extra tile there. Any axes
    after the third (e.g. channels) are not padded.

    Returns the padded array (``np.pad`` always returns a new array).
    """
    # (-n) % tile is the distance to the next multiple, 0 when already aligned
    pad_z = -img.shape[0] % z_tile
    pad_y = -img.shape[1] % xy_tile
    pad_x = -img.shape[2] % xy_tile
    widths = [(0, pad_z), (0, pad_y), (0, pad_x)] + [(0, 0)] * (img.ndim - 3)
    return np.pad(img, widths)
def split_z(img, z_tile=32, z_buffer=2):
    """Split ``img`` along axis 0 into overlapping tiles of exactly ``z_tile`` planes.

    Consecutive tiles start ``z_tile - z_buffer`` planes apart, so neighbours
    overlap by ``z_buffer`` planes.

    If the last tile would run past the end of the stack, the stack is
    zero-padded at the end along axis 0. This means every tile has the same
    depth, instead of the last one being short. ``reconstruct_z`` stitches
    the tiles back together; crop its output to the original depth.

    Parameters
    ----------
    img : numpy.ndarray
        Array split along axis 0; other axes are kept whole.
    z_tile : int
        Planes per tile.
    z_buffer : int
        Planes shared by consecutive tiles.

    Returns
    -------
    list of numpy.ndarray

    Raises
    ------
    ValueError
        If ``z_buffer >= z_tile``. The tiles would never advance; the previous
        loop spun forever in that case.
    """
    step = z_tile - z_buffer
    if step <= 0:
        raise ValueError("z_buffer must be smaller than z_tile")
    n_z = img.shape[0]
    if n_z == z_tile:
        return [img]
    # ceil((n_z - z_tile) / step) extra tiles after the first
    n_tiles = 1 if n_z <= z_tile else -(-(n_z - z_tile) // step) + 1
    needed = (n_tiles - 1) * step + z_tile
    if needed > n_z:
        img = np.pad(img, [(0, needed - n_z)] + [(0, 0)] * (img.ndim - 1))
    return [img[i * step:i * step + z_tile] for i in range(n_tiles)]
def reconstruct_z(tiles, z_tile=32, z_buffer=2):
    """Stitch z-tiles (as produced by ``split_z``) back into one stack along axis 0.

    Every tile except the last contributes its first ``z_tile - z_buffer``
    planes, which is the start offset between consecutive tiles. The last
    tile contributes all of its planes.

    The result therefore covers the full extent of the split input, including
    any zero-padding added before or during splitting. Crop it to the
    original depth if needed. The previous version stacked the tiles on a new
    axis and cut the tail of the last tile.

    Returns the single tile unchanged when there is only one.

    Raises
    ------
    ValueError
        If ``tiles`` is empty.
    """
    if not tiles:
        raise ValueError("no tiles to reconstruct")
    if len(tiles) == 1:
        return tiles[0]
    step = z_tile - z_buffer
    kept = [tile[:step] for tile in tiles[:-1]]
    kept.append(tiles[-1])
    return np.concatenate(kept, axis=0)
def filter_coordinates(filepath_coordinates, z, z_window=3, box_radius=3):
    """Turn peak coordinates near plane ``z`` into AnnotatedImage boxes.

    Parameters
    ----------
    filepath_coordinates : str
        Comma-separated file with one ``z,y,x`` row per peak, as written by
        ``detect_cells``. The file may hold a single row or be empty.
    z : int or float
        Displayed plane.
    z_window : int or float, default 3
        Peaks with ``|peak_z - z| < z_window`` are kept.
    box_radius : int or float, default 3
        Half-size of the square box drawn around each kept peak.

    Returns
    -------
    list of tuple
        ``((x0, y0, x1, y1), 'nothing')`` per kept peak. Corners are clipped
        at 0, because negative values would wrap around when cast to uint32.
        The list is empty when there are no peaks.
    """
    # ndmin=2 keeps a single-row file 2-D (loadtxt would otherwise return 1-D)
    coordinates = np.loadtxt(filepath_coordinates, delimiter=",", ndmin=2)
    if coordinates.size == 0:
        return []
    near = coordinates[np.abs(coordinates[:, 0] - z) < z_window]
    xy = near[:, (2, 1)]
    corners = np.concatenate((xy - box_radius, xy + box_radius), axis=1)
    boxes = np.clip(corners, 0, None).astype('uint32')
    return [(tuple(box.tolist()), 'nothing') for box in boxes]
with gr.Blocks(title = "Hello",
               css=".gradio-container {background:purple;}") as demo:
    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:20pt; font-weight:bold; text-align:center; color:white;">Cellpose-SAM for cellular
segmentation <a style="color:#cfe7fe; font-size:14pt;" href="https://www.biorxiv.org/content/10.1101/2025.04.28.651001v1" target="_blank">[paper]</a>
<a style="color:white; font-size:14pt;" href="https://github.com/MouseLand/cellpose" target="_blank">[github]</a>
<a style="color:white; font-size:14pt;" href="https://www.youtube.com/watch?v=KIdYXgQemcI" target="_blank">[talk]</a>
</div>""")
            gr.HTML("""<h4 style="color:white;">You may need to login/refresh for 5 minutes of free GPU compute per day (enough to process hundreds of images). </h4>""")
            input_image = gr.AnnotatedImage(label = "Input", show_legend=False, color_map = {'nothing': '#FFFF00'})
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        resize = gr.Number(label = 'max resize', value = 1000)
                        max_iter = gr.Number(label = 'max iterations', value = 250)
                        # Plane index passed as ``z`` to the view and detection callbacks.
                        # precision=0 makes Gradio deliver an int, which array slicing needs.
                        depth = gr.Number(label = 'z-plane', value = 10, precision = 0)
                    up_btn = gr.UploadButton("Multi-file upload (png, jpg, tif etc)", visible=True, file_count = "multiple")
                with gr.Column(scale=1):
                    send_btn = gr.Button("Run Cellpose-SAM")
                    # These buttons receive detect_cells' result TIF and coordinate CSV
                    # paths and feed them back into update_z.
                    down_btn = gr.DownloadButton("Download masks (TIF)", visible=False)
                    down_btn2 = gr.DownloadButton("Download outlines (PNG)", visible=False)
        with gr.Column(scale=2):
            output_image = gr.AnnotatedImage(label = "Output", show_legend=False, color_map = {'nothing': '#FFFF00'})
            # sorted: os.listdir order is arbitrary, keep the example gallery stable
            sample_list = sorted(os.listdir("./gradio_examples/jpegs"))
            print(sample_list)
            sample_list = ["./gradio_examples/jpegs/" + sample for sample in sample_list]
            example_image = gr.Image(visible=False, type='filepath')
            gr.Examples(sample_list, fn= example, inputs=example_image, outputs=[example_image], examples_per_page=5, label = "Click on an example to try it")
    up_btn.upload(update_image, [up_btn, depth], [input_image, up_btn, output_image])
    depth.change(update_z, [up_btn, down_btn, down_btn2, depth], [input_image, output_image])
    # DO NOT RENDER OUTPUT TWICE: detect_cells already draws the result into
    # output_image. A follow-up .then(update_image, ...) would overwrite it with
    # the blank fp0 placeholder that update_image returns for the output view.
    send_btn.click(detect_cells, [up_btn, depth], [down_btn, down_btn2, output_image])
    gr.HTML("""<h4 style="color:white;"> Notes:<br>
<li>you can load and process 2D, multi-channel tifs.
<li>the smallest dimension of a tif --> channels
<li>you can upload multiple files and download a zip of the segmentations
<li>install Cellpose-SAM locally for full functionality.
</h4>""")
demo.launch()