# -*- coding: utf-8 -*-
"""
Created on Fri Aug 11 18:08:06 2023
@author: prarthana.ts
"""
import torch
import numpy as np
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pytorch_grad_cam.utils.image import show_cam_on_image
import gradio as gr

# Create your config module or import it from the existing config.py file.
import config
from main_yolov3_lightening import YOLOv3Lightning
from utils_for_app import cells_to_bboxes, non_max_suppression, plot_image, YoloCAM
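# Load the Lightning-wrapped YOLOv3 and restore the trained weights on CPU;
# strict=False tolerates minor key mismatches between checkpoint and model.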
model = YOLOv3Lightning()
model.load_state_dict(torch.load("yolov3_model_without_75_mosaic.pth", map_location=torch.device('cpu')), strict=False)
model.setup(stage="test")
IMAGE_SIZE = 416
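# Inference-time preprocessing: resize the longest side to 416 px, pad to a
# square 416x416 canvas, scale pixels to [0, 1], and convert to a CHW tensor.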
transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        ToTensorV2(),
    ],
)
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]  # Note these have been rescaled to be between [0, 1]
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
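# Anchors scaled to grid-cell units for each prediction scale (13x13, 26x26, 52x52).
# config.ANCHORS / config.S are expected to match the ANCHORS and S defined above.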
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
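# Full inference pipeline for one uploaded image: preprocess, run the model,
# decode boxes at all three scales, apply NMS, plot the detections, and
# build a Grad-CAM overlay from the penultimate layer.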
def process_image_and_plot(image, iou_threshold=0.5, threshold=0.4):
    # Preprocess and add a batch dimension (batch size is 1)
    transformed_image = transforms(image=image)["image"].unsqueeze(0)
    output = model(transformed_image)

    # Decode predictions from the three output scales into one list of boxes
    bboxes = [[] for _ in range(1)]
    for i in range(3):
        batch_size, num_anchors, grid_size, _, _ = output[i].shape
        anchor = scaled_anchors[i]
        boxes_scale_i = cells_to_bboxes(output[i], anchor, S=grid_size, is_preds=True)
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    # Drop overlapping and low-confidence boxes, then draw the survivors
    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=iou_threshold, threshold=threshold, box_format="midpoint",
    )
    fig = plot_image(transformed_image[0].permute(1, 2, 0), nms_boxes)

    # Grad-CAM heatmap overlaid on the resized input image
    cam = YoloCAM(model=model, target_layers=[model.model.layers[-2]], use_cuda=False)
    grayscale_cam = cam(transformed_image, scaled_anchors)[0, :, :]
    img = cv2.resize(image, (416, 416))
    img = np.float32(img) / 255
    cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)
    return fig, cam_image
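# Quick local sanity check (hypothetical usage; assumes the bundled example
# image and the weights file are present in the working directory):
# img = cv2.cvtColor(cv2.imread("images/cycle.jpg"), cv2.COLOR_BGR2RGB)
# fig, cam = process_image_and_plot(img, iou_threshold=0.5, threshold=0.4)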
examples = [
    ["images/cycle.jpg"],
    ["images/human.jpg"],
    ["images/automobile.jpg"],
    ["images/barn.jpg"],
    ["images/car.jpg"],
    ["images/cars.jpg"],
    ["images/farm.jpg"],
    ["images/farms.jpg"],
    ["images/living.jpg"],
    ["images/livings.jpg"],
]
icon_html = '<i class="fas fa-chart-bar"></i>'
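# Font Awesome bar-chart icon used in the HTML title and description
# (renders only where the Font Awesome stylesheet is available).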
title_with_icon = f"""
<div style="background-color: #f5f1f2; padding: 10px; display: flex; align-items: center;">
{icon_html} <span style="margin-left: 10px;">Object Detection on Pascal VOC Dataset with YoloV3</span>
</div>
"""
description_with_icon = f"""
<div style="background-color: #f1f1f5; padding: 10px; display: flex; align-items: center;">
    {icon_html}
    <span style="margin-left: 10px;">
        <p><strong>PyTorch Lightning Implementation of YOLOv3 Trained from Scratch</strong></p>
        <p><strong>Trained Classes:</strong></p>
        <ul>
            <li>🚶‍♂️ Person: person</li>
            <li>🦮 Animal: bird, cat, cow, dog, horse, sheep</li>
            <li>🚚 Vehicle: aeroplane, bicycle, boat, bus, car, motorbike, train</li>
            <li>🎄 Indoor: bottle, chair, dining table, potted plant, sofa, TV/monitor</li>
        </ul>
        <p>Note: a lower IoU threshold together with a higher confidence threshold usually produces cleaner detections.</p>
    </span>
</div>
"""
demo = gr.Interface(
    process_image_and_plot,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Intersection over Union (IoU) Threshold",
                  info="How much overlap two boxes may have before the lower-scoring one is suppressed"),
        gr.Slider(0, 1, value=0.4, label="Confidence Threshold",
                  info="Boxes with confidence scores below this value are discarded; raise it to keep only strong detections"),
    ],
    outputs=[
        gr.Plot(label="Output with Classes"),
        gr.Image(shape=(32, 32), label="GradCAM Output"),
    ],
    title=title_with_icon,
    description=description_with_icon,
    examples=examples,
)
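# Start the Gradio server (Hugging Face Spaces serves this automatically;
# locally it defaults to http://127.0.0.1:7860).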
demo.launch()