Spaces:
Sleeping
Sleeping
File size: 4,066 Bytes
64797fe 731680c 64797fe 4c1bb98 64797fe 4c1bb98 de556a9 64797fe 3a88fe8 64797fe 3a88fe8 64797fe de556a9 324c550 de556a9 f083318 ca7167f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import warnings

import gradio as gr
import numpy as np
import cv2
import torch
from torchvision import datasets, transforms
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib.patches as patches

#from train import YOLOv3Lightning
from utils import non_max_suppression, plot_image, cells_to_bboxes
from dataset import YOLODataset
import config
from model import YoloVersion3
# Load the model: build the network and load the trained weights on the CPU.
model = YoloVersion3()
_load_result = model.load_state_dict(
    torch.load("Yolov3.pth", map_location=torch.device("cpu")), strict=False
)
# strict=False tolerates checkpoint/model key mismatches, but doing so silently
# can leave layers at their random initialisation and produce garbage
# detections with no error. Keep the lenient load, but make a mismatch visible.
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "Yolov3.pth does not match YoloVersion3 exactly: "
        f"{len(_load_result.missing_keys)} missing and "
        f"{len(_load_result.unexpected_keys)} unexpected keys; "
        "missing parameters keep their initial values.",
        RuntimeWarning,
        stacklevel=1,
    )
del _load_result
model.eval()

# Anchors scaled to grid units for each prediction scale: config.S (one grid
# size per scale) is expanded to shape (len(S), 3, 2) and multiplied into
# config.ANCHORS (presumably normalised anchors of shape (3, 3, 2) — confirm
# against config).
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to("cpu")

# Letterbox preprocessing: resize so the longest side is 416, pad to 416x416
# with a constant border, divide pixel values by 255 (mean 0, std 1,
# max_pixel_value 255), and convert to a CHW tensor.
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=416),
        A.PadIfNeeded(
            min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ]
)
def plot_image(image, boxes, save_path="inference.png"):
    """Draw predicted bounding boxes on an image and save the figure as a PNG.

    This redefines (and so shadows) the ``plot_image`` imported from
    ``utils``: it writes the plot to a file rather than showing it, so the
    Gradio handler can return the file path.

    Args:
        image: the image to draw on; anything ``np.array`` turns into an
            H x W x C array.
        boxes: iterable of 6-element boxes
            ``[class_pred, confidence, x_center, y_center, width, height]``
            whose coordinates are normalised to ``image`` (midpoint format).
        save_path: where the rendered PNG is written. Defaults to
            ``"inference.png"``, the path the Gradio handler returns.
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = config.PASCAL_CLASSES
    # One evenly spaced colour per class so each class is drawn consistently.
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(im)
        for box in boxes:
            assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
            class_idx = int(box[0])
            x_center, y_center, box_w, box_h = box[2:]
            # Convert midpoint format to the upper-left corner Rectangle expects.
            upper_left_x = x_center - box_w / 2
            upper_left_y = y_center - box_h / 2
            rect = patches.Rectangle(
                (upper_left_x * width, upper_left_y * height),
                box_w * width,
                box_h * height,
                linewidth=2,
                edgecolor=colors[class_idx],
                facecolor="none",
            )
            ax.add_patch(rect)
            ax.text(
                upper_left_x * width,
                upper_left_y * height,
                s=class_labels[class_idx],
                color="white",
                verticalalignment="top",
                bbox={"color": colors[class_idx], "pad": 0},
            )
        # Remove margins and axes so the PNG contains only the annotated image.
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        ax.axis("off")
        fig.savefig(save_path)
    finally:
        # Every call creates a new figure. Without closing it, figures pile up
        # in pyplot's registry for the lifetime of the server process.
        plt.close(fig)
# Inference function
def _letterbox_to_original(boxes, orig_height, orig_width, size=416):
    """Map boxes from the padded network-input frame back to the original image.

    ``test_transforms`` resizes the image so its longest side equals ``size``
    (``A.LongestMaxSize``) and then centre-pads it to ``size`` x ``size``
    (``A.PadIfNeeded``). Predicted boxes are normalised to that padded square,
    so scaling them by the original image's width/height misplaces them
    whenever the image is not square. This undoes the padding and the resize.

    Args:
        boxes: iterable of ``[class_pred, confidence, x, y, w, h]`` with
            x, y, w, h normalised to the padded ``size`` x ``size`` frame
            (midpoint format).
        orig_height: height of the original image in pixels.
        orig_width: width of the original image in pixels.
        size: side of the padded network input; must match ``test_transforms``.

    Returns:
        A new list of boxes in the same format, normalised to the original image.
    """
    # Mirror LongestMaxSize: scale so the longest side becomes `size`.
    scale = size / max(orig_height, orig_width)
    resized_h = round(orig_height * scale)
    resized_w = round(orig_width * scale)
    # Mirror PadIfNeeded's centre padding: top/left receive the floor half.
    pad_top = int((size - resized_h) / 2.0)
    pad_left = int((size - resized_w) / 2.0)

    mapped = []
    for class_pred, confidence, x, y, w, h in boxes:
        mapped.append([
            class_pred,
            confidence,
            (x * size - pad_left) / resized_w,
            (y * size - pad_top) / resized_h,
            w * size / resized_w,
            h * size / resized_h,
        ])
    return mapped


def inference(inp_image, iou_threshold=0.5, threshold=0.6):
    """Run YOLOv3 on one image and return the path of the annotated plot.

    Args:
        inp_image: image array from the Gradio input (assumed H x W x 3 RGB
            uint8, the default numpy type of the image component), or None
            when the user submitted without an image.
        iou_threshold: passed to ``non_max_suppression`` as its IoU threshold.
        threshold: passed to ``non_max_suppression`` as its ``threshold``
            (presumably the minimum box confidence; see ``utils``).

    Returns:
        ``"inference.png"``, the file ``plot_image`` writes by default, or
        None when there is no input image.
    """
    if inp_image is None:
        return None
    orig_height, orig_width = inp_image.shape[:2]

    # Letterbox to 416x416 and add a batch dimension.
    x = test_transforms(image=inp_image)["image"].unsqueeze(0)

    model.eval()
    with torch.no_grad():
        out = model(x)

    # Gather boxes from the three prediction scales, one list per batch item.
    bboxes = [[] for _ in range(x.shape[0])]
    for i in range(3):
        _, _, S, _, _ = out[i].shape
        boxes_scale_i = cells_to_bboxes(
            out[i], scaled_anchors[i], S=S, is_preds=True
        )
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_threshold,
        threshold=threshold,
        box_format="midpoint",
    )
    # NMS output is normalised to the padded network input. Convert it so
    # plot_image, which scales by the image's own width/height, draws the
    # boxes in the right place on the original image.
    boxes_on_original = _letterbox_to_original(nms_boxes, orig_height, orig_width)
    plot_image(inp_image, boxes_on_original)
    return "inference.png"
# Gradio UI: one image in, the annotated detection plot out.
# gr.Image replaces the gr.inputs / gr.outputs namespaces, which were
# deprecated in Gradio 3 and removed in Gradio 4. Its defaults (numpy input)
# match the old components.
inputs = gr.Image(label="Original Image")
outputs = gr.Image(type="pil", label="Output Image")
title = "YOLOv3 model trained on PASCAL VOC Dataset"
description = "YOLOv3 object detection using Gradio demo"
examples = [
    ['examples/car.jpg'],
    ['examples/home.jpg'],
    ['examples/train.jpg'],
    ['examples/train_persons.jpg'],
]
demo = gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    examples=examples,
    description=description,
    theme='xiaobaiyuan/theme_brief',
)
demo.launch(debug=False)
|