import sys
sys.path.append(".")
import gradio as gr
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor
import torch
from PIL import Image
import requests
import numpy as np
import time
from transformers import CLIPProcessor, CLIPModel
from segment_anything import SamPredictor, sam_model_registry
cache_dir = ".cache"
device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model = "facebook/sam-vit-huge"
# wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -O tmp/data/sam_vit_h_4b8939.pth
sam_ckpt = "tmp/data/sam_vit_h_4b8939.pth"
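# Added sanity check (not in the original script): fail early with a clear message if the
# checkpoint has not been downloaded yet, since loading it below would otherwise raise a
# bare FileNotFoundError.
import os
if not os.path.isfile(sam_ckpt):
    raise FileNotFoundError(f"SAM checkpoint not found at {sam_ckpt}; download it with the wget command above.")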
sam = sam_model_registry["vit_h"](sam_ckpt)
sam = sam.to(device)
sam = SamPredictor(sam)
captioner_model = "Salesforce/blip-image-captioning-base"
clip_model = "openai/clip-vit-base-patch32"
clip = CLIPModel.from_pretrained(clip_model, cache_dir=cache_dir).to(device)
clip_processor = CLIPProcessor.from_pretrained(clip_model, cache_dir=cache_dir)
# NOTE(xiaoke): the original CLIP uses float16; here we keep float32, the Hugging Face default.
dtype = clip.dtype
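# Note: the SAMCaptioner* imports, sam_model, captioner_model, clip, clip_processor, and dtype
# are set up above but not used further in this file; they appear to be placeholders for the
# full SAM-captioner pipeline.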
img_url = "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw)
NUM_OUTPUT_HEADS = 3
LIBRARIES = ["caption_mask_with_highest_iou", "multimask_output", "return_patches"]
DEFAULT_LIBRARIES = ["multimask_output", "return_patches"]
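# The checkbox flags are forwarded to run() as plain strings. In this file only
# "multimask_output" changes behavior (it is passed through to SamPredictor.predict);
# "caption_mask_with_highest_iou" and "return_patches" are parsed but not acted on yet.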
def click_and_run(input_image, args, evt: gr.SelectData):
    # Triggered when the user clicks on the input image; the click becomes a positive point prompt.
    x, y = evt.index
    input_points = [[x, y]]
    return run(args, input_image, input_points=input_points, input_labels=[1])
def box_and_run(input_image, args, input_boxes_text):
    # Triggered by the "Run with Input Boxes" button; parses "x,y,x2,y2" from the textbox.
    x, y, x2, y2 = list(map(int, input_boxes_text.split(",")))
    input_boxes = [[x, y, x2, y2]]
    return run(args, input_image, input_boxes=input_boxes)
def run(args, input_image, input_points=None, input_boxes=None, input_labels=None):
    if input_points is None and input_boxes is None:
        raise ValueError("input_points and input_boxes cannot both be None")
    if input_points is not None:
        input_points = np.array(input_points)
    if input_labels is not None:
        input_labels = np.array(input_labels)
    if input_boxes is not None:
        input_boxes = np.array(input_boxes)
    caption_mask_with_highest_iou = "caption_mask_with_highest_iou" in args
    multimask_output = "multimask_output" in args
    return_patches = "return_patches" in args

    input_image = np.array(input_image)
    sam.set_image(input_image)
    masks, iou_predictions, low_res_masks = sam.predict(
        point_coords=input_points, box=input_boxes, point_labels=input_labels, multimask_output=multimask_output
    )
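    # SamPredictor.predict returns masks of shape (num_masks, H, W) (bool), iou_predictions of
    # shape (num_masks,), and low-res mask logits of shape (num_masks, 256, 256); with
    # multimask_output=True, num_masks is 3, otherwise 1.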
    outputs = []
    num_heads = len(masks)
    # Each gr.AnnotatedImage expects:
    # Tuple[numpy.ndarray | PIL.Image | str, List[Tuple[numpy.ndarray | Tuple[int, int, int, int], str]]]
    # iou_predictions has shape (num_heads,) here, since we call SamPredictor directly.
    iou_scores = iou_predictions
    for i in range(num_heads):
        output = [input_image, [[masks[i], f"iou:{iou_scores[i]:.4f}"]]]
        outputs.append(output)
    # Pad the remaining slots of the first output row with blank images.
    for i in range(num_heads, NUM_OUTPUT_HEADS):
        output = [np.ones((1, 1)), []]
        outputs.append(output)
    # The second output row is filled with blanks as well.
    for i in range(NUM_OUTPUT_HEADS):
        output = [np.ones((1, 1)), []]
        outputs.append(output)
    return outputs
def fake_click_and_run(input_image, args, evt: gr.SelectData):
    # Debug stub with the same output structure as click_and_run; not wired to any event below.
    outputs = []
    # Tuple[numpy.ndarray | PIL.Image | str, List[Tuple[numpy.ndarray | Tuple[int, int, int, int], str]]]
    num_heads = 1
    for i in range(num_heads):
        output = [input_image, []]
        outputs.append(output)
    for i in range(num_heads, NUM_OUTPUT_HEADS):
        output = [input_image, []]
        outputs.append(output)
    return outputs
with gr.Blocks() as demo:
    input_image = gr.Image(value=raw_image, label="Input Image", interactive=True, type="pil", height=500)
    args = gr.CheckboxGroup(choices=LIBRARIES, value=DEFAULT_LIBRARIES, label="SAM Captioner Arguments")
    input_boxes_text = gr.Textbox(lines=1, label="Input Boxes (x,y,x2,y2)", value="0,0,100,100")
    input_boxes_button = gr.Button(value="Run with Input Boxes")

    output_images = []
    with gr.Row():
        for i in range(NUM_OUTPUT_HEADS):
            output_images.append(gr.AnnotatedImage(label=f"Output Image {i}", height=500))
    with gr.Row():
        for i in range(NUM_OUTPUT_HEADS):
            output_images.append(gr.AnnotatedImage(label=f"Output Image {i}", height=500))

    input_image.select(click_and_run, [input_image, args], [*output_images])
    input_boxes_button.click(box_and_run, [input_image, args, input_boxes_text], [*output_images])

demo.launch()
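# Usage sketch (assumption: this file is saved as e.g. demo.py and run from the repo root
# after downloading the SAM checkpoint):
#   python demo.py
# To expose the demo publicly or serialize concurrent clicks, Gradio's standard options apply, e.g.:
#   demo.queue().launch(share=True)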