# yolov3 / app.py
import os
import warnings

import albumentations as A
import cv2
import gradio as gr
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import transforms

import config
from model import YOLOV3LITE
from utils import cells_to_bboxes, non_max_suppression, plot_image

warnings.filterwarnings("ignore")
torch.backends.cudnn.benchmark = True
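# cudnn.benchmark lets cuDNN auto-tune convolution kernels; it pays off here
# because every input is resized to a fixed shape (416x416 in this app).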
def load_checkpoint(checkpoint_file, model, optimizer, lr, with_optim=False):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    if with_optim:
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Reset the learning rate, otherwise the stale value stored in the
        # checkpoint is used -- which leads to many hours of debugging :\
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
    return model
model_handler = YOLOV3LITE()
loaded_model = load_checkpoint(
    config.CHECKPOINT_FILE, model_handler, model_handler.optimizer, config.LEARNING_RATE
)
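# NOTE: with_optim defaults to False, so only the model weights are restored;
# the optimizer state stored in the checkpoint is ignored for inference.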
# Invert a Normalize(mean=0.5, std=0.23) transform so tensors display as images.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
    std=[1 / 0.23, 1 / 0.23, 1 / 0.23],
)
# The 20 PASCAL VOC object classes, in label-index order.
classes = (
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
)
def inference(
    input_img,
    gradcam_on="TRUE",
    transparency=0.5,
    target_layer_number=-1,
    top_num_images=4,
    view_missclassified="FALSE",
    missclassified_count=2,
):
    # Evaluation-time preprocessing (mirrors the dataset pipeline).
    # NOTE: kept for reference; the code below feeds the model a plain
    # ToTensor() image instead.
    test_transform = A.Compose(
        [
            # Rescale the image so that its longest side equals IMAGE_SIZE
            A.LongestMaxSize(max_size=config.IMAGE_SIZE),
            # Pad the remaining area with zeros
            A.PadIfNeeded(
                min_height=config.IMAGE_SIZE,
                min_width=config.IMAGE_SIZE,
                border_mode=cv2.BORDER_CONSTANT,
            ),
            # Normalize the image
            A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
            # Convert the image to a PyTorch tensor
            ToTensorV2(),
        ],
        # Bounding-box handling for the augmentations
        bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
    )
    # Scale the per-cell anchors to each of the three prediction grids (config.S).
    anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(config.DEVICE)
    transform = transforms.ToTensor()
    org_img = input_img  # keep the original numpy image for the GradCAM overlay
    input_img1 = transform(input_img)
    print("Input image", input_img1.shape)
    input_img = input_img1.unsqueeze(0)  # add a batch dimension
    print("Input image unsqueezed", input_img.shape)
    out = loaded_model(input_img)
    iou_thresh = 0.5
    thresh = 0.6
    bboxes = [[] for _ in range(input_img.shape[0])]
    # Decode each of the three prediction scales into image-space boxes.
    for i in range(3):
        batch_size, num_anchors, S, _, _ = out[i].shape
        anchor = anchors[i]
        boxes_scale_i = cells_to_bboxes(out[i], anchor, S=S, is_preds=True)
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box
    # Merge overlapping detections that survive the confidence threshold.
    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
    )
    plot_image(input_img1.permute(1, 2, 0).detach().cpu(), nms_boxes)
    if gradcam_on == "TRUE":
        # Assumes the wrapped network exposes a ResNet-style `model.layer2`
        # block; pick a target layer that exists in the actual YOLOv3 backbone.
        target_layers = [loaded_model.model.layer2[target_layer_number]]
        cam = GradCAM(model=loaded_model, target_layers=target_layers, use_cuda=False)
        grayscale_cam = cam(input_tensor=input_img, targets=None)
        grayscale_cam = grayscale_cam[0, :]
        visualization = show_cam_on_image(
            np.float32(org_img) / 255,
            grayscale_cam,
            use_rgb=True,
            image_weight=transparency,
        )
    else:
        # Undo the assumed mean=0.5 / std=0.23 normalization for display.
        img = inv_normalize(input_img.squeeze(0))
        visualization = img.permute(1, 2, 0).cpu().numpy()
    return visualization
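# Minimal sketch of calling `inference` directly, outside Gradio
# (hypothetical image path):
#   img = np.array(Image.open("imgs/dog.jpg").convert("RGB"))
#   vis = inference(img, gradcam_on="FALSE")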
title = "YOLOv3 with GradCAM"
description = "A simple Gradio interface to run YOLOv3 inference and view GradCAM results"
# Each example row supplies one value per input component, using the
# function defaults: [image, gradcam_on, transparency, target_layer_number,
# top_num_images, view_missclassified, missclassified_count].
examples = [
    [os.path.join(os.path.dirname(__file__), "imgs/cat.jpg"), "TRUE", 0.5, -1, 4, False, 2],
    [os.path.join(os.path.dirname(__file__), "imgs/dog.jpg"), "TRUE", 0.5, -1, 4, False, 2],
    [os.path.join(os.path.dirname(__file__), "imgs/car.jpg"), "TRUE", 0.5, -1, 4, False, 2],
    [os.path.join(os.path.dirname(__file__), "imgs/frog.jpg"), "TRUE", 0.5, -1, 4, False, 2],
    [os.path.join(os.path.dirname(__file__), "imgs/horse.jpg"), "TRUE", 0.5, -1, 4, False, 2],
    [os.path.join(os.path.dirname(__file__), "imgs/tiger.jpg"), "TRUE", 0.5, -1, 4, False, 2],
    [os.path.join(os.path.dirname(__file__), "imgs/dog2.jpg"), "TRUE", 0.5, -1, 4, False, 2],
    [os.path.join(os.path.dirname(__file__), "imgs/bird2.jpg"), "TRUE", 0.5, -1, 4, False, 2],
    [os.path.join(os.path.dirname(__file__), "imgs/Cat03.jpg"), "TRUE", 0.5, -1, 4, False, 2],
    [os.path.join(os.path.dirname(__file__), "imgs/truck.jpg"), "TRUE", 0.5, -1, 4, False, 2],
]
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(shape=(416, 416), label="Input Image"),
        gr.Radio(["TRUE", "FALSE"], label="GradCAM Req", info="Do you need GradCAM images?"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
        gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?"),
        gr.Slider(1, 10, value=4, step=1, label="How Many Top Classes"),
        gr.Checkbox(label="View Misclassified", info="Do you want to view misclassified images?"),
        gr.Slider(1, 10, value=1, label="Misclassified Count", info="Choose between 1 and 10"),
    ],
    # `inference` returns a single visualization, so expose a single image output.
    outputs=gr.Image(shape=(416, 416), label="Output").style(width=128, height=128),
    title=title,
    description=description,
    examples=examples,
)
demo.launch()
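# Passing share=True to demo.launch() would also create a temporary public link.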