|
import streamlit as st
|
|
import numpy as np
|
|
from tensorflow.keras.models import load_model
|
|
from tensorflow.keras.preprocessing.image import img_to_array
|
|
from PIL import Image
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.colors import ListedColormap
|
|
|
|
|
|
@st.cache_resource
def load_trained_model(model_path="segnet_model.keras"):
    """Load the trained segmentation model, cached by Streamlit.

    ``st.cache_resource`` caches the returned object keyed on the call
    arguments, so reruns of the script reuse the already-loaded model
    instead of reading it from disk again.

    Args:
        model_path: Path to the saved Keras model. A relative path is
            resolved against the process's current working directory.

    Returns:
        The loaded Keras model, used here only for ``predict``.
    """
    # compile=False: the app only runs inference, so the optimizer, loss and
    # metrics are not needed. Skipping compilation avoids load failures when
    # the model was trained with custom losses/metrics that are not
    # registered in this process, and makes loading cheaper.
    return load_model(model_path, compile=False)
|
|
|
|
def predict_segmentation(model, image, target_size=(256, 256)):
    """Run the segmentation model on a single PIL image.

    The image is converted to 3-channel RGB, resized to ``target_size``,
    scaled from 8-bit values to [0, 1], and given a leading batch axis of
    size 1 before prediction.

    Args:
        model: A Keras model exposing ``predict``.
        image: A PIL image of any mode (RGB, RGBA, L, P, ...).
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        ``model.predict`` output for this one image, with the batch axis
        removed.
    """
    # Uploaded PNGs are frequently RGBA or palette ("P") images and some JPEGs
    # are grayscale; without conversion img_to_array would produce 4 or 1
    # channels and the model call would fail.
    # NOTE(review): assumes the model was trained on 3-channel RGB input.
    rgb_image = image.convert("RGB")
    resized = rgb_image.resize(target_size)

    # convert("RGB") guarantees 8-bit channels, so /255 maps to [0, 1].
    image_array = img_to_array(resized) / 255.0
    batch = np.expand_dims(image_array, axis=0)

    # verbose=0 suppresses the Keras progress bar that would otherwise be
    # printed to the server log on every upload.
    return model.predict(batch, verbose=0)[0]
|
|
|
|
def create_mask_plot(mask, colormap, labels):
    """Render a segmentation mask with a class legend as a PIL image.

    Args:
        mask: Model output for one image. Singleton axes are squeezed away.
            A resulting 2-D array is colour-mapped directly over the value
            range ``[0, len(labels) - 1]``, so a continuous single-channel
            probability map is binned into the colormap's listed colours. A
            resulting 3-D array is treated as per-class scores on the last
            axis and reduced to class indices with ``argmax``.
        colormap: A ``ListedColormap``; ``colormap.colors[i]`` is used as the
            legend colour for ``labels[i]``.
        labels: Class names, indexed by class id.

    Returns:
        An RGBA ``PIL.Image`` of the rendered figure, cropped tightly so that
        the legend (anchored outside the axes) is fully visible.
    """
    import io

    mask = np.squeeze(mask)
    if mask.ndim == 3:
        # Multi-channel output (e.g. softmax): take the highest-scoring class.
        mask = np.argmax(mask, axis=-1)

    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        # "nearest" keeps class values from being blended when matplotlib
        # resamples the image, which would blur categorical boundaries.
        ax.imshow(
            mask,
            cmap=colormap,
            vmin=0,
            vmax=len(labels) - 1,
            interpolation="nearest",
        )
        ax.axis("off")

        legend_handles = [
            plt.Line2D([0], [0], color=colormap.colors[i], lw=4, label=label)
            for i, label in enumerate(labels)
        ]
        ax.legend(
            handles=legend_handles,
            loc="upper right",
            bbox_to_anchor=(1.2, 1.0),
        )

        # The legend is anchored at x=1.2 in axes coordinates, which extends
        # past the right edge of the figure; bbox_inches="tight" enlarges the
        # saved area so the legend is not clipped.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even on error: pyplot keeps every open
        # figure referenced until it is closed.
        plt.close(fig)

    buffer.seek(0)
    # convert() forces PIL to decode the PNG now instead of lazily.
    return Image.open(buffer).convert("RGBA")
|
|
|
|
|
|
def main():
    """Streamlit page: upload an image, segment it, and show the result.

    Renders the page header and loads the (cached) model. Once the user
    uploads a JPG/PNG, it shows the original image next to the predicted
    mask. An upload that cannot be decoded as an image shows an error
    message instead of a traceback.
    """
    st.title("Flood Area Segmentation")
    st.write("Upload an image to predict its segmentation mask.")

    model = load_trained_model()

    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open only reads the header; load() forces a full decode so a
        # corrupt or truncated file fails here rather than inside prediction.
        image.load()
    except OSError:  # also covers PIL.UnidentifiedImageError
        st.error(
            "Could not read the uploaded file as an image. "
            "Please upload a valid JPG or PNG."
        )
        return

    with st.spinner("Predicting..."):
        predicted_mask = predict_segmentation(model, image)

    # NOTE(review): assumes class 0 = non-flooded and class 1 = flooded in the
    # model's output; confirm against the training labels.
    colormap = ListedColormap(["green", "blue"])
    labels = ["Non-Flooded Area", "Flooded Area"]
    mask_image = create_mask_plot(predicted_mask, colormap, labels)

    st.subheader("Results")
    col1, col2 = st.columns(2)

    with col1:
        st.write("### Original Image")
        st.image(image, caption="Original Image", use_container_width=True)

    with col2:
        st.write("### Predicted Mask")
        st.image(mask_image, caption="Predicted Mask", use_container_width=True)
|
|
|
|
if __name__ == "__main__":
    # Launch the page only when this file is executed as a script,
    # not when it is imported as a module.
    main()
|
|
|