Spaces:
Runtime error
Runtime error
| import torch | |
| import cv2 | |
| import os | |
| import wget | |
| import gradio as gr | |
| import numpy as np | |
| import gdown | |
| from huggingface_hub import hf_hub_download | |
| from argparse import Namespace | |
try:
    import detectron2
except ImportError:
    # detectron2 cannot simply be listed in requirements.txt: its setup.py
    # requires torch to be installed already, which is not the case while
    # requirements.txt is being processed. Install it here instead and retry.
    # Only ImportError is handled so that Ctrl-C or unrelated failures are not
    # silently turned into a pip install.
    os.system("python3 -m pip install 'git+https://github.com/facebookresearch/detectron2.git'")
    import detectron2
| from demo import setup_cfg | |
| from proxydet.predictor import VisualizationDemo | |
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| # # download metadata | |
| # zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy' | |
| # if not os.path.exists(zs_weight_path): | |
| # wget.download("https://github.com/facebookresearch/Detic/raw/main/datasets/metadata/lvis_v1_clip_a+cname.npy", out=zs_weight_path) | |
| # base_cat_mask_path = "datasets/metadata/lvis_v1_base_cat_mask.npy" | |
| # if not os.path.exists(base_cat_mask_path): | |
| # wget.download("https://docs.google.com/uc?export=download&id=1CbSs5yeqMsWDkRSsIlB-ln_bXDv686rH", out=base_cat_mask_path) | |
| # lvis_train_cat_info_path = "datasets/metadata/lvis_v1_train_cat_info.json" | |
| # if not os.path.exists(lvis_train_cat_info_path): | |
| # wget.download("https://docs.google.com/uc?export=download&id=17WmkAJYBK4xT-YkiXLcwIWmtfulSUtmO", out=lvis_train_cat_info_path) | |
| # # download model | |
| # model_path = "models/proxydet_swinb_w_inl.pth" | |
| # if not os.path.exists(model_path): | |
| # gdown.download("https://docs.google.com/uc?export=download&id=17kUPoi-pEK7BlTBheGzWxe_DXJlg28qF", model_path) | |
# Fetch the model checkpoint and the LVIS metadata files from the
# "doublejtoh/proxydet_data" Hub repository. They are requested with
# local_dir="./", and the paths defined below use the same relative names.
_HUB_FILES = (
    "models/proxydet_swinb_w_inl.pth",
    "datasets/metadata/lvis_v1_base_cat_mask.npy",
    "datasets/metadata/lvis_v1_clip_a+cname.npy",
    "datasets/metadata/lvis_v1_train_cat_info.json",
)
for _hub_file in _HUB_FILES:
    hf_hub_download(
        repo_id="doublejtoh/proxydet_data",
        filename=_hub_file,
        repo_type="model",
        local_dir="./",
    )

# Relative locations of the checkpoint and the zero-shot classifier weights.
model_path = "models/proxydet_swinb_w_inl.pth"
zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
# Argument namespace handed to setup_cfg() and VisualizationDemo(), built
# directly instead of being parsed from a command line.
args = Namespace(
    base_cat_threshold=0.9,
    confidence_threshold=0.0,
    config_file='configs/ProxyDet_SwinB_Lbase_INL.yaml',
    cpu=not torch.cuda.is_available(),
    # Spelling fixed ("coffe" -> "coffee") to match the example query used in
    # the Gradio interface.
    custom_vocabulary='headphone,webcam,paper,coffee',
    # Path fixed (".assets/..." -> "assets/...") to match the example images
    # referenced by the Gradio interface.
    input=['assets/desk.jpg'],
    # Override the checkpoint location with the downloaded weights.
    opts=['MODEL.WEIGHTS', model_path],
    output='out.jpg',
    pred_all_class=False,
    video_input=None,
    vocabulary='custom',
    webcam=None,
    zeroshot_weight_path=zs_weight_path,
)

# Build the config and the demo wrapper once at import time; query_image()
# reuses this single ovd_demo instance for every request.
cfg = setup_cfg(args)
ovd_demo = VisualizationDemo(cfg, args)
def query_image(img, text_queries, score_threshold, base_alpha, novel_beta):
    """Detect the queried classes in ``img`` and draw the detections on a copy.

    Args:
        img: Image array indexed as (height, width, channel). The channel axis
            is reversed before inference to obtain the BGR input passed to the
            detector (assumes the UI delivers RGB -- confirm). ``None`` is
            accepted and returned unchanged.
        text_queries: Comma separated class names, e.g. ``"dog,cat"``. The raw
            string resets the classifier; the split list supplies the label
            text drawn for each predicted class index.
        score_threshold: Detections scoring below this value are not drawn.
        base_alpha: Value assigned to ``roi_heads.cmm_base_alpha``.
        novel_beta: Value assigned to ``roi_heads.cmm_novel_beta``.

    Returns:
        A copy of ``img`` with a box and class label drawn for every detection
        whose score is at least ``score_threshold``, or ``None`` when no image
        was given.
    """
    if img is None:
        # Nothing was uploaded, so there is nothing to detect or draw on.
        return None

    text_queries_split = text_queries.split(",")

    # Re-target the shared detector at the requested vocabulary and weights.
    ovd_demo.reset_classifier(text_queries)
    ovd_demo.reset_base_cat_mask()
    roi_heads = ovd_demo.predictor.model.roi_heads
    roi_heads.cmm_base_alpha = base_alpha
    roi_heads.cmm_novel_beta = novel_beta

    img_bgr = img[:, :, ::-1]
    with torch.no_grad():
        predictions, _ = ovd_demo.run_on_image(img_bgr)

    # Everything below is host-side drawing, so pull the results to the CPU
    # and convert them to plain Python values once.
    output_instances = predictions["instances"].to(torch.device("cpu"))
    boxes = output_instances.pred_boxes.tensor.tolist()
    scores = output_instances.scores.tolist()
    labels = output_instances.pred_classes.tolist()

    # Draw on a private, contiguous copy so the caller's array is not mutated.
    annotated = img.copy()
    img_height = annotated.shape[0]
    font = cv2.FONT_HERSHEY_SIMPLEX
    color = (255, 0, 0)

    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        x0, y0, x1, y1 = (int(v) for v in box)
        cv2.rectangle(annotated, (x0, y0), (x1, y1), color, 5)
        # Put the label just below the box, unless that would fall outside
        # the actual image (previously a fixed height of 768 was assumed);
        # in that case place it just above the box's bottom edge.
        if y1 + 25 > img_height:
            y = y1 - 10
        else:
            y = y1 + 25
        cv2.putText(
            annotated, text_queries_split[label], (x0, y), font, 1, color, 2, cv2.LINE_AA
        )
    return annotated
if __name__ == "__main__":
    # Text displayed on the demo page alongside the title.
    description = """
Gradio demo for ProxyDet, introduced in <a href="https://arxiv.org/abs/2312.07266">ProxyDet: Synthesizing Proxy Novel Classes via Classwise Mixup for Open-Vocabulary Object Detection</a>.
\n\nYou can use ProxyDet to query images with text descriptions of any object.
How to use?
- Simply upload an image and enter comma separated objects (e.g., "dog,cat,headphone") which you want to detect within the image.\n
Parameters:
- You can also use the score threshold slider to set a threshold to filter out low probability predictions.
- adjust alpha and beta value for base and novel classes, respectively. These determine <b>how much importance will you assign to the scores sourced from our proposed detection head which is trained with our proxy-novel classes</b>.
"""

    # One component per query_image() parameter, in the same order:
    # image, text queries, score threshold, base alpha, novel beta.
    input_components = [
        gr.Image(),
        "text",
        gr.Slider(0, 1, value=0.1),
        gr.Slider(0, 1, value=0.15),
        gr.Slider(0, 1, value=0.35),
    ]

    # Example rows; each row follows the order of input_components.
    example_rows = [
        ["assets/desk.jpg", "headphone,webcam,paper,coffee", 0.11, 0.15, 0.35],
        ["assets/beach.jpg", "person,kite", 0.1, 0.15, 0.35],
        ["assets/pikachu.jpg", "pikachu,person", 0.15, 0.15, 0.35],
    ]

    demo = gr.Interface(
        query_image,
        inputs=input_components,
        outputs="image",
        title="Open-Vocabulary Object Detection with ProxyDet",
        description=description,
        examples=example_rows,
    )
    demo.launch()