Gosula commited on
Commit
475ae01
1 Parent(s): 2f57795

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -0
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+ from torchvision import datasets, transforms
6
+ from PIL import Image
7
+ #from train import YOLOv3Lightning
8
+ from utils import non_max_suppression, plot_image, cells_to_bboxes
9
+ from dataset import YOLODataset
10
+ import config
11
+ import albumentations as A
12
+ from albumentations.pytorch import ToTensorV2
13
+
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib.patches as patches
16
+
17
+
18
+ # Load the model
19
+ model = YoloVersion3( )
20
+ model.load_state_dict(torch.load('/content/drive/MyDrive/sunandini/Checkpoint/lightning_logs/version_4/checkpoints/Yolov3.pth', map_location=torch.device('cpu')), strict=False)
21
+ model.eval()
22
+
23
+ # Anchor
24
+ scaled_anchors = (
25
+ torch.tensor(config.ANCHORS)
26
+ * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
27
+ ).to("cpu")
28
+
29
+
30
+ test_transforms = A.Compose(
31
+ [
32
+ A.LongestMaxSize(max_size=416),
33
+ A.PadIfNeeded(
34
+ min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT
35
+ ),
36
+ A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
37
+ ToTensorV2(),
38
+ ]
39
+ )
40
+
41
+
42
+ def plot_image(image, boxes):
43
+ """Plots predicted bounding boxes on the image"""
44
+ cmap = plt.get_cmap("tab20b")
45
+ class_labels = config.PASCAL_CLASSES
46
+ colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
47
+ im = np.array(image)
48
+ height, width, _ = im.shape
49
+
50
+ # Create figure and axes
51
+ fig, ax = plt.subplots(1)
52
+ # Display the image
53
+ ax.imshow(im)
54
+
55
+ # Create a Rectangle patch
56
+ for box in boxes:
57
+ assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
58
+ class_pred = box[0]
59
+ box = box[2:]
60
+ upper_left_x = box[0] - box[2] / 2
61
+ upper_left_y = box[1] - box[3] / 2
62
+ rect = patches.Rectangle(
63
+ (upper_left_x * width, upper_left_y * height),
64
+ box[2] * width,
65
+ box[3] * height,
66
+ linewidth=2,
67
+ edgecolor=colors[int(class_pred)],
68
+ facecolor="none",
69
+ )
70
+ # Add the patch to the Axes
71
+ ax.add_patch(rect)
72
+ plt.text(
73
+ upper_left_x * width,
74
+ upper_left_y * height,
75
+ s=class_labels[int(class_pred)],
76
+ color="white",
77
+ verticalalignment="top",
78
+ bbox={"color": colors[int(class_pred)], "pad": 0},
79
+ )
80
+
81
+ # plt.show()
82
+ fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
83
+ ax.axis('off')
84
+ plt.savefig('inference.png')
85
+
86
+
87
+ # Inference function
88
+ def inference(inp_image):
89
+ inp_image=inp_image
90
+ org_image = inp_image
91
+ transform = test_transforms
92
+ x = transform(image=inp_image)["image"]
93
+ x=x.unsqueeze(0)
94
+ # Perform inference
95
+ device = "cpu"
96
+ model.to(device)
97
+
98
+ # Ensure model is in evaluation mode
99
+ model.eval()
100
+
101
+ # Perform inference
102
+ with torch.no_grad():
103
+ out = model(x)
104
+ #out = model(x)
105
+
106
+ # Ensure model is in evaluation mode
107
+
108
+
109
+
110
+ bboxes = [[] for _ in range(x.shape[0])]
111
+
112
+ for i in range(3):
113
+ batch_size, A, S, _, _ = out[i].shape
114
+ anchor = scaled_anchors[i]
115
+ boxes_scale_i = cells_to_bboxes(
116
+ out[i], anchor, S=S, is_preds=True
117
+ )
118
+ for idx, (box) in enumerate(boxes_scale_i):
119
+ bboxes[idx] += box
120
+
121
+ nms_boxes = non_max_suppression(
122
+ bboxes[0], iou_threshold=0.5, threshold=0.6, box_format="midpoint",
123
+ )
124
+
125
+ # print(nms_boxes[0])
126
+
127
+ width_ratio = org_image.shape[1] / 416
128
+ height_ratio = org_image.shape[0] / 416
129
+
130
+
131
+
132
+ plot_image(org_image, nms_boxes)
133
+ plotted_img = 'inference.png'
134
+ return plotted_img
135
+
136
+ inputs = gr.inputs.Image(label="Original Image")
137
+ outputs = gr.outputs.Image(type="pil",label="Output Image")
138
+ title = "YOLOv3 model trained on PASCAL VOC Dataset"
139
+ description = "YOLOv3 Gradio demo for object detection"
140
+ examples = [['/content/car1.jpg'], ['/content/home.jpg']]
141
+ gr.Interface(inference, inputs, outputs, title=title, examples=examples, description=description, theme='abidlabs/dracula_revamped').launch(
142
+ debug=False)