Yolo_V3 / app.py
Gosula's picture
Update app.py
a5806ac verified
raw
history blame contribute delete
No virus
4.07 kB
import gradio as gr
import numpy as np
import cv2
import torch
from torchvision import datasets, transforms
from PIL import Image
#from train import YOLOv3Lightning
from utils import non_max_suppression, plot_image, cells_to_bboxes
from dataset import YOLODataset
import config
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model import YoloVersion3
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Build the detector and restore its trained weights on the CPU.
# strict=False lets load_state_dict accept a checkpoint whose keys do not
# match the model exactly instead of raising.
model = YoloVersion3()
model.load_state_dict(
    torch.load('Yolov3.pth', map_location=torch.device('cpu')),
    strict=False,
)
# Inference only: switch the module to evaluation mode.
model.eval()
# Anchors multiplied by the grid sizes in config.S, one group per scale.
# Each entry of config.S is expanded to a (3, 2) tile so it lines up
# element-wise with config.ANCHORS
# (presumably 3 anchors x (w, h) per scale -- confirm against config).
scaled_anchors = torch.mul(
    torch.tensor(config.ANCHORS),
    torch.tensor(config.S)[:, None, None].repeat(1, 3, 2),
).to("cpu")
# Pre-processing applied to every input image before it reaches the model.
test_transforms = A.Compose([
    # Resize so the longest side is 416 px (aspect ratio preserved).
    A.LongestMaxSize(max_size=416),
    # Pad the shorter side with a constant border up to 416 x 416.
    A.PadIfNeeded(min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT),
    # mean 0 / std 1 with max_pixel_value 255: plain rescale of pixels to [0, 1].
    A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
    # Convert the array to a torch tensor.
    ToTensorV2(),
])
def plot_image(image, boxes):
    """Draw predicted bounding boxes on an image and save it as 'inference.png'.

    Args:
        image: Anything ``np.array`` converts to a 3-D
            ``(height, width, channels)`` array.
        boxes: Iterable of 6-element sequences
            ``[class_pred, confidence, x, y, width, height]``. ``x, y`` is the
            box centre; the four geometry values are fractions of the image
            size (they are multiplied by the image width/height here).

    Raises:
        ValueError: If a box does not contain exactly six values.

    The rendered figure is written to ``inference.png`` in the current working
    directory; nothing is returned.
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = config.PASCAL_CLASSES
    # One evenly spaced colour per class label.
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(im)

        for box in boxes:
            if len(box) != 6:
                raise ValueError(
                    "box should contain class pred, confidence, x, y, width, height"
                )
            class_idx = int(box[0])
            center_x, center_y, box_w, box_h = box[2:]
            # Convert midpoint format to the upper-left corner Rectangle expects.
            upper_left_x = center_x - box_w / 2
            upper_left_y = center_y - box_h / 2

            rect = patches.Rectangle(
                (upper_left_x * width, upper_left_y * height),
                box_w * width,
                box_h * height,
                linewidth=2,
                edgecolor=colors[class_idx],
                facecolor="none",
            )
            ax.add_patch(rect)
            # Draw on this figure's axes explicitly rather than through pyplot's
            # implicit "current axes", which is shared process-wide.
            ax.text(
                upper_left_x * width,
                upper_left_y * height,
                s=class_labels[class_idx],
                color="white",
                verticalalignment="top",
                bbox={"color": colors[class_idx], "pad": 0},
            )

        # Let the image fill the whole canvas, without axes or margins.
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        ax.axis('off')
        fig.savefig('inference.png')
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # call would leak a figure in a long-running server process.
        plt.close(fig)
def _boxes_to_original_frame(boxes, orig_height, orig_width, input_size):
    """Re-express boxes normalised to the padded square input as fractions of the original image.

    Undoes the "resize longest side to ``input_size``, then pad to a square"
    step of ``test_transforms``.
    NOTE(review): the rounding and the centred padding mirror albumentations'
    ``LongestMaxSize`` / ``PadIfNeeded`` defaults -- confirm against the
    installed version.
    """
    scale = input_size / max(orig_height, orig_width)
    # Size of the real image content inside the padded square.
    resized_h = max(1, round(orig_height * scale))
    resized_w = max(1, round(orig_width * scale))
    # Centred padding: the top/left border takes the floor of half the slack.
    pad_top = (input_size - resized_h) // 2
    pad_left = (input_size - resized_w) // 2

    return [
        [
            class_pred,
            score,
            (cx * input_size - pad_left) / resized_w,
            (cy * input_size - pad_top) / resized_h,
            w * input_size / resized_w,
            h * input_size / resized_h,
        ]
        for class_pred, score, cx, cy, w, h in boxes
    ]


# Inference function
def inference(inp_image):
    """Run the detector on one image and return the path of the annotated result.

    Args:
        inp_image: Image as an array whose first two dimensions are
            ``(height, width)``, as passed to ``test_transforms``.

    Returns:
        The string ``'inference.png'``: the file ``plot_image`` writes the
        annotated image to.
        NOTE(review): the file name is fixed, so overlapping requests would
        overwrite each other's output.

    Raises:
        ValueError: If no image was provided.
    """
    if inp_image is None:
        raise ValueError("No input image was provided.")

    orig_height, orig_width = inp_image.shape[:2]

    # Pre-process and add the batch dimension.
    x = test_transforms(image=inp_image)["image"].unsqueeze(0)

    model.eval()
    with torch.no_grad():
        out = model(x)

    # Collect the candidate boxes of every prediction scale, per batch item.
    bboxes = [[] for _ in range(x.shape[0])]
    for scale_pred, anchor in zip(out, scaled_anchors):
        grid_size = scale_pred.shape[2]
        boxes_scale = cells_to_bboxes(
            scale_pred, anchor, S=grid_size, is_preds=True
        )
        for idx, box in enumerate(boxes_scale):
            bboxes[idx] += box

    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=0.5, threshold=0.6, box_format="midpoint",
    )

    # The model saw a padded square version of the image, while plot_image
    # scales boxes by the size of the image it is given. Map the boxes back
    # onto the original frame so they line up on non-square inputs.
    # Assumes cells_to_bboxes yields coordinates normalised to the network
    # input and that the input is square (side = x.shape[-1]) -- TODO confirm.
    nms_boxes = _boxes_to_original_frame(
        nms_boxes, orig_height, orig_width, input_size=x.shape[-1]
    )

    plot_image(inp_image, nms_boxes)
    return 'inference.png'
# Gradio UI: one uploaded image in, the annotated detection image out.
# gr.Image is used for both ends; the gr.inputs / gr.outputs wrappers are
# deprecated and no longer exist in newer Gradio releases.
inputs = gr.Image(label="Original Image")
outputs = gr.Image(type="pil", label="Output Image")
title = "YOLOv3 model trained on PASCAL VOC Dataset"
description = "YOLOv3 object detection using Gradio demo"
examples = [
    ['examples/car.jpg'],
    ['examples/home.jpg'],
    ['examples/train.jpg'],
    ['examples/train_persons.jpg'],
]

demo = gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    examples=examples,
    description=description,
    theme='xiaobaiyuan/theme_brief',
)
demo.launch(debug=False)