|
import torch |
|
import albumentations as A |
|
import cv2 |
|
|
|
from albumentations.pytorch import ToTensorV2 |
|
import numpy as np |
|
import config |
|
from PIL import Image |
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
from model import YOLOv3 |
|
import gradio as gr |
|
import os |
|
import matplotlib.pyplot as plt |
|
from utils import * |
|
|
|
|
|
# Build the detector, move it to the configured device, then restore the
# trained weights from the checkpoint file in the working directory.
model = YOLOv3(num_classes=config.NUM_CLASSES)
model = model.to(config.DEVICE)

# The checkpoint is deserialized onto the CPU regardless of config.DEVICE.
# strict=False makes load_state_dict tolerate keys that are missing from, or
# unexpected in, the checkpoint instead of raising.
model.load_state_dict(
    torch.load("custom_yolo_v3.pt", map_location=torch.device('cpu')),
    strict=False,
)
|
|
|
# Side length used for both the resize and the padding steps below.
IMAGE_SIZE = config.IMAGE_SIZE

# Preprocessing applied to every image before it is fed to the model.
# Steps run in the order listed:
#   1. LongestMaxSize - resize so the longest side is IMAGE_SIZE.
#   2. PadIfNeeded    - pad with a constant border up to IMAGE_SIZE x IMAGE_SIZE.
#   3. Normalize      - with mean 0 / std 1 / max_pixel_value 255 this only
#                       rescales pixel values by 1/255.
#   4. ToTensorV2     - convert the array to a torch tensor.
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE,
            min_width=IMAGE_SIZE,
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.Normalize(
            mean=[0, 0, 0],
            std=[1, 1, 1],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ]
)
|
|
|
# Anchor boxes multiplied by the grid size of their prediction scale.
# Each entry of config.S is expanded to a (len(S), 3, 2) multiplier so it
# lines up element-wise with config.ANCHORS.
# NOTE(review): assumes config.ANCHORS is laid out as (scale, anchor, 2) with
# three anchors per scale - confirm against config.
anchors = torch.mul(
    torch.tensor(config.ANCHORS),
    torch.tensor(config.S)[:, None, None].repeat(1, 3, 2),
).to(config.DEVICE)
|
|
|
def inference(input_img, transparency=0.5, thresh=0.8, iou_thresh=0.3):
    """Run the detector on one image and return the plotted detections.

    Args:
        input_img: Image passed to ``test_transforms`` as ``image=``; the
            Gradio input component in this file supplies a numpy array.
        transparency: Not read by this function; kept so the signature stays
            compatible with existing callers.
        thresh: Forwarded to ``non_max_suppression`` as ``threshold``.
        iou_thresh: Forwarded to ``non_max_suppression`` as ``iou_threshold``.

    Returns:
        The result of ``plot_image`` for the preprocessed image and the boxes
        kept by non-max suppression.

    Side effects:
        Puts the module-level ``model`` into eval mode and leaves it there.
    """
    # Preprocess and add a leading batch dimension of size 1.
    x = test_transforms(image=input_img)['image'].unsqueeze(0).to(config.DEVICE)

    # Inference must run in eval mode: otherwise BatchNorm normalizes with the
    # statistics of this single image (and keeps updating its running stats)
    # and Dropout stays active. The model is deliberately NOT switched back to
    # train mode afterwards, since this app only ever performs inference.
    model.eval()
    with torch.no_grad():
        out = model(x)

        # One list of candidate boxes per image in the batch.
        bboxes = [[] for _ in range(x.shape[0])]

        # The model output is indexed once per prediction scale (three scales).
        for scale_idx in range(3):
            scale_out = out[scale_idx]
            grid_size = scale_out.shape[2]
            boxes_scale_i = cells_to_bboxes(
                scale_out, anchors[scale_idx], S=grid_size, is_preds=True
            )
            for idx, box in enumerate(boxes_scale_i):
                bboxes[idx] += box

    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_thresh,
        threshold=thresh,
        box_format="midpoint",
    )

    # plot_image receives the image channel-last, on the CPU.
    return plot_image(x[0].permute(1, 2, 0).detach().cpu(), nms_boxes)
|
|
|
# Static text shown at the top of the Gradio page.
title = "Object Detection using YOLOv3"
description = "A simple Gradio interface to show object detection on images"

# Example rows offered in the UI. Each row is a one-element list holding only
# the image path.
examples = [
    [image_path]
    for image_path in (
        './images/006294.jpg',
        './images/005898.jpg',
        './images/003785.jpg',
        './images/001624.jpg',
        './images/006796.jpg',
        './images/003388.jpg',
        './images/002216.jpg',
        './images/000341.jpg',
        './images/006818.jpg',
    )
]
|
|
|
|
|
def _inference_from_ui(input_img, thresh, iou_thresh, transparency):
    """Map the UI's input order onto ``inference``'s parameter order.

    Gradio passes component values positionally in the order of ``inputs``
    (image, Threshold, IoU Threshold, Opacity), whereas ``inference`` declares
    ``transparency`` before the two thresholds. Forwarding by keyword keeps
    each slider attached to the parameter its label describes.
    """
    return inference(
        input_img,
        transparency=transparency,
        thresh=thresh,
        iou_thresh=iou_thresh,
    )


# Web UI: one image plus three sliders in, one annotated image out.
demo = gr.Interface(
    _inference_from_ui,
    inputs=[
        gr.Image(label="Input Image", type='numpy'),
        gr.Slider(0, 1, value=0.75, label="Threshold"),
        gr.Slider(0, 1, value=0.75, label="IoU Threshold"),
        gr.Slider(0, 1, value=0.8, label="Opacity of GradCAM"),
    ],
    # NOTE(review): Component.style() is deprecated in later Gradio 3.x
    # releases and absent in 4.x - confirm the pinned Gradio version.
    outputs=[gr.Image(label="Output").style(width=600, height=600)],
    title=title,
    description=description,
    examples=examples,
)

# Start the web server when the script is executed.
demo.launch()