Spaces:
Sleeping
Sleeping
File size: 2,701 Bytes
b6cd222 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import streamlit as st
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
# Load the trained model
@st.cache_resource
def load_trained_model():
    """Load the segmentation model from ``segnet_model.keras``.

    Decorated with ``st.cache_resource`` so the loaded model is reused
    instead of being read from disk again on each call.

    Returns:
        The object returned by ``load_model`` for the saved model file.
    """
    return load_model("segnet_model.keras")
def predict_segmentation(model, image, target_size=(256, 256)):
    """Run the segmentation model on one image and return its raw output.

    Args:
        model: Object exposing ``predict(batch)``; it receives a single-image
            batch of pixel values scaled to the range [0, 1].
        image: PIL image to segment. It is not modified; conversion and
            resizing both produce new images.
        target_size: Size handed to ``Image.resize``, which interprets it as
            ``(width, height)``.

    Returns:
        The first (and only) entry of the model's batch output.
    """
    # Uploaded PNGs can be RGBA, palette or grayscale, which would change the
    # channel count of the array fed to the network. Normalise to 3 channels.
    # NOTE(review): assumes the model was trained on RGB input - confirm.
    if image.mode != "RGB":
        image = image.convert("RGB")

    resized = image.resize(target_size)
    pixels = img_to_array(resized) / 255.0
    # The model works on batches, so add a leading batch axis of length 1.
    batch = np.expand_dims(pixels, axis=0)
    return model.predict(batch)[0]
def create_mask_plot(mask, colormap, labels):
    """Render a segmentation mask with a class legend as a PIL image.

    Args:
        mask: Array holding the mask values. Singleton dimensions are
            squeezed away before plotting (assumes the result is 2-D -
            TODO confirm against the model's output shape).
        colormap: Colormap exposing a ``colors`` sequence with one colour
            per entry in ``labels`` (e.g. a ``ListedColormap``).
        labels: Legend texts, in class-index order. The colour scale spans
            ``0`` to ``len(labels) - 1``.

    Returns:
        A PIL image built from the RGBA pixels of the rendered figure.
    """
    # Constrained layout shrinks the axes so the legend, which is anchored
    # outside the axes, still fits on the canvas instead of being clipped.
    fig, ax = plt.subplots(figsize=(5, 5), constrained_layout=True)
    try:
        ax.imshow(mask.squeeze(), cmap=colormap, vmin=0, vmax=len(labels) - 1)
        ax.axis("off")

        # One coloured line per class serves as the legend entry.
        legend_handles = [
            plt.Line2D([0], [0], color=colormap.colors[i], lw=4, label=label)
            for i, label in enumerate(labels)
        ]
        ax.legend(
            handles=legend_handles,
            loc="upper right",
            bbox_to_anchor=(1.2, 1.0),
        )

        # Rasterise the figure, then copy the pixels out through the public
        # canvas API so they outlive the figure.
        fig.canvas.draw()
        rendered = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure, even if plotting failed; pyplot keeps
        # every open figure alive until it is closed explicitly.
        plt.close(fig)
    return Image.fromarray(rendered)
# Streamlit App
def main():
    """Render the app: upload an image, segment it, and show the result.

    Nothing beyond the title, intro text and uploader is drawn until a file
    has been uploaded. A file that cannot be decoded as an image produces an
    error message instead of a prediction.
    """
    st.title("Flood Area Segmentation")
    st.write("Upload an image to predict its segmentation mask.")

    # Load the model (cached)
    model = load_trained_model()

    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force decoding now so a corrupt or truncated
        # upload fails here rather than midway through prediction.
        image.load()
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        return

    with st.spinner("Predicting..."):
        predicted_mask = predict_segmentation(model, image)

    # Green: Non-Flooded, Blue: Flooded
    colormap = ListedColormap(["green", "blue"])
    labels = ["Non-Flooded Area", "Flooded Area"]
    mask_image = create_mask_plot(predicted_mask, colormap, labels)

    # Display results side by side
    st.subheader("Results")
    original_col, mask_col = st.columns(2)
    with original_col:
        st.write("### Original Image")
        st.image(image, caption="Original Image", use_container_width=True)
    with mask_col:
        st.write("### Predicted Mask")
        st.image(mask_image, caption="Predicted Mask", use_container_width=True)
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|