Spaces:
Running
Running
File size: 3,944 Bytes
a64bccf e4dee6a a64bccf 7eb36e4 58f3f8d a64bccf e4dee6a a64bccf e4dee6a a64bccf e4dee6a a64bccf e4dee6a a64bccf e4dee6a 837d4ca e4dee6a 4dfaafd e4dee6a 4dfaafd e4dee6a 6dd29b3 e4dee6a a64bccf e4dee6a a64bccf e4dee6a 4dfaafd e4dee6a b688ee0 a64bccf e4dee6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import torch
import gradio as gr
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import spaces
import numpy as np
from PIL import Image
import io
import random
from transformers import SamModel, SamProcessor
def apply_colored_masks_on_image(image, masks, alpha=128):
    """Overlay every mask on ``image`` as a translucent, randomly colored layer.

    Args:
        image: A PIL image, or an array that is converted with
            ``Image.fromarray(image.astype('uint8'), 'RGB')``.
        masks: Stack of masks indexed along its first axis. Each entry is
            squeezed, moved to the CPU and turned into a NumPy array, so it is
            presumably a torch tensor of 0/1 (or boolean) values whose squeezed
            shape matches the image size -- confirm against the caller.
        alpha: Opacity of the overlay inside a mask, from 0 (invisible) to
            255 (opaque). Values outside that range are clamped.

    Returns:
        An RGBA PIL image with all masks composited on top of ``image``.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype('uint8'), 'RGB')
    composited = image.convert("RGBA")

    alpha = max(0, min(255, int(alpha)))

    for i in range(masks.shape[0]):
        mask = masks[i].squeeze().cpu().numpy()
        # putalpha() replaces the overlay's whole alpha channel, so the opacity
        # has to be encoded in the mask itself: scaling by 255 (as before)
        # discarded the alpha stored in the fill color and produced fully
        # opaque overlays that hid the segmented object.
        mask_alpha = Image.fromarray((mask * alpha).astype(np.uint8), 'L')
        rgb = tuple(random.randint(0, 255) for _ in range(3))
        overlay = Image.new("RGBA", image.size, rgb + (alpha,))
        overlay.putalpha(mask_alpha)
        composited = Image.alpha_composite(composited, overlay)

    return composited
# Run inference on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint identifiers shared by each model and its matching processor.
_OWLV2_CHECKPOINT = "google/owlv2-base-patch16-ensemble"
_SAM_CHECKPOINT = "facebook/sam-vit-huge"

# Text-conditioned object detector and its pre/post-processor.
model = Owlv2ForObjectDetection.from_pretrained(_OWLV2_CHECKPOINT).to(device)
processor = Owlv2Processor.from_pretrained(_OWLV2_CHECKPOINT)

# Segmentation model (prompted with the detector's boxes) and its processor.
model_sam = SamModel.from_pretrained(_SAM_CHECKPOINT).to(device)
processor_sam = SamProcessor.from_pretrained(_SAM_CHECKPOINT)
@spaces.GPU
def query_image(img, text_queries, score_threshold=0.5):
    """Detect objects described by ``text_queries`` and overlay SAM masks on them.

    Args:
        img: Image array whose first two dimensions are (height, width); it is
            converted with ``Image.fromarray(img.astype('uint8'), 'RGB')``.
        text_queries: Comma-separated object descriptions. Surrounding
            whitespace is stripped and empty entries are ignored.
        score_threshold: Detections must score above this value to be kept.

    Returns:
        A ``(image, annotations)`` tuple: ``image`` is an RGBA PIL image with
        the segmentation masks painted on it, and ``annotations`` is a list of
        ``([x0, y0, x1, y1], label)`` pairs, one per kept detection. When there
        is nothing to detect or nothing is found, the unmodified image and an
        empty list are returned.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        raise gr.Error("Please provide an image.")

    img_pil = Image.fromarray(img.astype('uint8'), 'RGB')

    # "a, b" must yield the label "b", not " b"; blank entries carry no query.
    queries = [q.strip() for q in (text_queries or "").split(",")]
    queries = [q for q in queries if q]
    if not queries:
        return img_pil.convert("RGBA"), []

    height, width = img.shape[:2]
    # Boxes are rescaled to a square whose side is the longer image edge.
    size = max(height, width)
    target_sizes = torch.Tensor([[size, size]])
    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        model_outputs = model(**inputs)

    model_outputs.logits = model_outputs.logits.cpu()
    model_outputs.pred_boxes = model_outputs.pred_boxes.cpu()
    # Forward the threshold explicitly: the post-processor otherwise applies
    # its own default cut-off, which made slider values below it ineffective.
    results = processor.post_process_object_detection(
        outputs=model_outputs,
        threshold=score_threshold,
        target_sizes=target_sizes,
    )
    detections = results[0]

    result_labels = []
    result_boxes = []
    for box, label in zip(detections["boxes"], detections["labels"]):
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        # The square target size can exceed the real image on one axis and
        # predicted boxes can be slightly negative; keep them inside the image.
        x0, x1 = (min(max(v, 0), width) for v in (x0, x1))
        y0, y1 = (min(max(v, 0), height) for v in (y0, y1))
        if x1 <= x0 or y1 <= y0:
            continue  # Nothing of the box is left inside the image.
        clipped = [x0, y0, x1, y1]
        result_labels.append((clipped, queries[label.item()]))
        result_boxes.append(clipped)

    if not result_boxes:
        # SAM cannot be prompted with an empty box list.
        return img_pil.convert("RGBA"), []

    sam_image = generate_image_with_sam(np.array(img_pil), result_boxes)
    return sam_image, result_labels
def generate_image_with_sam(img, input_boxes):
    """Segment the boxed regions of ``img`` with SAM and paint the masks on it.

    Args:
        img: Image array; it is converted with
            ``Image.fromarray(img.astype('uint8'), 'RGB')``.
        input_boxes: List of boxes used as SAM prompts for this image
            (presumably ``[x0, y0, x1, y1]`` in pixel coordinates of ``img`` --
            confirm against the caller).

    Returns:
        An RGBA PIL image with one colored mask per box composited on top of
        the input. With no boxes, the input image is returned unchanged
        (converted to RGBA).
    """
    img_pil = Image.fromarray(img.astype('uint8'), 'RGB')

    if len(input_boxes) == 0:
        # There is nothing to prompt SAM with, and its processor does not
        # accept an empty box list.
        return img_pil.convert("RGBA")

    # One preprocessing pass yields both the pixel values and the box prompts.
    inputs = processor_sam(img_pil, input_boxes=[input_boxes], return_tensors="pt").to(device)

    with torch.no_grad():
        image_embeddings = model_sam.get_image_embeddings(inputs["pixel_values"])
        # Feed the precomputed embeddings to the model instead of raw pixels.
        inputs.pop("pixel_values", None)
        inputs.update({"image_embeddings": image_embeddings})
        outputs = model_sam(**inputs, multimask_output=False)

    masks = processor_sam.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # Only one image was processed, so its masks are the first entry.
    return apply_colored_masks_on_image(img_pil, masks[0])
# NOTE: this string is not passed to gr.Interface below, which supplies its
# own description literal.
description = """
Split anythings
"""

# Input widgets, in the positional order expected by query_image.
image_input = gr.Image()
query_input = gr.Textbox(label="Query Text")
threshold_input = gr.Slider(0, 1, value=0.1, label="Score Threshold")

# Each example row is: image path, comma-separated queries, score threshold.
example_rows = [
    ["images/dark_cell.png", "gray cells", 0.1],
    ["images/animals.png", "Rabbit,Squirrel,Parrot,Hedgehog,Turtle,Ladybug,Chick,Frog,Butterfly,Snail,Mouse", 0.35],
]

demo = gr.Interface(
    fn=query_image,
    inputs=[image_input, query_input, threshold_input],
    outputs=gr.AnnotatedImage(),
    title="Zero-Shot Object Detection SV3",
    description="This interface demonstrates object detection using zero-shot object detection and SAM for image segmentation.",
    examples=example_rows,
)

# Start the web app.
demo.launch()
|