# regionclip-demo / app.py
# Author: jw2yang — last commit "Update app.py" (30db11d)
import argparse
import requests
import logging
import os
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from config import get_config
from collections import OrderedDict
def _pip_install(*pip_args):
    """Run ``pip install <pip_args>`` with the current interpreter, best effort.

    Using ``sys.executable -m pip`` guarantees packages land in the same
    environment that is running this script (a bare ``pip``/``python`` on
    PATH may belong to a different interpreter). A failed install is logged
    as a warning instead of raised, preserving the original best-effort
    behaviour of the ``os.system`` calls whose exit codes were ignored.
    """
    import subprocess
    import sys

    cmd = [sys.executable, "-m", "pip", "install", *pip_args]
    result = subprocess.run(cmd, check=False)
    if result.returncode != 0:
        logging.warning(
            "pip install failed (exit code %d): %s", result.returncode, " ".join(cmd)
        )


# These installs must run before the detectron2 imports that follow.
_pip_install("-e", ".")
# ``sklearn`` is a deprecated PyPI alias that now refuses to install, which
# made pip abort this whole command; ``scikit-learn`` is the real package.
_pip_install("opencv-python", "timm", "diffdist", "h5py", "scikit-learn", "ftfy")
_pip_install("git+https://github.com/lvis-dataset/lvis-api.git")
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer as Trainer
from detectron2.engine import default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
CityscapesInstanceEvaluator,
CityscapesSemSegEvaluator,
COCOEvaluator,
COCOPanopticEvaluator,
DatasetEvaluators,
LVISEvaluator,
PascalVOCDetectionEvaluator,
SemSegEvaluator,
verify_results,
FLICKR30KEvaluator,
)
from detectron2.modeling import GeneralizedRCNNWithTTA
def parse_option():
    """Parse the demo's command-line options.

    Only ``--config-file`` is recognised. Any other command-line arguments
    are tolerated and discarded (``parse_known_args``), so the script can be
    launched with extra flags. Returns the parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser('RegionCLIP demo script', add_help=False)
    cli.add_argument(
        '--config-file',
        type=str,
        default="configs/CLIP_fast_rcnn_R_50_C4.yaml",
        metavar="FILE",
        help='path to config file',
    )
    return cli.parse_known_args()[0]
def build_transforms(img_size, center_crop=True):
    """Compose the torchvision preprocessing pipeline.

    With ``center_crop`` the image is resized to ``int((256 / 224) * img_size)``
    and then center-cropped to ``img_size``. Otherwise it is only resized to
    ``img_size``. Both pipelines finish with ``ToTensor``.
    """
    if not center_crop:
        steps = [transforms.Resize(img_size)]
    else:
        enlarged = int((256 / 224) * img_size)
        steps = [transforms.Resize(enlarged), transforms.CenterCrop(img_size)]
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
def setup(args):
    """Build and return the frozen detectron2 config for the demo.

    Starts from ``get_cfg()``, overlays the YAML at ``args.config_file``,
    freezes the result, then hands it to ``default_setup`` together with
    ``args``.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.freeze()
    default_setup(config, args)
    return config
# --- Build the model and load its weights --------------------------------
args = parse_option()
cfg = setup(args)
model = Trainer.build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
    cfg.MODEL.WEIGHTS, resume=False
)

# Load a 2nd pretrained checkpoint (cfg.MODEL.CLIP.BB_RPN_WEIGHTS, presumably
# backbone + RPN weights) only for these meta-architectures, only when such a
# checkpoint is configured, and only when regions come from the RPN.
_load_second_checkpoint = (
    cfg.MODEL.META_ARCHITECTURE in ['CLIPRCNN', 'CLIPFastRCNN', 'PretrainFastRCNN']
    and cfg.MODEL.CLIP.BB_RPN_WEIGHTS is not None
    and cfg.MODEL.CLIP.CROP_REGION_TYPE == 'RPN'
)
if _load_second_checkpoint:
    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, bb_rpn_weights=True).resume_or_load(
        cfg.MODEL.CLIP.BB_RPN_WEIGHTS, resume=False
    )
# --- Build the evaluation-time data transform ----------------------------
# Resize only (no center crop) with img_size=800, then convert to a tensor.
eval_transforms = build_transforms(img_size=800, center_crop=False)
def localize_object(image, texts):
    """Run the RegionCLIP model on one image with a free-form text query.

    ``image`` is the array passed in by the Gradio image input. It is
    converted to an RGB PIL image and run through ``eval_transforms``. The
    tensor is then multiplied by 255, since ``ToTensor`` yields [0, 1] values
    for 8-bit input. ``texts`` is forwarded to the model unchanged, and the
    model's output is returned as-is.
    """
    rgb_image = Image.fromarray(image).convert("RGB")
    batch = [{"image": eval_transforms(rgb_image) * 255}]
    model.eval()
    with torch.no_grad():
        # NOTE(review): the Interface declares a PIL image output; confirm the
        # customized model returns something Gradio can display for this call.
        return model(texts, batch)
# --- Gradio UI: an image plus a text query in, the grounding result out ---
image = gr.inputs.Image()
gr.Interface(
    description="Zero-Shot Object Detection with RegionCLIP (https://github.com/microsoft/RegionCLIP)",
    fn=localize_object,
    # Use the image component constructed above (same defaults as the
    # "image" shortcut string) so it is not built and then left unused.
    inputs=[image, "text"],
    outputs=[
        gr.outputs.Image(
            type="pil",
            label="grounding results"),
    ],
    examples=[
        ["./birds.png", "a goldfinch"],
        ["./apples_six.jpg", "a yellow apple"],
        ["./wines.jpg", "milk shake"],
        ["./logos.jpg", "a microsoft logo"],
    ],
).launch()