# Spaces: Running
import base64
import mimetypes
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image
from tensorflow.keras.preprocessing import image
# Fixed input resolution the network is fed (see the cv2.resize in perform_inference).
H = 256
W = 256
from metrics import dice_loss, dice_coef
model_path = "model.h5"


@st.cache_resource(show_spinner=False)
def _load_model(path):
    """Load the Keras segmentation model once and share it across reruns.

    Streamlit re-executes this whole script on every widget interaction, so an
    uncached ``load_model`` call re-reads the model file each time.
    ``st.cache_resource`` (Streamlit >= 1.18) keeps a single instance.
    ``show_spinner=False`` ensures no UI element is emitted before
    ``st.set_page_config`` is called.

    Args:
        path: Filesystem path of the saved model.

    Returns:
        The loaded ``tf.keras`` model.
    """
    # The project's dice loss/metric are passed so Keras can resolve them when
    # deserializing the saved model.
    return tf.keras.models.load_model(
        path,
        custom_objects={'dice_loss': dice_loss, 'dice_coef': dice_coef},
    )


model = _load_model(model_path)
# Page-level settings; Streamlit expects this before other st.* output calls.
st.set_page_config(
    layout="wide",
    page_icon=":brain:",
    page_title="Brain Tumor Segmentation App",
)

# CSS that hides Streamlit's built-in chrome (toolbar, decoration bar, status
# widget, main menu, header and footer) so only the app content is visible.
custom_style = """
<style>
div[data-testid="stToolbar"],
div[data-testid="stDecoration"],
div[data-testid="stStatusWidget"],
#MainMenu,
header,
footer {
visibility: hidden;
height: 0%;
}
</style>
"""
st.markdown(custom_style, unsafe_allow_html=True)
# Function to perform inference | |
def perform_inference(image, threshold=0.5):
    """Segment a scan and return the input, its binary mask and the masked input.

    Args:
        image: Decoded image array, (height, width, channels). Presumably uint8
            in 0-255 as produced by ``cv2.imdecode`` in ``main`` (it is divided
            by 255 below); confirm for other callers.
        threshold: Cut-off applied to the model output; pixels whose resized
            prediction is >= ``threshold`` are kept. Defaults to 0.5, the value
            that used to be hard-coded.

    Returns:
        A tuple ``(original_image, mask, segmented_image)``:
        ``original_image`` is an untouched copy of ``image``; ``mask`` is a
        boolean array at the input resolution with a trailing length-1 axis
        (assuming the model emits a single-channel map, which is what lets
        ``cv2.resize`` return a 2-D array here); ``segmented_image`` is
        ``original_image`` with every pixel outside the mask set to 0.
    """
    original_height, original_width = image.shape[:2]
    original_image = image.copy()

    # Network input: fixed W x H (cv2.resize takes width first), scaled to
    # [0, 1], with a leading batch axis of size 1.
    model_input = cv2.resize(image, (W, H)) / 255.0
    model_input = np.expand_dims(model_input, axis=0)

    # Take the only prediction out of the batch. Cast to float32 because
    # cv2.resize rejects some dtypes a model may emit (e.g. float16).
    prediction = np.asarray(model.predict(model_input, verbose=0)[0], dtype=np.float32)
    # Resize back to the caller's resolution before thresholding.
    prediction = cv2.resize(prediction, (original_width, original_height))

    mask = prediction >= threshold
    # Trailing axis so the mask broadcasts across the image's colour channels.
    mask = np.expand_dims(mask, axis=-1)
    segmented_image = original_image * mask
    return original_image, mask, segmented_image
# Function to display images using Matplotlib | |
def show_image(image, title="Image"):
    """Display a single image in Streamlit using its own Matplotlib figure.

    Args:
        image: Array accepted by ``Axes.imshow``; single-channel data is drawn
            with the gray colormap.
        title: Caption shown above the image.
    """
    # Use an explicit figure instead of pyplot's implicit "current" figure.
    # Calling st.pyplot() without a figure is deprecated, and the global
    # figure would otherwise keep accumulating drawings across reruns.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
        st.pyplot(fig)
    finally:
        # Free the figure so repeated reruns do not leak Matplotlib state.
        plt.close(fig)
# Function to download sample images | |
def download_sample_images():
    """Render a download button for each bundled sample image.

    The samples are looked up next to this script. For any that are missing,
    a warning is shown and the loop continues, so the page still renders.
    """
    sample_images = ["Sample image 1.png", "Sample image 2.png", "Sample image 3.png"]
    # Resolve relative to this file rather than the working directory; the
    # directory is the same for every sample, so compute it once.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    for image_name in sample_images:
        image_path = os.path.join(base_dir, image_name)
        if not os.path.exists(image_path):
            st.warning(f"Sample image {image_name} not found.")
            continue
        with open(image_path, "rb") as f:
            image_bytes = f.read()
        # Derive the MIME type from the file extension. The samples are PNGs,
        # so the previously hard-coded "image/jpeg" mislabelled them.
        mime_type = mimetypes.guess_type(image_name)[0] or "application/octet-stream"
        st.download_button(
            label=f"Download {image_name}",
            data=image_bytes,
            key=f"download_{image_name}",
            file_name=image_name,
            mime=mime_type,
        )
# Streamlit app | |
def _decode_upload(uploaded_file):
    """Decode an uploaded file into a 3-channel OpenCV (BGR) image, or None."""
    # np.fromstring's binary mode is deprecated; frombuffer views the bytes
    # directly. getvalue() returns the whole payload regardless of read position.
    file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
    if file_bytes.size == 0:
        # cv2.imdecode asserts on an empty buffer instead of returning None.
        return None
    # cv2.imdecode returns None (rather than raising) for undecodable data.
    return cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)


def _plot_results(original_image, mask, segmented_image):
    """Show the input, its mask and the segmented image side by side."""
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    try:
        # OpenCV decodes to BGR channel order, while Matplotlib draws RGB.
        axs[0].imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
        axs[0].set_title("Original Image")
        axs[1].imshow(mask.squeeze(), cmap='gray')
        axs[1].set_title("Mask")
        axs[2].imshow(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
        axs[2].set_title("Segmented Tumor")
        for ax in axs:
            ax.axis('off')
        st.pyplot(fig)
    finally:
        # Streamlit reruns the script on every interaction; close the figure
        # so one is not leaked per run.
        plt.close(fig)


def main():
    """Render the app: uploader, instructions, sample downloads and results."""
    st.title("Brain Tumor Segmentation App")
    uploaded_file = st.file_uploader("Upload a brain scan image...", type=["jpg", "png", "jpeg"])
    st.markdown("""
Example Instructions:
- Upload a brain scan image.
- Or, download sample images below and check the predictions.
""")
    download_sample_images()

    if uploaded_file is None:
        return

    image = _decode_upload(uploaded_file)
    if image is None:
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
        return

    original_image, mask, segmented_image = perform_inference(image)
    st.subheader("Results!")
    _plot_results(original_image, mask, segmented_image)


if __name__ == "__main__":
    main()