# Gradio demo: YOLOv3 (Pascal VOC) inference with an optional Grad-CAM overlay.
import os
import warnings

import albumentations as A
import cv2
import gradio as gr
import numpy as np
import pytorch_lightning as pl
import torch
import torch.optim as optim
import torchvision
from albumentations.pytorch import ToTensorV2
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from torchvision import transforms
from tqdm import tqdm

import config
from dataset import YOLODatasetOK
from loss import YoloLoss
from model import YOLOv3, YOLOV3LITE
from utils import (
    cells_to_bboxes,
    check_class_accuracy,
    get_evaluation_bboxes,
    get_loaders,
    load_checkpoint,
    mean_average_precision,
    non_max_suppression,
    plot_couple_examples,
    plot_image,
    save_checkpoint,
)

warnings.filterwarnings("ignore")

# Let cuDNN benchmark convolution algorithms; inputs are a fixed 416x416 size.
torch.backends.cudnn.benchmark = True

def load_checkpoint(checkpoint_file, model, optimizer, lr, with_optim=False):
    """Load model weights (and optionally optimizer state); overrides the utils version."""
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    if with_optim:
        optimizer.load_state_dict(checkpoint["optimizer"])
    # Reset the learning rate so a restored optimizer does not carry over a
    # stale (already-decayed) value.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return model
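
# With with_optim=True the optimizer state is restored as well; the learning
# rate is reset in either case so a resumed run starts at the configured LR.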

# YOLOV3LITE is assumed to expose its optimizer as `model_handler.optimizer`.
model_handler = YOLOV3LITE()
loaded_model = load_checkpoint(
    config.CHECKPOINT_FILE, model_handler, model_handler.optimizer, config.LEARNING_RATE
)
loaded_model.eval()  # inference only; Grad-CAM still works in eval mode

# Inverse of the A.Normalize used at inference time (mean 0, std 1), so this
# is the identity; adjust if the model was trained with other statistics.
inv_normalize = transforms.Normalize(
    mean=[0.0, 0.0, 0.0],
    std=[1.0, 1.0, 1.0],
)
# The 20 Pascal VOC class names, in label-index order.
classes = (
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
)

def inference(
    input_img,
    gradcam_on="TRUE",
    transparency=0.5,
    target_layer_number=-1,
    top_num_images=4,
    view_missclassified="FALSE",
    missclassified_count=2,
):
    """Run YOLOv3 on one image and return a Grad-CAM overlay or the plain image.

    The last three parameters exist to match the Gradio inputs; they are not
    used by this function.
    """
    # Resize the longest side to IMAGE_SIZE, pad to a square, and scale pixels
    # to [0, 1]. bbox_params is not needed here: inference has no ground-truth
    # boxes to transform.
    test_transform = A.Compose(
        [
            A.LongestMaxSize(max_size=config.IMAGE_SIZE),
            A.PadIfNeeded(
                min_height=config.IMAGE_SIZE,
                min_width=config.IMAGE_SIZE,
                border_mode=cv2.BORDER_CONSTANT,
            ),
            A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
            ToTensorV2(),
        ]
    )
    # Rescale the normalized anchors (config.ANCHORS) by each prediction
    # scale's grid size (config.S).
    anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(config.DEVICE)
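    # The result has shape (3, 3, 2): for each of the three scales, three
    # anchor (w, h) pairs in that scale's grid-cell units.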

    # Resize/pad the incoming image and scale it to [0, 1] for the network.
    input_img = test_transform(image=np.array(input_img))["image"]
    input_img = input_img.unsqueeze(0)  # add batch dimension: (1, 3, H, W)
    with torch.no_grad():
        out = loaded_model(input_img)

    iou_thresh = 0.5
    conf_thresh = 0.6
    bboxes = [[] for _ in range(input_img.shape[0])]
    for i in range(3):
        # out[i]: (batch, anchors_per_scale, S, S, 5 + num_classes).
        S = out[i].shape[2]
        anchor = anchors[i]
        # Decode cell-relative predictions into image-relative boxes.
        boxes_scale_i = cells_to_bboxes(out[i], anchor, S=S, is_preds=True)
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box
    # Merge all three scales and drop overlapping, low-confidence detections.
    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_thresh,
        threshold=conf_thresh,
        box_format="midpoint",
    )
    plot_image(input_img[0].permute(1, 2, 0).detach().cpu(), nms_boxes)
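    # Note: plot_image draws with matplotlib as a side effect; the Gradio
    # output comes from the returned `visualization` array below.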

    if gradcam_on == "TRUE":
        # Assumes the model exposes ResNet-style blocks at `.model.layer2`;
        # target_layer_number selects which block Grad-CAM hooks into.
        target_layers = [loaded_model.model.layer2[target_layer_number]]
        cam = GradCAM(model=loaded_model, target_layers=target_layers, use_cuda=False)
        grayscale_cam = cam(input_tensor=input_img, targets=None)[0, :]
        img = inv_normalize(input_img.squeeze(0))
        rgb_img = np.transpose(img.numpy(), (1, 2, 0))
        # Overlay the heatmap on the network input so the shapes always match.
        visualization = show_cam_on_image(
            rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency
        )
    else:
        img = inv_normalize(input_img.squeeze(0))
        visualization = np.transpose(img.numpy(), (1, 2, 0))
    return visualization
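
# Quick sanity check outside Gradio (hypothetical image path):
#   vis = inference(np.array(Image.open("imgs/dog.jpg")), gradcam_on="FALSE")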

title = "YOLOv3 (Pascal VOC) with Grad-CAM"
description = "A simple Gradio interface for YOLOv3 inference with optional Grad-CAM visualizations"
# Each example row supplies all seven inputs, in order: image, gradcam_on,
# transparency, target_layer_number, top_num_images, view_missclassified,
# missclassified_count.
_IMG_DIR = os.path.join(os.path.dirname(__file__), "imgs")
examples = [
    [os.path.join(_IMG_DIR, name), "TRUE", 0.5, -1, 4, False, 2]
    for name in (
        "cat.jpg", "dog.jpg", "car.jpg", "frog.jpg", "horse.jpg",
        "tiger.jpg", "dog2.jpg", "bird2.jpg", "Cat03.jpg", "truck.jpg",
    )
]

demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(shape=(416, 416), label="Input Image"),
        gr.Radio(["TRUE", "FALSE"], label="Grad-CAM Required", info="Do you need Grad-CAM images?"),
        gr.Slider(0, 1, value=0.5, label="Opacity of Grad-CAM"),
        gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?"),
        gr.Slider(1, 10, value=4, step=1, label="How Many Top Classes"),
        gr.Checkbox(label="View Misclassified", info="Do you want to view misclassified images?"),
        gr.Slider(1, 10, value=1, step=1, label="Misclassified Count", info="Choose between 1 and 10"),
    ],
    # `inference` returns a single image, so declare a single image output.
    outputs=gr.Image(shape=(416, 416), label="Output").style(width=128, height=128),
    title=title,
    description=description,
    examples=examples,
)
demo.launch()
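# demo.launch(share=True) would additionally serve a temporary public URL.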