"""Streamlit app that predicts and displays a flood segmentation mask."""

import streamlit as st
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap


# Load the trained model
@st.cache_resource
def load_trained_model():
    """Load the Keras model from ``segnet_model.keras``.

    Wrapped in ``st.cache_resource`` so the model is loaded once and reused
    across Streamlit reruns instead of being reloaded on every interaction.
    """
    model_path = "segnet_model.keras"
    return load_model(model_path)


def _expected_pil_mode(model):
    """Return the PIL mode ("L" or "RGB") matching the model's input channels."""
    try:
        # NOTE(review): assumes a channels-last input shape — confirm against
        # the saved model.
        channels = model.input_shape[-1]
    except (AttributeError, ValueError, TypeError, IndexError):
        # The model does not expose a usable input shape; default to RGB.
        return "RGB"
    return "L" if channels == 1 else "RGB"


def predict_segmentation(model, image, target_size=(256, 256)):
    """Run the model on a single PIL image and return its prediction.

    Args:
        model: Object exposing ``predict``; called with a batch of one image.
        image: PIL image to segment.
        target_size: ``(width, height)`` the image is resized to before
            prediction.

    Returns:
        The first element of ``model.predict``'s output, i.e. the prediction
        for the single image in the batch.
    """
    # Uploaded files (PNG in particular) may be RGBA, palette or grayscale.
    # Without normalising the mode, the array's channel count would not match
    # the model input and ``predict`` would fail.
    mode = _expected_pil_mode(model)
    if image.mode != mode:
        image = image.convert(mode)

    image = image.resize(target_size)
    image_array = img_to_array(image) / 255.0  # scale pixel values to [0, 1]
    image_array = np.expand_dims(image_array, axis=0)  # add the batch axis
    prediction = model.predict(image_array)[0]
    return prediction


def create_mask_plot(mask, colormap, labels) -> Image.Image:
    """Render a mask with a colour legend and return it as a PIL image.

    Args:
        mask: Array-like exposing ``squeeze()``; the squeezed result is drawn
            with ``imshow``.
        colormap: A ``ListedColormap`` whose ``colors[i]`` is the colour used
            for ``labels[i]`` in the legend.
        labels: Legend labels, one per colour. Mask values are normalised
            over ``[0, len(labels) - 1]`` before the colormap lookup.

    Returns:
        The rendered figure as an RGBA PIL image.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(mask.squeeze(), cmap=colormap, vmin=0, vmax=len(labels) - 1)
        ax.axis("off")

        # Add a legend
        legend_patches = [
            plt.Line2D([0], [0], color=colormap.colors[i], lw=4, label=label)
            for i, label in enumerate(labels)
        ]
        ax.legend(
            handles=legend_patches,
            loc="upper right",
            bbox_to_anchor=(1.2, 1.0),
        )

        # Convert the Matplotlib figure to a PIL Image. ``np.array`` copies the
        # canvas buffer so the pixels remain valid after the figure is closed.
        fig.canvas.draw()
        rendered = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure: pyplot keeps a global reference to every
        # open figure, which would accumulate in a long-running server.
        plt.close(fig)
    return Image.fromarray(rendered)


# Streamlit App
def main() -> None:
    """Build the page: upload an image, predict its mask, show both."""
    st.title("Flood Area Segmentation")
    st.write("Upload an image to predict its segmentation mask.")

    # Load the model (cached)
    model = load_trained_model()

    # File uploader
    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    # Load the uploaded image. ``Image.open`` is lazy, so ``load()`` forces
    # decoding here, where a corrupt or mislabelled file can be reported
    # cleanly instead of surfacing as a traceback later on.
    try:
        image = Image.open(uploaded_file)
        image.load()
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        return

    # Predict segmentation
    with st.spinner("Predicting..."):
        predicted_mask = predict_segmentation(model, image)

    # Define custom colormap and labels
    colormap = ListedColormap(["green", "blue"])  # Green: Non-Flooded, Blue: Flooded
    labels = ["Non-Flooded Area", "Flooded Area"]

    # Create the mask visualization
    mask_image = create_mask_plot(predicted_mask, colormap, labels)

    # Display results side by side
    st.subheader("Results")
    col1, col2 = st.columns(2)
    with col1:
        st.write("### Original Image")
        st.image(image, caption="Original Image", use_container_width=True)
    with col2:
        st.write("### Predicted Mask")
        st.image(mask_image, caption="Predicted Mask", use_container_width=True)


if __name__ == "__main__":
    main()