import os
import cv2
import numpy as np
import torch  # BUG FIX: torch was used below (torch.cuda.is_available) but never imported
import supervision as sv
import gradio as gr
from PIL import Image
from transformers import pipeline


class SamAutomaticMaskGenerator:
    """Thin adapter exposing a SAM-style ``generate`` API over a
    transformers ``mask-generation`` pipeline."""

    def __init__(self, sam_pipeline):
        self.sam_pipeline = sam_pipeline

    def generate(self, image_rgb):
        """Run SAM over an RGB image and return per-mask result dicts.

        Returns a list of ``{"segmentation": <bool HxW mask>, "area": <int>}``
        dicts — the format ``sv.Detections.from_sam`` consumes (it sorts by
        ``"area"`` and stacks the ``"segmentation"`` masks).
        """
        outputs = self.sam_pipeline(image_rgb, points_per_batch=32)
        # BUG FIX: the original returned one raw uint8 ndarray, but
        # sv.Detections.from_sam expects a list of dicts with
        # "segmentation"/"area" keys, so the raw array would crash downstream.
        results = []
        for m in outputs["masks"]:
            seg = np.asarray(m, dtype=bool)
            results.append({"segmentation": seg, "area": int(seg.sum())})
        return results


# SAM model configuration: prefer GPU when available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
sam_pipeline = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE,
)
mask_generator = SamAutomaticMaskGenerator(sam_pipeline)


def process_image(image_pil):
    """Segment a PIL image with SAM.

    Returns a ``(original, annotated)`` pair of PIL images; the annotated
    one has every generated mask overlaid with a distinct color.
    """
    image_rgb = np.array(image_pil)
    # supervision annotators draw on BGR (OpenCV) images.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    sam_result = mask_generator.generate(image_rgb)
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    detections = sv.Detections.from_sam(sam_result=sam_result)
    annotated_image = mask_annotator.annotate(
        scene=image_bgr.copy(), detections=detections
    )
    annotated_image_rgb = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(image_rgb), Image.fromarray(annotated_image_rgb)


# Gradio interface: image in, (original, segmented) pair out.
with gr.Blocks() as demo:
    gr.Markdown("# SAM - Segmentación de Imágenes")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Cargar Imagen")
            submit_button = gr.Button("Segmentar")
        with gr.Column():
            original_image = gr.Image(type="pil", label="Imagen Original")
            segmented_image = gr.Image(type="pil", label="Imagen Segmentada")
    submit_button.click(
        process_image,
        inputs=input_image,
        outputs=[original_image, segmented_image],
    )

demo.launch(debug=True)