Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
import copy | |
import numpy as np | |
import torch | |
from PIL import Image, ImageDraw, ImageFont | |
import PIL | |
# OwlViT Detection | |
from transformers import OwlViTProcessor, OwlViTForObjectDetection | |
# segment anything | |
from segment_anything import build_sam, SamPredictor | |
import cv2 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import gc | |
def show_mask(mask, ax, random_color=False):
    """Overlay a translucent coloured mask on a matplotlib-style axes.

    Args:
        mask: Array whose last two dimensions are taken as (height, width);
            it is reshaped to (height, width, 1) before colouring, so it
            must hold exactly height * width elements.
        ax: Object exposing ``imshow``; it receives the RGBA overlay.
        random_color: When true, use a random RGB colour; otherwise use a
            fixed blue. Alpha is 0.6 in both cases.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get an (H, W, 4) overlay.
    overlay = mask.reshape(height, width)[:, :, np.newaxis] * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
def show_box(box, ax):
    """Draw an unfilled green rectangle for ``box`` on ``ax``.

    Args:
        box: Indexable with at least four entries, read as
            (x0, y0, x1, y1); width and height are derived as
            x1 - x0 and y1 - y0.
        ax: Object exposing ``add_patch``; it receives the rectangle.
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent interior
        lw=2,
    )
    ax.add_patch(outline)
def plot_boxes_to_image(image_pil, tgt):
    """Draw labelled boxes onto ``image_pil`` and build a matching box mask.

    ``image_pil`` is modified in place (and also returned).

    Args:
        image_pil: PIL image to draw on.
        tgt: Mapping with
            - ``"boxes"``: iterable of 4-element boxes, each read as
              (x0, y0, x1, y1) and truncated to ints;
            - ``"labels"``: one label per box (rendered with ``str``).
            Other keys (such as ``"size"``) are not read.

    Returns:
        ``(image_pil, mask)`` where ``mask`` is a new mode-"L" image of the
        same size as ``image_pil`` with every box region filled with 255.

    Raises:
        AssertionError: if ``boxes`` and ``labels`` differ in length.
    """
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    assert len(boxes) == len(labels), "boxes and labels must have same length"

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)

    # The default font does not change between boxes, so load it once.
    font = ImageFont.load_default()
    # Older Pillow fonts lack getbbox; fall back to draw.textsize for those.
    font_has_getbbox = hasattr(font, "getbbox")

    # draw boxes and masks
    for box, label in zip(boxes, labels):
        # random color
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        x0, y0, x1, y1 = (int(coord) for coord in box)
        caption = str(label)

        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)

        # Measure the caption so it can sit on a filled background.
        if font_has_getbbox:
            bbox = draw.textbbox((x0, y0), caption, font)
        else:
            w, h = draw.textsize(caption, font)
            bbox = (x0, y0, w + x0, y0 + h)
        draw.rectangle(bbox, fill=color)
        # Draw with the same font that was measured above.
        draw.text((x0, y0), caption, fill="white", font=font)

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)

    return image_pil, mask
# Use GPU if available
if torch.cuda.is_available():
    # Keep the originally pinned GPU index 4 when the host has it; otherwise
    # fall back to GPU 0 so hosts with fewer devices do not fail on `.to(device)`.
    _gpu_index = 4 if torch.cuda.device_count() > 4 else 0
    device = torch.device("cuda", _gpu_index)
else:
    device = torch.device("cpu")

# load OWL-ViT model (inference only, hence eval mode)
owlvit_model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
owlvit_model.eval()
owlvit_processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

# load segment anything (SAM) from a local checkpoint file.
# NOTE: the SAM model is not moved to `device` here.
sam_predictor = SamPredictor(build_sam(checkpoint="./sam_vit_h_4b8939.pth"))
def query_image(img, text_prompt):
    """Detect objects named in ``text_prompt`` with OWL-ViT and segment them with SAM.

    Args:
        img: A PIL image, or array-like data accepted by
            ``Image.fromarray(np.uint8(img))``.
        text_prompt: Comma-separated text queries, e.g. ``"cat, dog"``.
            Whitespace around each query is stripped.

    Returns:
        ``(owlvit_segment_image, image_with_box)``: an RGB PIL image of the
        matplotlib rendering (image + SAM masks + boxes), and an RGB PIL
        image with labelled boxes drawn by ``plot_boxes_to_image``.
    """
    import io  # not imported at module level in this file

    # load image; always work on an RGB copy so the caller's image is not drawn on
    if isinstance(img, PIL.Image.Image):
        pil_img = img.convert('RGB')
    else:
        pil_img = Image.fromarray(np.uint8(img)).convert('RGB')

    texts = [[query.strip() for query in text_prompt.split(",")]]
    box_threshold = 0.0

    # run object detection model
    with torch.no_grad():
        inputs = owlvit_processor(text=texts, images=pil_img, return_tensors="pt").to(device)
        outputs = owlvit_model(**inputs)

    # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
    target_sizes = torch.Tensor([pil_img.size[::-1]])
    # Convert outputs (bounding boxes and class logits) to COCO API
    results = owlvit_processor.post_process_object_detection(
        outputs=outputs,
        threshold=box_threshold,
        target_sizes=target_sizes.to(device),
    )

    # Highest-scoring index along dim 1 for each entry of the last dim
    # (presumably: best predicted box per text query -- confirm against
    # the OWL-ViT output layout).
    scores = torch.sigmoid(outputs.logits)
    _, topk_idxs = torch.topk(scores, k=1, dim=1)

    i = 0  # Retrieve predictions for the first image for the corresponding text queries
    text = texts[i]
    # 1-D index tensor instead of a nested Python list, so indexing does not
    # rely on legacy "list treated as tuple" behaviour.
    topk_idxs = topk_idxs[i, 0]
    # NOTE(review): indexing `results` by these positions assumes the
    # post-processing kept every prediction in its original order (threshold 0.0).
    boxes = results[i]["boxes"][topk_idxs].cpu().detach().numpy()
    labels = results[i]["labels"][topk_idxs]

    size = pil_img.size
    pred_dict = {
        "boxes": boxes.copy(),
        "size": [size[1], size[0]],  # H, W
        "labels": [text[idx] for idx in labels.tolist()],
    }

    # free cached memory before running SAM
    gc.collect()
    torch.cuda.empty_cache()

    # run segment anything (SAM)
    # pil_img is already RGB, so no channel swap is needed; SamPredictor.set_image
    # takes RGB input by default (see segment_anything docs).
    image = np.array(pil_img)
    sam_predictor.set_image(image)

    boxes_tensor = torch.tensor(boxes, device=sam_predictor.device)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_tensor, image.shape[:2])
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # Render image + masks + boxes, then read the figure back as a PIL image.
    fig = plt.figure(figsize=(10, 10))
    try:
        ax = plt.gca()
        ax.imshow(image)
        for mask in masks:
            show_mask(mask.cpu().numpy(), ax, random_color=True)
        # Use the NumPy boxes so this works wherever the SAM predictor lives.
        for box in boxes:
            show_box(box, ax)
        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf)
        buf.seek(0)
        owlvit_segment_image = Image.open(buf).convert('RGB')
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    # grounded results
    image_with_box = plot_boxes_to_image(pil_img, pred_dict)[0]
    return owlvit_segment_image, image_with_box