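# Gradio demo: OWLv2 zero-shot object detection combined with SAM segmentation.
# Text queries select objects; each detected box is passed to SAM as a prompt
# and the resulting masks are drawn over the image.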
import torch
import gradio as gr
from transformers import Owlv2Processor, Owlv2ForObjectDetection, SamModel, SamProcessor
import spaces
import numpy as np
from PIL import Image
import random


def apply_colored_masks_on_image(image, masks):
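    """Overlay each mask on the image as a semi-transparent, randomly colored layer."""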
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype('uint8'), 'RGB')

    image_rgba = image.convert("RGBA")

    for i in range(masks.shape[0]):
        mask = masks[i].squeeze().cpu().numpy()
        # Scale the boolean mask to half opacity so the overlay stays translucent
        # and the underlying image remains visible through the colored layer.
        mask_image = Image.fromarray((mask * 128).astype(np.uint8), 'L')
        color = tuple([random.randint(0, 255) for _ in range(3)] + [128])
        colored_mask = Image.new("RGBA", image.size, color)
        colored_mask.putalpha(mask_image)
        image_rgba = Image.alpha_composite(image_rgba, colored_mask)

    return image_rgba


# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# OWLv2 handles text-conditioned detection; SAM converts its boxes into masks.
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model_sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor_sam = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    

@spaces.GPU
def query_image(img, text_queries, score_threshold=0.5):
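    """Detect objects matching the comma-separated text queries with OWLv2,
    then segment each detection above the score threshold with SAM."""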
    text_queries = [q.strip() for q in text_queries.split(",")]
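    # OWLv2 pads inputs to a square, so post-process against the longer side.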
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])
    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        model_outputs = model(**inputs)
        model_outputs.logits = model_outputs.logits.cpu()
        model_outputs.pred_boxes = model_outputs.pred_boxes.cpu()
        results = processor.post_process_object_detection(outputs=model_outputs, target_sizes=target_sizes)
    
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    img_pil = Image.fromarray(img.astype('uint8'), 'RGB')

    result_labels = []
    result_boxes = []
    for box, score, label in zip(boxes, scores, labels):
        if score >= score_threshold:
            box = [int(i) for i in box.tolist()]
            label_text = text_queries[label.item()]
            result_labels.append((box, label_text))
            result_boxes.append(box)

    # SAM needs at least one box prompt, so skip segmentation when nothing was detected.
    if not result_boxes:
        return img_pil, result_labels

    sam_image = generate_image_with_sam(np.array(img_pil), result_boxes)

    return sam_image, result_labels

   
def generate_image_with_sam(img, input_boxes):
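    """Prompt SAM with the given bounding boxes and overlay the predicted masks."""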
    img_pil = Image.fromarray(img.astype('uint8'), 'RGB')
    inputs = processor_sam(img_pil, return_tensors="pt").to(device)
    
    # Embed the image once and reuse the embedding for the box-prompted pass.
    image_embeddings = model_sam.get_image_embeddings(inputs["pixel_values"])

    inputs = processor_sam(img_pil, input_boxes=[input_boxes], return_tensors="pt").to(device)
    # Feed the precomputed embedding instead of the raw pixels.
    inputs.pop("pixel_values", None)
    inputs.update({"image_embeddings": image_embeddings})

    with torch.no_grad():
        # multimask_output=False returns a single mask per box prompt.
        outputs = model_sam(**inputs, multimask_output=False)
    
    masks = processor_sam.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    sam_image = apply_colored_masks_on_image(img_pil, masks[0])
    return sam_image


description = """
Split anythings
"""
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image(), gr.Textbox(label="Query Text"), gr.Slider(0, 1, value=0.1, label="Score Threshold")],
    outputs=gr.AnnotatedImage(),
    title="Zero-Shot Object Detection SV3",
    description="This interface demonstrates object detection using  zero-shot object detection and SAM  for image segmentation.",
    examples=[
        ["images/dark_cell.png", "gray cells", 0.1],
        ["images/animals.png", "Rabbit,Squirrel,Parrot,Hedgehog,Turtle,Ladybug,Chick,Frog,Butterfly,Snail,Mouse", 0.35],
    ],
)

demo.launch()