import gradio as gr
import torch
#import spaces
import json
import base64
from io import BytesIO
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
import os
import pandas as pd
from utils import *
from PIL import Image

# Load models
sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

#@spaces.GPU
def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
    if input_boxes is not None:
        input_boxes = [input_boxes]
    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    scores = outputs.iou_scores
    return masks, scores

def encode_pil_to_base64(pil_image):
    buffer = BytesIO()
    pil_image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def compare_images_points_and_masks(user_image, input_boxes, input_points):
    # If the displayed image matches a known example (by size), reload the full-resolution original.
    for example_path, example_data in example_data_map.items():
        if example_data["size"] == list(user_image.size):
            user_image = Image.open(example_data['original_image_path'])

    # Convert the Gradio dataframes to nested lists, drop empty rows, and fall back to None when empty.
    input_boxes = input_boxes.values.tolist()
    input_points = input_points.values.tolist()
    input_boxes = [[[int(coord) for coord in box] for box in input_boxes if any(box)]]
    input_points = [[[int(coord) for coord in point] for point in input_points if any(point)]]
    input_boxes = input_boxes if input_boxes[0] else None
    input_points = input_points if input_points[0] else None

    sam_masks, sam_scores = predict_masks_and_scores(sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points)
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores(sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points)

    if input_boxes and input_points:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM_HQ')
    elif input_boxes:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], None, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], input_boxes[0], None, model_name='SAM_HQ')
    elif input_points:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], None, input_points[0], model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], None, input_points[0], model_name='SAM_HQ')

    print('user_image', user_image)
    print("img1_b64", img1_b64)
    print("img2_b64", img2_b64)

    html_code = f"""
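    <!-- Assumed minimal layout (not the original template): embed the SAM and SAM-HQ
         renderings side by side as base64-encoded PNGs. -->
    <div style="display: flex; gap: 12px; justify-content: center;">
        <img src="data:image/png;base64,{img1_b64}" style="max-width: 48%;" alt="SAM result"/>
        <img src="data:image/png;base64,{img2_b64}" style="max-width: 48%;" alt="SAM-HQ result"/>
    </div>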
""" return html_code def load_examples(json_file="examples.json"): with open(json_file, "r") as f: examples = json.load(f) return examples examples = load_examples() example_paths = [example["image_path"] for example in examples] example_data_map = { example["image_path"]: { "original_image_path": example["original_image_path"], "points": example["points"], "boxes": example["boxes"], "size": example["size"] } for example in examples } theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald") with gr.Blocks(theme=theme, title="🔍 Compare SAM vs SAM-HQ") as demo: image_path_box = gr.Textbox(visible=False) gr.Markdown("## 🔍 Compare SAM vs SAM-HQ") gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it") gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)") with gr.Row(): image_input = gr.Image( type="pil", label="Example image (click below to load)", interactive=False, height=500, show_label=True ) gr.Examples( examples=example_paths, inputs=[image_input], label="Click an example to try 👇", ) result_html = gr.HTML(elem_id="result-html") with gr.Row(): points_input = gr.Dataframe( headers=["x", "y"], label="Points", datatype=["number", "number"], col_count=(2, "fixed") ) boxes_input = gr.Dataframe( headers=["x0", "y0", "x1", "y1"], label="Boxes", datatype=["number", "number", "number", "number"], col_count=(4, "fixed") ) def on_image_change(image): for example_path, example_data in example_data_map.items(): print(image.size) if example_data["size"] == list(image.size): return example_data["points"], example_data["boxes"] return [], [] image_input.change( fn=on_image_change, inputs=[image_input], outputs=[points_input, boxes_input] ) compare_button = gr.Button("Compare points and masks") compare_button.click(fn=compare_images_points_and_masks, inputs=[image_input, boxes_input, points_input], outputs=result_html) gr.HTML(""" """) demo.launch()