# Author: Duzduran
# Commit message: visualization optimize
# Commit: efadf48
import numpy as np
import cv2
import nibabel as nib
from PIL import Image
import io
import matplotlib.pyplot as plt
import gradio as gr
from huggingface_hub import from_pretrained_keras
# Hugging Face Hub repository that holds the trained Keras segmentation model.
MODEL_REPO_ID = "duzduran/NeuroNest3D"
# Loaded once at module import and shared by every request handler below.
model = from_pretrained_keras(MODEL_REPO_ID)
# ---------------------------------------------------------------------------
# Preprocessing / display constants
# ---------------------------------------------------------------------------
# Side length (pixels) every slice is resized to before inference/display.
IMG_SIZE = 128
# Number of consecutive slices (along the volume's last axis) fed to the model.
VOLUME_SLICES = 100
# Index of the first volume slice used; slice k of the batch is volume slice
# k + VOLUME_START_AT.
VOLUME_START_AT = 22
# Display labels for the model's output channels, indexed by channel number.
SEGMENT_CLASSES = ["NOT tumor", "ENHANCING", "CORE", "WHOLE"]
def predictByPath(flair, ce):
    """Segment a FLAIR/T1ce volume pair with the loaded Keras model.

    Despite the name, this takes already-loaded arrays, not file paths.
    VOLUME_SLICES consecutive slices along the last axis, starting at
    VOLUME_START_AT, are resized to IMG_SIZE x IMG_SIZE and stacked into a
    two-channel batch (channel 0 = FLAIR, channel 1 = T1ce).

    Args:
        flair: FLAIR volume, indexable as ``flair[:, :, k]``.
        ce: contrast-enhanced T1 volume with the same layout as ``flair``.

    Returns:
        The result of ``model.predict`` on the batch of shape
        ``(VOLUME_SLICES, IMG_SIZE, IMG_SIZE, 2)``.

    Raises:
        ValueError: if either volume has fewer than
            ``VOLUME_START_AT + VOLUME_SLICES`` slices along axis 2.
    """
    slices_needed = VOLUME_START_AT + VOLUME_SLICES
    # Fail with a readable message instead of an IndexError deep in the loop.
    for label, volume in (("flair", flair), ("t1ce", ce)):
        if np.ndim(volume) < 3 or np.shape(volume)[2] < slices_needed:
            raise ValueError(
                f"{label} volume must have at least {slices_needed} slices "
                f"along axis 2, got shape {np.shape(volume)}"
            )

    X = np.empty((VOLUME_SLICES, IMG_SIZE, IMG_SIZE, 2))
    for j in range(VOLUME_SLICES):
        src = j + VOLUME_START_AT
        X[j, :, :, 0] = cv2.resize(flair[:, :, src], (IMG_SIZE, IMG_SIZE))
        X[j, :, :, 1] = cv2.resize(ce[:, :, src], (IMG_SIZE, IMG_SIZE))

    # Scale both channels jointly by the global maximum. An all-zero (or
    # non-positive) batch would turn into NaN/inf via 0 / 0, so it is passed
    # through unscaled instead.
    peak = X.max()
    X_normalized = X / peak if peak > 0 else X
    return model.predict(X_normalized, verbose=1)
def create_subplot_image(origImage, gt, predictions, slice_index, start_at, img_size):
    """Render a six-panel overview of one slice and return it as a PIL image.

    Panels (left to right): the FLAIR slice; ground truth overlaid on it; the
    prediction channels 1-3 shown together as an RGB overlay; then one overlay
    per prediction channel 1-3, titled with ``SEGMENT_CLASSES``.

    Args:
        origImage: background volume, indexed as ``origImage[:, :, k]``.
        gt: ground-truth volume with the same layout as ``origImage``.
        predictions: model output indexed as ``predictions[slice, :, :, channel]``.
        slice_index: index into ``predictions``; the matching volume slice is
            ``slice_index + start_at``.
        start_at: offset between prediction index and volume slice index.
        img_size: side length the volume slices are resized to.

    Returns:
        PIL.Image.Image holding the PNG-rendered figure.
    """
    volume_index = slice_index + start_at
    # The same background is used in every panel; resize it once.
    background = cv2.resize(origImage[:, :, volume_index], (img_size, img_size))
    curr_gt = cv2.resize(gt[:, :, volume_index], (img_size, img_size),
                         interpolation=cv2.INTER_NEAREST)  # keep label values crisp

    fig, axarr = plt.subplots(1, 6, figsize=(18, 10))
    try:
        for ax in axarr:
            ax.imshow(background, cmap="gray", interpolation='none')

        axarr[0].title.set_text('Original image flair')

        axarr[1].imshow(curr_gt, cmap="Reds", interpolation='none', alpha=0.3)
        axarr[1].title.set_text('Ground truth')

        # Three channels are rendered as RGB, so cmap has no effect here.
        axarr[2].imshow(predictions[slice_index, :, :, 1:4], cmap="Reds", interpolation='none', alpha=0.3)
        axarr[2].title.set_text('All classes')

        for class_idx in range(1, 4):
            ax = axarr[class_idx + 2]
            ax.imshow(predictions[slice_index, :, :, class_idx], cmap="OrRd", interpolation='none', alpha=0.3)
            ax.title.set_text(f'{SEGMENT_CLASSES[class_idx]} predicted')

        # Save this figure explicitly instead of pyplot's "current" figure,
        # which another concurrent request could have replaced.
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure so long-running servers don't accumulate them.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# Bundled BraTS 2020 training cases selectable from the "Select Example"
# dropdown. Each entry maps a modality name to its NIfTI file path.
examples = {
    "Example 1": dict(
        flair="examples/ex_1/BraTS20_Training_001_flair.nii",
        t1ce="examples/ex_1/BraTS20_Training_001_t1ce.nii",
        seg="examples/ex_1/BraTS20_Training_001_seg.nii",
    ),
    "Example 2": dict(
        flair="examples/ex_2/BraTS20_Training_002_flair.nii",
        t1ce="examples/ex_2/BraTS20_Training_002_t1ce.nii",
        seg="examples/ex_2/BraTS20_Training_002_seg.nii",
    ),
}
def automatic_process(example_key, slice_index=50):
    """Load a bundled example case, segment it and render one slice.

    Args:
        example_key: key into ``examples`` (the dropdown value). ``None`` or
            an empty string means nothing is selected yet.
        slice_index: slice to render, relative to VOLUME_START_AT
            (defaults to 50, the previously hard-coded value).

    Returns:
        The rendered PIL image, or None when no example is selected, which
        matches how ``process_and_display`` handles missing input.

    Raises:
        KeyError: if ``example_key`` is not a key of ``examples``.
    """
    if not example_key:
        return None  # dropdown still empty; nothing to show

    paths = examples[example_key]
    flair = nib.load(paths["flair"]).get_fdata()
    t1ce = nib.load(paths["t1ce"]).get_fdata()
    seg = nib.load(paths["seg"]).get_fdata()
    return process_and_display_direct(flair, t1ce, seg, slice_index)
def process_and_display_direct(flair_data, t1ce_data, seg_data, slice_index):
    """Segment in-memory volumes and render one slice as a PIL image.

    Args:
        flair_data: FLAIR volume (array-like, e.g. nibabel ``get_fdata()`` output).
        t1ce_data: T1ce volume with the same layout as ``flair_data``.
        seg_data: ground-truth segmentation volume.
        slice_index: slice to render, relative to VOLUME_START_AT.

    Returns:
        The PIL image produced by ``create_subplot_image``.
    """
    # np.asarray reuses existing ndarrays instead of copying three full
    # volumes per request; neither predictByPath nor create_subplot_image
    # writes to these arrays, so sharing the caller's buffers is safe.
    flair = np.asarray(flair_data)
    t1ce = np.asarray(t1ce_data)
    seg = np.asarray(seg_data)

    p = predictByPath(flair, t1ce)
    return create_subplot_image(flair, seg, p, slice_index, VOLUME_START_AT, IMG_SIZE)
def process_and_display(flair_file, t1ce_file, seg_file, slice_index):
if not flair_file or not t1ce_file or not seg_file:
return None # Ensure all files are uploaded
flair = nib.load(flair_file.name).get_fdata()
t1ce = nib.load(t1ce_file.name).get_fdata()
gt = nib.load(seg_file.name).get_fdata()
p = predictByPath(flair, t1ce)
# Create the subplot image
subplot_img = create_subplot_image(flair, gt, p, slice_index, VOLUME_START_AT, IMG_SIZE)
return subplot_img
# HTML page heading. NOTE(review): this variable is not referenced elsewhere in
# this file; gr.Blocks passes its own title literal.
# Fixed: the heading text was copied from another project, and the closing
# tag was written as "<font>" instead of "</font>".
title = "<center><strong><font size='8'>MRI Brain Tumor Segmentation</font></strong></center>"
# Stylesheet handed to gr.Blocks(css=...): centers h1 headings and
# justifies/pads elements with the "about" class.
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
# Gradio Interface
with gr.Blocks(css=css, title="Tumor Segmentation") as demo:
    # Page header: project title, author and thesis details.
    gr.Markdown(
        """
<p style="text-align: center; font-size: 24px;">MRI Brain Tumor Segmentation</p>
<p style="text-align: center;">made by Ahmet Duzduran</p>
### <p style="text-align: left;">Faculty: Faculty of Computer Science</p>
### <p style="text-align: left;">Specialization: Intelligent Systems and Data Science</p>
### <p style="text-align: left;">Supervisor: Wojciech Oronowicz, PhD, Prof. Of PJATK</p>
"""
    )

    # Manual path: upload the three NIfTI volumes and choose a slice.
    with gr.Row():
        flair_input = gr.File(label="Upload Flair NIfTI File")
        t1ce_input = gr.File(label="Upload T1ce NIfTI File")
        seg_input = gr.File(label="Upload Seg NIfTI File")
    slice_input = gr.Slider(minimum=0, maximum=VOLUME_SLICES - 1, label="Slice Index")
    submit_button = gr.Button("Submit")

    # Example path: pick one of the bundled cases from `examples`.
    with gr.Row():
        example_selector = gr.Dropdown(choices=list(examples), label="Select Example")
        auto_button = gr.Button("Load Example")

    # Both paths render into the same image component.
    output_image = gr.Image(label="Visualization")

    upload_inputs = [flair_input, t1ce_input, seg_input, slice_input]
    submit_button.click(fn=process_and_display, inputs=upload_inputs, outputs=output_image)
    auto_button.click(fn=automatic_process, inputs=[example_selector], outputs=output_image)

demo.launch()