from typing import List
import os

import cv2
import numpy as np
import torch
import gradio as gr
import supervision as sv
from PIL import Image
from transformers import pipeline


class SamAutomaticMaskGenerator:
    """Thin wrapper around the HF ``mask-generation`` pipeline.

    Produces results in the list-of-dicts format that
    ``sv.Detections.from_sam`` parses (each entry carries a boolean
    ``segmentation`` mask plus its pixel ``area``).
    """

    def __init__(self, sam_pipeline):
        # Callable HF pipeline created with task="mask-generation".
        self.sam_pipeline = sam_pipeline

    def generate(self, image_rgb):
        """Run SAM over an RGB image and return SAM-format mask records.

        FIX: the previous version returned a single raw uint8 ndarray of
        all masks, which ``sv.Detections.from_sam`` cannot consume — it
        expects a list of dicts and sorts them by their "area" key.
        NOTE(review): assumes the pipeline output dict exposes its masks
        under the "masks" key as an iterable of per-object boolean arrays
        — confirm against the installed transformers version.
        """
        outputs = self.sam_pipeline(image_rgb, points_per_batch=32)
        return [
            {
                "segmentation": np.asarray(mask, dtype=bool),
                "area": int(np.count_nonzero(mask)),
            }
            for mask in outputs["masks"]
        ]


# SAM model configuration: prefer GPU when available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
sam_pipeline = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE,
)

# Example images offered in the UI (fetched by URL by Gradio).
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg"],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg"],
]

mask_generator = SamAutomaticMaskGenerator(sam_pipeline)


def process_image(image_pil):
    """Segment a PIL image with SAM and annotate the detected masks.

    Returns a ``(original, annotated)`` pair of PIL images for the two
    Gradio output components.
    """
    # PIL -> numpy RGB for the model; BGR copy for OpenCV/supervision drawing.
    image_rgb = np.array(image_pil)
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

    # Generate masks and overlay them, one color per mask index.
    sam_result = mask_generator.generate(image_rgb)
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    detections = sv.Detections.from_sam(sam_result=sam_result)
    annotated_image = mask_annotator.annotate(
        scene=image_bgr.copy(), detections=detections
    )

    # Back to RGB then PIL for display in Gradio.
    annotated_image_rgb = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image_pil = Image.fromarray(annotated_image_rgb)
    return image_pil, annotated_image_pil


# Gradio interface: image input on the left, original + segmented on the right.
with gr.Blocks() as demo:
    gr.Markdown("# SAM - Segmentación de Imágenes")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Cargar Imagen")
            submit_button = gr.Button("Segmentar")
        with gr.Column():
            original_image = gr.Image(type="pil", label="Imagen Original")
            segmented_image = gr.Image(type="pil", label="Imagen Segmentada")

    submit_button.click(
        process_image,
        inputs=input_image,
        outputs=[original_image, segmented_image],
    )

    with gr.Row():
        # FIX: `fn=inference` and `outputs=[gallery]` referenced undefined
        # names (NameError at startup). The examples must run the same
        # handler and feed the same outputs as the submit button, since
        # run_on_click=True executes `fn` when an example is selected.
        gr.Examples(
            examples=EXAMPLES,
            fn=process_image,
            inputs=[input_image],
            outputs=[original_image, segmented_image],
            cache_examples=False,
            run_on_click=True,
        )

demo.launch(debug=True)