SkalskiP's picture
:tada: initial commit
3bd34d6
raw history blame
No virus
1.26 kB
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline
# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Segment Anything (ViT-Large checkpoint) wrapped in the transformers
# "mask-generation" pipeline. It is built once at import time, so every
# request reuses the already-loaded weights.
SAM_GENERATOR = pipeline(
    model="facebook/sam-vit-large",
    task="mask-generation",
    device=DEVICE,
)
def run_segmentation(
    image_rgb_pil: Image.Image,
    points_per_batch: int = 32,
) -> sv.Detections:
    """Segment everything in an image with the module-level SAM pipeline.

    Args:
        image_rgb_pil: Image to segment.
        points_per_batch: Number of prompt points the pipeline runs through
            the model simultaneously; higher values may be faster but use
            more (GPU) memory. Defaults to 32, the previously hard-coded value.

    Returns:
        One detection per generated mask: ``mask`` holds the boolean masks
        stacked along axis 0 and ``xyxy`` the bounding box of each mask.
        When SAM produces no masks at all, an empty ``sv.Detections`` is
        returned instead of raising.
    """
    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=points_per_batch)
    masks = outputs["masks"]
    if len(masks) == 0:
        # np.array([]) would be 1-D, which sv.Detections rejects as a mask
        # array; return a well-formed empty result instead.
        return sv.Detections.empty()
    # Stack the per-mask arrays into one boolean array; assumed (N, H, W),
    # one HxW mask per entry -- TODO confirm against the pipeline version.
    mask = np.asarray(masks, dtype=bool)
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
def inference(image_rgb_pil: Image.Image) -> Image.Image:
    """Overlay SAM's automatically generated masks on an image.

    Args:
        image_rgb_pil: Image from the Gradio input component. It is ``None``
            when Submit is clicked before an image has been uploaded.

    Returns:
        A new RGB image with each mask drawn in its own colour, or ``None``
        when no input image was provided (the output component is left
        empty rather than the handler crashing).
    """
    if image_rgb_pil is None:
        return None
    # Normalise to 3-channel RGB so the channel flips below are always
    # RGB<->BGR; RGBA or greyscale images from non-Gradio callers would
    # otherwise produce a wrong channel order or a 2-D array.
    image_rgb_pil = image_rgb_pil.convert("RGB")
    detections = run_segmentation(image_rgb_pil)
    # INDEX lookup: each mask is coloured by its position in `detections`.
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    # The annotator works on BGR arrays (hence the flips). The [..., ::-1]
    # view has a negative stride; make it a contiguous copy so OpenCV-based
    # drawing can use it, including as an in-place destination.
    img_bgr_numpy = np.ascontiguousarray(np.asarray(image_rgb_pil)[:, :, ::-1])
    annotated_bgr_image = mask_annotator.annotate(
        scene=img_bgr_numpy, detections=detections)
    # Flip back to RGB for PIL.
    return Image.fromarray(np.ascontiguousarray(annotated_bgr_image[:, :, ::-1]))
with gr.Blocks() as demo:
    # Input image and annotated result laid out in a single row.
    with gr.Row():
        input_image = gr.Image(type='pil', image_mode='RGB')
        result_image = gr.Image(type='pil', image_mode='RGB')
    submit_button = gr.Button("Submit")

    # Clicking Submit runs `inference` on the uploaded image and shows the
    # annotated result in `result_image`.
    submit_button.click(
        fn=inference,
        inputs=[input_image],
        outputs=result_image,
    )

demo.launch(debug=False)