File size: 3,628 Bytes
4121bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80a17c3
 
 
 
4121bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1e8dfd
4121bec
 
 
 
 
 
 
 
966e6d7
30db11d
b2a8e32
1c54f2a
4121bec
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import requests
import logging
import os
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from config import get_config

from collections import OrderedDict

# Install the local package and the demo's runtime dependencies into the
# interpreter that is running this script. As with the previous os.system
# calls, a failing install is not fatal: the exit status is ignored, and a
# missing package only surfaces at the imports that follow.
# NOTE(review): cv2 and timm are imported above this point, so installing
# opencv-python / timm here cannot rescue those imports — consider moving them.
import subprocess
import sys


def _pip_install(*pip_args):
    """Run ``<current python> -m pip install <pip_args>``; exit status is ignored."""
    # sys.executable guarantees the packages land in *this* interpreter,
    # unlike a bare "pip" found on PATH; an argument list avoids the shell.
    subprocess.run([sys.executable, "-m", "pip", "install", *pip_args], check=False)


_pip_install("-e", ".")
# "sklearn" on PyPI is a deprecated placeholder that errors on install, and one
# failing requirement makes pip abort the whole command; the real distribution
# is "scikit-learn".
_pip_install("opencv-python", "timm", "diffdist", "h5py", "scikit-learn", "ftfy")
_pip_install("git+https://github.com/lvis-dataset/lvis-api.git")

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer as Trainer 
from detectron2.engine import default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
    FLICKR30KEvaluator,
)
from detectron2.modeling import GeneralizedRCNNWithTTA

def parse_option():
    """Parse the demo's command-line options.

    Only ``--config-file`` is recognised. Any other arguments present on the
    command line are ignored (``parse_known_args``) instead of raising an
    error, and ``-h``/``--help`` is not registered (``add_help=False``).

    Returns:
        argparse.Namespace with a ``config_file`` attribute.
    """
    cli = argparse.ArgumentParser('RegionCLIP demo script', add_help=False)
    cli.add_argument(
        '--config-file',
        type=str,
        default="configs/CLIP_fast_rcnn_R_50_C4.yaml",
        metavar="FILE",
        help='path to config file',
    )
    known, _leftover = cli.parse_known_args()
    return known

def build_transforms(img_size, center_crop=True):
    """Compose the torchvision preprocessing pipeline for an input image.

    Args:
        img_size: target size in pixels. When ``center_crop`` is true the
            image is resized to ``int((256 / 224) * img_size)`` and then
            center-cropped to ``img_size``; otherwise it is only resized to
            ``img_size``.
        center_crop: whether to use the resize-then-crop scheme.

    Returns:
        A ``transforms.Compose`` whose last step is ``ToTensor()``.
    """
    if center_crop:
        enlarged = int((256 / 224) * img_size)
        steps = [transforms.Resize(enlarged), transforms.CenterCrop(img_size)]
    else:
        steps = [transforms.Resize(img_size)]
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)

def setup(args):
    """Build the frozen detectron2 config used by the demo.

    Starts from ``get_cfg()`` defaults, overrides them with the YAML file named
    by ``args.config_file``, freezes the result, and then passes it to
    detectron2's ``default_setup`` together with ``args``.

    Args:
        args: parsed command-line options (see ``parse_option``); must have a
            ``config_file`` attribute.

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    # Frozen before default_setup so later code cannot modify it in place.
    config.freeze()
    default_setup(config, args)
    return config

# --- Build the model and load its weights --------------------------------
args = parse_option()
cfg = setup(args)

model = Trainer.build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
    cfg.MODEL.WEIGHTS, resume=False
)

# Optionally load a 2nd pretrained checkpoint (cfg.MODEL.CLIP.BB_RPN_WEIGHTS)
# on top of the first. This only happens for the listed meta-architectures,
# when such weights are configured, and when regions are cropped via an RPN.
# The CLIP sub-config is read only after the architecture check passes.
if cfg.MODEL.META_ARCHITECTURE in ['CLIPRCNN', 'CLIPFastRCNN', 'PretrainFastRCNN']:
    if (cfg.MODEL.CLIP.BB_RPN_WEIGHTS is not None
            and cfg.MODEL.CLIP.CROP_REGION_TYPE == 'RPN'):
        rpn_loader = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, bb_rpn_weights=True)
        rpn_loader.resume_or_load(cfg.MODEL.CLIP.BB_RPN_WEIGHTS, resume=False)

# --- Build the data transform --------------------------------------------
# Resize only (size 800), no center crop; see build_transforms.
eval_transforms = build_transforms(800, center_crop=False)

def localize_object(image, texts):
    """Run the model on one uploaded image together with a text query.

    Args:
        image: the value from the Gradio "image" input. It is passed to
            ``PIL.Image.fromarray``, so it is presumably an HxWxC uint8 array
            (TODO confirm against the Gradio version in use).
        texts: the raw string from the Gradio "text" input, forwarded
            unchanged as the model's first positional argument.

    Returns:
        Whatever ``model(texts, [{"image": ...}])`` returns, unchanged.
    """
    rgb_image = Image.fromarray(image).convert("RGB")
    # eval_transforms ends in ToTensor, which yields floats in [0, 1] for 8-bit
    # images; scale back up to the 0-255 range.
    pixels = eval_transforms(rgb_image) * 255
    batch = [{"image": pixels}]
    model.eval()
    with torch.no_grad():
        prediction = model(texts, batch)
    # NOTE(review): the Interface declares a PIL image output; confirm that
    # this model's forward returns an image rather than raw detections.
    return prediction


# Build and launch the web UI. The "image"/"text" shorthands create the input
# components, whose values are passed to localize_object(image, texts) in order.
gr.Interface(
    description="Zero-Shot Object Detection with RegionCLIP (https://github.com/microsoft/RegionCLIP)",
    fn=localize_object,
    inputs=["image", "text"],
    outputs=[
        gr.outputs.Image(
            type="pil",
            label="grounding results",
        ),
    ],
    # Each example is [image file path, text query].
    examples=[
        ["./birds.png", "a goldfinch"],
        ["./apples_six.jpg", "a yellow apple"],
        ["./wines.jpg", "milk shake"],
        ["./logos.jpg", "a microsoft logo"],
    ],
).launch()