# sergiopaniego's picture
# Updated
# ca815e1
import gradio as gr
import torch
#import spaces
import json
import base64
from io import BytesIO
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
import os
import pandas as pd
from utils import *
from PIL import Image
# Load models: each model/processor pair is created once at import time and
# shared by every request handled below.
sam_hq_model, sam_hq_processor = (
    SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge"),
    SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge"),
)
sam_model, sam_processor = (
    SamModel.from_pretrained("facebook/sam-vit-huge"),
    SamProcessor.from_pretrained("facebook/sam-vit-huge"),
)
# @spaces.GPU  # disabled; re-enable together with the commented-out `import spaces`
def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
    """Run one prompted forward pass of a SAM-family model.

    Args:
        model: The segmentation model (this file passes ``SamModel`` or ``SamHQModel``).
        processor: The matching processor (``SamProcessor`` or ``SamHQProcessor``).
        raw_image: The image to segment.
        input_points: Optional point prompts, forwarded to the processor unchanged.
        input_boxes: Optional box prompts; wrapped in one extra list level before
            being handed to the processor.

    Returns:
        ``(masks, scores)``: the masks returned by
        ``processor.image_processor.post_process_masks`` (given the original and
        reshaped input sizes) and the model's ``iou_scores`` tensor as-is.
    """
    boxes_for_processor = None if input_boxes is None else [input_boxes]
    model_inputs = processor(
        raw_image,
        input_boxes=boxes_for_processor,
        input_points=input_points,
        return_tensors="pt",
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        model_outputs = model(**model_inputs)
    raw_masks = model_outputs.pred_masks.cpu()
    resized_masks = processor.image_processor.post_process_masks(
        raw_masks,
        model_inputs["original_sizes"].cpu(),
        model_inputs["reshaped_input_sizes"].cpu(),
    )
    return resized_masks, model_outputs.iou_scores
def encode_pil_to_base64(pil_image):
    """Encode an image as PNG and return the bytes as a base64 text string.

    Args:
        pil_image: Any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 encoding of the PNG bytes, decoded to a ``str``.
    """
    with BytesIO() as png_stream:
        pil_image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def _resolve_original_image(user_image):
    """Return the registered original image for *user_image*, or *user_image* itself.

    Examples are matched on pixel size only. The FIRST matching example wins,
    which is the same rule ``on_image_change`` uses to fill the prompt tables.
    The image used for inference therefore belongs to the example whose
    points and boxes were loaded.
    """
    for example_data in example_data_map.values():
        if example_data["size"] == list(user_image.size):
            return Image.open(example_data["original_image_path"])
    return user_image


def _is_missing_cell(value):
    """Return True for an empty Dataframe cell: None, NaN or a blank string."""
    if isinstance(value, str):
        return not value.strip()
    return bool(pd.isna(value))


def _dataframe_to_prompts(frame):
    """Convert a Dataframe value into ``[[row, ...]]`` of int coordinates, or None.

    Accepts a pandas DataFrame, a list of rows, or None. Rows with any empty
    cell are skipped, since ``int()`` would fail on them. Rows whose values are
    all falsy (e.g. all zeros) are treated as unused placeholders and skipped
    as well. Returns None when no usable row remains.
    """
    if frame is None:
        return None
    rows = frame.values.tolist() if hasattr(frame, "values") else list(frame)
    cleaned = [
        [int(coord) for coord in row]
        for row in rows
        if not any(_is_missing_cell(coord) for coord in row) and any(row)
    ]
    return [cleaned] if cleaned else None


def _build_compare_html(bottom_b64, top_b64):
    """Return the slider widget HTML that reveals *top_b64* over *bottom_b64*.

    The slider starts at 0 with the top image fully clipped, so its initial
    look matches what the ``oninput`` handler renders for value 0. The divider
    line is positioned in percent so it stays aligned if the container resizes.
    """
    return f"""
<div style="position: relative; width: 100%; max-width: 600px; margin: 0 auto;" id="imageCompareContainer">
  <div style="position: relative; width: 100%;">
    <img src="data:image/png;base64,{bottom_b64}" style="width:100%; display:block;">
    <div id="topWrapper" style="position:absolute; top:0; left:0; width:100%; overflow:hidden;">
      <img id="topImage" src="data:image/png;base64,{top_b64}" style="width:100%; clip-path: inset(0 100% 0 0);">
    </div>
    <div id="sliderLine" style="position:absolute; top:0; left:0; width:2px; height:100%; background-color:red; pointer-events:none;"></div>
  </div>
  <input type="range" min="0" max="100" value="0"
    style="width:100%; margin-top: 10px;"
    oninput="
      const val = this.value;
      const clipValue = 100 - val;
      document.getElementById('topImage').style.clipPath = 'inset(0 ' + clipValue + '% 0 0)';
      document.getElementById('sliderLine').style.left = val + '%';
    ">
</div>
"""


def compare_images_points_and_masks(user_image, input_boxes, input_points):
    """Segment the same prompts with SAM and SAM-HQ and return comparison HTML.

    Args:
        user_image: The image shown in the UI. It is presumably a PIL image,
            since ``image_input`` uses ``type="pil"``. It is None when nothing
            is loaded.
        input_boxes: Value of the boxes Dataframe, with rows of x0, y0, x1, y1.
        input_points: Value of the points Dataframe, with rows of x, y.

    Returns:
        An HTML string. It is either the before/after slider (SAM below,
        SAM-HQ on top) or a short message explaining which input is missing.
    """
    if user_image is None:
        return "<p>Please load an example image first.</p>"

    boxes = _dataframe_to_prompts(input_boxes)
    points = _dataframe_to_prompts(input_points)
    # Without any prompt there is nothing to render, and the old branching
    # left the output variables unassigned.
    if boxes is None and points is None:
        return "<p>Please provide at least one point or box.</p>"

    image = _resolve_original_image(user_image)

    sam_masks, sam_scores = predict_masks_and_scores(
        sam_model, sam_processor, image, input_boxes=boxes, input_points=points
    )
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores(
        sam_hq_model, sam_hq_processor, image, input_boxes=boxes, input_points=points
    )

    # Pass the single batch entry (or None) for each prompt type.
    batch_boxes = boxes[0] if boxes else None
    batch_points = points[0] if points else None
    # NOTE(review): indexing below assumes batch/point-batch size 1, matching the
    # single-image, single-prompt-set inputs built above; confirm if that changes.
    img1_b64 = show_all_annotations_on_image_base64(
        image, sam_masks[0][0], sam_scores[:, 0, :], batch_boxes, batch_points, model_name='SAM'
    )
    img2_b64 = show_all_annotations_on_image_base64(
        image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], batch_boxes, batch_points, model_name='SAM_HQ'
    )
    return _build_compare_html(img1_b64, img2_b64)
def load_examples(json_file="examples.json"):
    """Load the example catalogue from a JSON file.

    Args:
        json_file: Path to the JSON file. The default is resolved against the
            current working directory.

    Returns:
        The decoded JSON document. This file's callers iterate it as a list of
        dicts with ``image_path``, ``original_image_path``, ``points``,
        ``boxes`` and ``size`` keys.
    """
    # JSON text is UTF-8; don't let the platform's locale encoding decide.
    with open(json_file, "r", encoding="utf-8") as f:
        return json.load(f)
# Example catalogue: gallery paths for gr.Examples, plus a per-path record of the
# original image, preset prompts and pixel size used to recognise the image later.
examples = load_examples()
example_paths = [entry["image_path"] for entry in examples]
example_data_map = {
    entry["image_path"]: {
        field: entry[field]
        for field in ("original_image_path", "points", "boxes", "size")
    }
    for entry in examples
}
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")

with gr.Blocks(theme=theme, title="🔍 Compare SAM vs SAM-HQ") as demo:
    image_path_box = gr.Textbox(visible=False)
    gr.Markdown("## 🔍 Compare SAM vs SAM-HQ")
    gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it")
    gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)")

    with gr.Row():
        image_input = gr.Image(
            type="pil",
            label="Example image (click below to load)",
            interactive=False,
            height=500,
            show_label=True
        )
        gr.Examples(
            examples=example_paths,
            inputs=[image_input],
            label="Click an example to try 👇",
        )

    result_html = gr.HTML(elem_id="result-html")

    with gr.Row():
        points_input = gr.Dataframe(
            headers=["x", "y"],
            label="Points",
            datatype=["number", "number"],
            col_count=(2, "fixed")
        )
        boxes_input = gr.Dataframe(
            headers=["x0", "y0", "x1", "y1"],
            label="Boxes",
            datatype=["number", "number", "number", "number"],
            col_count=(4, "fixed")
        )

    def on_image_change(image):
        """Fill the prompt tables with the preset prompts of the matching example.

        Examples are matched on pixel size and the first match wins. Returns
        ``(points, boxes)``, or two empty lists when there is no image or no
        matching example.
        """
        # The change event may deliver None (e.g. if the value is cleared).
        if image is None:
            return [], []
        for example_data in example_data_map.values():
            if example_data["size"] == list(image.size):
                return example_data["points"], example_data["boxes"]
        return [], []

    image_input.change(
        fn=on_image_change,
        inputs=[image_input],
        outputs=[points_input, boxes_input]
    )

    compare_button = gr.Button("Compare points and masks")
    compare_button.click(fn=compare_images_points_and_masks, inputs=[image_input, boxes_input, points_input], outputs=result_html)

    # Styles for the result panel, targeted via its elem_id.
    gr.HTML("""
<style>
#result-html {
min-height: 500px;
border: 1px solid #ccc;
padding: 10px;
box-sizing: border-box;
background-color: #fff;
border-radius: 8px;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
}
</style>
""")

demo.launch()