SAMSEGMENT / app.py
fireedman's picture
Create app.py
f6b477c verified
raw
history blame
No virus
3.99 kB
from typing import Callable, List, Optional

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline, CLIPProcessor, CLIPModel
# ************
# Global variables

# Markdown rendered at the top of the Gradio page. An ATX heading needs a
# space after the '#'; "#SAM" would be rendered as literal text.
MARKDOWN = """
# SAM
"""

# Example rows: [image URL, text prompt, confidence].
# NOTE(review): the prompt/confidence columns look like leftovers from a
# CLIP-filtered variant of this app (CLIPProcessor/CLIPModel are imported but
# never used) — confirm before relying on them.
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "coffee", 0.6],
]

# Masks whose area is at most this fraction of the whole image are discarded.
MIN_AREA_THRESHOLD = 0.01

# Run the models on the GPU when one is visible to torch, otherwise on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Automatic mask-generation pipeline backed by SAM ViT-Large. It is built once
# at import time and placed on DEVICE.
SAM_GENERATOR = pipeline(
    "mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE,
)

# Red overlay for drawing masks over the input picture. No opacity argument is
# passed, so the annotator's own default opacity applies.
SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX,
)

# Fully opaque (opacity=1) white masks, used to paint masks onto a blank canvas.
SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.white(),
    color_lookup=sv.ColorLookup.INDEX,
    opacity=1,
)
# ************
# Worker functions
def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    """Generate SAM masks for an image and wrap them as supervision detections.

    Args:
        image_rgb_pil: Image to segment (expected to be RGB).

    Returns:
        An ``sv.Detections`` whose ``mask`` is the stacked ``"masks"`` output
        of the pipeline and whose ``xyxy`` boxes are derived from those masks
        with ``sv.mask_to_xyxy``.
    """
    # points_per_batch controls how many prompt points the pipeline feeds the
    # model per forward pass (a speed / memory trade-off).
    generated = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
    masks = np.array(generated["masks"])
    boxes = sv.mask_to_xyxy(masks=masks)
    return sv.Detections(xyxy=boxes, mask=masks)
def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    """Keep the pixels of *image* selected by *mask* and paint the rest gray.

    Args:
        image: Colour image; its last axis is assumed to hold 3 channels
            (TODO confirm for non-RGB inputs).
        mask: Boolean array matching ``image``'s spatial dimensions; ``True``
            marks pixels to keep.
        gray_value: Intensity written to all three channels of every pixel
            outside the mask.

    Returns:
        A new array with ``image``'s pixels where ``mask`` is true and
        ``(gray_value, gray_value, gray_value)`` elsewhere.
    """
    fill = np.array([gray_value] * 3, dtype=np.uint8)
    # Add a trailing channel axis so the 2-D mask broadcasts over the channels.
    keep = np.expand_dims(mask, axis=-1)
    return np.where(keep, image, fill)
def filter_detections(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    predicate: Optional[Callable[[np.ndarray], bool]] = None,
) -> sv.Detections:
    """Return the subset of *detections* accepted by *predicate*.

    For each detection, the image and the detection's mask are cropped to its
    bounding box, and pixels outside the mask are painted gray with
    ``reverse_mask_image``. That masked crop is passed to *predicate*, and the
    detection is kept when the predicate returns a truthy value.

    Args:
        image_rgb_pil: Image the detections were computed on.
        detections: Detections carrying per-detection masks (as built by
            ``run_sam``).
        predicate: Optional callable receiving one detection's masked crop
            as an ``np.ndarray``; return True to keep that detection.
            When omitted, every detection is kept.

    Returns:
        A new ``sv.Detections`` containing only the accepted detections.

    Raises:
        ValueError: If *predicate* is given but ``detections`` has no masks.
    """
    if predicate is None:
        # No selection criterion: keep everything, still returning a new
        # Detections object the way boolean indexing does.
        return detections[np.ones(len(detections), dtype=bool)]

    if detections.mask is None:
        raise ValueError("filter_detections requires detections with masks")

    img_rgb_numpy = np.array(image_rgb_pil)
    keep = []
    for xyxy, mask in zip(detections.xyxy, detections.mask):
        crop = sv.crop_image(image=img_rgb_numpy, xyxy=xyxy)
        mask_crop = sv.crop_image(image=mask, xyxy=xyxy)
        masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
        keep.append(bool(predicate(masked_crop)))

    # Force a bool dtype: np.array([]) would be float64, and NumPy refuses
    # float arrays as indices, so an empty selection would otherwise crash.
    return detections[np.array(keep, dtype=bool)]
def annotate(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    annotator: sv.MaskAnnotator,
) -> Image.Image:
    """Draw *detections* onto a copy of *image_rgb_pil* with *annotator*.

    supervision follows OpenCV's BGR channel order for NumPy scenes, so the
    RGB PIL image is converted to BGR, annotated, and converted back to RGB.

    Args:
        image_rgb_pil: Image to draw on; it is not modified.
        detections: Detections whose masks are drawn.
        annotator: A supervision annotator, such as one of the mask annotators.

    Returns:
        A new RGB PIL image with the annotations applied.
    """
    # Reversing the channel axis yields a negative-stride view; make it
    # contiguous so OpenCV-backed drawing can write into it.
    scene_bgr = np.ascontiguousarray(np.array(image_rgb_pil.convert("RGB"))[:, :, ::-1])
    annotated_bgr = annotator.annotate(scene=scene_bgr, detections=detections)
    return Image.fromarray(np.ascontiguousarray(annotated_bgr[:, :, ::-1]))


def inference(image_rgb_pil: Image.Image) -> List[Image.Image]:
    """Segment an image with SAM and render the resulting masks.

    Args:
        image_rgb_pil: Image supplied by the Gradio ``Image`` input.

    Returns:
        Two images: the input with red semi-transparent mask overlays, and a
        black canvas of the same size with the masks painted solid white.

    Raises:
        gr.Error: If no image was provided (Gradio passes ``None`` then).
    """
    if image_rgb_pil is None:
        raise gr.Error("Please upload an image first.")

    width, height = image_rgb_pil.size
    area = width * height
    detections = run_sam(image_rgb_pil)
    # Drop masks that cover no more than MIN_AREA_THRESHOLD of the image.
    detections = detections[detections.area / area > MIN_AREA_THRESHOLD]
    detections = filter_detections(
        image_rgb_pil=image_rgb_pil,
        detections=detections,
    )
    blank_image = Image.new("RGB", (width, height), "black")
    return [
        annotate(
            image_rgb_pil=image_rgb_pil,
            detections=detections,
            annotator=SEMITRANSPARENT_MASK_ANNOTATOR,
        ),
        annotate(
            image_rgb_pil=blank_image,
            detections=detections,
            annotator=SOLID_MASK_ANNOTATOR,
        ),
    ]
# ************
# Gradio UI construction
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            # Was `gr,Image(...)`: the comma built a tuple and called PIL's
            # Image module instead of creating a Gradio component.
            input_image = gr.Image(
                image_mode="RGB",
                type="pil",
                height=500,
            )
            submit_button = gr.Button("Pruébalo!!!")
        gallery = gr.Gallery(
            label="Result",
            object_fit="scale-down",
            preview=True,
        )
    with gr.Row():
        gr.Examples(
            # `inference` only takes the image, so the examples are the unique
            # image URLs from the first column of EXAMPLES (order preserved).
            examples=[[url] for url in dict.fromkeys(row[0] for row in EXAMPLES)],
            fn=inference,
            inputs=[input_image],
            outputs=[gallery],
            cache_examples=True,
            run_on_click=True,
        )
    # Wire only the components that exist; the previous `prompt_text` /
    # `confidence_slider` inputs were never defined.
    submit_button.click(
        inference,
        inputs=[input_image],
        outputs=gallery,
    )
demo.launch(debug=True, show_error=True)