# Source: Hugging Face Spaces page. Status shown at capture time: "Runtime error".
import importlib
import os
import subprocess
import sys
from argparse import Namespace

import cv2
import gdown
import gradio as gr
import numpy as np
import torch
import wget
from huggingface_hub import hf_hub_download
try:
    import detectron2
except ImportError:
    # detectron2 cannot be listed in requirements.txt: its setup.py requires
    # torch to be installed already, which is not the case at that point.
    # Install it with the interpreter that is running this script and fail
    # loudly (CalledProcessError) if pip does not succeed.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )
    # The package was installed while this process is running; refresh the
    # import system's caches so the import below can find it.
    importlib.invalidate_caches()
    import detectron2

# These project modules are imported only after the detectron2 check above.
from demo import setup_cfg
from proxydet.predictor import VisualizationDemo
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # download metadata
# zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
# if not os.path.exists(zs_weight_path):
# wget.download("https://github.com/facebookresearch/Detic/raw/main/datasets/metadata/lvis_v1_clip_a+cname.npy", out=zs_weight_path)
# base_cat_mask_path = "datasets/metadata/lvis_v1_base_cat_mask.npy"
# if not os.path.exists(base_cat_mask_path):
# wget.download("https://docs.google.com/uc?export=download&id=1CbSs5yeqMsWDkRSsIlB-ln_bXDv686rH", out=base_cat_mask_path)
# lvis_train_cat_info_path = "datasets/metadata/lvis_v1_train_cat_info.json"
# if not os.path.exists(lvis_train_cat_info_path):
# wget.download("https://docs.google.com/uc?export=download&id=17WmkAJYBK4xT-YkiXLcwIWmtfulSUtmO", out=lvis_train_cat_info_path)
# # download model
# model_path = "models/proxydet_swinb_w_inl.pth"
# if not os.path.exists(model_path):
# gdown.download("https://docs.google.com/uc?export=download&id=17kUPoi-pEK7BlTBheGzWxe_DXJlg28qF", model_path)
# Download the model checkpoint and the LVIS metadata files from the
# "doublejtoh/proxydet_data" Hub repository, using "./" as the local directory.
for _hub_filename in (
    "models/proxydet_swinb_w_inl.pth",
    "datasets/metadata/lvis_v1_base_cat_mask.npy",
    "datasets/metadata/lvis_v1_clip_a+cname.npy",
    "datasets/metadata/lvis_v1_train_cat_info.json",
):
    hf_hub_download(
        repo_id="doublejtoh/proxydet_data",
        filename=_hub_filename,
        repo_type="model",
        local_dir="./",
    )
# Relative paths of the downloaded checkpoint and zero-shot classifier weights.
model_path = "models/proxydet_swinb_w_inl.pth"
zs_weight_path = "datasets/metadata/lvis_v1_clip_a+cname.npy"

# Argument namespace handed to setup_cfg() and VisualizationDemo(); it stands
# in for command-line arguments (presumably those of the project's demo
# script -- confirm against demo.py).
args = Namespace(
    base_cat_threshold=0.9,
    confidence_threshold=0.0,
    config_file="configs/ProxyDet_SwinB_Lbase_INL.yaml",
    cpu=not torch.cuda.is_available(),
    # Initial vocabulary; spelling matches the desk example of the UI.
    custom_vocabulary="headphone,webcam,paper,coffee",
    # Same asset path as the desk example of the UI.
    input=["assets/desk.jpg"],
    opts=["MODEL.WEIGHTS", model_path],
    output="out.jpg",
    pred_all_class=False,
    video_input=None,
    vocabulary="custom",
    webcam=None,
    zeroshot_weight_path=zs_weight_path,
)
# Build the configuration from the argument namespace, then create the single
# demo object that query_image() reuses for every request.
cfg = setup_cfg(args)
ovd_demo = VisualizationDemo(cfg, args)
def query_image(img, text_queries, score_threshold, base_alpha, novel_beta):
    """Detect the queried classes in ``img`` and draw the kept detections.

    Args:
        img: Image array indexed as ``[row, column, channel]``; the channel
            axis is reversed before inference (RGB -> BGR, assuming the UI
            delivers RGB -- confirm). ``None`` when no image was supplied.
        text_queries: Comma-separated class names, e.g. ``"dog,cat"``. The raw
            string is passed to ``ovd_demo.reset_classifier``; its
            comma-split pieces are used as the label text.
        score_threshold: Detections scoring below this value are not drawn.
        base_alpha: Value stored in ``roi_heads.cmm_base_alpha``.
        novel_beta: Value stored in ``roi_heads.cmm_novel_beta``.

    Returns:
        A copy of ``img`` with a box and class name drawn for every kept
        detection, or ``None`` when ``img`` is ``None``.
    """
    # Nothing to do when the request carries no image.
    if img is None:
        return None

    class_names = text_queries.split(",")

    # Re-target the shared detector to the requested vocabulary and weights.
    ovd_demo.reset_classifier(text_queries)
    ovd_demo.reset_base_cat_mask()
    roi_heads = ovd_demo.predictor.model.roi_heads
    roi_heads.cmm_base_alpha = base_alpha
    roi_heads.cmm_novel_beta = novel_beta

    img_bgr = img[:, :, ::-1]
    with torch.no_grad():
        predictions, _ = ovd_demo.run_on_image(img_bgr)

    # Convert each prediction tensor to plain Python values once, instead of
    # converting element by element inside the drawing loop.
    instances = predictions["instances"].to(device)
    boxes = instances.pred_boxes.tensor.tolist()
    scores = instances.scores.tolist()
    labels = instances.pred_classes.tolist()

    # Draw on a contiguous copy so the caller's array is left untouched.
    canvas = img.copy()
    img_height = canvas.shape[0]
    font = cv2.FONT_HERSHEY_SIMPLEX
    color = (255, 0, 0)

    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 5)
        # Put the label below the box unless that would fall outside the
        # actual image, in which case place it just inside the bottom edge.
        text_y = y2 - 10 if y2 + 25 > img_height else y2 + 25
        cv2.putText(
            canvas, class_names[label], (x1, text_y), font, 1, color, 2, cv2.LINE_AA
        )
    return canvas
if __name__ == "__main__":
    # Text shown under the title of the web UI (contains inline HTML).
    description = """
Gradio demo for ProxyDet, introduced in <a href="https://arxiv.org/abs/2312.07266">ProxyDet: Synthesizing Proxy Novel Classes via Classwise Mixup for Open-Vocabulary Object Detection</a>.
\n\nYou can use ProxyDet to query images with text descriptions of any object.
How to use?
- Simply upload an image and enter comma separated objects (e.g., "dog,cat,headphone") which you want to detect within the image.\n
Parameters:
- You can also use the score threshold slider to set a threshold to filter out low probability predictions.
- adjust alpha and beta value for base and novel classes, respectively. These determine <b>how much importance will you assign to the scores sourced from our proposed detection head which is trained with our proxy-novel classes</b>.
"""

    # One component per query_image() parameter, in the same order:
    # image, text queries, score threshold, base alpha, novel beta.
    input_components = [
        gr.Image(),
        "text",
        gr.Slider(0, 1, value=0.1),
        gr.Slider(0, 1, value=0.15),
        gr.Slider(0, 1, value=0.35),
    ]

    # Example rows; each row supplies one value per input component.
    example_rows = [
        ["assets/desk.jpg", "headphone,webcam,paper,coffee", 0.11, 0.15, 0.35],
        ["assets/beach.jpg", "person,kite", 0.1, 0.15, 0.35],
        ["assets/pikachu.jpg", "pikachu,person", 0.15, 0.15, 0.35],
    ]

    demo = gr.Interface(
        fn=query_image,
        inputs=input_components,
        outputs="image",
        title="Open-Vocabulary Object Detection with ProxyDet",
        description=description,
        examples=example_rows,
    )
    demo.launch()