import gradio as gr
import torch
#import spaces
import json
import base64
from io import BytesIO
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
import os
import pandas as pd
from utils import *  # provides show_all_annotations_on_image_base64 and related helpers
from PIL import Image

# Load the SAM-HQ and baseline SAM checkpoints from the Hugging Face Hub
sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")

sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
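
# As written, both models run on the default device (CPU); the commented-out
# @spaces.GPU decorator below would request a GPU when running on Spaces.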

#@spaces.GPU
def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
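    """Run a SAM-style forward pass and return masks post-processed to the
    original image resolution, together with the predicted IoU scores."""
    # The processor expects a batch dimension around the boxes for one image.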
    if input_boxes is not None:
        input_boxes = [input_boxes]
    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    scores = outputs.iou_scores
    return masks, scores

def encode_pil_to_base64(pil_image):
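    """Serialize a PIL image to a base64-encoded PNG string."""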
    buffer = BytesIO()
    pil_image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def compare_images_points_and_masks(user_image, input_boxes, input_points):
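    """Run both SAM and SAM-HQ on the image and return an HTML snippet with a
    slider that reveals one result over the other."""
    # If the image matches a known example by size, swap in the clean original
    # image from disk.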
    for example_data in example_data_map.values():
        if example_data["size"] == list(user_image.size):
            user_image = Image.open(example_data['original_image_path'])
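    # The Dataframe components arrive as pandas DataFrames; coerce them to
    # plain int coordinate lists and drop empty (all-zero) rows.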
    input_boxes = input_boxes.values.tolist()
    input_points = input_points.values.tolist()

    input_boxes = [[[int(coord) for coord in box] for box in input_boxes if any(box)]]
    input_points = [[[int(coord) for coord in point] for point in input_points if any(point)]]

    input_boxes = input_boxes if input_boxes[0] else None
    input_points = input_points if input_points[0] else None

    # Without at least one prompt the rendering below would have nothing to
    # draw, so fail early with a user-visible error.
    if input_boxes is None and input_points is None:
        raise gr.Error("Add at least one point or box before comparing.")

    sam_masks, sam_scores = predict_masks_and_scores(sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points)
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores(sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points)

    boxes_arg = input_boxes[0] if input_boxes else None
    points_arg = input_points[0] if input_points else None
    img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], boxes_arg, points_arg, model_name='SAM')
    img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], boxes_arg, points_arg, model_name='SAM_HQ')


    # Build a self-contained HTML comparison: the two annotated renders are
    # stacked, and a range input drives a clip-path reveal of the SAM-HQ
    # render (topImage) over the SAM render.
    html_code = f"""
    <div style="position: relative; width: 100%; max-width: 600px; margin: 0 auto;" id="imageCompareContainer">
        <div style="position: relative; width: 100%;">
            <img src="data:image/png;base64,{img1_b64}" style="width:100%; display:block;">
            <div id="topWrapper" style="position:absolute; top:0; left:0; width:100%; overflow:hidden;">
                <img id="topImage" src="data:image/png;base64,{img2_b64}" style="width:100%;">
            </div>
            <div id="sliderLine" style="position:absolute; top:0; left:0; width:2px; height:100%; background-color:red; pointer-events:none;"></div>
        </div>
        <input type="range" min="0" max="100" value="0" 
            style="width:100%; margin-top: 10px;" 
            oninput="
            const val = this.value;
            const container = document.getElementById('imageCompareContainer');
            const width = container.offsetWidth;
            const clipValue = 100 - val;
            document.getElementById('topImage').style.clipPath = 'inset(0 ' + clipValue + '% 0 0)';
            document.getElementById('sliderLine').style.left = (width * val / 100) + 'px';
            ">
    </div>
    """
    return html_code

def load_examples(json_file="examples.json"):
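    """Load example metadata (image paths, prompt points/boxes, sizes) from JSON."""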
    with open(json_file, "r") as f:
        examples = json.load(f)
    return examples

examples = load_examples()
example_paths = [example["image_path"] for example in examples]
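# Index example metadata by image path; image size is used elsewhere as the
# lookup key, since Gradio only hands the raw image back to the callbacks.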
example_data_map = {
    example["image_path"]: {
        "original_image_path": example["original_image_path"],
        "points": example["points"],
        "boxes": example["boxes"],
        "size": example["size"]
    }
    for example in examples
}

theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")
with gr.Blocks(theme=theme, title="πŸ” Compare SAM vs SAM-HQ") as demo:
    image_path_box = gr.Textbox(visible=False)
    gr.Markdown("## πŸ” Compare SAM vs SAM-HQ")
    gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it")
    gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)")

    with gr.Row():
        image_input = gr.Image(
            type="pil",
            label="Example image (click below to load)",
            interactive=False,
            height=500,
            show_label=True
        )
    
    gr.Examples(
        examples=example_paths,
        inputs=[image_input],
        label="Click an example to try πŸ‘‡",
    )

    result_html = gr.HTML(elem_id="result-html")

    with gr.Row():
        points_input = gr.Dataframe(
            headers=["x", "y"],
            label="Points",
            datatype=["number", "number"],
            col_count=(2, "fixed")
        )
        boxes_input = gr.Dataframe(
            headers=["x0", "y0", "x1", "y1"],
            label="Boxes",
            datatype=["number", "number", "number", "number"],
            col_count=(4, "fixed")
        )

    def on_image_change(image):
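        """Pre-fill the point and box tables when an example image is loaded.

        The example is identified by its dimensions, since the callback only
        receives the image itself.
        """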
        if image is None:
            return [], []
        for example_data in example_data_map.values():
            if example_data["size"] == list(image.size):
                return example_data["points"], example_data["boxes"]
        return [], []

    image_input.change(
        fn=on_image_change,
        inputs=[image_input],
        outputs=[points_input, boxes_input]
    )

    compare_button = gr.Button("Compare points and masks")
    compare_button.click(fn=compare_images_points_and_masks, inputs=[image_input, boxes_input, points_input], outputs=result_html)

    gr.HTML("""
        <style>
        #result-html {
            min-height: 500px;
            border: 1px solid #ccc;
            padding: 10px;
            box-sizing: border-box;
            background-color: #fff;
            border-radius: 8px;
            box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
        }
        </style>
    """)

demo.launch()