|
import gradio as gr |
|
import numpy as np |
|
import tifffile |
|
import matplotlib.pyplot as plt |
|
from prediction import predict_mask |
|
|
|
|
|
def process_3d_image(image, resx, resy, resz):
    """Segment a 3D volume and return its binary mask.

    This is a thin pass-through to ``prediction.predict_mask``: the volume and
    the three per-axis resolutions are forwarded unchanged and its result is
    returned as-is.

    Args:
        image: The 3D volume to segment.
        resx: Resolution along X (µm), as entered in the UI.
        resy: Resolution along Y (µm), as entered in the UI.
        resz: Resolution along Z (µm), as entered in the UI.

    Returns:
        Whatever ``predict_mask`` returns for these arguments (the binary mask).
    """
    return predict_mask(image, resx, resy, resz)
|
|
|
def auximread(filepath):
    """Read a TIFF stack and move its shortest axis to the last position.

    The rest of the app treats axis 2 as the slice axis (``image[:, :, k]``),
    so the axis with the fewest samples is moved to the end. That axis is
    presumably the z/depth axis of the stack (TODO confirm). The relative
    order of the remaining two axes is preserved. On ties, the first of the
    shortest axes is the one moved.

    Args:
        filepath: Path to a TIFF file readable by ``tifffile.imread``.

    Returns:
        The image array. For 3D input its axes are permuted so that
        ``image.shape[2] == min(image.shape)``. Arrays that are not 3D are
        returned unchanged so the caller can reject them with a clear error.
    """
    image = tifffile.imread(filepath)

    # Indexing shape[2] on a 2D image would raise IndexError here, pre-empting
    # the caller's explicit "not 3D" check; 4D input would crash in the
    # permutation. Leave such arrays untouched and let the caller decide.
    if np.ndim(image) != 3:
        return image

    shortest_axis = int(np.argmin(np.shape(image)))
    # moveaxis(0, -1) == transpose(1, 2, 0); moveaxis(1, -1) == transpose(0, 2, 1);
    # moveaxis(2, -1) is the identity -- same permutations as before.
    return np.moveaxis(image, shortest_axis, -1)
|
|
|
|
|
def process_file(file, resx, resy, resz):
    """
    Process the uploaded file and return the binary mask.

    Args:
        file: The value of the ``gr.File`` input. Newer Gradio versions pass
            a file path string; older ones pass a temp-file wrapper exposing
            ``.name``. Both are accepted. ``None`` means nothing was uploaded.
        resx, resy, resz: Per-axis resolution in µm, forwarded to the model.

    Returns:
        ``(image, binary_mask, output_path)``. These are the loaded
        (axis-reordered) volume, the predicted mask, and the path of the
        TIFF the mask was written to.

    Raises:
        ValueError: if no file was uploaded, the file is not a TIFF, or the
            image is not 3D.
    """
    import os
    import tempfile

    if file is None:
        raise ValueError("No file uploaded. Please upload a .tif file.")

    # Accept both a plain path and an object carrying the path in `.name`.
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name

    # Case-insensitive, and accept the common `.tiff` spelling too.
    if not str(path).lower().endswith((".tif", ".tiff")):
        raise ValueError("Unsupported file format. Please upload a .tif or .tiff file.")

    image = auximread(path)

    if len(np.shape(image)) != 3:
        raise ValueError("Input image is not 3D.")

    binary_mask = process_3d_image(image, resx, resy, resz)

    # A fixed path in the working directory is shared by every session, so a
    # concurrent request could overwrite another user's download. Write into
    # a fresh temporary directory instead, keeping the same file name.
    output_path = os.path.join(tempfile.mkdtemp(), "output_mask.tif")
    tifffile.imwrite(output_path, binary_mask)

    return image, binary_mask, output_path
|
|
|
|
|
def visualize_slice(image, mask, slice_index):
    """
    Visualizes a 2D slice of the image and the corresponding mask at the given index.

    Args:
        image: 3D array; the slice is taken along the last axis
            (``image[:, :, slice_index]``).
        mask: 3D array indexed the same way as ``image``.
        slice_index: Position along axis 2. It is cast to ``int`` because the
            slider value may arrive as a float, which NumPy rejects as an
            index.

    Returns:
        A matplotlib Figure with the image slice on the left and the mask
        slice on the right, both in grayscale with axes hidden.
    """
    index = int(slice_index)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    panels = ((image, "Image Slice"), (mask, "Mask Slice"))
    for ax, (volume, title) in zip(axes, panels):
        ax.imshow(volume[:, :, index], cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    # Lay out this specific figure instead of pyplot's global "current" figure,
    # which another request could change when callbacks run on worker threads.
    fig.tight_layout()
    # Remove the figure from pyplot's registry so figures don't accumulate
    # across slider moves; the Figure object itself is still returned.
    plt.close(fig)

    return fig
|
|
|
|
|
# Most recently segmented volume and its mask. segment_button_click fills
# these in and update_visualization reads them. Both stay None until the
# first successful segmentation.
# NOTE(review): module-level state is shared by every browser session.
processed_image = processed_mask = None
|
|
|
def segment_button_click(file, resx, resy, resz):
    """Handle the Segment button: run segmentation and reveal the slice slider.

    The loaded volume and its predicted mask are stored in the module-level
    ``processed_image`` / ``processed_mask`` so the slider callback can read
    them. NOTE(review): that state is shared by all sessions.

    Args:
        file: The uploaded file value from ``gr.File``.
        resx, resy, resz: Per-axis resolution in µm.

    Returns:
        A status message, the path of the mask TIFF for download, and a
        ``gr.update`` that shows the slider, limits it to this volume's slice
        range, and resets it to slice 0.
    """
    global processed_image, processed_mask

    processed_image, processed_mask, output_path = process_file(file, resx, resy, resz)
    num_slices = processed_image.shape[2]

    # Reset the slider to 0. A value left over from a previous, deeper volume
    # could otherwise exceed this volume's last slice index.
    return (
        "Segmentation completed! Use the slider to explore slices.",
        output_path,
        gr.update(visible=True, maximum=num_slices - 1, value=0),
    )
|
|
|
def update_visualization(slice_index):
    """Draw the image/mask pair for the slice currently selected on the slider.

    Args:
        slice_index: Slider value, forwarded to ``visualize_slice``.

    Returns:
        The figure produced by ``visualize_slice`` for the stored volume.

    Raises:
        ValueError: if no volume has been segmented yet.
    """
    have_results = processed_image is not None and processed_mask is not None
    if not have_results:
        raise ValueError("Please process an image first by clicking the Segment button.")

    return visualize_slice(processed_image, processed_mask, slice_index)
|
|
|
|
|
# Build the Gradio UI. Components are placed on the page in the order they are
# created inside this context, so statement order here defines the layout.
with gr.Blocks() as iface:

    gr.Markdown("""# 3DVascNet: Retinal Blood Vessel Segmentation

Upload a 3D image in .tif format. Click the **Segment** button to process the image and generate a 3D binary mask.

Use the slider to navigate through the 2D slices. This is the official implementation of 3DVascNet, described in this paper: https://www.ahajournals.org/doi/10.1161/ATVBAHA.124.320672.

The raw code is available at https://github.com/HemaxiN/3DVascNet.

""")

    # Per-axis resolution of the uploaded stack, in µm. These values are
    # passed through segment_button_click -> process_file -> process_3d_image.
    with gr.Row():

        resx_input = gr.Number(value=0.333, label="Resolution in X (µm)", precision=3)

        resy_input = gr.Number(value=0.333, label="Resolution in Y (µm)", precision=3)

        resz_input = gr.Number(value=0.5, label="Resolution in Z (µm)", precision=3)

    with gr.Row():

        file_input = gr.File(label="Upload 3D Image (.tif)")

        segment_button = gr.Button("Segment")

    status_output = gr.Textbox(label="Status", interactive=False)

    download_output = gr.File(label="Download Binary Mask (.tif)")

    # The slider starts hidden. segment_button_click makes it visible and sets
    # its maximum to the last slice index, so maximum=100 here is only a
    # placeholder.
    with gr.Row():

        slice_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Slice Index", interactive=True, visible=False)

        visualization_output = gr.Plot(label="2D Slice Visualization")

    # Segment: returns (status text, mask file path, slider update) in the
    # same order as `outputs`.
    segment_button.click(segment_button_click,

                         inputs=[file_input, resx_input, resy_input, resz_input],

                         outputs=[status_output, download_output, slice_slider])

    # Redraw the image/mask slice pair whenever the slider value changes.
    slice_slider.change(update_visualization, inputs=slice_slider, outputs=visualization_output)


# Start the web server only when this file is run as a script, not on import.
if __name__ == "__main__":

    iface.launch()
|
|