# Hugging Face file-viewer header (page residue, not part of the original app.py):
# Arulkumar03's picture | Upload app.py | bffd5b1 | raw | history blame | No virus | 3.01 kB
import argparse
import copy
from IPython.display import display
from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert
# Grounding DINO
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.inference import annotate, load_image, predict
import supervision as sv
# segment anything
from segment_anything import build_sam, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt
# diffusers
import PIL
import requests
import torch
from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline
from huggingface_hub import hf_hub_download
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a Grounding DINO model from files hosted in a Hugging Face repo.

    Args:
        repo_id: Hub repository that holds both the config and the checkpoint.
        filename: Name of the checkpoint file inside ``repo_id``.
        ckpt_config_filename: Name of the SLConfig file inside ``repo_id``.
        device: Stored on the config as ``device`` and used as ``map_location``
            when the checkpoint is deserialized.

    Returns:
        The model with the checkpoint's ``'model'`` weights loaded, switched to
        eval mode.
    """
    # Config first: it is needed to construct the (still untrained) network.
    config_path = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    model_args = SLConfig.fromfile(config_path)
    model_args.device = device
    model = build_model(model_args)

    # Then the weights. NOTE: torch.load unpickles the downloaded file, so the
    # repository named by the caller must be trusted.
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    state = torch.load(weights_path, map_location=device)

    # strict=False tolerates key mismatches; the returned report is printed so
    # that any missing/unexpected keys are visible in the log.
    load_report = model.load_state_dict(clean_state_dict(state['model']), strict=False)
    print("Model loaded from {} \n => {}".format(weights_path, load_report))

    model.eval()
    return model
# Choose the compute device once. The original script used `device` without
# ever defining it, which raised NameError before any model was loaded.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Grounding DINO checkpoint and config, fetched from the Hugging Face Hub.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swinb_cogcoor.pth"
ckpt_filenmae = ckpt_filename  # original (misspelled) name, kept as an alias
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"

# load_model_hf does not call .to(device) itself, so place the model explicitly.
groundingdino_model = load_model_hf(
    ckpt_repo_id, ckpt_filename, ckpt_config_filename, device
).to(device)

# Segment Anything predictor.
# NOTE(review): this is a bare relative path and nothing in this file downloads
# it -- presumably the checkpoint must already sit in the working directory.
checkpoint = 'sam_vit_h_4b8939.pth'
predictor = SamPredictor(build_sam(checkpoint=checkpoint).to(device))
def _preprocess_for_dino(rgb_array):
    """Turn an RGB uint8 array into the normalized tensor passed to ``predict``."""
    # Same resize/normalize recipe GroundingDINO's own load_image helper uses
    # (NOTE(review): confirm against the installed groundingdino version).
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_tensor, _ = transform(Image.fromarray(rgb_array), None)
    return image_tensor


# detect object using grounding DINO
def detect(image, text_prompt, model, box_threshold=0.3, text_threshold=0.25,
           image_source=None):
    """Run Grounding DINO on an image and draw the detections on it.

    Args:
        image: Either a raw RGB picture (numpy array or PIL image), which is
            preprocessed here, or an already preprocessed ``torch.Tensor``.
        text_prompt: Caption describing the objects to look for.
        model: Grounding DINO model handed to ``predict``.
        box_threshold: Forwarded to ``predict``.
        text_threshold: Forwarded to ``predict``.
        image_source: RGB array to draw the boxes on. Optional when ``image``
            is a raw picture (that picture is used); required when ``image``
            is a tensor.

    Returns:
        ``(annotated_frame, boxes)``: the annotated picture with its channel
        order reversed (BGR to RGB), and the boxes returned by ``predict``.

    Raises:
        ValueError: If ``image`` is a tensor and ``image_source`` is missing.
    """
    if isinstance(image, torch.Tensor):
        # A preprocessed tensor cannot be drawn on; the caller must supply the
        # original picture. (The original code read an undefined
        # `image_source` name here and crashed with NameError.)
        if image_source is None:
            raise ValueError(
                "image_source is required when image is a preprocessed tensor"
            )
        image_tensor = image
    else:
        rgb_array = np.asarray(Image.fromarray(np.asarray(image)).convert("RGB"))
        image_tensor = _preprocess_for_dino(rgb_array)
        if image_source is None:
            image_source = rgb_array

    # Run on whatever device the model already lives on, so CPU-only hosts
    # work. NOTE(review): assumes predict() accepts `device` and defaults it
    # to "cuda" -- confirm against the installed groundingdino version.
    model_device = next(model.parameters()).device

    boxes, logits, phrases = predict(
        model=model,
        image=image_tensor,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=model_device,
    )

    annotated_frame = annotate(
        image_source=image_source, boxes=boxes, logits=logits, phrases=phrases
    )
    annotated_frame = annotated_frame[..., ::-1]  # BGR to RGB
    return annotated_frame, boxes
import gradio as gr
# Define the Gradio interface
def detect_objects(image, text_prompt):
    """Gradio callback: annotate ``image`` with the objects named in ``text_prompt``.

    Args:
        image: Picture supplied by the image input (array or PIL image), or
            ``None`` when no picture has been provided yet.
        text_prompt: Caption describing the objects to detect.

    Returns:
        A PIL image with the detections drawn on it; the unmodified picture
        when the prompt is blank; ``None`` when there is no picture.
    """
    # The interface is created with live=True, so this can fire before the
    # user has supplied a picture or typed anything.
    if image is None:
        return None

    # Normalize whatever Gradio delivered to a plain RGB uint8 array.
    if isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
    else:
        pil_image = Image.fromarray(np.asarray(image)).convert("RGB")

    if not text_prompt or not text_prompt.strip():
        return pil_image

    image_array = np.array(pil_image)

    # NOTE: the original also called load_image(image_array) here. Its result
    # was never used, and it was given an in-memory array rather than a file
    # path, so the call has been dropped.
    # NOTE(review): load_image taking a path is presumed -- confirm against
    # the installed groundingdino version.

    # Detect objects using grounding DINO
    annotated_frame, _detected_boxes = detect(
        image_array, text_prompt, groundingdino_model
    )

    # Convert the annotated frame to Gradio output format
    return Image.fromarray(annotated_frame)
# Create the Gradio interface: one image plus one text box in, one image out.
# live=True re-runs detect_objects whenever an input changes.
# NOTE(review): the `interpretation` argument may not exist in newer Gradio
# major versions -- confirm against the pinned Gradio release.
_iface_inputs = [gr.Image(), "text"]
_iface_output = gr.Image()
iface = gr.Interface(
    fn=detect_objects,
    inputs=_iface_inputs,
    outputs=_iface_output,
    live=True,
    interpretation="default",
)

# Launch the Gradio interface
iface.launch()