import cv2
import numpy as np
import supervision as sv
from PIL import Image

import torch
import torchvision
import torchvision.transforms as TS

from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor

from ram.models import ram
from ram import inference_ram
# Configuration: input image, device, checkpoints, and thresholds.
SOURCE_IMAGE_PATH = "./assets/demo9.jpg"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"

SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"

RAM_CHECKPOINT_PATH = "./ram_swin_large_14m.pth"

BOX_THRESHOLD = 0.2
TEXT_THRESHOLD = 0.2
IOU_THRESHOLD = 0.5

# GroundingDINO: open-vocabulary detector, prompted below with the RAM tags.
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
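# Note: recent GroundingDINO releases also accept a device argument on Model
# (it defaults to CUDA); if the installed version exposes it, passing
# device=str(DEVICE) here would keep every model on the same device as below.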
# SAM: promptable segmenter used to turn detected boxes into instance masks.
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)  # keep SAM on the same device as the other models
sam_predictor = SamPredictor(sam)

# RAM preprocessing: 384x384 input, normalized with ImageNet statistics.
normalize = TS.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
transform = TS.Compose(
    [
        TS.Resize((384, 384)),
        TS.ToTensor(),
        normalize
    ]
)

# RAM: image tagger that proposes the class names to ground.
ram_model = ram(pretrained=RAM_CHECKPOINT_PATH,
                image_size=384,
                vit='swin_l')
ram_model.eval()
ram_model = ram_model.to(DEVICE)
# Load the image (OpenCV gives BGR) and build an RGB tensor for RAM;
# the transform above handles resizing and normalization.
image = cv2.imread(SOURCE_IMAGE_PATH)
image_pillow = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
image_pillow = transform(image_pillow).unsqueeze(0).to(DEVICE)

# Tag the image with RAM; res[0] holds the English " | "-separated tag string.
res = inference_ram(image_pillow, ram_model)
AUTOMATIC_CLASSES = res[0].split(" | ")
print(f"Tags: {res[0].replace(' |', ',')}")
# Detect boxes for the RAM tags with GroundingDINO. Note: predict_with_classes
# can also return detections whose phrase failed to map to a class; the label
# loops below assume every class_id resolved.
detections = grounding_dino_model.predict_with_classes(
    image=image,
    classes=AUTOMATIC_CLASSES,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD
)
# Class-agnostic NMS to drop duplicate boxes across overlapping tags.
print(f"Before NMS: {len(detections.xyxy)} boxes")
nms_idx = torchvision.ops.nms(
    torch.from_numpy(detections.xyxy),
    torch.from_numpy(detections.confidence),
    IOU_THRESHOLD
).numpy().tolist()

detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]

print(f"After NMS: {len(detections.xyxy)} boxes")
# Draw boxes and confidence labels, then save an intermediate visualization.
# (BoxAnnotator.annotate taking labels, and Detections iterating as 6-tuples,
# matches the supervision release this demo was written against; newer
# releases draw labels with sv.LabelAnnotator instead.)
box_annotator = sv.BoxAnnotator()
labels = [
    f"{AUTOMATIC_CLASSES[class_id]} {confidence:0.2f}"
    for _, _, confidence, class_id, _, _
    in detections
]
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
cv2.imwrite("groundingdino_auto_annotated_image.jpg", annotated_frame)
# Convert each detected box into a SAM mask, keeping the best proposal per box.
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, _ = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        index = np.argmax(scores)  # highest-scoring of the candidate masks
        result_masks.append(masks[index])
    return np.array(result_masks)
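# A hypothetical batched variant: SamPredictor.predict_torch accepts all boxes
# at once, avoiding the per-box Python loop above. A sketch under the same
# assumptions (RGB input image); defined but not called in this script.
def segment_batched(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    sam_predictor.set_image(image)
    boxes = torch.from_numpy(xyxy).to(sam_predictor.device)
    boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
    masks, scores, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=boxes,
        multimask_output=True
    )
    best = torch.argmax(scores, dim=1)  # best mask index per box
    return masks[torch.arange(masks.shape[0]), best].cpu().numpy()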
detections.mask = segment(
    sam_predictor=sam_predictor,
    image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
    xyxy=detections.xyxy
)
# Final visualization: instance masks plus labeled boxes.
box_annotator = sv.BoxAnnotator()
mask_annotator = sv.MaskAnnotator()
labels = [
    f"{AUTOMATIC_CLASSES[class_id]} {confidence:0.2f}"
    for _, _, confidence, class_id, _, _
    in detections
]
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
cv2.imwrite("ram_grounded_sam_auto_annotated_image.jpg", annotated_image)