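# Gradio demo: compare segmentation masks produced by SAM and SAM-HQ on the same
# point/box prompts, rendered as an interactive before/after image slider.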
import gradio as gr
import torch
#import spaces
import json
import base64
from io import BytesIO
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
import os
import pandas as pd
from utils import *
from PIL import Image

# Load models
sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
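
# Run a forward pass of the given SAM / SAM-HQ model on the prompted image and
# post-process the predicted masks back to the original image resolution.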
#@spaces.GPU
def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
    if input_boxes is not None:
        input_boxes = [input_boxes]
    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    scores = outputs.iou_scores
    return masks, scores
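
# Serialize a PIL image to a base64-encoded PNG so it can be embedded inline in HTML.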
def encode_pil_to_base64(pil_image):
    buffer = BytesIO()
    pil_image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
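
# Run both models on the same prompts and build an HTML before/after slider
# that overlays the annotated SAM and SAM-HQ results.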
def compare_images_points_and_masks(user_image, input_boxes, input_points):
    for example_path, example_data in example_data_map.items():
        if example_data["size"] == list(user_image.size):
            user_image = Image.open(example_data['original_image_path'])
    input_boxes = input_boxes.values.tolist()
    input_points = input_points.values.tolist()
    input_boxes = [[[int(coord) for coord in box] for box in input_boxes if any(box)]]
    input_points = [[[int(coord) for coord in point] for point in input_points if any(point)]]
    input_boxes = input_boxes if input_boxes[0] else None
    input_points = input_points if input_points[0] else None
    sam_masks, sam_scores = predict_masks_and_scores(sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points)
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores(sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points)
    if input_boxes and input_points:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM_HQ')
    elif input_boxes:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], None, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], input_boxes[0], None, model_name='SAM_HQ')
    elif input_points:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], None, input_points[0], model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], None, input_points[0], model_name='SAM_HQ')
    else:
        # Guard against unbound img1_b64/img2_b64 when no prompts were provided.
        raise gr.Error("Please provide at least one point or box.")
    print('user_image', user_image)
    print("img1_b64", img1_b64)
    print("img2_b64", img2_b64)
    html_code = f"""
    <div style="position: relative; width: 100%; max-width: 600px; margin: 0 auto;" id="imageCompareContainer">
        <div style="position: relative; width: 100%;">
            <img src="data:image/png;base64,{img1_b64}" style="width:100%; display:block;">
            <div id="topWrapper" style="position:absolute; top:0; left:0; width:100%; overflow:hidden;">
                <img id="topImage" src="data:image/png;base64,{img2_b64}" style="width:100%;">
            </div>
            <div id="sliderLine" style="position:absolute; top:0; left:0; width:2px; height:100%; background-color:red; pointer-events:none;"></div>
        </div>
        <input type="range" min="0" max="100" value="0"
               style="width:100%; margin-top: 10px;"
               oninput="
                   const val = this.value;
                   const container = document.getElementById('imageCompareContainer');
                   const width = container.offsetWidth;
                   const clipValue = 100 - val;
                   document.getElementById('topImage').style.clipPath = 'inset(0 ' + clipValue + '% 0 0)';
                   document.getElementById('sliderLine').style.left = (width * val / 100) + 'px';
               ">
    </div>
    """
    return html_code
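
# Load example prompts (image paths, points, boxes, sizes) from a JSON file.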
def load_examples(json_file="examples.json"):
    with open(json_file, "r") as f:
        examples = json.load(f)
    return examples
examples = load_examples()
example_paths = [example["image_path"] for example in examples]
example_data_map = {
    example["image_path"]: {
        "original_image_path": example["original_image_path"],
        "points": example["points"],
        "boxes": example["boxes"],
        "size": example["size"]
    }
    for example in examples
}
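
# Build the Gradio interface: example picker, prompt tables, and comparison output.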
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")
with gr.Blocks(theme=theme, title="π Compare SAM vs SAM-HQ") as demo:
    image_path_box = gr.Textbox(visible=False)
    gr.Markdown("## π Compare SAM vs SAM-HQ")
    gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it")
    gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)")

    with gr.Row():
        image_input = gr.Image(
            type="pil",
            label="Example image (click below to load)",
            interactive=False,
            height=500,
            show_label=True
        )

    gr.Examples(
        examples=example_paths,
        inputs=[image_input],
        label="Click an example to try π",
    )

    result_html = gr.HTML(elem_id="result-html")

    with gr.Row():
        points_input = gr.Dataframe(
            headers=["x", "y"],
            label="Points",
            datatype=["number", "number"],
            col_count=(2, "fixed")
        )
        boxes_input = gr.Dataframe(
            headers=["x0", "y0", "x1", "y1"],
            label="Boxes",
            datatype=["number", "number", "number", "number"],
            col_count=(4, "fixed")
        )

    def on_image_change(image):
        for example_path, example_data in example_data_map.items():
            print(image.size)
            if example_data["size"] == list(image.size):
                return example_data["points"], example_data["boxes"]
        return [], []

    image_input.change(
        fn=on_image_change,
        inputs=[image_input],
        outputs=[points_input, boxes_input]
    )

    compare_button = gr.Button("Compare points and masks")
    compare_button.click(fn=compare_images_points_and_masks, inputs=[image_input, boxes_input, points_input], outputs=result_html)

    gr.HTML("""
    <style>
    #result-html {
        min-height: 500px;
        border: 1px solid #ccc;
        padding: 10px;
        box-sizing: border-box;
        background-color: #fff;
        border-radius: 8px;
        box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
    }
    </style>
    """)
demo.launch()