# NIfTI-Viewer / app.py
# Last commit 8de9095 (verified) by IFMedTechdemo:
#   "Remove text output, display only slice visualization"
import gradio as gr
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
def _nifti_path(file):
    """Return a filesystem path for a Gradio upload value.

    Depending on the Gradio version / ``gr.File(type=...)`` setting, the
    upload arrives either as a path (``str`` / ``os.PathLike``) or as a
    tempfile-like object exposing ``.name``; accept both.
    """
    import os

    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def _clamp_index(idx, upper):
    """Coerce a slider value to an ``int`` within the closed range [0, upper]."""
    return max(0, min(int(idx), upper))


def _render_axial_slice(slice_data, title):
    """Draw a 2-D array as a grayscale image and return the Figure."""
    # Use the object-oriented API rather than pyplot: pyplot keeps every
    # figure in a global registry (one leaked figure per click, since none
    # are closed) and its "current figure" state is shared by any handlers
    # Gradio runs concurrently.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 8))
    ax = fig.subplots()
    # Transposing and using origin='lower' draws the array's first axis
    # horizontally and its second axis increasing upward.
    ax.imshow(slice_data.T, cmap='gray', origin='lower')
    ax.axis('off')
    ax.set_title(title)
    fig.tight_layout()
    return fig


def visualize_nifti(file, volume_idx=0, slice_idx=0):
    """Render one axial slice of an uploaded NIfTI image.

    Args:
        file: The Gradio upload value (a path or an object with ``.name``),
            or ``None`` when nothing has been uploaded.
        volume_idx: Index along the 4th axis; ignored for 3-D images.
        slice_idx: Index along the 3rd axis (the axial / Z plane).
            Both indices are converted to ``int`` and clamped to the range
            valid for the image.

    Returns:
        A matplotlib ``Figure`` showing the slice, or ``None`` if no file was
        given, the image is neither 3-D nor 4-D, or loading/rendering failed
        (failures are logged rather than silently discarded).
    """
    import logging

    logger = logging.getLogger(__name__)

    if file is None:
        return None
    try:
        img = nib.load(_nifti_path(file))
        shape = img.shape
        if len(shape) not in (3, 4):
            logger.warning(
                "Unsupported NIfTI dimensionality %d (shape %s); expected 3-D or 4-D.",
                len(shape), shape,
            )
            return None

        max_slice = shape[2] - 1
        slice_idx = _clamp_index(slice_idx, max_slice)

        # Index the lazy array proxy so only the requested plane is read,
        # instead of materialising the whole (possibly 4-D) image as float64
        # through get_fdata().
        if len(shape) == 3:
            slice_data = np.asarray(img.dataobj[:, :, slice_idx])
            title = f'Axial Slice: {slice_idx}/{max_slice}'
        else:
            max_volume = shape[3] - 1
            volume_idx = _clamp_index(volume_idx, max_volume)
            slice_data = np.asarray(img.dataobj[:, :, slice_idx, volume_idx])
            title = f'Volume: {volume_idx}/{max_volume}, Slice: {slice_idx}/{max_slice}'

        return _render_axial_slice(slice_data, title)
    except Exception:
        # UI boundary: keep the app responsive (the plot simply stays empty)
        # but record why, instead of discarding the error.
        logger.exception("Failed to visualize NIfTI upload %r", file)
        return None
def _fit_sliders_to_upload(file):
    """Return (volume, slice) slider updates matching the uploaded image.

    The sliders below are declared with a placeholder maximum of 100, so
    without this handler any slice/volume index above 100 is unreachable.
    Only the header shape is needed here. When there is no usable file
    (cleared upload, unreadable file, not 3-D/4-D) both sliders are left
    unchanged; the Visualize handler reports the actual problem.
    """
    unchanged = (gr.update(), gr.update())
    if file is None:
        return unchanged
    # Upload value may be a path string or a tempfile-like object.
    path = file if isinstance(file, str) else getattr(file, "name", file)
    try:
        shape = nib.load(path).shape
    except Exception:
        return unchanged
    if len(shape) not in (3, 4):
        return unchanged

    slice_update = gr.update(maximum=max(shape[2] - 1, 0), value=0)
    if len(shape) == 4:
        volume_update = gr.update(maximum=max(shape[3] - 1, 0), value=0, interactive=True)
    else:
        # 3-D images have no volume axis; the volume index is ignored.
        volume_update = gr.update(value=0, interactive=False)
    return volume_update, slice_update


# Create Gradio interface with Blocks for dynamic controls
with gr.Blocks() as demo:
    gr.Markdown("# NIfTI File Visualizer")
    gr.Markdown("Upload a NIfTI (.nii or .nii.gz) file to visualize axial slices.")
    with gr.Row():
        file_input = gr.File(label="Upload NIfTI file (.nii or .nii.gz)")
    with gr.Row():
        # Maximums are placeholders; they are refitted on upload.
        volume_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Volume/Frame (for 4D)")
        slice_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Axial Slice (Z-plane)")
    with gr.Row():
        visualize_btn = gr.Button("Visualize")
    with gr.Row():
        output_plot = gr.Plot(label="Slice Visualization")

    file_input.change(
        fn=_fit_sliders_to_upload,
        inputs=[file_input],
        outputs=[volume_slider, slice_slider],
    )
    visualize_btn.click(
        fn=visualize_nifti,
        inputs=[file_input, volume_slider, slice_slider],
        outputs=[output_plot],
    )

if __name__ == "__main__":
    demo.launch()