DhrubaAdhikary1991's picture
push changes
abb203b verified
raw
history blame
2.66 kB
import torch
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2
import numpy as np
import config
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from model import YOLOv3
import gradio as gr
import os
import matplotlib.pyplot as plt
from utils import *
# Build the detector on the configured device and restore the trained weights.
# The checkpoint is deserialized onto the CPU first; load_state_dict then
# copies the values into the model's own parameters.
# NOTE(review): strict=False tolerates missing/unexpected keys silently --
# confirm the checkpoint really matches this architecture.
model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
model.load_state_dict(
    torch.load("custom_yolo_v3.pt", map_location=torch.device("cpu")),
    strict=False,
)

IMAGE_SIZE = config.IMAGE_SIZE

# Preprocessing applied to every incoming image: resize so the longest side
# equals IMAGE_SIZE, pad up to an IMAGE_SIZE x IMAGE_SIZE canvas, normalize
# (mean 0 / std 1 with max_pixel_value=255) and convert to a tensor.
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE,
            min_width=IMAGE_SIZE,
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.Normalize(
            mean=[0, 0, 0],
            std=[1, 1, 1],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ]
)

# Multiply the configured anchors by the per-scale grid sizes from config.S.
# The grid sizes get two singleton axes and are then tiled (1, 3, 2) so they
# line up element-wise with the anchor tensor.
# assumes config.ANCHORS is laid out as (scales, 3 anchors, 2 dims) -- TODO confirm
anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S)[:, None, None].repeat(1, 3, 2)
).to(config.DEVICE)
def inference(input_img, transparency=0.5, thresh=0.8, iou_thresh=0.3):
    """Run the YOLOv3 model on one image and return the plotted detections.

    Args:
        input_img: Image to analyse; it is passed as ``image=`` to
            ``test_transforms`` (the Gradio input is configured with
            ``type='numpy'``).
        transparency: Accepted for interface compatibility; this function
            does not use it.
        thresh: Forwarded to ``non_max_suppression`` as ``threshold``.
        iou_thresh: Forwarded to ``non_max_suppression`` as
            ``iou_threshold``.

    Returns:
        Whatever ``plot_image`` returns for the preprocessed image and the
        boxes that survive non-max suppression.
    """
    x = test_transforms(image=input_img)["image"].unsqueeze(0).to(config.DEVICE)

    # Inference must run in eval mode: in training mode, layers such as batch
    # norm and dropout behave differently, and batch-norm running statistics
    # are updated on every forward pass even under no_grad, so the model would
    # drift from request to request. The previous code never called eval()
    # and switched the model back to train() after each call.
    model.eval()
    with torch.no_grad():
        out = model(x)

    # One list of candidate boxes per image in the batch (here a single image).
    bboxes = [[] for _ in range(x.shape[0])]
    for scale_idx in range(3):
        scale_pred = out[scale_idx]
        # The five-way unpack keeps the original check that each prediction
        # is a 5-D tensor; only the grid size (third axis) is needed.
        _, _, grid_size, _, _ = scale_pred.shape
        boxes_scale_i = cells_to_bboxes(
            scale_pred, anchors[scale_idx], S=grid_size, is_preds=True
        )
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx].extend(box)

    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_thresh,
        threshold=thresh,
        box_format="midpoint",
    )
    # plot_image receives the preprocessed image in channels-last layout.
    visualization = plot_image(x[0].permute(1, 2, 0).detach().cpu(), nms_boxes)
    return visualization
# Text shown at the top of the Gradio page.
title = "Object Detection using YOLOv3"
description = "A simple Gradio interface to show object detection on images"

# Sample inputs offered in the UI. Each row holds a single value: the path of
# an image under ./images.
examples = [
    [f"./images/{image_id}.jpg"]
    for image_id in (
        "006294",
        "005898",
        "003785",
        "001624",
        "006796",
        "003388",
        "002216",
        "000341",
        "006818",
    )
]
def _inference_from_ui(input_img, thresh, iou_thresh, transparency):
    """Map the UI values (in on-screen order) onto inference() by keyword.

    Gradio passes the input components' values positionally, in the order
    they are listed in ``inputs``. That order (image, threshold, IoU
    threshold, opacity) differs from inference()'s parameter order (image,
    transparency, thresh, iou_thresh), so passing inference directly would
    feed each slider into the wrong parameter.
    """
    return inference(
        input_img,
        transparency=transparency,
        thresh=thresh,
        iou_thresh=iou_thresh,
    )


# Build the web UI: one image input plus three sliders, one image output.
demo = gr.Interface(
    _inference_from_ui,
    inputs=[
        gr.Image(label="Input Image", type='numpy'),
        gr.Slider(0, 1, value=0.75, label="Threshold"),
        gr.Slider(0, 1, value=0.75, label="IoU Threshold"),
        gr.Slider(0, 1, value=0.8, label="Opacity of GradCAM"),
    ],
    outputs=[gr.Image(label="Output").style(width=600, height=600)],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()