Create app.py
app.py
ADDED
import gradio as gr
import numpy as np
import cv2
import torch
from torchvision import datasets, transforms  # unused below
from PIL import Image
#from train import YOLOv3Lightning
from model import YoloVersion3  # assumed import path; YoloVersion3 must be defined in the project for this to run
from utils import non_max_suppression, cells_to_bboxes  # plot_image is redefined locally below
from dataset import YOLODataset  # unused below
import config
import albumentations as A
from albumentations.pytorch import ToTensorV2

import matplotlib.pyplot as plt
import matplotlib.patches as patches


# Load the trained model (weights only, on CPU)
model = YoloVersion3()
model.load_state_dict(
    torch.load(
        '/content/drive/MyDrive/sunandini/Checkpoint/lightning_logs/version_4/checkpoints/Yolov3.pth',
        map_location=torch.device('cpu'),
    ),
    strict=False,  # tolerate missing/unexpected keys in the checkpoint
)
model.eval()

# Rescale the per-cell anchors to grid units for each prediction scale
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to("cpu")
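# Sanity check, assuming the usual YOLOv3 config (3 scales x 3 anchors x 2 dims,
# anchor sizes in [0, 1], config.S = [13, 26, 52]): each anchor ends up in
# grid-cell units for its scale, and
#   scaled_anchors.shape -> torch.Size([3, 3, 2])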

# Letterbox to 416x416 (longest side resized, remainder padded), then
# scale pixel values to [0, 1] and convert to a CHW tensor
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=416),
        A.PadIfNeeded(
            min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        ToTensorV2(),
    ]
)
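# Example, with a hypothetical 375x500 RGB array: the longest side is resized
# to 416, the rest is padded, and the result is a float tensor under "image":
#   t = test_transforms(image=np.zeros((375, 500, 3), dtype=np.uint8))["image"]
#   t.shape -> torch.Size([3, 416, 416])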


def plot_image(image, boxes):
    """Plots predicted bounding boxes on the image and saves it to inference.png."""
    cmap = plt.get_cmap("tab20b")
    class_labels = config.PASCAL_CLASSES
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    # Create figure and axes, then display the image
    fig, ax = plt.subplots(1)
    ax.imshow(im)

    # Draw one rectangle (and class label) per predicted box
    for box in boxes:
        assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
        class_pred = box[0]
        box = box[2:]
        # Convert midpoint format (x, y, w, h) to the box's upper-left corner
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=2,
            edgecolor=colors[int(class_pred)],
            facecolor="none",
        )
        # Add the patch to the Axes
        ax.add_patch(rect)
        ax.text(
            upper_left_x * width,
            upper_left_y * height,
            s=class_labels[int(class_pred)],
            color="white",
            verticalalignment="top",
            bbox={"color": colors[int(class_pred)], "pad": 0},
        )

    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    ax.axis('off')
    plt.savefig('inference.png')
    plt.close(fig)  # free the figure so repeated calls don't accumulate
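# Example, with hypothetical values: each box is [class_pred, confidence, x, y,
# w, h], where (x, y, w, h) are midpoint coordinates normalized to [0, 1], so
# this draws a single box covering the middle half of a blank image:
#   plot_image(np.zeros((416, 416, 3), dtype=np.uint8),
#              [[11, 0.9, 0.5, 0.5, 0.5, 0.5]])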


# Inference function: runs the model on one image and returns the annotated plot
def inference(inp_image):
    org_image = inp_image
    x = test_transforms(image=inp_image)["image"]
    x = x.unsqueeze(0)  # add the batch dimension

    device = "cpu"
    model.to(device)
    model.eval()  # ensure the model is in evaluation mode

    # Forward pass without gradient tracking
    with torch.no_grad():
        out = model(x)

    # Collect boxes from all three prediction scales for each image in the batch
    bboxes = [[] for _ in range(x.shape[0])]
    for i in range(3):
        batch_size, num_anchors, S, _, _ = out[i].shape
        anchor = scaled_anchors[i]
        boxes_scale_i = cells_to_bboxes(
            out[i], anchor, S=S, is_preds=True
        )
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    # Keep only confident, non-overlapping boxes
    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=0.5, threshold=0.6, box_format="midpoint",
    )

    # (Currently unused) ratios between the original image and the 416x416 input
    width_ratio = org_image.shape[1] / 416
    height_ratio = org_image.shape[0] / 416

    plot_image(org_image, nms_boxes)
    return 'inference.png'
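# Quick local smoke test, bypassing the UI (hypothetical image path):
#   img = np.array(Image.open("car1.jpg").convert("RGB"))
#   inference(img)  # writes the annotated image to inference.png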

# Gradio UI (gr.inputs/gr.outputs were removed in Gradio 3; use gr.Image directly)
inputs = gr.Image(label="Original Image")
outputs = gr.Image(type="pil", label="Output Image")
title = "YOLOv3 model trained on PASCAL VOC Dataset"
description = "YOLOv3 Gradio demo for object detection"
examples = [['/content/car1.jpg'], ['/content/home.jpg']]
gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    examples=examples,
    description=description,
    theme='abidlabs/dracula_revamped',
).launch(debug=False)
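Assuming the checkpoint and example image paths above exist, running the script with python app.py starts the Gradio server for this demo; share=True can also be passed to launch() to get a temporary public link.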