# Author: souranil3d
# Commit 08aa404: "Base inference"
import gradio as gr
import torch
import torchvision
import logging
from detectron2.engine import DefaultPredictor
import cv2
from detectron2.config import get_cfg
from src.utils.visualizer import add_bboxes
# Path to the detectron2 YAML config that is merged over the library defaults.
config_file = "config.yaml"


def _build_cfg(path):
    """Return a detectron2 config loaded from *path* with this demo's inference overrides."""
    node = get_cfg()
    node.merge_from_file(path)
    # Force CPU inference regardless of what the YAML file specifies.
    node.MODEL.DEVICE = "cpu"
    # Test-time score threshold for ROI-head detections.
    node.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    # NOTE(review): the doubled ".pth.pth" extension looks accidental;
    # confirm it matches the checkpoint file actually shipped with the app.
    node.MODEL.WEIGHTS = "checkpoints_model_final_imagenet_40k_synthetic.pth.pth"
    return node


# Module-level config shared by the inference code below.
cfg = _build_cfg(config_file)
# Shared DefaultPredictor, created lazily by _get_predictor().
_predictor = None


def _get_predictor():
    """Return the shared DefaultPredictor, building it from ``cfg`` on first use.

    Constructing a DefaultPredictor builds the model and loads
    ``cfg.MODEL.WEIGHTS`` (per the detectron2 docs). Because of that cost, it is
    done once per process rather than once per request. ``cfg`` is read only at
    that first call, so later changes to it are not picked up.
    """
    global _predictor
    if _predictor is None:
        _predictor = DefaultPredictor(cfg)
    return _predictor


def predict(config_file, checkpoint_file, img_path):
    """Run the pet detector on one image and return it with the detections drawn.

    Args:
        config_file: Unused; kept for signature compatibility. The module-level
            ``cfg`` is always used.
        checkpoint_file: Unused; kept for signature compatibility. Weights come
            from ``cfg.MODEL.WEIGHTS``.
        img_path: Path to the input image, or a file-like object that exposes
            its path as ``.name`` (e.g. a temporary upload file).

    Returns:
        The result of ``add_bboxes`` for the image, its predicted boxes and
        their scores.

    Raises:
        ValueError: If no image is given or OpenCV cannot read the file.
    """
    if img_path is None:
        raise ValueError("No input image was provided.")
    # Accept temp-file wrappers as well as plain path strings.
    path = getattr(img_path, "name", img_path)
    im = cv2.imread(path)
    # cv2.imread signals failure by returning None rather than raising.
    if im is None:
        raise ValueError(f"Could not read image: {path!r}")
    instances = _get_predictor()(im)["instances"]
    # NOTE(review): cv2.imread yields a BGR array; confirm add_bboxes converts
    # to RGB before the result is shown as a PIL image.
    return add_bboxes(im, instances.pred_boxes, scores=instances.scores)
title = "Pet Detection"
description = "Demo for Indoor Pet Detection"
examples = [['example.jpg']]


def _predict_from_ui(image):
    """Adapt the single-input Gradio callback to ``predict``'s three-argument signature.

    The interface below has exactly one input component, so Gradio calls its
    callback with a single argument. ``predict`` requires
    ``(config_file, checkpoint_file, img_path)``, so it cannot be passed to
    ``gr.Interface`` directly.
    """
    # Depending on the Gradio version, type="file" may deliver a temp-file
    # object instead of a path (TODO confirm); cv2.imread needs a path string.
    image_path = getattr(image, "name", image)
    return predict(config_file, cfg.MODEL.WEIGHTS, image_path)


demo = gr.Interface(
    _predict_from_ui,
    inputs=gr.inputs.Image(type="file"),
    outputs=gr.outputs.Image(type="pil"),
    enable_queue=True,
    title=title,
    description=description,
    examples=examples,
)
demo.launch()