File size: 2,725 Bytes
28924f5
 
1991283
 
f6b477c
1991283
 
28924f5
 
605391e
1991283
f6b477c
1991283
 
 
 
605391e
1991283
f72a933
 
 
1991283
 
f6b477c
1991283
 
 
 
 
 
f6b477c
88e9c63
 
 
 
 
f6b477c
1991283
f6b477c
1991283
 
88e9c63
1991283
 
f6b477c
88e9c63
1991283
 
 
 
88e9c63
 
1991283
88e9c63
f6b477c
88e9c63
605391e
1991283
f6b477c
1991283
f6b477c
 
1991283
 
 
 
 
 
f6b477c
1991283
 
 
f6b477c
88e9c63
f6a36d2
 
 
 
 
 
 
 
1991283
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import List

import os
import cv2
import supervision as sv
import numpy as np
import gradio as gr
import torch

from transformers import pipeline
from PIL import Image

# Adapter that exposes a SAM mask-generation pipeline through ``generate()``.
class SamAutomaticMaskGenerator:
    """Run a mask-generation pipeline on NumPy images.

    Args:
        sam_pipeline: Callable invoked as
            ``sam_pipeline(pil_image, points_per_batch=...)``; its result
            must support ``result['masks']``.
        points_per_batch: Value forwarded to the pipeline on every call.
            Defaults to 32, the value that was previously hard-coded.
    """

    def __init__(self, sam_pipeline, points_per_batch=32):
        self.sam_pipeline = sam_pipeline
        self.points_per_batch = points_per_batch

    def generate(self, image_rgb):
        """Return the masks the pipeline produces for ``image_rgb``.

        Args:
            image_rgb: NumPy image array accepted by ``Image.fromarray``
                (presumably HxWx3 RGB -- confirm against callers).

        Returns:
            ``np.ndarray`` of dtype ``uint8`` built from ``outputs['masks']``.
        """
        # The pipeline is fed a PIL image, so convert the NumPy array first.
        image_pil = Image.fromarray(image_rgb)
        outputs = self.sam_pipeline(
            image_pil, points_per_batch=self.points_per_batch
        )
        return np.array(outputs["masks"], dtype=np.uint8)

# SAM model configuration: run on the GPU when CUDA is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Mask-generation pipeline and its wrapper, both built once at import time.
sam_pipeline = pipeline(
    "mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE,
)
mask_generator = SamAutomaticMaskGenerator(sam_pipeline)

# Sample image URLs offered as clickable examples in the UI.
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg"],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg"],
]

# Segment an image with SAM and overlay the resulting masks on it.
def process_image(image_pil):
    """Segment ``image_pil`` and return it alongside an annotated copy.

    Args:
        image_pil: PIL image delivered by Gradio, or ``None`` when the
            callback fires with no image loaded.

    Returns:
        Tuple ``(original, annotated)``. Both are PIL images; the first is
        the input itself. ``(None, None)`` is returned for a ``None`` input,
        and an unmodified copy of the input is returned as ``annotated``
        when the generator yields no masks.
    """
    # Nothing loaded yet: clear both outputs instead of crashing in OpenCV.
    if image_pil is None:
        return None, None

    image_rgb = np.array(image_pil)

    # ``mask_generator.generate`` returns a plain stack of masks, not the
    # list of SAM dicts (``area``/``bbox``/``segmentation``) that
    # ``sv.Detections.from_sam`` indexes into, so the detections are built
    # directly from the masks instead.
    masks = np.asarray(mask_generator.generate(image_rgb)).astype(bool)
    if masks.ndim != 3 or masks.shape[0] == 0:
        return image_pil, image_pil.copy()

    # Largest masks first, the same ordering ``from_sam`` applies, so the
    # index-based colours are assigned consistently.
    areas = masks.reshape(masks.shape[0], -1).sum(axis=1)
    masks = masks[np.argsort(-areas, kind="stable")]
    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)

    # Annotate on a BGR scene; ``cvtColor`` already returns a fresh array,
    # so the annotator never touches ``image_rgb``.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    annotated_image = mask_annotator.annotate(
        scene=image_bgr, detections=detections
    )

    # Back to RGB and then to a PIL image for Gradio.
    annotated_image_rgb = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image_pil = Image.fromarray(annotated_image_rgb)

    return image_pil, annotated_image_pil

# Build the Gradio interface: one column for the upload and trigger button,
# one column showing the original and the segmented image.
with gr.Blocks() as demo:
    # User-facing heading (Spanish). The accented characters had been
    # corrupted into mojibake and are restored here.
    gr.Markdown("# SAM - Segmentación de Imágenes")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Cargar Imagen")
            submit_button = gr.Button("Segmentar")
        with gr.Column():
            original_image = gr.Image(type="pil", label="Imagen Original")
            segmented_image = gr.Image(type="pil", label="Imagen Segmentada")

    # Run the segmentation when the button is pressed.
    submit_button.click(
        process_image,
        inputs=input_image,
        outputs=[original_image, segmented_image],
    )
    with gr.Row():
        # Examples are not cached; clicking one runs ``process_image``
        # on it immediately.
        gr.Examples(
            examples=EXAMPLES,
            fn=process_image,
            inputs=[input_image],
            outputs=[original_image, segmented_image],
            cache_examples=False,
            run_on_click=True,
        )
demo.launch(debug=True)