import streamlit as st
from transformers import SamModel, SamProcessor, pipeline
from PIL import Image
import numpy as np
import torch

# Constants
# Divisor pairs applied to the image width/height to derive point prompts:
# (2.0, 2.0) yields the image centre (0.5, 0.5); (2.5, 2.5) yields (0.4, 0.4).
XS_YS = [(2.0, 2.0), (2.5, 2.5)]
WIDTH = 600  # display width in pixels for st.image
 
# Load models once and cache them across Streamlit reruns
@st.cache_resource
def load_models():
    model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
    processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
    od_pipe = pipeline("object-detection", "facebook/detr-resnet-50")
    return model, processor, od_pipe

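# Prompt formats follow the transformers SAM convention: input_boxes is nested
# as [images, boxes_per_image, 4] with each box [x1, y1, x2, y2], and
# input_points as [images, points_per_image, 2]; hence the extra list wrapping
# inside process_image below.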
def process_image(image, model, processor, bounding_box=None, input_point=None):
    try:
        # Convert image to RGB mode
        image = image.convert('RGB')
        # Convert image to numpy array
        image_array = np.array(image)

        if bounding_box:
            # One box per image, each box as [x1, y1, x2, y2], wrapped in a batch list
            inputs = processor(images=image_array, input_boxes=[[bounding_box]], return_tensors="pt")
        elif input_point:
            # One [x, y] point per image, wrapped the same way
            inputs = processor(images=image_array, input_points=[[input_point]], return_tensors="pt")
        else:
            raise ValueError("Either bounding_box or input_point must be provided")

        with torch.no_grad():
            outputs = model(**inputs)

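        # post_process_masks removes padding, upscales the low-resolution mask
        # logits to the original image size, and binarizes them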
        predicted_masks = processor.image_processor.post_process_masks(
            outputs.pred_masks,
            inputs["original_sizes"],
            inputs["reshaped_input_sizes"]
        )
        return predicted_masks[0]
    except Exception as e:
        st.error(f"Error processing image: {str(e)}")
        return None

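# SAM proposes three candidate masks per prompt (multimask output is on by
# default), so the helper below renders each candidate composited over a
# white background.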
def display_masked_images(raw_image, predicted_mask, caption_prefix):
    for i in range(3):
        mask = predicted_mask[0][i]
        # Convert the boolean mask to an 8-bit grayscale image (0 or 255)
        int_mask = (np.array(mask) * 255).astype('uint8')
        mask_image = Image.fromarray(int_mask, mode='L')
        # Keep the original pixels where the mask is set; fill the rest with white
        final_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (255, 255, 255)), mask_image)
        st.image(final_image, caption=f"{caption_prefix} {i+1}", width=WIDTH)

def main():
    st.title("Image Segmentation with Object Detection")

    # Introduction and How-to
    st.markdown("""
    Welcome to the Image Segmentation and Object Detection app, where cutting-edge AI models bring your images to life by identifying and segmenting objects. Here's how it works:
    - **Upload an image**: Drag and drop or use the browse files option.
    - **Detection**: The `facebook/detr-resnet-50` model detects objects and their bounding boxes.
    - **Segmentation**: Following detection, `Zigeng/SlimSAM-uniform-77` segments the objects using the bounding box data.
    - **Further Segmentation**: The app also segments around fixed input points at the relative positions (0.5, 0.5) and (0.4, 0.4) for a more granular analysis.
    
    Please note that processing takes some time. We appreciate your patience as the models do their work!
    """)
    
    # Model credits
    st.subheader("Powered by:")
    st.write("- Object Detection Model: `facebook/detr-resnet-50`")
    st.write("- Segmentation Model: `Zigeng/SlimSAM-uniform-77`")

    model, processor, od_pipe = load_models()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        raw_image = Image.open(uploaded_file).convert('RGB')  # normalise to RGB so compositing works for RGBA/palette PNGs
        st.subheader("Uploaded Image")
        st.image(raw_image, caption="Uploaded Image", width=WIDTH)

        with st.spinner('Processing image...'):
            # Object detection: each DETR result carries a 'box' dict (xmin/ymin/xmax/ymax) and a 'label'
            pipeline_output = od_pipe(raw_image)
            input_boxes_format = [[b['box']['xmin'], b['box']['ymin'], b['box']['xmax'], b['box']['ymax']] for b in pipeline_output]
            labels_format = [b['label'] for b in pipeline_output]

            # Process bounding boxes
            for b, l in zip(input_boxes_format, labels_format):
                st.subheader(f"Bounding box: {l}")
                predicted_mask = process_image(raw_image, model, processor, bounding_box=b)
                if predicted_mask is not None:
                    display_masked_images(raw_image, predicted_mask, "Masked Image")

            # Process input points
            for (x, y) in XS_YS:
                # Convert the relative position into absolute pixel coordinates
                point_x, point_y = int(raw_image.size[0] // x), int(raw_image.size[1] // y)
                st.subheader(f"Input point: ({1/x}, {1/y})")
                predicted_mask = process_image(raw_image, model, processor, input_point=[point_x, point_y])
                if predicted_mask is not None:
                    display_masked_images(raw_image, predicted_mask, "Masked Image")

if __name__ == "__main__":
    main()
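
# To run the app locally (assuming this file is saved as app.py):
#   streamlit run app.py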