# Hugging Face Space app by Vishakaraj.
# Commit a312060 — "Save output results as json" (3.76 kB; HF malware scan: no virus).
# Bootstrap: install runtime dependencies at startup, then import them.
# The order of the statements below matters: the pip installs must run
# before the detectron2 / grit imports further down.
import os
# Install the vendored detectron2 wheel from ./detectron2 (wheel tags:
# CPython 3.10, linux x86_64) and a pinned deepspeed.
# NOTE(review): the exit status returned by os.system is discarded, so a
# failed install would likely only surface later as an ImportError below.
os.system("cd detectron2 && pip install detectron2-0.6-cp310-cp310-linux_x86_64.whl")
os.system("pip install deepspeed==0.7.0")
import site
from importlib import reload
# Re-run `site` initialisation, presumably so that package directories
# created by the pip installs above become importable in this process
# — confirm this is still required on the target runtime.
reload(site)
from PIL import Image
from io import BytesIO
import argparse
import sys
import numpy as np
import torch
import gradio as gr
from detectron2.config import get_cfg
# NOTE(review): read_image does not appear to be used in the visible code.
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger
# Make the vendored CenterNet2 project importable as the `centernet` package.
sys.path.insert(0, "third_party/CenterNet2/projects/CenterNet2/")
from centernet.config import add_centernet_config
from grit.config import add_grit_config
from grit.predictor import VisualizationDemo
def get_parser():
    """Build the command-line parser for the GRiT demo.

    Returns:
        argparse.ArgumentParser: a parser exposing ``--config-file``, ``--cpu``,
        ``--confidence-threshold``, ``--test-task`` and ``--opts`` (in that order).
    """
    # (flag, add_argument keyword options). The table is rebuilt on every call,
    # so each parser gets its own fresh default list for --opts.
    option_specs = (
        (
            "--config-file",
            dict(
                default="configs/GRiT_B_DenseCap_ObjectDet.yaml",
                metavar="FILE",
                help="path to config file",
            ),
        ),
        ("--cpu", dict(action="store_true", help="Use CPU only.")),
        (
            "--confidence-threshold",
            dict(
                type=float,
                default=0.5,
                help="Minimum score for instance predictions to be shown",
            ),
        ),
        (
            "--test-task",
            dict(
                type=str,
                default="",
                help="Choose a task to have GRiT perform",
            ),
        ),
        (
            "--opts",
            dict(
                help="Modify config options using the command-line 'KEY VALUE' pairs",
                # Without --opts, fall back to the bundled checkpoint; with it,
                # every remaining argv token is collected verbatim.
                default=["MODEL.WEIGHTS", "./models/grit_b_densecap_objectdet.pth"],
                nargs=argparse.REMAINDER,
            ),
        ),
    )

    cli = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli
def setup_cfg(args):
    """Build a frozen detectron2 config for the GRiT model from CLI arguments.

    Args:
        args: namespace produced by ``get_parser().parse_args()``; reads
            ``config_file``, ``opts``, ``cpu``, ``confidence_threshold`` and
            ``test_task``.

    Returns:
        The config node, frozen via ``cfg.freeze()`` before it is returned.
    """
    cfg = get_cfg()
    # Register the CenterNet2 / GRiT config keys before merging a file that uses them.
    add_centernet_config(cfg)
    add_grit_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # Explicit CLI flags are applied after the file/--opts merge so they win;
    # previously --cpu was set first and could be silently overridden by a
    # MODEL.DEVICE entry in the YAML or in --opts.
    if args.cpu:
        cfg.MODEL.DEVICE = "cpu"
    # Set score_threshold for builtin models
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = (
        args.confidence_threshold
    )
    if args.test_task:
        cfg.MODEL.TEST_TASK = args.test_task

    # Fixed demo-time overrides.
    cfg.MODEL.BEAM_SIZE = 1
    cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
    cfg.USE_ACT_CHECKPOINT = False
    cfg.freeze()
    return cfg
def predict(image_file):
    """Run GRiT dense captioning on one image.

    Args:
        image_file: PIL image from the ``gr.Image(type='pil')`` input, or
            ``None`` when the form is submitted without an image.

    Returns:
        tuple: ``(rendered, output)``.

        * ``rendered`` is a PIL image of the visualized predictions, or ``None``
          when no image was given.
        * ``output`` is
          ``{"dense_captioning_results": {"detections": {caption: [det, ...]}}}``.
          Each ``det`` is a dict with float ``xmin``, ``ymin``, ``xmax``,
          ``ymax`` (from ``pred_boxes``) and ``score``.
        * Detections sharing the same caption are grouped under one key.
    """
    detections = {}
    output = {"dense_captioning_results": {"detections": detections}}
    if image_file is None:
        # Nothing to caption: return an empty result instead of crashing on
        # np.array(None)[:, :, ::-1].
        return None, output

    # Force 3-channel RGB so the channel flip below always yields BGR; RGBA,
    # grayscale or palette images would otherwise produce a wrong layout or
    # an IndexError.
    image_array = np.array(image_file.convert("RGB"))[:, :, ::-1]  # BGR
    predictions, visualized_output = dense_captioning_demo.run_on_image(image_array)

    # Render the visualization figure to an in-memory PNG for the image output.
    buffer = BytesIO()
    visualized_output.fig.savefig(buffer, format="png")
    buffer.seek(0)

    instances = predictions["instances"].to(torch.device("cpu"))
    for box, description, score in zip(
        instances.pred_boxes,
        instances.pred_object_descriptions.data,
        instances.scores,
    ):
        detections.setdefault(description, []).append(
            {
                "xmin": float(box[0]),
                "ymin": float(box[1]),
                "xmax": float(box[2]),
                "ymax": float(box[3]),
                "score": float(score),
            }
        )
    return Image.open(buffer), output
# Script entry: parse CLI args, build the GRiT model, and serve the Gradio UI.
args = get_parser().parse_args()
# This app is a dense-captioning demo, so default the task to "DenseCap",
# but honour an explicit --test-task instead of silently overwriting it.
args.test_task = args.test_task or "DenseCap"
setup_logger(name="fvcore")
logger = setup_logger()
logger.info("Arguments: %s", args)
cfg = setup_cfg(args)
# Module-level global that predict() reads when handling a request.
dense_captioning_demo = VisualizationDemo(cfg)
demo = gr.Interface(
    title="Dense Captioning - GRiT",
    fn=predict,
    inputs=gr.Image(type="pil", label="Original Image"),
    # predict() returns (rendered image, JSON-serialisable detections dict).
    outputs=[gr.Image(type="pil", label="Output Image"), "json"],
    examples=["example_1.jpg", "example_2.jpg"],
)
demo.launch()