alwin00007's picture
uploaded model
b6cd222 verified
import io

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from matplotlib.patches import Patch
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
# Trained segmentation model, loaded lazily and shared via Streamlit's resource cache.
@st.cache_resource
def load_trained_model():
    """Return the Keras model stored in ``segnet_model.keras``.

    The path is relative, so it is resolved against the current working
    directory. ``st.cache_resource`` memoizes the returned object, so script
    reruns reuse it instead of reading the file again.
    """
    return load_model("segnet_model.keras")
def predict_segmentation(model, image, target_size=(256, 256)):
    """Run the segmentation model on a single PIL image.

    Args:
        model: Loaded Keras model; invoked with ``model.predict`` on a batch
            containing just this image.
        image: PIL image in any mode. It is converted to RGB first, so RGBA
            PNGs, grayscale and palette images all yield three channels.
            Assumes the model was trained on 3-channel input (TODO confirm).
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        The model output for this image, with the leading batch axis removed.
    """
    # Normalize the colour mode; without this an RGBA PNG gives 4 channels and
    # a grayscale image gives 1, which do not match a 3-channel model input.
    rgb_image = image.convert("RGB")
    resized = rgb_image.resize(target_size)
    # Scale 8-bit pixel values from [0, 255] to [0, 1].
    pixels = img_to_array(resized) / 255.0
    batch = np.expand_dims(pixels, axis=0)  # leading batch axis of size 1
    # verbose=0 keeps Keras from printing a progress bar on every request.
    prediction = model.predict(batch, verbose=0)[0]
    return prediction
def create_mask_plot(mask, colormap, labels):
    """Render a predicted mask with a class legend and return it as a PIL image.

    Args:
        mask: Predicted mask array. Singleton dimensions are squeezed before
            plotting; assumes the result is a 2-D per-pixel map (TODO confirm
            against the model's output shape).
        colormap: ``ListedColormap`` whose ``colors[i]`` is used for ``labels[i]``.
        labels: Legend labels, one per class. Mask values are mapped onto the
            colormap over the range ``[0, len(labels) - 1]``.

    Returns:
        PIL.Image.Image: The rendered figure decoded from PNG. It is cropped
        tightly so that the legend, which sits to the right of the axes, is
        fully included.
    """
    # A standalone Figure (not pyplot) avoids pyplot's global figure registry,
    # which is shared by every Streamlit session in the server process and
    # would otherwise need an explicit plt.close() to avoid leaking figures.
    fig = Figure(figsize=(5, 5))
    ax = fig.subplots()
    ax.imshow(mask.squeeze(), cmap=colormap, vmin=0, vmax=len(labels) - 1)
    ax.axis("off")

    legend_patches = [
        Patch(color=colormap.colors[i], label=label)
        for i, label in enumerate(labels)
    ]
    ax.legend(handles=legend_patches, loc="upper right", bbox_to_anchor=(1.2, 1.0))

    # The legend is anchored at x=1.2 in axes coordinates, which is past the
    # figure's right edge. bbox_inches="tight" enlarges the saved area to
    # include it instead of clipping it as the raw canvas buffer did.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # decode now so the image does not depend on the buffer later
    return image
# Streamlit App
def main():
    """Render the flood-segmentation page.

    Loads the cached model and shows an uploader. When an image is uploaded,
    its mask is predicted and shown next to the original. Files that Pillow
    cannot decode are reported with ``st.error`` instead of raising.
    """
    st.title("Flood Area Segmentation")
    st.write("Upload an image to predict its segmentation mask.")

    # Load the model (cached by st.cache_resource on load_trained_model).
    model = load_trained_model()

    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force decoding so corrupt files fail here.
        image.load()
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error("Could not read the uploaded file as an image.")
        return

    # create_mask_plot pairs colormap color i with labels[i].
    colormap = ListedColormap(["green", "blue"])  # Green: Non-Flooded, Blue: Flooded
    labels = ["Non-Flooded Area", "Flooded Area"]

    with st.spinner("Predicting..."):
        predicted_mask = predict_segmentation(model, image)
        mask_image = create_mask_plot(predicted_mask, colormap, labels)

    # Display results side by side.
    st.subheader("Results")
    col1, col2 = st.columns(2)
    with col1:
        st.write("### Original Image")
        st.image(image, caption="Original Image", use_container_width=True)
    with col2:
        st.write("### Predicted Mask")
        st.image(mask_image, caption="Predicted Mask", use_container_width=True)
# Build the page only when this file is run as the main script, not on import.
if __name__ == "__main__":
    main()