|
import argparse |
|
import copy |
|
|
|
from IPython.display import display |
|
from PIL import Image, ImageDraw, ImageFont |
|
from torchvision.ops import box_convert |
|
|
|
|
|
import groundingdino.datasets.transforms as T |
|
from groundingdino.models import build_model |
|
from groundingdino.util import box_ops |
|
from groundingdino.util.slconfig import SLConfig |
|
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap |
|
from groundingdino.util.inference import annotate, load_image, predict |
|
|
|
import supervision as sv |
|
|
|
|
|
from segment_anything import build_sam, SamPredictor |
|
import cv2 |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
import PIL |
|
import requests |
|
import torch |
|
from io import BytesIO |
|
from diffusers import StableDiffusionInpaintPipeline |
|
from huggingface_hub import hf_hub_download |
|
|
|
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a GroundingDINO model from a Hugging Face Hub repo and load its weights.

    Args:
        repo_id: Hub repository id that holds both the config and the checkpoint.
        filename: Name of the checkpoint file inside ``repo_id``.
        ckpt_config_filename: Name of the SLConfig config file inside ``repo_id``.
        device: Stored on the parsed config as ``device`` and used as
            ``map_location`` when deserialising the checkpoint.

    Returns:
        The model switched to eval mode, with the checkpoint's ``'model'``
        state dict loaded non-strictly. Missing/unexpected keys are only
        printed, not raised.
    """
    # Fetch and parse the architecture config, then build the model,
    # before fetching the weight file.
    config_path = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    model_args = SLConfig.fromfile(config_path)
    model_args.device = device
    model = build_model(model_args)

    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # NOTE(review): torch.load unpickles the downloaded file -- only point this
    # at repositories you trust.
    state = torch.load(weights_path, map_location=device)
    load_report = model.load_state_dict(clean_state_dict(state['model']), strict=False)
    print("Model loaded from {} \n => {}".format(weights_path, load_report))

    model.eval()
    return model
|
|
|
# Compute device shared by both model loads below and by the Gradio handler.
# Falls back to CPU when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# GroundingDINO (Swin-B) config and checkpoint hosted on the Hugging Face Hub.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swinb_cogcoor.pth"
ckpt_filenmae = ckpt_filename  # backward-compatible alias for the old misspelled name
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"

groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename, device)

# Segment Anything checkpoint, read from a local file path.
checkpoint = 'sam_vit_h_4b8939.pth'

predictor = SamPredictor(build_sam(checkpoint=checkpoint).to(device))
|
|
|
|
|
def detect(image, text_prompt, model, box_threshold = 0.3, text_threshold = 0.25,
           image_source=None, device="cuda"):
    """Run GroundingDINO on one image and draw the detected boxes.

    Args:
        image: Model input in the form GroundingDINO's ``predict`` expects
            (presumably the normalised tensor produced by its preprocessing
            transform -- TODO confirm against your groundingdino version).
        text_prompt: Caption describing the objects to detect.
        model: A loaded GroundingDINO model.
        box_threshold: Forwarded to ``predict``.
        text_threshold: Forwarded to ``predict``.
        image_source: The original RGB image as a numpy array, used as the
            canvas for ``annotate``. If omitted, a module-level
            ``image_source`` is used (the behaviour older callers relied on).
        device: Forwarded to ``predict``. ``"cuda"`` is presumably
            ``predict``'s own default -- verify for your installed version.

    Returns:
        ``(annotated_frame, boxes)``. ``annotated_frame`` is ``annotate``'s
        output with its channel axis reversed. ``boxes`` are the raw boxes
        returned by ``predict``.

    Raises:
        ValueError: If no source image is passed and no module-level one exists.
    """
    # Resolve the drawing canvas before running the (expensive) model.
    if image_source is None:
        image_source = globals().get("image_source")
    if image_source is None:
        raise ValueError(
            "detect() needs the original RGB image to draw on; pass image_source=..."
        )

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device,
    )

    annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    # Reverse the channel order; presumably annotate() returns BGR and PIL/Gradio want RGB.
    annotated_frame = annotated_frame[..., ::-1]
    return annotated_frame, boxes


import gradio as gr


def detect_objects(image, text_prompt):
    """Gradio handler: detect ``text_prompt`` objects in ``image``.

    Returns the annotated result as a PIL image. Returns None, which clears
    the output, while the image or the prompt is still missing. With
    ``live=True`` the handler fires on every input change.
    """
    if image is None or not text_prompt or not text_prompt.strip():
        return None

    # Normalise whatever Gradio hands over (numpy array by default) to an RGB
    # PIL image. This drops any alpha channel and expands greyscale.
    pil_image = Image.fromarray(np.asarray(image)).convert("RGB")
    image_source = np.asarray(pil_image)

    # groundingdino's load_image() opens a file *path*, so it cannot take the
    # in-memory array. Apply the same preprocessing pipeline it uses instead.
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_tensor, _ = transform(pil_image, None)

    annotated_frame, _ = detect(
        image_tensor,
        text_prompt,
        groundingdino_model,
        image_source=image_source,
        device=device,
    )

    # The channel flip in detect() yields a negatively-strided view; make it
    # contiguous before handing it to PIL.
    return Image.fromarray(np.ascontiguousarray(annotated_frame))
|
|
|
|
|
# Live demo: an image plus a text prompt in, the annotated detection image out.
# With live=True the handler re-runs whenever either input changes.
# `interpretation` is intentionally not passed: Gradio 4 removed that argument
# (it raises TypeError), and it does not apply to an image output.
iface = gr.Interface(
    fn=detect_objects,
    inputs=[gr.Image(), "text"],
    outputs=gr.Image(),
    live=True,
)

iface.launch()
|
|