import os

# Runtime environment setup. The commented commands document alternative or
# manual install steps kept from development.
# os.system("sudo apt-get update && sudo apt-get install -y git")
# os.system("sudo apt-get -y install pybind11-dev")
# os.system("git clone https://github.com/facebookresearch/detectron2.git")
# os.system("pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html")
# os.system("pip3 install torch torchvision torchaudio")

# Install detectron2 from the prebuilt wheel shipped with the Space, and pin DeepSpeed.
os.system("cd detectron2 && pip install detectron2-0.6-cp310-cp310-linux_x86_64.whl")
os.system("pip install deepspeed==0.7.0")
import site
from importlib import reload
reload(site)
from PIL import Image
import argparse
import sys
import numpy as np
import cv2
import gradio as gr
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger
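
# Make the vendored CenterNet2 project importable (GRiT builds on CenterNet2).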
sys.path.insert(0, "third_party/CenterNet2/projects/CenterNet2/")
from centernet.config import add_centernet_config
from grit.config import add_grit_config
from grit.predictor import VisualizationDemo
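

# CLI arguments; defaults point at the GRiT dense-captioning config and weights.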
def get_parser():
    parser = argparse.ArgumentParser(description="GRiT dense-captioning demo (built on Detectron2)")
parser.add_argument(
"--config-file",
default="configs/GRiT_B_DenseCap_ObjectDet.yaml",
metavar="FILE",
help="path to config file",
)
parser.add_argument("--cpu", action="store_true", help="Use CPU only.")
parser.add_argument(
"--confidence-threshold",
type=float,
default=0.5,
help="Minimum score for instance predictions to be shown",
)
parser.add_argument(
"--test-task",
type=str,
default="",
help="Choose a task to have GRiT perform",
)
parser.add_argument(
"--opts",
help="Modify config options using the command-line 'KEY VALUE' pairs",
default=["MODEL.WEIGHTS", "./models/grit_b_densecap_objectdet.pth"],
nargs=argparse.REMAINDER,
)
return parser
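

# Assemble a frozen detectron2 config with the CenterNet2 and GRiT extensions applied.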
def setup_cfg(args):
cfg = get_cfg()
if args.cpu:
cfg.MODEL.DEVICE = "cpu"
add_centernet_config(cfg)
add_grit_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
# Set score_threshold for builtin models
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = (
args.confidence_threshold
)
if args.test_task:
cfg.MODEL.TEST_TASK = args.test_task
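    # Decode captions greedily (beam size 1), disable Soft-NMS, and skip
    # activation checkpointing since this runs inference only.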
cfg.MODEL.BEAM_SIZE = 1
cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
cfg.USE_ACT_CHECKPOINT = False
cfg.freeze()
return cfg


def predict(image_file):
    # PIL images are RGB; detectron2 expects BGR, so normalize to RGB first
    # (guards against RGBA/grayscale uploads), then flip the channel order.
    image_array = np.array(image_file.convert("RGB"))[:, :, ::-1]
    _, visualized_output = dense_captioning_demo.run_on_image(image_array)
    output_path = os.path.join(os.getcwd(), "output.jpg")
    visualized_output.save(output_path)
    # cv2 loads BGR; convert back to RGB before returning a PIL image to Gradio.
    output_image = cv2.imread(output_path)
    output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(output_image)
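

# Build the model once at startup and force the dense-captioning task.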
args = get_parser().parse_args()
args.test_task = "DenseCap"
setup_logger(name="fvcore")
logger = setup_logger()
logger.info("Arguments: " + str(args))
cfg = setup_cfg(args)
dense_captioning_demo = VisualizationDemo(cfg)
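
# Gradio UI: upload an image and get back the dense-captioned visualization.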
demo = gr.Interface(
title="Dense Captioning - GRiT",
fn=predict,
    inputs=gr.Image(type="pil", label="Original Image"),
    outputs=gr.Image(type="pil", label="Output Image"),
examples=["example_1.jpg", "example_2.jpg"],
)
demo.launch()