Spaces:
Runtime error
Runtime error
import numpy as np | |
import cv2 | |
import nibabel as nib | |
from PIL import Image | |
import io | |
import matplotlib.pyplot as plt | |
import gradio as gr | |
from huggingface_hub import from_pretrained_keras | |
# Preprocessing / display constants shared by the prediction and plotting code.
IMG_SIZE = 128  # side length each volume slice is resized to
VOLUME_SLICES = 100  # number of consecutive slices fed to the model per volume
VOLUME_START_AT = 22  # index (along the last volume axis) of the first slice used
# Titles for the prediction channels, indexed by channel number.
SEGMENT_CLASSES = [
    'NOT tumor',
    'ENHANCING',
    'CORE',
    'WHOLE',
]

# Pre-trained Keras segmentation model, downloaded from the Hugging Face Hub
# when this module is imported.
model = from_pretrained_keras("duzduran/NeuroNest3D")
def predictByPath(flair, ce):
    """Run the segmentation model on a FLAIR / T1ce volume pair.

    Takes VOLUME_SLICES consecutive slices along the last axis of each volume,
    starting at VOLUME_START_AT. Each slice is resized to IMG_SIZE x IMG_SIZE,
    and the two modalities are stacked as the two input channels.

    Args:
        flair: FLAIR volume indexable as ``[:, :, k]``; it must have at least
            VOLUME_START_AT + VOLUME_SLICES entries on its last axis.
        ce: T1ce volume with the same layout as ``flair``.

    Returns:
        The result of ``model.predict`` on a batch of shape
        (VOLUME_SLICES, IMG_SIZE, IMG_SIZE, 2).
    """
    X = np.empty((VOLUME_SLICES, IMG_SIZE, IMG_SIZE, 2))
    for j in range(VOLUME_SLICES):
        k = j + VOLUME_START_AT
        X[j, :, :, 0] = cv2.resize(flair[:, :, k], (IMG_SIZE, IMG_SIZE))
        X[j, :, :, 1] = cv2.resize(ce[:, :, k], (IMG_SIZE, IMG_SIZE))
    # Normalize by the global maximum. An all-zero input would otherwise
    # divide 0 by 0 and send an all-NaN batch to the model.
    peak = np.max(X)
    X_normalized = X / peak if peak != 0 else X
    return model.predict(X_normalized, verbose=1)
def create_subplot_image(origImage, gt, predictions, slice_index, start_at, img_size):
    """Render one slice as a 6-panel figure and return it as a PIL image.

    Panels, left to right:
      1. the resized FLAIR slice;
      2. the ground-truth overlay;
      3. prediction channels 1-3 together;
      4-6. one overlay per prediction channel 1, 2 and 3, titled from
         SEGMENT_CLASSES.

    Args:
        origImage: FLAIR volume indexed as ``[:, :, k]``.
        gt: ground-truth volume with the same layout as ``origImage``.
        predictions: model output indexed as ``[slice, row, col, channel]``
            with at least 4 channels.
        slice_index: index into ``predictions``; the matching volume slice is
            ``slice_index + start_at``. It is cast to ``int`` so float values
            (e.g. from a UI slider) work as array indices.
        start_at: offset between prediction slices and volume slices.
        img_size: side length each volume slice is resized to.

    Returns:
        A ``PIL.Image.Image`` holding the rendered PNG.
    """
    slice_index = int(slice_index)
    vol_index = slice_index + int(start_at)
    # The same background slice is shown under every panel; resize it once.
    background = cv2.resize(origImage[:, :, vol_index], (img_size, img_size))

    f, axarr = plt.subplots(1, 6, figsize=(18, 10))
    try:
        for ax in axarr:
            ax.imshow(background, cmap="gray", interpolation='none')
        # Original image flair
        axarr[0].title.set_text('Original image flair')
        # Ground truth: nearest-neighbour resize avoids blending label values.
        curr_gt = cv2.resize(gt[:, :, vol_index], (img_size, img_size), interpolation=cv2.INTER_NEAREST)
        axarr[1].imshow(curr_gt, cmap="Reds", interpolation='none', alpha=0.3)
        axarr[1].title.set_text('Ground truth')
        # All classes: a 3-channel array is drawn as RGB (matplotlib ignores cmap here).
        axarr[2].imshow(predictions[slice_index, :, :, 1:4], cmap="Reds", interpolation='none', alpha=0.3)
        axarr[2].title.set_text('All classes')
        # Class-specific predictions for channels 1..3
        for i in range(1, 4):
            axarr[i + 2].imshow(predictions[slice_index, :, :, i], cmap="OrRd", interpolation='none', alpha=0.3)
            axarr[i + 2].title.set_text(f'{SEGMENT_CLASSES[i]} predicted')
        # Convert plot to image
        buf = io.BytesIO()
        f.savefig(buf, format='png')
    finally:
        # Always release the figure, even if plotting failed, so repeated
        # requests do not accumulate open figures.
        plt.close(f)
    buf.seek(0)
    return Image.open(buf)
# Bundled BraTS20 training cases offered by the "Select Example" dropdown.
# Each case maps a modality key ("flair", "t1ce", "seg") to its NIfTI path.
examples = {
    "Example 1": dict(
        flair="examples/ex_1/BraTS20_Training_001_flair.nii",
        t1ce="examples/ex_1/BraTS20_Training_001_t1ce.nii",
        seg="examples/ex_1/BraTS20_Training_001_seg.nii",
    ),
    "Example 2": dict(
        flair="examples/ex_2/BraTS20_Training_002_flair.nii",
        t1ce="examples/ex_2/BraTS20_Training_002_t1ce.nii",
        seg="examples/ex_2/BraTS20_Training_002_seg.nii",
    ),
}
def automatic_process(example_key, slice_index=50):
    """Load a bundled example case and render its visualization.

    Args:
        example_key: key into ``examples`` (the dropdown selection). If nothing
            is selected (``None`` or empty), returns ``None`` instead of raising
            ``KeyError``. This matches how ``process_and_display`` handles
            missing uploads.
        slice_index: prediction slice to display (default 50).

    Returns:
        The image produced by ``process_and_display_direct``, or ``None`` when
        no example is selected.
    """
    if not example_key:
        return None
    paths = examples[example_key]
    print(paths["flair"])
    flair = nib.load(paths["flair"]).get_fdata()
    t1ce = nib.load(paths["t1ce"]).get_fdata()
    seg = nib.load(paths["seg"]).get_fdata()
    return process_and_display_direct(flair, t1ce, seg, slice_index)
def process_and_display_direct(flair_data, t1ce_data, seg_data, slice_index):
    """Predict on in-memory volumes and return the rendered visualization.

    Each input is first converted (copied) into a NumPy array. The FLAIR and
    T1ce arrays are passed to ``predictByPath``. The FLAIR array, the
    segmentation and the predictions are then passed to
    ``create_subplot_image``, which renders ``slice_index``.
    """
    flair, t1ce, seg = (np.array(vol) for vol in (flair_data, t1ce_data, seg_data))
    return create_subplot_image(
        flair,
        seg,
        predictByPath(flair, t1ce),
        slice_index,
        VOLUME_START_AT,
        IMG_SIZE,
    )
def process_and_display(flair_file, t1ce_file, seg_file, slice_index):
    """Gradio handler: predict on uploaded NIfTI files and render one slice.

    Args:
        flair_file, t1ce_file, seg_file: values from the ``gr.File`` inputs.
            Depending on the Gradio version this is either a path string or a
            file wrapper whose ``.name`` is the path; both are accepted.
        slice_index: slider value. It may arrive as a float, so it is cast to
            ``int`` before being used as an array index.

    Returns:
        The rendered PIL image, or ``None`` if any of the three files is missing.
    """
    if not flair_file or not t1ce_file or not seg_file:
        return None  # Ensure all files are uploaded

    def _path(upload):
        # Use the wrapper's ``.name`` when present, else the value itself.
        return getattr(upload, "name", upload)

    flair = nib.load(_path(flair_file)).get_fdata()
    t1ce = nib.load(_path(t1ce_file)).get_fdata()
    gt = nib.load(_path(seg_file)).get_fdata()
    p = predictByPath(flair, t1ce)
    # Create the subplot image
    return create_subplot_image(flair, gt, p, int(slice_index), VOLUME_START_AT, IMG_SIZE)
# HTML banner title (not referenced by the visible Blocks layout). The closing
# tag was previously a second opening "<font>", and the text was a leftover
# from an unrelated Space.
title = "<center><strong><font size='8'>MRI Brain Tumor Segmentation</font></strong></center>"
# Stylesheet passed to gr.Blocks.
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
# Gradio Interface
with gr.Blocks(css=css, title="Tumor Segmentation") as demo:
    gr.Markdown(
        """
<p style="text-align: center; font-size: 24px;">MRI Brain Tumor Segmentation</p>
<p style="text-align: center;">made by Ahmet Duzduran</p>
### <p style="text-align: left;">Faculty: Faculty of Computer Science</p>
### <p style="text-align: left;">Specialization: Intelligent Systems and Data Science</p>
### <p style="text-align: left;">Supervisor: Wojciech Oronowicz, PhD, Prof. Of PJATK</p>
"""
    )
    with gr.Row():
        flair_input = gr.File(label="Upload Flair NIfTI File")
        t1ce_input = gr.File(label="Upload T1ce NIfTI File")
        seg_input = gr.File(label="Upload Seg NIfTI File")
    # Without an explicit step Gradio derives a fractional one from the range
    # (0.1 for 0..99), producing non-integer slice indices; pin it to 1.
    slice_input = gr.Slider(minimum=0, maximum=VOLUME_SLICES - 1, step=1, label="Slice Index")
    submit_button = gr.Button("Submit")
    with gr.Row():
        example_selector = gr.Dropdown(list(examples.keys()), label="Select Example")
        auto_button = gr.Button("Load Example")
    output_image = gr.Image(label="Visualization")
    submit_button.click(
        process_and_display,
        inputs=[flair_input, t1ce_input, seg_input, slice_input],
        outputs=output_image
    )
    auto_button.click(
        automatic_process,
        inputs=[example_selector],
        outputs=output_image
    )

if __name__ == "__main__":
    demo.launch()