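# Gradio demo: automatic image segmentation with Segment Anything (SAM) via the
# Hugging Face transformers "mask-generation" pipeline, visualized with the
# supervision library.
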
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline

# Load the SAM mask-generation pipeline once at startup; use the GPU when available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE)


def run_segmentation(image_rgb_pil: Image.Image) -> sv.Detections:
    # Run SAM over the whole image; the pipeline returns a list of boolean masks.
    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
    masks = np.array(outputs['masks'])
    # Wrap the masks in a Detections object, deriving bounding boxes from the masks.
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)


def inference(image_rgb_pil: Image.Image) -> Image.Image:
    detections = run_segmentation(image_rgb_pil)
    # Color each mask by its index so overlapping segments stay distinguishable.
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

    # Convert the RGB PIL image to a BGR NumPy array for annotation,
    # then flip back to RGB before returning a PIL image.
    img_bgr_numpy = np.array(image_rgb_pil)[:, :, ::-1]
    annotated_bgr_image = mask_annotator.annotate(
        scene=img_bgr_numpy, detections=detections)
    return Image.fromarray(annotated_bgr_image[:, :, ::-1])


# Minimal Gradio UI: input image on the left, segmented result on the right.
with gr.Blocks() as demo:
    with gr.Row():
        input_image = gr.Image(image_mode='RGB', type='pil')
        result_image = gr.Image(image_mode='RGB', type='pil')
    submit_button = gr.Button("Submit")

    submit_button.click(inference, inputs=[input_image], outputs=result_image)

demo.launch(debug=False)