import numpy as np
import torch
from PIL import Image

# Grounding DINO
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import annotate, predict

# Segment Anything
from segment_anything import build_sam, SamPredictor

# Diffusers (used for the inpainting stage of the pipeline)
from diffusers import StableDiffusionInpaintPipeline

from huggingface_hub import hf_hub_download

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Download a Grounding DINO config and checkpoint from the Hub and build the model."""
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    args = SLConfig.fromfile(cache_config_file)
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location=device)
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {}\n => {}".format(cache_file, log))
    model.eval()
    return model


ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swinb_cogcoor.pth"
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"

groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename, device)

# The SAM checkpoint must be downloaded separately (e.g. from the
# segment-anything release page)
sam_checkpoint = 'sam_vit_h_4b8939.pth'
sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))


def preprocess(image_pil):
    """Apply Grounding DINO's eval transform to an in-memory PIL image.

    The library's `load_image` helper expects a file path, so we replicate its
    transform here for images handed to us by Gradio.
    """
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_pil = image_pil.convert("RGB")
    image_source = np.asarray(image_pil)               # HWC uint8, for annotation and SAM
    image_transformed, _ = transform(image_pil, None)  # CHW float tensor, for the model
    return image_source, image_transformed


# Detect objects using Grounding DINO
def detect(image, image_source, text_prompt, model, box_threshold=0.3, text_threshold=0.25):
    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold
    )
    annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    annotated_frame = annotated_frame[..., ::-1]  # BGR to RGB
    return annotated_frame, boxes


import gradio as gr


def detect_objects(image, text_prompt):
    # Convert the Gradio input into the array/tensor pair the models expect
    image_source, image_tensor = preprocess(image)
    # Detect objects using Grounding DINO
    annotated_frame, detected_boxes = detect(image_tensor, image_source, text_prompt, groundingdino_model)
    # Convert the annotated frame back to a PIL image for Gradio
    return Image.fromarray(annotated_frame)


# Create the Gradio interface. The `interpretation` argument has been removed
# from recent Gradio releases, and `live=True` would rerun the model on every
# keystroke, so the app relies on the default submit button instead.
iface = gr.Interface(
    fn=detect_objects,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Text prompt")],
    outputs=gr.Image(),
)

# Launch the Gradio interface
iface.launch()
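The script loads SAM and imports the Stable Diffusion inpainting pipeline but only wires up the detection stage. A minimal sketch of the remaining two stages follows, assuming the standard `segment_anything` and `diffusers` APIs; the helper names `segment` and `inpaint`, the `runwayml/stable-diffusion-inpainting` model choice, and the example prompt in the usage note are illustrative, not from the original script.

```python
# Hypothetical continuation: segment the detected boxes with SAM, then
# replace the masked region with Stable Diffusion inpainting.

def segment(image_source, boxes):
    # SAM wants pixel-space xyxy boxes; Grounding DINO returns normalized cxcywh
    sam_predictor.set_image(image_source)
    H, W, _ = image_source.shape
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(
        boxes_xyxy.to(device), image_source.shape[:2]
    )
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    return masks.cpu()  # (num_boxes, 1, H, W) boolean masks


def inpaint(image_source, mask, prompt):
    # The inpainting pipeline works at 512x512, so resize in and back out
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
    ).to(device)
    image_pil = Image.fromarray(image_source).resize((512, 512))
    mask_pil = Image.fromarray(mask.squeeze().numpy().astype(np.uint8) * 255).resize((512, 512))
    result = pipe(prompt=prompt, image=image_pil, mask_image=mask_pil).images[0]
    return result.resize(image_source.shape[1::-1])  # back to original (W, H)
```

With the output of `detect`, the full edit would then look like `masks = segment(image_source, detected_boxes)` followed by `inpaint(image_source, masks[0], "a sleeping cat")`, where the prompt is just an example.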