# Page-header residue from the hosting site ("Spaces: Sleeping"); not part of the program.
import io

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from matplotlib.colors import ListedColormap
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
# Load the trained model
@st.cache_resource
def load_trained_model():
    """Load the trained segmentation model, reusing it across reruns.

    Streamlit re-executes the script on every widget interaction (each
    file upload, for example). ``st.cache_resource`` keeps the loaded model
    in memory so the model file is not read and rebuilt on every rerun.

    Returns:
        The object returned by ``load_model`` for ``segnet_model.keras``.
        The path is relative, so it is resolved against the working
        directory of the process.
    """
    model_path = "segnet_model.keras"
    return load_model(model_path)
def predict_segmentation(model, image, target_size=(256, 256)):
    """Run the segmentation model on a single PIL image.

    Args:
        model: Object exposing ``predict(batch)``. If it also exposes an
            ``input_shape`` 4-tuple whose last entry is 1, 3 or 4, the
            image is first converted to the PIL mode with that many
            channels.
        image: PIL image to segment. The caller's object is not modified;
            ``convert`` and ``resize`` both return new images.
        target_size: Size handed unchanged to ``Image.resize``, which
            interprets it as ``(width, height)``.

    Returns:
        The first element of the batch returned by ``model.predict``.
    """
    # Uploaded PNGs are often RGBA, palette ("P") or grayscale. Passed
    # through unchanged they yield an array with the wrong channel count.
    # NOTE(review): this assumes a channels-last input shape
    # (batch, height, width, channels) -- confirm for this model. When the
    # shape is missing or does not match, the image is left untouched.
    modes_by_channels = {1: "L", 3: "RGB", 4: "RGBA"}
    input_shape = getattr(model, "input_shape", None)
    if isinstance(input_shape, tuple) and len(input_shape) == 4:
        wanted_mode = modes_by_channels.get(input_shape[-1])
        if wanted_mode is not None and image.mode != wanted_mode:
            image = image.convert(wanted_mode)

    image = image.resize(target_size)
    # Divide by 255 so 8-bit pixel values land in [0, 1].
    image_array = img_to_array(image) / 255.0
    # The model is called on a batch, so add a leading batch axis of size 1.
    image_array = np.expand_dims(image_array, axis=0)
    prediction = model.predict(image_array)[0]
    return prediction
def create_mask_plot(mask, colormap, labels):
    """Render a segmentation mask with a class legend as a PIL image.

    Args:
        mask: Array with a ``squeeze()`` method; the squeezed array is drawn
            with ``imshow`` using a value range of ``0 .. len(labels) - 1``.
        colormap: Colormap exposing a ``colors`` sequence (for example a
            ``ListedColormap``); ``colors[i]`` is the legend colour for
            ``labels[i]``.
        labels: Sequence of class names, one per colormap entry.

    Returns:
        An RGBA ``PIL.Image.Image`` of the rendered figure, cropped to the
        drawn content (mask plus legend).
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(mask.squeeze(), cmap=colormap, vmin=0, vmax=len(labels) - 1)
        ax.axis("off")

        # Add a legend
        legend_handles = [
            plt.Line2D([0], [0], color=colormap.colors[i], lw=4, label=label)
            for i, label in enumerate(labels)
        ]
        ax.legend(
            handles=legend_handles,
            loc="upper right",
            bbox_to_anchor=(1.2, 1.0),
        )

        # The legend is anchored beyond the right edge of the axes, so it
        # overhangs the fixed-size canvas and would be cut off if the raw
        # canvas buffer were used. Saving with bbox_inches="tight" expands
        # the output to include it, and does not depend on a specific
        # rendering backend.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting fails, so figures do
        # not accumulate in pyplot's global registry.
        plt.close(fig)

    buffer.seek(0)
    # convert() forces the PNG to be decoded now and returns an image that
    # no longer depends on the in-memory buffer.
    return Image.open(buffer).convert("RGBA")
# Streamlit App
def main():
    """Build the page: upload an image, predict its mask, show both.

    Renders the title and uploader, and returns early while no file has
    been uploaded. Once a file is present it is opened with PIL, passed to
    ``predict_segmentation``, rendered by ``create_mask_plot``, and shown
    next to the original in two columns.
    """
    st.title("Flood Area Segmentation")
    st.write("Upload an image to predict its segmentation mask.")

    # Load the model
    segmentation_model = load_trained_model()

    # File uploader
    upload = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    if upload is None:
        # Nothing to process until the user provides a file.
        return

    source_image = Image.open(upload)

    # Predict segmentation
    with st.spinner("Predicting..."):
        mask = predict_segmentation(segmentation_model, source_image)

    # Class index 0 is drawn green (non-flooded), index 1 blue (flooded).
    class_labels = ["Non-Flooded Area", "Flooded Area"]
    class_colors = ListedColormap(["green", "blue"])
    rendered_mask = create_mask_plot(mask, class_colors, class_labels)

    # Display results side by side: (heading, picture, caption) per column.
    st.subheader("Results")
    panels = (
        ("### Original Image", source_image, "Original Image"),
        ("### Predicted Mask", rendered_mask, "Predicted Mask"),
    )
    for column, (heading, picture, caption) in zip(st.columns(2), panels):
        with column:
            st.write(heading)
            st.image(picture, caption=caption, use_container_width=True)
# Run the app only when this file is executed as the main script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()