import torch
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2
import numpy as np
import config
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from model import YOLOv3
import gradio as gr
import os
import matplotlib.pyplot as plt
from utils import *

# Build the network and load the trained weights. strict=False means any
# missing/unexpected keys in the checkpoint are silently ignored.
model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
model.load_state_dict(
    torch.load("custom_yolo_v3.pt", map_location=torch.device("cpu")),
    strict=False,
)
# Switch to inference mode once, up front. Previously the model was never put
# in eval mode (and was even reset to train mode after every request), so any
# BatchNorm/Dropout layers it contains ran with training-time behavior.
model.eval()

IMAGE_SIZE = config.IMAGE_SIZE

# Resize so the longest side is IMAGE_SIZE, pad to an IMAGE_SIZE x IMAGE_SIZE
# square, divide pixel values by 255 (mean 0, std 1, max_pixel_value 255) and
# convert the numpy image to a tensor.
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE,
            min_width=IMAGE_SIZE,
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        ToTensorV2(),
    ]
)

# Per-scale anchors: config.ANCHORS[i] multiplied by config.S[i].
# NOTE(review): presumably ANCHORS is (num_scales, 3, 2) in image-relative
# units, so this converts them to grid-cell units — confirm against config.
anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(config.DEVICE)


def inference(input_img, transparency=0.5, thresh=0.8, iou_thresh=0.3):
    """Run YOLOv3 on one image and return the plotted detections.

    Args:
        input_img: Image as a numpy array (Gradio ``type='numpy'``), or None
            when the user submitted without an image.
        transparency: Currently unused. It is wired to the "Opacity of
            GradCAM" slider, but no GradCAM overlay is computed yet.
        thresh: Passed as ``threshold`` to ``non_max_suppression``.
        iou_thresh: Passed as ``iou_threshold`` to ``non_max_suppression``.

    Returns:
        The result of ``plot_image`` for the transformed (resized/padded)
        image and the boxes kept by NMS, or None if ``input_img`` is None.
    """
    if input_img is None:
        # Gradio calls the function with None when the image box is empty.
        return None

    x = test_transforms(image=input_img)["image"].unsqueeze(0).to(config.DEVICE)
    with torch.no_grad():
        out = model(x)

    # One list of boxes per image in the batch (here the batch size is 1).
    bboxes = [[] for _ in range(x.shape[0])]
    # Each element of `out` is one prediction scale; pair it with the anchors
    # for that scale. The 5-way unpack also checks the output rank.
    for scale_out, scale_anchors in zip(out, anchors):
        _, _, grid_size, _, _ = scale_out.shape
        boxes_scale = cells_to_bboxes(
            scale_out, scale_anchors, S=grid_size, is_preds=True
        )
        for idx, box in enumerate(boxes_scale):
            bboxes[idx] += box

    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_thresh,
        threshold=thresh,
        box_format="midpoint",
    )
    # Convert the first image from CHW back to HWC on the CPU for plotting.
    visualization = plot_image(x[0].permute(1, 2, 0).detach().cpu(), nms_boxes)
    return visualization


title = "Object Detection using YOLOv3"
description = "A simple Gradio interface to show object detection on images"
examples = [
    ["./images/006294.jpg"],
    ["./images/005898.jpg"],
    ["./images/003785.jpg"],
    ["./images/001624.jpg"],
    ["./images/006796.jpg"],
    ["./images/003388.jpg"],
    ["./images/002216.jpg"],
    ["./images/000341.jpg"],
    ["./images/006818.jpg"],
]

# Gradio passes input values to `inference` positionally, so the components
# must be listed in the same order as its parameters:
# (input_img, transparency, thresh, iou_thresh). The previous order sent the
# "Threshold" slider to `transparency` and "Opacity" to the NMS IoU threshold.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(label="Input Image", type="numpy"),
        gr.Slider(0, 1, value=0.8, label="Opacity of GradCAM"),
        gr.Slider(0, 1, value=0.75, label="Threshold"),
        gr.Slider(0, 1, value=0.75, label="IoU Threshold"),
    ],
    # NOTE(review): `.style()` is deprecated in later Gradio 3.x and removed
    # in Gradio 4; if upgrading, pass width/height to gr.Image directly.
    outputs=[gr.Image(label="Output").style(width=600, height=600)],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()