import argparse
import copy

from IPython.display import display
from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert

# Grounding DINO
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.inference import annotate, load_image, predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt

# diffusers
import PIL
import requests
import torch
from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline

from huggingface_hub import hf_hub_download
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Download a Grounding DINO config and checkpoint from the Hugging Face Hub
# and build the model from them
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    args = SLConfig.fromfile(cache_config_file)
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location=device)
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    _ = model.eval()
    return model
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swinb_cogcoor.pth"
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename, device)
# The SAM checkpoint must be present locally, e.g. downloaded from
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
checkpoint = 'sam_vit_h_4b8939.pth'
predictor = SamPredictor(build_sam(checkpoint=checkpoint).to(device))
# Detect objects using Grounding DINO. `image` is the normalized tensor fed to
# the model; `image_source` is the raw RGB array the boxes are drawn onto.
def detect(image, image_source, text_prompt, model, box_threshold=0.3, text_threshold=0.25):
    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold
    )

    annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    annotated_frame = annotated_frame[..., ::-1]  # BGR to RGB
    return annotated_frame, boxes
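
# The SamPredictor built above is never exercised by the demo below; this is a
# minimal sketch of how its output could be consumed, assuming the usual
# Grounded-SAM flow. The `segment` helper is an illustration, not part of the
# original app. `boxes` are the normalized cxcywh tensors returned by `detect`.
def segment(image_source, boxes):
    predictor.set_image(image_source)
    H, W, _ = image_source.shape
    # Grounding DINO returns normalized cxcywh boxes; SAM expects absolute xyxy
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.tensor([W, H, W, H])
    transformed_boxes = predictor.transform.apply_boxes_torch(
        boxes_xyxy.to(device), image_source.shape[:2]
    )
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    return masks.cpu()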
import gradio as gr

# Turn a PIL image into the normalized tensor Grounding DINO expects;
# groundingdino's load_image reads from a file path, so build the tensor directly
def transform_image(image_pil):
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_transformed, _ = transform(image_pil, None)
    return image_transformed

# Define the Gradio interface
def detect_objects(image, text_prompt):
    image_source = np.array(image)         # raw RGB array for drawing
    image_tensor = transform_image(image)  # normalized tensor for the model
    # Detect objects using Grounding DINO
    annotated_frame, detected_boxes = detect(image_tensor, image_source, text_prompt, groundingdino_model)
    # Convert the annotated frame to Gradio output format
    return Image.fromarray(annotated_frame)
# Create the Gradio interface
iface = gr.Interface(
    fn=detect_objects,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Text prompt")],
    outputs=gr.Image(),
)

# Launch the Gradio interface
iface.launch()
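
# The StableDiffusionInpaintPipeline import above is otherwise unused; below is
# a minimal inpainting sketch, assuming the runwayml/stable-diffusion-inpainting
# weights and a mask taken from the illustrative `segment` helper above; neither
# is part of the original app.
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float16,
).to(device)

def inpaint(image_source, mask, prompt):
    # `mask` is a single 2-D boolean mask, e.g. masks[0][0] from `segment`;
    # the pipeline works on 512x512 PIL images for both the image and the mask
    image_pil = Image.fromarray(image_source).resize((512, 512))
    mask_pil = Image.fromarray((mask.numpy() * 255).astype(np.uint8)).resize((512, 512))
    return inpaint_pipe(prompt=prompt, image=image_pil, mask_image=mask_pil).images[0]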