# YOLOv3 object-detection Gradio demo (PASCAL VOC) — Hugging Face Space app.
# Third-party imports
import gradio as gr
import numpy as np
import cv2
import torch
from torchvision import datasets, transforms
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Project-local imports
# NOTE(review): the imported `plot_image` is shadowed by the local definition
# further down this file; the local version is the one actually used.
from utils import non_max_suppression, plot_image, cells_to_bboxes
from dataset import YOLODataset
import config
from model import YoloVersion3
# Load the trained YOLOv3 weights for CPU-only inference.
model = YoloVersion3()
model.load_state_dict(
    torch.load('Yolov3.pth', map_location=torch.device('cpu')),
    # strict=False tolerates key mismatches between checkpoint and model
    # (presumably leftover Lightning wrapper prefixes) — TODO confirm.
    strict=False,
)
model.eval()

# Anchors rescaled to each prediction grid.
# config.ANCHORS is (3 scales, 3 anchors, 2 coords); multiplying by the grid
# size S of each scale converts anchor sizes from relative to cell units.
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to("cpu")
# Inference-time preprocessing: resize the longest side to 416, pad to a
# square 416x416 canvas, scale pixel values to [0, 1] (mean 0, std 1,
# max_pixel_value 255), and convert HWC numpy -> CHW float tensor.
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=416),
        A.PadIfNeeded(min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        ToTensorV2(),
    ]
)
def plot_image(image, boxes):
    """Draw predicted bounding boxes on the image and save to 'inference.png'.

    Args:
        image: HWC image array (as produced by the Gradio image input).
        boxes: iterable of [class_pred, confidence, x, y, w, h] where
            (x, y, w, h) are midpoint-format coordinates normalized to [0, 1].

    Raises:
        ValueError: if a box does not have exactly 6 elements.
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = config.PASCAL_CLASSES
    # One distinct color per class, sampled evenly from the colormap.
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    fig, ax = plt.subplots(1)
    ax.imshow(im)

    for box in boxes:
        # Validate with a real exception: `assert` is stripped under -O.
        if len(box) != 6:
            raise ValueError("box should contain class pred, confidence, x, y, width, height")
        class_pred = box[0]
        box = box[2:]
        # Convert midpoint (x, y, w, h) to the rectangle's upper-left corner.
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=2,
            edgecolor=colors[int(class_pred)],
            facecolor="none",
        )
        ax.add_patch(rect)
        # Label the box with its predicted class name.
        ax.text(
            upper_left_x * width,
            upper_left_y * height,
            s=class_labels[int(class_pred)],
            color="white",
            verticalalignment="top",
            bbox={"color": colors[int(class_pred)], "pad": 0},
        )

    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    ax.axis('off')
    # Save via the explicit figure handle and close it: the original left every
    # figure open, leaking memory on each repeated inference call.
    fig.savefig('inference.png')
    plt.close(fig)
# Inference function | |
def inference(inp_image):
    """Run YOLOv3 detection on one image and return the annotated image path.

    Args:
        inp_image: HWC RGB numpy array supplied by the Gradio image input.

    Returns:
        str: path of the saved annotated image ('inference.png').
    """
    org_image = inp_image
    # Preprocess to a batched (1, 3, 416, 416) float tensor.
    x = test_transforms(image=inp_image)["image"].unsqueeze(0)

    device = "cpu"
    model.to(device)
    model.eval()  # ensure dropout/batch-norm run in inference mode

    with torch.no_grad():
        out = model(x)

    # Gather candidate boxes from all three prediction scales.
    bboxes = [[] for _ in range(x.shape[0])]
    for i in range(3):
        # NOTE: the original unpacked this as `A`, shadowing the
        # albumentations alias — only the grid size S is actually needed.
        grid_size = out[i].shape[2]
        anchor = scaled_anchors[i]
        boxes_scale_i = cells_to_bboxes(out[i], anchor, S=grid_size, is_preds=True)
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    # Single-image batch: run NMS on the first (only) entry.
    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=0.5, threshold=0.6, box_format="midpoint",
    )
    plot_image(org_image, nms_boxes)
    return 'inference.png'
# Gradio UI wiring. The gr.inputs / gr.outputs namespaces were deprecated in
# Gradio 3 and removed in Gradio 4 — use the top-level components instead
# (this is the usual cause of a Space "Runtime error" on startup).
inputs = gr.Image(label="Original Image")
outputs = gr.Image(type="pil", label="Output Image")
title = "YOLOv3 model trained on PASCAL VOC Dataset"
description = "YOLOv3 object detection using Gradio demo"
examples = [
    ['examples/car.jpg'],
    ['examples/home.jpg'],
    ['examples/train.jpg'],
    ['examples/train_persons.jpg'],
]
gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    examples=examples,
    description=description,
    theme='xiaobaiyuan/theme_brief',
).launch(debug=False)