3DVascNet / app.py
Hemaxi's picture
Update app.py
f7a8dfa verified
raw
history blame
4.75 kB
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tifffile

from prediction import predict_mask
def process_3d_image(image, resx, resy, resz):
    """Run the 3DVascNet segmentation model on a 3D volume.

    Thin wrapper around ``prediction.predict_mask``: the volume and the three
    voxel resolutions are forwarded unchanged and its result is returned as-is.

    Args:
        image: 3D numpy array as produced by ``auximread`` (smallest axis last).
        resx, resy, resz: Voxel size along X, Y and Z in micrometres, as
            entered in the UI.

    Returns:
        The value returned by ``predict_mask``; the app treats it as the binary
        mask volume and saves it with ``tifffile.imwrite``.
    """
    return predict_mask(image, resx, resy, resz)
def auximread(filepath):
    """Read a TIFF file and reorient 3D volumes to (X, Y, Z) order.

    The axis with the fewest elements is taken to be Z and moved to the end.
    The relative order of the two remaining axes is preserved.

    Args:
        filepath: Path to a file readable by ``tifffile.imread``.

    Returns:
        The image as a numpy array. For 3D images the shortest axis becomes
        the last one. Images of any other rank are returned unchanged, so the
        caller can reject them with a clear message. The previous version
        crashed here instead, with an IndexError (2D) or a transpose
        ValueError (4D).
    """
    image = tifffile.imread(filepath)
    if np.ndim(image) != 3:
        # Not a single 3D volume (e.g. a 2D image or a multi-channel stack);
        # leave validation and the user-facing error to the caller.
        return image
    # np.argmin resolves ties to the earliest axis, matching the old logic.
    # moveaxis(0, -1) == transpose(1, 2, 0); moveaxis(1, -1) == transpose(0, 2, 1);
    # moveaxis(2, -1) leaves the axis order unchanged.
    z_axis = int(np.argmin(image.shape))
    return np.moveaxis(image, z_axis, -1)
def process_file(file, resx, resy, resz):
    """Load an uploaded TIFF volume, segment it, and save the mask to disk.

    Args:
        file: The ``gr.File`` upload value. This is either a path string
            (``str`` / ``os.PathLike``, including Gradio's ``str`` subclass)
            or an object whose ``.name`` attribute holds the temp-file path.
            It is ``None`` when nothing was uploaded.
        resx, resy, resz: Voxel size along X, Y and Z in micrometres.

    Returns:
        ``(image, binary_mask, output_path)``: the reoriented input volume,
        the model output, and the path of the TIFF the mask was written to.

    Raises:
        ValueError: If no file was uploaded, the file is not a .tif/.tiff,
            or the image is not 3D.
    """
    if file is None:
        raise ValueError("No file uploaded. Please upload a .tif or .tiff file.")
    # Check for str/PathLike first: a pathlib.Path also has a `.name`, but it
    # is only the basename, not the full path.
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name
    # Case-insensitive, so uploads named *.TIF or *.tiff are accepted too.
    if not path.lower().endswith((".tif", ".tiff")):
        raise ValueError("Unsupported file format. Please upload a .tif or .tiff file.")
    image = auximread(path)
    if len(image.shape) != 3:
        raise ValueError("Input image is not 3D.")
    binary_mask = process_3d_image(image, resx, resy, resz)
    # Write into a fresh directory per request. Concurrent users then can't
    # overwrite each other's result, and the download is still named
    # output_mask.tif.
    output_dir = tempfile.mkdtemp(prefix="3dvascnet_")
    output_path = os.path.join(output_dir, "output_mask.tif")
    tifffile.imwrite(output_path, binary_mask)
    return image, binary_mask, output_path
def visualize_slice(image, mask, slice_index):
    """Plot one slice of the image next to the same slice of the mask.

    Args:
        image: 3D array, sliced as ``image[:, :, slice_index]``.
        mask: 3D array with the same layout as ``image``.
        slice_index: Position along the last axis. It is passed through
            ``int()`` because slider values may arrive as floats, and NumPy
            rejects float indices.

    Returns:
        A matplotlib Figure with two grayscale panels, "Image Slice" and
        "Mask Slice". The figure is closed in pyplot before being returned,
        so repeated slider moves don't pile up open figures.

    Raises:
        IndexError: If ``slice_index`` is outside the last axis. This is
            raised before any figure is created, so nothing leaks.
    """
    z = int(slice_index)
    # Slice first so an out-of-range index fails before a figure exists.
    panels = ((image[:, :, z], "Image Slice"), (mask[:, :, z], "Mask Slice"))
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()
    plt.close(fig)
    return fig
# Most recent segmentation result. segment_button_click writes these and
# update_visualization reads them; both stay None until the first run.
# NOTE(review): plain module globals, so every session served by this process
# sees the same values — confirm that is acceptable for the deployment.
processed_image = processed_mask = None
def segment_button_click(file, resx, resy, resz):
    """Gradio callback for the Segment button.

    Runs ``process_file`` and caches the image and mask in the module-level
    ``processed_image`` / ``processed_mask`` read by ``update_visualization``.
    It also reveals the slice slider, sized to the new volume.

    Returns:
        ``(status_message, output_path, slider_update)``.

    Raises:
        ValueError: Propagated from ``process_file`` for bad uploads.
    """
    global processed_image, processed_mask
    processed_image, processed_mask, output_path = process_file(file, resx, resy, resz)
    last_slice = processed_image.shape[2] - 1
    # Reset the position as well as the range. A slider position left over
    # from a deeper volume could otherwise point past the last slice of this one.
    slider_update = gr.update(visible=True, maximum=last_slice, value=0)
    return "Segmentation completed! Use the slider to explore slices.", output_path, slider_update
def update_visualization(slice_index):
    """Gradio callback for the slice slider.

    Draws the requested slice of the cached image and mask via
    ``visualize_slice``.

    Raises:
        ValueError: If no segmentation result has been cached yet.
    """
    if processed_image is not None and processed_mask is not None:
        return visualize_slice(processed_image, processed_mask, slice_index)
    raise ValueError("Please process an image first by clicking the Segment button.")
# Gradio interface. Components are rendered in the order they are created:
# description, resolution inputs, upload/segment controls, status/download,
# then the slice viewer.
# NOTE(review): indentation was lost in this copy; the row membership below is
# a reconstruction — confirm against the upstream layout.
with gr.Blocks() as iface:
    gr.Markdown("""# 3DVascNet: Retinal Blood Vessel Segmentation
Upload a 3D image in .tif format. Click the **Segment** button to process the image and generate a 3D binary mask.
Use the slider to navigate through the 2D slices. This is the official implementation of 3DVascNet, described in this paper: https://www.ahajournals.org/doi/10.1161/ATVBAHA.124.320672.
The raw code is available at https://github.com/HemaxiN/3DVascNet.
""")
    # Voxel size in micrometres. These are passed as resx/resy/resz to
    # segment_button_click when Segment is clicked.
    with gr.Row():
        resx_input = gr.Number(value=0.333, label="Resolution in X (µm)", precision=3)
        resy_input = gr.Number(value=0.333, label="Resolution in Y (µm)", precision=3)
        resz_input = gr.Number(value=0.5, label="Resolution in Z (µm)", precision=3)
    with gr.Row():
        file_input = gr.File(label="Upload 3D Image (.tif)")
        segment_button = gr.Button("Segment")
    status_output = gr.Textbox(label="Status", interactive=False)
    download_output = gr.File(label="Download Binary Mask (.tif)")
    # The slider starts hidden with a placeholder maximum of 100.
    # segment_button_click makes it visible and sets maximum to the last slice index.
    with gr.Row():
        slice_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Slice Index", interactive=True, visible=False)
    visualization_output = gr.Plot(label="2D Slice Visualization")
    # Button click triggers segmentation. The outputs are filled, in order, with
    # the status message, the saved mask path, and the slider update.
    segment_button.click(segment_button_click,
                         inputs=[file_input, resx_input, resy_input, resz_input],
                         outputs=[status_output, download_output, slice_slider])
    # Slider changes redraw the image/mask slice pair at the chosen index.
    slice_slider.change(update_visualization, inputs=slice_slider, outputs=visualization_output)
if __name__ == "__main__":
    iface.launch()