from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Grounding DINO
from GroundingDINO.groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict

# Segment Anything
from segment_anything import build_sam, SamPredictor


def load_model(model_config_path, model_checkpoint_path, device):
    """Build a Grounding DINO model from its config file and load the checkpoint weights."""
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(
        clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    model.eval()
    return model


def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run Grounding DINO on a preprocessed image tensor and return the predicted
    boxes (normalized cx, cy, w, h) whose best token logit exceeds box_threshold.

    text_threshold and with_logits are accepted for interface compatibility but
    are unused in this simplified version, which returns only the boxes.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (num_queries, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (num_queries, 4)

    # Keep only the boxes whose highest token score clears the box threshold
    filt_mask = logits.max(dim=1)[0] > box_threshold
    boxes_filt = boxes[filt_mask]  # (num_filt, 4)

    return boxes_filt


def grounded_sam_demo(input_pil, config_file, grounded_checkpoint, sam_checkpoint,
                      text_prompt, box_threshold=0.3, text_threshold=0.25,
                      device="cuda"):
    """Detect objects matching text_prompt with Grounding DINO, segment them with
    SAM, and return the combined segmentation map rendered as a PIL image."""

    # Resize and normalize the PIL image into a tensor for Grounding DINO
    transform = Compose([
        RandomResize([800], max_size=1333),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    if input_pil.mode != "RGB":
        input_pil = input_pil.convert("RGB")

    image, _ = transform(input_pil, None)

    # Load model
    model = load_model(config_file, grounded_checkpoint, device=device)

    # Get grounding dino model output
    boxes_filt = get_grounding_output(
        model, image, text_prompt, box_threshold, text_threshold, device=device)

    # Initialize SAM; set_image expects an RGB numpy array, so pass the PIL
    # image's pixels directly instead of converting to BGR
    predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
    image = np.array(input_pil)
    predictor.set_image(image)

    # Convert normalized (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2) pixels
    W, H = input_pil.size
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    boxes_filt = boxes_filt.cpu()
    transformed_boxes = predictor.transform.apply_boxes_torch(
        boxes_filt, image.shape[:2]).to(device)

    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # Build a label image: 0 is background, object i gets label i + 1
    mask_img = torch.zeros(masks.shape[-2:])
    for idx, mask in enumerate(masks):
        mask_img[mask.cpu().numpy()[0]] = idx + 1

    # Render the segmentation map to an in-memory PNG and return it as a PIL image
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(mask_img.numpy())
    plt.axis('off')

    buf = BytesIO()
    plt.savefig(buf, format='png', bbox_inches="tight",
                dpi=300, pad_inches=0.0)
    plt.close(fig)  # free the figure so repeated calls do not accumulate figures
    buf.seek(0)
    out_pil = Image.open(buf)

    return out_pil
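

# A minimal usage sketch (not part of the original script): the image, config,
# and checkpoint paths below are placeholders and depend on where the Grounding
# DINO and SAM weights live in your setup.
if __name__ == "__main__":
    input_image = Image.open("assets/demo.jpg")  # hypothetical input path
    seg_map = grounded_sam_demo(
        input_image,
        config_file="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",  # assumed config path
        grounded_checkpoint="groundingdino_swint_ogc.pth",  # assumed checkpoint path
        sam_checkpoint="sam_vit_h_4b8939.pth",  # assumed checkpoint path
        text_prompt="dog",
        box_threshold=0.3,
        text_threshold=0.25,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    seg_map.save("segmentation_map.png")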