nervn / grounded_sam_demo.py
mart9992's picture
m
9856e13
raw history blame
No virus
4.02 kB
from GroundingDINO.groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
from io import BytesIO
import os
import copy
import numpy as np
import json
import torch
from PIL import Image, ImageDraw, ImageFont
# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
# segment anything
from segment_anything import (
build_sam,
build_sam_hq,
SamPredictor
)
import cv2
import numpy as np
import matplotlib.pyplot as plt
def load_model(model_config_path, model_checkpoint_path, device):
    """Build a Grounding DINO model from an SLConfig file and load checkpoint weights.

    Args:
        model_config_path: path of the config file read by ``SLConfig.fromfile``.
        model_checkpoint_path: checkpoint file read with ``torch.load``; its
            ``"model"`` entry is used as the state dict.
        device: stored on the config object as ``device`` before the model is built.

    Returns:
        The built model, switched to eval mode.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    net = build_model(config)
    # The checkpoint is mapped onto the CPU. NOTE(review): torch.load unpickles
    # the file, so the checkpoint path must point at a trusted file.
    state_dict = clean_state_dict(
        torch.load(model_checkpoint_path, map_location="cpu")["model"])
    # strict=False tolerates missing/unexpected keys; the load result is printed
    # so any key mismatches are visible.
    print(net.load_state_dict(state_dict, strict=False))
    net.eval()
    return net
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run Grounding DINO on a single image and return the boxes that pass ``box_threshold``.

    Args:
        model: Grounding DINO model; it is moved to ``device`` and called as
            ``model(images, captions=[...])``, returning ``pred_logits`` / ``pred_boxes``.
        image: image tensor for ONE image; a leading batch dimension is added here.
        caption: text prompt. It is lower-cased, stripped, and terminated with
            ``"."`` if it does not already end with one.
        box_threshold: a query is kept when its highest sigmoid token score is
            strictly greater than this value.
        text_threshold: accepted for interface compatibility but currently unused
            (no phrase extraction is performed).
        with_logits: accepted for interface compatibility but currently unused.
        device: device that the model and image are moved to before inference.

    Returns:
        CPU tensor of shape ``(num_kept, 4)`` holding the kept rows of
        ``pred_boxes`` (presumably normalized cx, cy, w, h — that is how
        ``grounded_sam_demo`` interprets them).
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
    # Keep queries whose best token score clears the threshold. Boolean-mask
    # indexing already returns a new tensor, so callers may mutate the result
    # in place without touching the model outputs.
    keep = logits.max(dim=1).values > box_threshold
    return boxes[keep]
def _cxcywh_to_xyxy_pixels(boxes, width, height):
    """Convert normalized (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1) pixel boxes."""
    scaled = boxes * torch.Tensor([width, height, width, height])
    top_left = scaled[:, :2] - scaled[:, 2:] / 2
    # x1 = w + x0, y1 = h + y0; works for an empty (0, 4) tensor as well.
    bottom_right = scaled[:, 2:] + top_left
    return torch.cat([top_left, bottom_right], dim=1)


def _label_masks(masks):
    """Merge per-box masks into one label map: 0 = background, k = k-th mask (1-based)."""
    label_map = torch.zeros(masks.shape[-2:])
    # Later masks overwrite earlier ones where they overlap.
    for label, mask in enumerate(masks, start=1):
        label_map[mask.cpu()[0].bool()] = label
    return label_map


def _render_label_map(label_map):
    """Render a label map with matplotlib (default colormap, no axes) and return it as a PIL PNG."""
    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(label_map.numpy())
        plt.axis('off')
        buf = BytesIO()
        plt.savefig(buf, format='png', bbox_inches="tight",
                    dpi=300, pad_inches=0.0)
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def grounded_sam_demo(input_pil, config_file, grounded_checkpoint, sam_checkpoint,
                      text_prompt, box_threshold=0.3, text_threshold=0.25,
                      device="cuda"):
    """Detect ``text_prompt`` with Grounding DINO, segment the detections with SAM,
    and return a rendered label-map image.

    Both models are loaded from disk on every call.

    Args:
        input_pil: input PIL image; converted to RGB if it is in another mode.
        config_file: Grounding DINO config path (passed to ``load_model``).
        grounded_checkpoint: Grounding DINO checkpoint path.
        sam_checkpoint: SAM checkpoint path (passed to ``build_sam``).
        text_prompt: text describing the objects to detect.
        box_threshold: score threshold for keeping Grounding DINO boxes.
        text_threshold: forwarded to ``get_grounding_output``.
        device: device the models run on.

    Returns:
        PIL image (PNG) of the label map: background is 0 and the k-th detected
        mask gets label k. It is rendered through matplotlib's default colormap.
        If nothing is detected, an all-background map is rendered without
        running SAM.
    """
    # Grounding DINO preprocessing: resize to 800 (max side 1333), then normalize.
    transform = Compose([
        RandomResize([800], max_size=1333),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    if input_pil.mode != "RGB":
        input_pil = input_pil.convert("RGB")
    image_tensor, _ = transform(input_pil, None)

    model = load_model(config_file, grounded_checkpoint, device=device)
    boxes_norm = get_grounding_output(
        model, image_tensor, text_prompt, box_threshold, text_threshold, device=device)

    width, height = input_pil.size
    if boxes_norm.size(0) == 0:
        # Nothing detected: skip SAM and render an all-background map.
        return _render_label_map(torch.zeros(height, width))

    predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
    # SamPredictor.set_image defaults to image_format="RGB", so hand it the RGB
    # array as-is. A BGR array would reach the model channel-swapped.
    rgb = np.array(input_pil)
    predictor.set_image(rgb)

    boxes_xyxy = _cxcywh_to_xyxy_pixels(boxes_norm.cpu(), width, height)
    transformed_boxes = predictor.transform.apply_boxes_torch(
        boxes_xyxy, rgb.shape[:2]).to(device)
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    return _render_label_map(_label_masks(masks))