brain-tumor / app.py
subek's picture
Create app.py
baacefd verified
raw
history blame
No virus
2.26 kB
import streamlit as st
import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
from PIL import Image
import base64
# Load the pre-trained brain tumor segmentation model
model_path = 'brain_tumor_segmentation_model.h5'  # Replace with your actual model path


@st.cache_resource(show_spinner=False)
def _load_segmentation_model(path):
    """Load the Keras model stored at *path*, cached across reruns and sessions.

    Streamlit re-executes this whole script on every widget interaction; without
    caching, the model file would be deserialized again on each rerun.
    ``show_spinner=False`` keeps this call from emitting a UI element, because it
    runs before ``st.set_page_config`` (which Streamlit expects to be the first
    Streamlit command).
    """
    return tf.keras.models.load_model(path)


model = _load_segmentation_model(model_path)
# (width, height) passed to PIL's Image.resize before inference; adjust to the
# model's input size.
img_size = (256, 256)
# Browser-tab title and icon, plus the full-width page layout.
st.set_page_config(
    layout="wide",
    page_icon=":brain:",
    page_title="Brain Tumor Segmentation App",
)
# CSS that hides Streamlit's built-in chrome (toolbar, decoration bar, status
# widget, #MainMenu, header and footer) so only the app content is visible.
custom_style = """
<style>
div[data-testid="stToolbar"],
div[data-testid="stDecoration"],
div[data-testid="stStatusWidget"],
#MainMenu,
header,
footer {
visibility: hidden;
height: 0%;
}
</style>
"""
# unsafe_allow_html lets the raw <style> tag through instead of escaping it.
st.markdown(body=custom_style, unsafe_allow_html=True)
def preprocess_image(img, target_size=None):
    """Turn a PIL image into a normalized, single-image batch for the model.

    Args:
        img: PIL image uploaded by the user. It is not modified; resizing and
            mode conversion produce new image objects.
        target_size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.
            Defaults to the module-level ``img_size``.

    Returns:
        The ``img_to_array`` result scaled to [0, 1] with a leading batch axis of
        size 1 (``(1, H, W, C)`` under Keras' default channels_last format).
    """
    if target_size is None:
        target_size = img_size
    # Uploaded PNG/JPEG files may be palette-based ("P"), carry an alpha channel
    # ("RGBA"/"LA") or be CMYK. img_to_array would then produce palette indices
    # or an unexpected channel count, so resolve them to plain grayscale/RGB.
    # "L" and "RGB" images pass through unchanged.
    if img.mode == "LA":
        img = img.convert("L")
    elif img.mode in ("P", "RGBA", "CMYK"):
        img = img.convert("RGB")
    img = img.resize(target_size)
    img_array = image.img_to_array(img) / 255.0
    # Add the batch dimension expected by model.predict.
    return np.expand_dims(img_array, axis=0)
def predict_tumor_segmentation(img):
    """Run the module-level ``model`` on one PIL image.

    The image goes through ``preprocess_image`` (a batch of one), and the leading
    batch axis of size 1 is squeezed out of the prediction before it is returned.
    """
    return np.squeeze(model.predict(preprocess_image(img)), axis=0)
def display_segmentation_result(original_image, segmentation_mask, threshold=0.5):
    """Show the uploaded image next to the binarized segmentation mask.

    Args:
        original_image: The PIL image the user uploaded.
        segmentation_mask: Per-pixel model output for one image, of shape
            ``(H, W)`` or ``(H, W, 1)``.
        threshold: Pixels whose value is strictly greater than this are drawn
            white (255); all others are drawn black (0).

    Raises:
        ValueError: If the mask is not 2-D after dropping a trailing channel
            axis of size 1.
    """
    # The PIL image is handed to st.image as-is: np.array() on a palette-mode
    # ("P") image yields raw palette indices, which would render as noise.
    mask = np.asarray(segmentation_mask)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    if mask.ndim != 2:
        # Previously a multi-channel mask was silently mis-decoded as grayscale.
        raise ValueError(
            f"Expected a single-channel segmentation mask, got shape {mask.shape}"
        )
    binary_mask = (mask > threshold).astype(np.uint8) * 255
    # A 2-D uint8 array is converted to a Pillow "L" (8-bit grayscale) image.
    mask_image = Image.fromarray(binary_mask)
    st.image(
        [original_image, mask_image],
        caption=['Original Image', 'Segmentation Mask'],
        use_column_width=True,
    )
def main():
    """Render the app: upload an MRI image, segment it, and show the result.

    Unreadable or non-image uploads are reported with ``st.error`` instead of
    crashing the script run.
    """
    st.title("Brain Tumor Segmentation App")
    uploaded_file = st.file_uploader(
        "Upload an MRI image for tumor segmentation...",
        type=["jpg", "png", "jpeg"],
    )
    if uploaded_file is None:
        # Nothing uploaded yet (or the upload was cleared): show only the uploader.
        return
    try:
        original_image = Image.open(uploaded_file)
        # Image.open only reads the header; load() forces full decoding here so
        # truncated or corrupt files fail inside this try block.
        original_image.load()
    except (OSError, Image.DecompressionBombError) as err:
        st.error(f"Could not read the uploaded file as an image: {err}")
        return
    st.image(original_image, caption="Uploaded Image", use_column_width=True)
    st.markdown("## Tumor Segmentation Result")
    # Perform segmentation
    segmentation_mask = predict_tumor_segmentation(original_image)
    # Display the segmentation result
    display_segmentation_result(original_image, segmentation_mask)
if __name__ == "__main__":
    # `streamlit run` executes this script with __name__ set to "__main__",
    # so this is the entry point that renders the page body.
    main()