File size: 2,701 Bytes
b6cd222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import streamlit as st
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Load the trained model
@st.cache_resource
def load_trained_model(model_path="segnet_model.keras"):
    """Load the Keras segmentation model stored at ``model_path``.

    The ``st.cache_resource`` decorator makes Streamlit reuse the loaded model
    object instead of reloading it from disk on every rerun. The cache is keyed
    on the arguments, so each distinct path is loaded once.

    Args:
        model_path: Path to a saved ``.keras`` model. Defaults to
            ``"segnet_model.keras"``, a relative path that is resolved against
            the process's current working directory.

    Returns:
        The model returned by ``tensorflow.keras.models.load_model``.
    """
    return load_model(model_path)

def predict_segmentation(model, image, target_size=(256, 256)):
    """Run the segmentation model on a single PIL image.

    The image is converted to 3-channel RGB, resized to ``target_size``,
    scaled from 0..255 to 0..1, and given a leading batch axis of size 1
    before it is passed to ``model.predict``.

    Args:
        model: A loaded Keras model, e.g. from ``load_trained_model``.
        image: A ``PIL.Image.Image`` in any mode (RGB, RGBA, L, P, ...).
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.
            Presumably this must match the model's input size; TODO confirm.

    Returns:
        The model output for this one image, with the batch axis removed.
    """
    # Uploaded PNGs are frequently RGBA or palette images, and grayscale files
    # are mode "L". img_to_array would then return 4 or 1 channels instead of
    # 3, so the mode is normalised first. For images that are already RGB this
    # is a no-op in terms of pixel data.
    # NOTE(review): assumes the model was trained on 3-channel RGB input.
    rgb_image = image.convert("RGB")
    resized = rgb_image.resize(target_size)

    image_array = img_to_array(resized) / 255.0
    batch = np.expand_dims(image_array, axis=0)  # prepend a batch axis of size 1

    # verbose=0 stops Keras from printing a progress bar for every prediction.
    return model.predict(batch, verbose=0)[0]

def create_mask_plot(mask, colormap, labels):
    """Render a predicted mask as a colour-coded image with a class legend.

    Args:
        mask: Model output for one image. It is squeezed before plotting, so
            singleton axes (e.g. a trailing channel axis of size 1) are allowed.
        colormap: A ``ListedColormap``; ``colormap.colors[i]`` is the colour
            drawn in the legend for ``labels[i]``.
        labels: One legend label per class, in colour order.

    Returns:
        A ``PIL.Image.Image`` decoded from the PNG rendering of the figure.
    """
    # Build a Figure directly instead of going through pyplot. pyplot keeps
    # global figure state, which Matplotlib advises against in web servers,
    # and a standalone Figure needs no plt.close() (so nothing leaks if an
    # exception occurs mid-render).
    from io import BytesIO

    from matplotlib.figure import Figure
    from matplotlib.lines import Line2D

    fig = Figure(figsize=(5, 5))
    ax = fig.subplots()
    # Pin the colour scale to [0, len(labels) - 1] so colours do not depend on
    # which values happen to appear in this mask. Intermediate values are
    # binned by the ListedColormap into equal-width bands, one per colour.
    ax.imshow(mask.squeeze(), cmap=colormap, vmin=0, vmax=len(labels) - 1)
    ax.axis("off")

    legend_handles = [
        Line2D([0], [0], color=colour, lw=4, label=label)
        for colour, label in zip(colormap.colors, labels)
    ]
    ax.legend(handles=legend_handles, loc="upper right", bbox_to_anchor=(1.2, 1.0))

    # The legend is anchored outside the axes. bbox_inches="tight" enlarges the
    # saved canvas to include it, so it is not clipped at the figure edge.
    buffer = BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    buffer.seek(0)

    image = Image.open(buffer)
    image.load()  # decode eagerly so the returned image is fully materialised
    return image

# Streamlit App
def main():
    """Streamlit entry point: upload an image, predict and display its flood mask.

    The original image and the colour-coded predicted mask are shown side by
    side. Uploads that cannot be decoded as images produce an error message
    instead of a traceback.
    """
    st.title("Flood Area Segmentation")
    st.write("Upload an image to predict its segmentation mask.")

    # Cached by st.cache_resource on load_trained_model.
    model = load_trained_model()

    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return  # nothing uploaded yet

    # Image.open is lazy: it only identifies the file. load() forces a full
    # decode here, so a corrupt or truncated upload fails with a readable
    # message rather than deep inside predict_segmentation.
    try:
        image = Image.open(uploaded_file)
        image.load()
    except OSError:  # also covers PIL.UnidentifiedImageError (an OSError subclass)
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.")
        return

    with st.spinner("Predicting..."):
        predicted_mask = predict_segmentation(model, image)

    # create_mask_plot draws colormap.colors[i] for labels[i], so the two lists
    # must stay in the same order.
    colormap = ListedColormap(["green", "blue"])  # Green: Non-Flooded, Blue: Flooded
    labels = ["Non-Flooded Area", "Flooded Area"]

    mask_image = create_mask_plot(predicted_mask, colormap, labels)

    st.subheader("Results")
    col1, col2 = st.columns(2)

    with col1:
        st.write("### Original Image")
        st.image(image, caption="Original Image", use_container_width=True)

    with col2:
        st.write("### Predicted Mask")
        st.image(mask_image, caption="Predicted Mask", use_container_width=True)


if __name__ == "__main__":
    main()