from typing import List
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline, CLIPProcessor, CLIPModel
MARKDOWN = """
# Segment Anything Model + MetaCLIP

This is a demo of open-vocabulary image segmentation using the
[Segment Anything Model](https://github.com/facebookresearch/segment-anything) and
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP) combo.
"""
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE)
CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX)


def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    # Generate class-agnostic segmentation masks for the whole image.
    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
    mask = np.array(outputs['masks'])
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)


def run_clip(image_rgb_pil: Image.Image, text: List[str]) -> np.ndarray:
    # Score the image against each text prompt; returns softmax probabilities
    # of shape (1, len(text)).
    inputs = CLIP_PROCESSOR(
        text=text,
        images=image_rgb_pil,
        return_tensors="pt",
        padding=True
    ).to(DEVICE)
    outputs = CLIP_MODEL(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs.detach().cpu().numpy()


def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    # Keep pixels inside the mask and replace everything outside it with flat
    # gray, so CLIP sees the segment without distracting background context.
    gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
    return np.where(mask[..., None], image, gray_color)


def annotate(image_rgb_pil: Image.Image, detections: sv.Detections) -> Image.Image:
    # supervision annotators follow the OpenCV BGR convention, so flip the
    # channels before annotating and flip them back afterwards.
    img_bgr_numpy = np.array(image_rgb_pil)[:, :, ::-1]
    annotated_bgr_image = MASK_ANNOTATOR.annotate(
        scene=img_bgr_numpy, detections=detections)
    return Image.fromarray(annotated_bgr_image[:, :, ::-1])


def filter_detections(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    prompt: str
) -> sv.Detections:
    img_rgb_numpy = np.array(image_rgb_pil)
    text = [f"a picture of {prompt}", "a picture of background"]
    filtering_mask = []
    for xyxy, mask in zip(detections.xyxy, detections.mask):
        # Crop each detection, gray out everything outside its mask, and ask
        # CLIP whether the crop matches the prompt or the background caption.
        crop = sv.crop_image(image=img_rgb_numpy, xyxy=xyxy)
        mask_crop = sv.crop_image(image=mask, xyxy=xyxy)
        masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
        masked_crop_pil = Image.fromarray(masked_crop)
        probs = run_clip(image_rgb_pil=masked_crop_pil, text=text)
        class_index = np.argmax(probs)
        filtering_mask.append(class_index == 0)

    filtering_mask = np.array(filtering_mask)
    return detections[filtering_mask]
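
# Design note: each mask is classified as prompt-vs-background rather than
# thresholded on a single CLIP similarity score; the generic background caption
# gives CLIP an explicit alternative to vote for.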


def inference(image_rgb_pil: Image.Image, prompt: str) -> Image.Image:
    width, height = image_rgb_pil.size
    area = width * height

    detections = run_sam(image_rgb_pil)
    # Drop tiny masks covering less than 0.5% of the image area.
    detections = detections[detections.area / area > 0.005]
    detections = filter_detections(
        image_rgb_pil=image_rgb_pil,
        detections=detections,
        prompt=prompt)
    return annotate(image_rgb_pil=image_rgb_pil, detections=detections)


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(image_mode='RGB', type='pil')
            prompt_text = gr.Textbox(label="Prompt", value="dog")
        result_image = gr.Image(image_mode='RGB', type='pil')
    submit_button = gr.Button("Submit")

    submit_button.click(
        inference,
        inputs=[input_image, prompt_text],
        outputs=result_image)

demo.launch(debug=False)
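
# To run locally (assumed dependencies, matching the imports above):
#   pip install torch transformers supervision gradio pillow numpy
#   python app.py
# then open the local URL that Gradio prints to the console.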