# Hugging Face Space file app.py by sjtans
# Last commit: "Update app.py" (0494df2, verified)
import gradio as gr
import spaces
import tifffile as tiff
import zarr
import numpy as np
import os
import cv2
from PIL import Image
from skimage.feature import peak_local_max
import scipy as sc
import huggingface_hub
# Available backend options are: "jax", "torch", "tensorflow".
os.environ["KERAS_BACKEND"] = "torch"
import keras
# Register extra serialization names so keras.saving.load_model can resolve the
# classes referenced by the downloaded .keras files.
# NOTE(review): presumably the Hub models were saved with a Keras version that
# serialized these classes under these names — confirm when upgrading Keras.
try:
    from keras.src import api_export
    api_export.REGISTERED_NAMES_TO_OBJS["keras.models.functional.Functional"] = keras.src.models.functional.Functional
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.numpy.Concatenate"] = keras.src.ops.numpy.Concatenate
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.numpy.Flip"] = keras.src.ops.numpy.Flip
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.numpy.GetItem"] = keras.src.ops.numpy.GetItem
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.numpy.Stack"] = keras.src.ops.numpy.Stack
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.numpy.Absolute"] = keras.src.ops.numpy.Absolute
    api_export.REGISTERED_NAMES_TO_OBJS["keras.ops.nn.Conv"] = keras.src.ops.nn.Conv
    api_export.REGISTERED_NAMES_TO_OBJS["keras.backend.torch.optimizers.torch_adam.Adam"] = keras.src.optimizers.Adam
except ImportError:
    # ImportError also covers ModuleNotFoundError. `from keras.src import api_export`
    # raises a plain ImportError when keras.src exists but has no api_export,
    # which previously crashed the app at import time.
    print('keras.src.api_export not available; skipping Keras name registration')
import keras.saving
# Dropdown label (including the image resolution the model expects) ->
# Hugging Face Hub repository that holds the model.
model_adresses = {
    "Intestinal organoids (0.32x0.32x2.0 um)": 'sjtans/organoids_pytorch',
    "c. Elegans embryo (0.1x0.1x1.0 um)": 'sjtans/elegans_pytorch',
}

# 96 x 128 uint8 placeholder images: all zeros, and uniformly filled with 200.
fp0 = np.zeros((96, 128), dtype=np.uint8)
fp1 = np.full((96, 128), 200, dtype=np.uint8)
from huggingface_hub import hf_hub_download
def download_model(model):
    """Fetch ``model.keras`` from the Hub repository ``model``.

    Args:
        model: Hugging Face Hub repo id, e.g. ``'sjtans/organoids_pytorch'``.

    Returns:
        The local file path returned by ``hf_hub_download``.
    """
    local_path = hf_hub_download(repo_id=model, filename="model.keras")
    return local_path
# generic image reader
def imread(filepath):
    """Read an image file into a numpy array.

    Files with a ``.tif``/``.tiff`` extension (any letter case) are read with
    tifffile; everything else goes through ``cv2.imread`` with its default flags.
    NOTE(review): cv2 returns colour images in BGR channel order.

    Args:
        filepath: path of the image to read.

    Returns:
        The image as a numpy array.

    Raises:
        ValueError: if OpenCV cannot read the file. ``cv2.imread`` signals
            failure by returning None instead of raising.
    """
    print('imread')
    fext = os.path.splitext(filepath)[1]
    if fext.lower() in ('.tiff', '.tif'):
        print('imread_tiff')
        img = tiff.imread(filepath)
    else:
        print('imread_cv2')
        img = cv2.imread(filepath)
        if img is None:
            raise ValueError("could not read image: " + str(filepath))
    return img
def check_dims(filepath):
    """Return (number of z slices, number of channels) of a TIF volume.

    Only the array shape is used. Layouts follow the indexing used by
    get_slice: (z, y, x), (z, c, y, x) and (t, z, c, y, x).

    Args:
        filepath: path to the TIF file.

    Returns:
        ``(n_z, None)`` for 3-D files, ``(n_z, n_c)`` for 4-D and 5-D files.

    Raises:
        ValueError: for any other number of dimensions.
    """
    # Close both the zarr store and the TiffFile; the TiffFile used to be left open.
    with tiff.TiffFile(filepath) as tif:
        store = tif.aszarr()
        try:
            shape = zarr.open(store, mode='r', chunks=None).shape
        finally:
            store.close()
    ndim = len(shape)
    if ndim == 3:
        return shape[0], None
    if ndim == 4:
        return shape[0], shape[1]
    if ndim == 5:
        return shape[1], shape[2]
    raise ValueError("TIF has wrong dimensions")
# tiff volume to png slice
def tif_view(filepath, z, c=0, show_depth=True):
    """Render slice ``z`` (channel ``c``) of a TIF volume to an 8-bit, 3-channel image file.

    With ``show_depth`` the three channels hold slices z, z-1 and z+1. Otherwise
    slice z is repeated three times. The stack is min-max scaled to 0..255.

    NOTE: the output is written with tifffile, so the file is TIFF-encoded
    despite its ``.png`` name (norm_path later rewrites the same path as a PNG).

    Args:
        filepath: path to a ``.tif``/``.tiff`` file (extension checked
            case-insensitively, so e.g. ``stack.TIF`` is accepted).
        z: slice index; out-of-range neighbours come back as zeros from get_slice.
        c: channel index.
        show_depth: include the neighbouring slices as extra channels.

    Returns:
        Path of the written file: ``<stem>z<z>c<c>.png``.

    Raises:
        ValueError: if the file extension is not .tif/.tiff.
    """
    fpath, fext = os.path.splitext(filepath)
    print('tif'+filepath)
    print('tif'+ fext)
    if fext.lower() not in ('.tiff', '.tif'):
        raise ValueError("not a TIF/TIFF")
    img = get_slice(filepath, z, c = c)
    if show_depth:
        # channel 0 = slice z, channel 1 = slice z-1, channel 2 = slice z+1
        img = np.stack([img, get_slice(filepath, z-1, c = c), get_slice(filepath, z+1, c = c)], axis=-1)
    else:
        img = np.stack([img]*3, axis=-1)
    # Always (y, x, 3) here, so no channel padding or cropping is needed.
    imgi = img.astype(np.float64)
    imgi = imgi - np.min(imgi)
    imgi = imgi/(np.max(imgi)+0.0000001)
    imgi = 255. * imgi
    filepath = fpath+'z'+str(z)+'c'+str(c)+'.png'
    tiff.imwrite(filepath, imgi.astype('uint8'))
    print('tif'+filepath)
    return filepath
def get_slice(filepath, z, c=0):
    """Read a single (y, x) plane from a TIF volume.

    Supported layouts: (z, y, x), (z, c, y, x) and (t, z, c, y, x). For 5-D
    files the first timepoint is used.

    Args:
        filepath: path to the TIF file.
        z: slice index.
        c: channel index (ignored for 3-D files).

    Returns:
        The plane as a numpy array. If z is outside the volume, a float zero
        image of the same (y, x) size is returned instead, which lets tif_view
        request z-1 / z+1 at the edges of the volume.

    Raises:
        ValueError: if the file has fewer than three or more than five dimensions.
    """
    # Open and close the file explicitly, and read the plane before closing.
    with tiff.TiffFile(filepath) as tif:
        store = tif.aszarr()
        try:
            img = zarr.open(store, mode='r', chunks=None)
            print(z)
            if img.ndim < 3:
                raise ValueError("TIF has only two dimensions")
            if img.ndim > 5:
                raise ValueError("TIF cannot have more than five dimensions")
            # z is axis 0 for (z,y,x) / (z,c,y,x) and axis 1 for (t,z,c,y,x)
            z_axis = 1 if img.ndim == 5 else 0
            if not 0 <= z < img.shape[z_axis]:
                if img.ndim == 3:
                    print('z to big')
                return np.zeros(img.shape[-2:])
            if img.ndim == 3:
                return img[z, :, :]
            if img.ndim == 4:
                return img[z, c, :, :]
            # select first timepoint
            plane = img[0, z, c, :, :]
            print(plane.shape)
            return plane
        finally:
            store.close()
def get_volume(filepath, c=0):
    """Load a whole TIF and return a single-channel volume.

    4-D files are indexed as (z, c, y, x) and 5-D files as (t, z, c, y, x). The
    first timepoint and channel ``c`` are selected. 3-D files are returned
    unchanged.

    Raises:
        ValueError: for 2-D files or files with more than five dimensions.
    """
    volume = tiff.imread(filepath)
    print(volume.shape)
    rank = volume.ndim
    if rank == 2:
        raise ValueError("TIF has only two dimensions")
    if rank == 4:
        volume = volume[:, c, :, :]
    elif rank == 5:
        # first timepoint, chosen channel
        volume = volume[0, :, c, :, :]
        print(volume.shape)
    elif rank > 5:
        raise ValueError("TIF cannot have more than five dimensions")
    return volume
def tif_view_3D(filepath, z):
    """Render slice ``z`` of a TIF volume, with up to 3 channels, to ``<stem>.png``.

    Layout is assumed to be (t,)z,(c,)y,x. For 5-D data the first timepoint is
    used. For 4-D data the smallest remaining axis is treated as the channel
    axis. 3-D slices are repeated into three channels. Values are divided by the
    maximum and scaled to 0..255 (no minimum subtraction, unlike tif_view).
    NOTE(review): not called anywhere in the visible code.

    Raises:
        ValueError: if the extension is not .tif/.tiff (checked
            case-insensitively), or the data has 2 or more than 5 dimensions.
    """
    fpath, fext = os.path.splitext(filepath)
    print('tif'+filepath)
    print('tif'+ fext)
    if fext.lower() not in ('.tiff', '.tif'):
        raise ValueError("not a TIF/TIFF")
    img = tiff.imread(filepath)
    print(img.shape)
    if img.ndim == 2:
        raise ValueError("TIF has only two dimensions")
    # select first timepoint
    if img.ndim == 5:
        img = img[0, :, :, :, :]
        print(img.shape)
    # distinguishes between z,y,x and z,c,y,x
    if img.ndim == 4:
        img = img[z, :, :, :]
        print(img.shape)
    elif img.ndim == 3:
        img = img[z, :, :]
        print(img.shape)
        img = np.tile(img[:, :, np.newaxis], [1, 1, 3])
    else:
        raise ValueError("TIF cannot have more than five dimensions")
    # move the smallest axis (assumed to be channels) to the end
    imin = np.argmin(img.shape)
    img = np.moveaxis(img, imin, 2)
    print(img.shape)
    Ly, Lx, nchan = img.shape
    imgi = np.zeros((Ly, Lx, 3))
    nn = np.minimum(3, img.shape[-1])
    imgi[:, :, :nn] = img[:, :, :nn]
    imgi = imgi/(np.max(imgi)+0.0000001)
    imgi = 255. * imgi
    filepath = fpath+'.png'
    tiff.imwrite(filepath, imgi.astype('uint8'))
    print('tif'+filepath)
    return filepath
# function to change image appearance
def norm_path(filepath):
    """Scale an image by its maximum and save it as an 8-bit PNG.

    The image is read with imread, divided by (max + 1e-7), multiplied by 255
    and saved with PIL to ``<stem>.png``. For a ``.png`` input this overwrites
    the input file.

    Returns:
        Path of the written PNG.
    """
    pixels = imread(filepath)
    scaled = pixels / (np.max(pixels) + 0.0000001)
    png_path = os.path.splitext(filepath)[0] + '.png'
    Image.fromarray((255. * scaled).astype(np.uint8)).save(png_path)
    print('norm' + png_path)
    return png_path
def update_image(filepath, z):
    """Show a newly uploaded volume and reset all result outputs.

    Args:
        filepath: list of uploaded file paths (UploadButton value); the last
            one is displayed.
        z: requested slice index. It is clamped to 0..n_z-1 *before* rendering,
            so the picture matches the slider position.

    Returns:
        Values for [input_image, up_btn, output_image, down_btn, down_btn2,
        depth, channel]: the rendered input, the unchanged file list, an empty
        output placeholder, cleared download buttons, and reconfigured z and
        channel sliders. The channel slider is hidden for 3-D files.
    """
    print('update_img')
    volume_path = filepath[-1]
    max_z, num_c = check_dims(volume_path)
    # Valid slice indices are 0..max_z-1 (the slider range below). Previously z
    # was rendered unclamped and then clamped to max_z, one past the end.
    z = max(0, min(z, max_z - 1))
    filepath_show = tif_view(volume_path, z)
    filepath_show = norm_path(filepath_show)
    print(filepath_show)
    print(filepath)
    if num_c is None:
        visible_c = False
        num_c = 10
    else:
        visible_c = True
    print(visible_c)
    return ((filepath_show, [((5, 5, 10, 10), 'nothing')]),
            filepath,
            (fp0, [((5, 5, 10, 10), 'nothing')]),
            None,
            None,
            gr.Slider(0, max_z - 1, value=z, visible=True),
            gr.Slider(0, num_c - 1, visible=visible_c))
def update_with_example(filepath):
    """Load the TIF volume that belongs to a clicked example thumbnail.

    Only the file name of ``filepath`` is used: the volume is expected at
    ``./gradio_examples/<thumbnail stem>.tif``. (Presumably Gradio hands over a
    copy of the thumbnail from a temp directory, hence basename only.)

    Returns:
        The same tuple as update_image, requested at z = 10.
    """
    print('update_btn')
    print(filepath)
    # os.path.basename handles both separators; split('/') broke on Windows paths.
    stem = os.path.splitext(os.path.basename(filepath))[0]
    tif_path = os.path.join("./gradio_examples", stem + '.tif')
    return update_image([tif_path], z=10)
def example(filepath):
    """Debug helper: print the example TIF path for a thumbnail and echo the input.

    Returns:
        ``filepath`` unchanged.
    """
    print('update_btn')
    print(filepath)
    shown = filepath
    tif_name = (os.path.splitext(filepath)[0] + '.tif').split('/')[-1]
    print("./gradio_examples/" + tif_name)
    return shown
def update_button(filepath, z):
    """Render slice z of a single TIF and reset the output view.

    NOTE(review): no active event listener in the visible code calls this.

    Args:
        filepath: path to one TIF file (not a list).
        z: slice index.

    Returns:
        (input AnnotatedImage value, ``[filepath]``, output AnnotatedImage value).
        Each AnnotatedImage value is an (image, annotations) 2-tuple, as in
        update_image. Previously a misplaced parenthesis put ``z`` inside the
        third value, producing an invalid 3-tuple.
    """
    print('update_btn')
    print(filepath)
    filepath_show = tif_view(filepath, z)
    filepath_show = norm_path(filepath_show)
    print(filepath_show)
    return (filepath_show, [((5, 5, 10, 10), 'nothing')]), [filepath], (fp0, [((5, 5, 10, 10), 'nothing')])
def update(filepath, filepath_result, filepath_coordinates, z, c):
    """Re-render the input and result views for the current slice/channel.

    Args:
        filepath: list of uploaded paths (UploadButton value); the last is shown.
        filepath_result: path of the distance-map TIFF, or any non-str value
            when no result exists yet (then the fp0 placeholder is shown).
        filepath_coordinates: path of the detections CSV, or None.
        z: slice index.
        c: channel index.

    Returns:
        (input AnnotatedImage value, output AnnotatedImage value). Both use the
        same detection boxes.
    """
    print('update_img')
    filepath_show = tif_view(filepath[-1], z, c=c)
    filepath_show = norm_path(filepath_show)
    if isinstance(filepath_result, str):
        filepath_result_show = tif_view(filepath_result, z, show_depth=False)
        filepath_result_show = norm_path(filepath_result_show)
    else:
        filepath_result_show = fp0
    print(filepath_show)
    print(filepath)
    if filepath_coordinates is None:
        display_boxes = []
    else:
        # Read the rendered image once; only its (height, width) is needed.
        image_shape = imread(filepath_show).shape
        print(image_shape)
        display_boxes = filter_coordinates_alt(filepath_coordinates, z, image_shape[0:2])
    return (filepath_show, display_boxes), (filepath_result_show, display_boxes)
def filter_coordinates_alt(filepath_coordinates, z, image_shape=(512, 512)):
    """Rasterise detections near slice ``z`` into a 2-D box mask.

    Each CSV row is (z, row, col), as written by detect_cells. Detections with
    |dz| < 3 get a square of half-width 4 - |dz| drawn into the mask, so nearer
    cells get bigger boxes.

    Args:
        filepath_coordinates: CSV path (may hold zero, one or many rows).
        z: current slice index.
        image_shape: (height, width) of the mask.

    Returns:
        ``[(mask, 'nothing')]``: one mask annotation for gr.AnnotatedImage.
    """
    import warnings

    depth = 3
    with warnings.catch_warnings():
        # An empty CSV (no detections) is expected; silence numpy's warning.
        warnings.simplefilter("ignore", UserWarning)
        # ndmin=2 keeps a single detection as a (1, 3) row instead of a flat vector.
        coordinates = np.loadtxt(filepath_coordinates, delimiter=",", ndmin=2)
    boxes = np.zeros(image_shape)
    if coordinates.size == 0:
        return [(boxes, 'nothing')]
    coordinates = coordinates.astype('int')
    print(coordinates)
    dz = np.abs(coordinates[:, 0] - z)
    print(dz)
    coordinates = coordinates[dz < depth, :]
    print(coordinates)
    sizes = 4 - np.abs(coordinates[:, 0] - z)
    print(sizes)
    for x, y, size in zip(coordinates[:, 1], coordinates[:, 2], sizes):
        boxes = draw_box(boxes, x, y, size)
    return [(boxes, 'nothing')]


def draw_box(array, x, y, size):
    """Set the (2*size+1)-wide square centred on (x, y) to 1, clipped to ``array``.

    ``x`` indexes axis 0 and ``y`` axis 1. ``array`` is modified in place and
    also returned.
    """
    x0 = max(x - size, 0)
    y0 = max(y - size, 0)
    # Slicing already stops at the array edge, so clip to shape, not shape-1.
    # The old -1 made the last row/column impossible to draw.
    x1 = min(x + size + 1, array.shape[0])
    y1 = min(y + size + 1, array.shape[1])
    array[x0:x1, y0:y1] = 1
    return array
def add_boxes_norm_path(filepath, boxes):
    """Overlay a box mask on an image, normalise, and save ``<stem>_with_boxes.png``.

    The image is divided by (max + 1e-7). Wherever ``boxes`` > 0 the pixel is
    replaced by the (integer) mask value in all three channels. The result is
    scaled by 255 and saved with PIL.

    Returns:
        Path of the written PNG.
    """
    img = imread(filepath)
    scaled = img / (np.max(img) + 0.0000001)
    mask = np.stack([boxes, boxes, boxes], axis=-1).astype('int')
    print(scaled.shape)
    print(np.sum(mask))
    print(mask.shape)
    merged = np.where(mask > 0, mask, scaled)
    out_path = os.path.splitext(filepath)[0] + '_with_boxes' + '.png'
    Image.fromarray((255. * merged).astype(np.uint8)).save(out_path)
    print('norm_with_box' + out_path)
    return out_path
def loss(y_true, y_pred):
    """Placeholder for the training loss; always returns None.

    Passed as a custom object to keras.saving.load_model. Presumably the saved
    model references it by name; it is not evaluated for inference here.
    """
    return None


def position_precision(y_true, y_pred):
    """Placeholder for the 'position_precision' metric; always returns None."""
    return None


def position_recall(y_true, y_pred):
    """Placeholder for the 'position_recall' metric; always returns None."""
    return None


def overcount(y_true, y_pred):
    """Placeholder for the 'overcount' metric; always returns None."""
    return None
@spaces.GPU(duration=60)
def run_model_gpu60(model, tile):
    """Load the Keras model at path ``model`` and predict on one input tile.

    The model is loaded from disk on every call. Placeholder loss/metric
    functions are supplied as custom objects. ``.cpu().detach()`` assumes the
    torch backend selected via KERAS_BACKEND at import.
    NOTE(review): duration=60 is presumably the `spaces` GPU time budget in seconds.

    Returns:
        The prediction as a numpy array.
    """
    batch = keras.ops.convert_to_tensor(tile)
    custom_objects = {
        'loss': loss,
        'position_precision': position_precision,
        'position_recall': position_recall,
        'overcount': overcount,
    }
    network = keras.saving.load_model(model, custom_objects=custom_objects)
    prediction = network(batch)
    return prediction.cpu().detach().numpy()
def detect_cells(filepath, c, model, rescale_z, rescale_xy, progress=gr.Progress()):
    """Run the selected detection network on the uploaded volume.

    Pipeline: load channel ``c``, rescale by (rescale_z, rescale_xy, rescale_xy),
    subtract the 75th-percentile background, normalise, tile in z, predict each
    tile, stitch, rescale back, crop to the original shape, and find peaks.

    Args:
        filepath: list of uploaded paths (UploadButton value); the last is used.
        c: channel index.
        model: dropdown label; a key of ``model_adresses``.
        rescale_z: z zoom factor applied before prediction.
        rescale_xy: y/x zoom factor applied before prediction.
        progress: Gradio progress tracker (unused).

    Returns:
        (distance-map TIFF path, detections CSV path with one z,y,x row per
        peak, placeholder output AnnotatedImage value).

    Raises:
        gr.Error: if no file was uploaded or no model was selected.
    """
    # Show a readable message in the UI instead of a TypeError/KeyError.
    if not filepath:
        raise gr.Error("Please upload an image volume first.")
    if model not in model_adresses:
        raise gr.Error("Please select a detection model.")
    model = download_model(model_adresses[model])
    xy_tile = 32
    img = get_volume(filepath[-1], c = c)
    original_shape = img.shape
    img = sc.ndimage.zoom(img, (rescale_z, rescale_xy, rescale_xy))
    # Aggressive background subtraction: everything below the 75th percentile -> 0.
    background = np.quantile(img, 0.75)
    img = np.maximum(img, background)-background
    img_max = np.max(img)
    if img_max > 0:  # a flat volume would otherwise become all-NaN (0/0)
        img = img/img_max
    img_padded = pad(img)
    print(img_padded.shape)
    tiles = split_z(img_padded)
    results = []
    print(tiles)
    for tile in tiles:
        # (z, y, x) -> (z, y, x, 2): the tile is duplicated into two channels,
        # then a leading batch axis is added.
        tile = np.tile(tile[:,:,:,np.newaxis], [1,1,2])
        tile = tile[np.newaxis,:,:,:,:]
        result = run_model_gpu60(model, tile)
        # remove the xy buffer added by pad(), keep the first output channel
        result = result[0, :, xy_tile//2 : -(xy_tile//2), xy_tile//2 : -(xy_tile//2), 0]
        results.append(result)
    result = reconstruct_z(results)
    print(result.shape)
    result = sc.ndimage.zoom(result, (1/rescale_z, 1/rescale_xy, 1/rescale_xy))
    result = result[0:original_shape[0], 0:original_shape[1], 0:original_shape[2]]
    print(result.shape)
    print(filepath)
    fpath, fext = os.path.splitext(filepath[-1])
    filepath_result = fpath+'result'+'.tiff'
    tiff.imwrite(filepath_result, result)
    result_max = np.max(result)
    if result_max > 0:
        result = result/result_max
    # Peak footprint of 5 x 13 x 13 voxels at model resolution, converted to the
    # original resolution. The sliders allow non-integer factors (step 0.1), so
    # 5 // 1.5 is a float, and np.ones rejects float shapes. Cast to int (>= 1).
    footprint_z = max(1, int(5 // rescale_z))
    footprint_xy = max(1, int(13 // rescale_xy))
    coordinates = peak_local_max(result, min_distance=2,
                                 footprint=np.ones((footprint_z, footprint_xy, footprint_xy)),
                                 threshold_abs=0.2, exclude_border=False)
    print(coordinates)
    filepath_coordinates = fpath+'coordinates'+'.csv'
    np.savetxt(filepath_coordinates, coordinates, delimiter=",")
    return filepath_result, filepath_coordinates, (fp0, [((5, 5, 10, 10), 'nothing')])
def pad(img, z_tile = 32, z_buffer=2, xy_tile = 32):
    """Zero-pad a (z, y, x) volume so split_z can cut it into full tiles.

    z: a volume with at most ``z_tile`` slices is padded to exactly ``z_tile``
    (one tile). A larger volume is padded to k*(z_tile - z_buffer) + z_buffer
    slices, with k = ceil(n_z / (z_tile - z_buffer)). Every tile split_z cuts is
    then ``z_tile`` deep, and the tiles trimmed by reconstruct_z still cover
    every original slice. Previously a multiple of z_tile (e.g. 64) got no
    padding, which produced a short last tile.

    y, x: padded at the end up to a multiple of ``xy_tile``, plus
    ``xy_tile // 2`` on both sides as a buffer that is cropped after prediction.

    Returns:
        The zero-padded array (same dtype).
    """
    n_z = img.shape[0]
    step = z_tile - z_buffer
    if n_z <= z_tile:
        pad_z = z_tile - n_z
    else:
        n_tiles = -(-n_z // step)  # ceil division in integers
        pad_z = n_tiles * step + z_buffer - n_z
    pad_y = (-img.shape[1]) % xy_tile
    pad_x = (-img.shape[2]) % xy_tile
    half = xy_tile // 2
    return np.pad(img, ((0, pad_z), (half, pad_y + half), (half, pad_x + half)))
def split_z(img, z_tile=32, z_buffer=2):
    """Cut a (z, y, x) volume into z-tiles that overlap by ``z_buffer`` slices.

    A volume exactly ``z_tile`` deep is returned as a single tile. Otherwise
    tiles start every ``z_tile - z_buffer`` slices. The last tile is only full
    depth if the volume was padded with pad() using the same parameters.

    Returns:
        List of array views, in z order.
    """
    # Compare against z_tile rather than a hard-coded 32 so the parameter is honoured.
    if img.shape[0] == z_tile:
        return [img]
    tiles = []
    height = 0
    while height < (img.shape[0] - z_buffer):
        tiles.append(img[height:(height + z_tile), :, :])
        height = height + z_tile - z_buffer
        print(height)
    return tiles
def reconstruct_z(tiles, z_tile=32, z_buffer=2):
    """Stitch per-tile predictions back into one volume along z.

    A single tile is returned as-is (same object). With several tiles, only the
    first ``z_tile - z_buffer`` slices of each are kept, dropping the overlap,
    and the results are concatenated.
    """
    if len(tiles) == 1:
        return tiles[0]
    keep = z_tile - z_buffer
    trimmed = [t[:keep, :, :] for t in tiles]
    return np.concatenate(trimmed, axis=0)
def filter_coordinates(filepath_coordinates, z):
    """Convert detections near slice ``z`` into (x0, y0, x1, y1) box annotations.

    Each CSV row is (z, row, col). Rows with |dz| < 3 become boxes centred on
    (col, row) with half-width 4 - |dz|, clipped at 0.

    Returns:
        List of ``((x0, y0, x1, y1), 'nothing')`` tuples; empty if none match.
    """
    import warnings

    with warnings.catch_warnings():
        # An empty CSV (no detections) is expected; silence numpy's warning.
        warnings.simplefilter("ignore", UserWarning)
        # ndmin=2 keeps a single detection as a (1, 3) row instead of a flat vector.
        coordinates = np.loadtxt(filepath_coordinates, delimiter=",", ndmin=2)
    print(coordinates)
    if coordinates.size == 0:
        return []
    coordinates = coordinates[np.abs(coordinates[:, 0] - z) < 3, :]
    print(coordinates)
    xy_coordinates = coordinates[:, (2, 1)]
    rel_z = np.abs(coordinates[:, 0] - z)[:, np.newaxis]
    print(rel_z)
    boxes = np.concatenate((xy_coordinates - 4 + rel_z, xy_coordinates + 4 - rel_z), axis=1)
    # Clip before the unsigned cast: negative corners would otherwise wrap around.
    boxes = np.clip(boxes, 0, None).astype('uint32')
    print(boxes)
    boxes = [(tuple(box.tolist()), 'nothing') for box in boxes]
    print(boxes)
    return boxes
# Gradio UI: input and controls on the left, prediction and downloads on the
# right, followed by examples, event wiring and notes.
with gr.Blocks(title = "Hello",
               css=""".gradio-container {background:green;}
#examples { background:green;}""") as demo:
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:25pt; font-weight:bold; text-align:center; color:white;">OrganoidTracker 2.0 for 3D cell tracking
<a style="color:#cfe7fe; font-size:14pt;" href="https://www.biorxiv.org/content/10.1101/2024.10.11.617799v1" target="_blank">[paper]</a>
<a style="color:#cfe7fe; font-size:14pt;" href="https://organoidtracker.org" target="_blank">[website]</a>
<a style="color:#cfe7fe; font-size:14pt;" href="https://jvzonlab.github.io/OrganoidTracker/index.html" target="_blank">[github]</a>
</div>""")
    gr.HTML("""<h4 style="color:white;"> What is this? </h4>
<ul>
<li style="color:white;">Test the performance of our pre-trained networks on your data.
<li style="color:white;">We implement only the initial cell detection step, but this is a good performance indicator for the other steps.
<li style="color:white;">Does not work? OrganoidTracker 2.0 allows for the easy creation of ground truth datasets and training of new neural networks.
</ul>
""")
    with gr.Row():
        # Left: input volume and controls.
        with gr.Column(scale=2):
            input_image = gr.AnnotatedImage(label = "Input", show_legend=False, color_map = {'nothing': '#FFFF00'})
            gr.HTML("""<h4 style="color:white;">You may need to log in/refresh for 5 minutes of free GPU compute per day. </h4>""")
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        depth = gr.Slider(0, 100, step=1, label = 'z-depth', value = 10, visible=False)
                        channel = gr.Slider(0, 100, label = 'channel', value = 0, visible=False)
                    up_btn = gr.UploadButton("Upload image volume (.tif/.tiff)", visible=True, file_count = "multiple")
                with gr.Column(scale=1):
                    model = gr.Dropdown(
                        model_adresses.keys(), label="Detection model (with resolutions)", info="Will add more models later!"
                    )
                    with gr.Row():
                        rescale_xy = gr.Slider(0.2, 2, step=0.1, label = 'resize xy', value = 1)
                        rescale_z = gr.Slider(0.2, 2, step=0.1, label = 'resize z', value = 1)
                    send_btn = gr.Button("Run cell detection")
        # Right: prediction and downloads.
        with gr.Column(scale=2):
            output_image = gr.AnnotatedImage(label = "Output", show_legend=False, color_map = {'nothing': '#FFFF00'})
            down_btn = gr.DownloadButton("Download distance map (.tif)", visible=True)
            down_btn2 = gr.DownloadButton("Download cell detections (.csv)", visible=True)
    gr.HTML("""<h4 style="color:white;"> Click on an example to try it:
</h4>""")
    # Sorted so the examples appear in a stable order; os.listdir order is arbitrary.
    sample_list = sorted(os.listdir("./gradio_examples/jpegs"))
    print(sample_list)
    sample_list = ["./gradio_examples/jpegs/" + sample for sample in sample_list]
    # Hidden component that only carries the clicked example's thumbnail path.
    example_image = gr.Image(visible=False, type='filepath')
    gr.Examples(sample_list, fn= update_with_example, inputs=example_image,
                outputs=[input_image, up_btn, output_image, down_btn, down_btn2, depth, channel],
                examples_per_page=5, label=' ', cache_examples=False, run_on_click=True,
                elem_id='examples')
    up_btn.upload(update_image, [up_btn, depth], [input_image, up_btn, output_image, down_btn, down_btn2, depth, channel])
    depth.change(update, [up_btn, down_btn, down_btn2, depth, channel], [input_image, output_image])
    channel.change(update, [up_btn, down_btn, down_btn2, depth, channel], [input_image, output_image])
    # Prediction, then refresh both views with the new result and detections.
    send_btn.click(detect_cells, [up_btn, channel, model, rescale_z, rescale_xy], [down_btn, down_btn2, output_image]).then(update, [up_btn, down_btn, down_btn2, depth, channel], [input_image, output_image])
    gr.HTML("""<h4 style="color:white;"> Notes:<br> </h4>
<li style="color:white;">You can load and process 3D tifs in the following dimensions: (T),Z,(C),Y,X. We automatically pick the first timepoint.
<li style="color:white;">Without GPU access, cell detection might take ~30 seconds.
<li style="color:white;">Locally OrganoidTracker will run faster: ~2 seconds per frame on a dedicated GPU, ~10 seconds on a CPU.
""")
    gr.HTML("""<h4 style="color:white;"> Caveats:<br> </h4>
<li style="color:white;">For this demo, an aggressive background subtraction step is implemented before prediction, which we find benefits most use cases. For transparency, in OrganoidTracker 2.0 itself users have to preprocess the data themselves.
<li style="color:white;">Because of incompatibilities between TensorFlow and HuggingFace, the models here are trained with the upcoming PyTorch version of OrganoidTracker (currently in beta). There might be performance differences when using the TensorFlow versions presented in our paper.
""")
    gr.HTML("""<h4 style="color:white;"> References:<br> </h4>
<li style="color:white;">The blastocyst sample data is taken from the BlastoSPIM dataset (Nunley et al., Development, 2024):
<a style="color:#cfe7fe" href="https://blastospim.flatironinstitute.org/html/index1.html" target="_blank">[website]</a>,
<a style="color:#cfe7fe" href="https://journals.biologists.com/dev/article/151/21/dev202817/362603/Nuclear-instance-segmentation-and-tracking-for" target="_blank">[paper]</a>
<li style="color:white;">The C. elegans sample data is taken from the Cell Tracking Challenge (Murray et al., Nature Methods, 2008):
<a style="color:#cfe7fe" href="https://celltrackingchallenge.net/3d-datasets/" target="_blank">[website]</a>,
<a style="color:#cfe7fe" href="https://www.nature.com/articles/nmeth.1228" target="_blank">[paper]</a>
""")
# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.queue().launch()