import gradio as gr
import numpy as np
import cv2
import torch
from utils import non_max_suppression, cells_to_bboxes  # plot_image is defined locally below
import config
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model import YoloVersion3
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Load the trained weights onto the CPU
model = YoloVersion3()
model.load_state_dict(torch.load('Yolov3.pth', map_location=torch.device('cpu')), strict=False)
model.eval()
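# Note: strict=False skips checkpoint keys that don't match the model's state_dict
# (useful when the weights were saved from a training wrapper), but it will also
# silently ignore genuinely missing weights, so a wrong checkpoint won't raise here.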

# Rescale the normalized anchors to each prediction grid size in config.S
# (13, 26, 52 for a 416x416 input)
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to("cpu")
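
# Quick sanity check, assuming the usual layout of 3 scales x 3 anchors x (w, h)
# in config.ANCHORS and three grid sizes in config.S:
assert scaled_anchors.shape == (3, 3, 2), f"unexpected anchor shape {scaled_anchors.shape}"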


test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=416),
        A.PadIfNeeded(
            min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ]
)
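
# The transform letterboxes the input: the longest side is scaled to 416 and the
# remainder is zero-padded, so the aspect ratio is preserved before normalization.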

def plot_image(image, boxes):
    """Plots predicted bounding boxes on the image"""
    cmap = plt.get_cmap("tab20b")
    class_labels = config.PASCAL_CLASSES
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    # Create figure and axes
    fig, ax = plt.subplots(1)
    # Display the image
    ax.imshow(im)

    # Create a Rectangle patch
    for box in boxes:
        assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
        class_pred = box[0]
        box = box[2:]
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=2,
            edgecolor=colors[int(class_pred)],
            facecolor="none",
        )
        # Add the patch to the Axes
        ax.add_patch(rect)
        plt.text(
            upper_left_x * width,
            upper_left_y * height,
            s=class_labels[int(class_pred)],
            color="white",
            verticalalignment="top",
            bbox={"color": colors[int(class_pred)], "pad": 0},
        )

    # Save the annotated image once, after all boxes have been drawn
    # (previously these lines were indented inside the loop, saving once per box)
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    ax.axis('off')
    plt.savefig('inference.png')
    plt.close(fig)  # free the figure so repeated calls don't leak memory
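    # Note: savefig uses matplotlib's default DPI, so inference.png will not
    # necessarily match the input resolution; pass dpi=... to plt.savefig if needed.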


# Inference function
def inference(inp_image):
    org_image = inp_image
    # Letterbox and normalize the input, then add a batch dimension
    x = test_transforms(image=inp_image)["image"]
    x = x.unsqueeze(0)

    # Run the model on CPU without tracking gradients
    device = "cpu"
    model.to(device)
    with torch.no_grad():
        out = model(x)

    bboxes = [[] for _ in range(x.shape[0])]
    
    # Decode raw predictions at each of the three scales into normalized boxes.
    # The anchor-count dimension is named num_anchors rather than A, which would
    # shadow the albumentations alias imported above.
    for i in range(3):
        batch_size, num_anchors, S, _, _ = out[i].shape
        anchor = scaled_anchors[i]
        boxes_scale_i = cells_to_bboxes(
            out[i], anchor, S=S, is_preds=True
        )
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box
    
    # Drop low-confidence and overlapping detections
    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=0.5, threshold=0.6, box_format="midpoint",
    )

    plot_image(org_image, nms_boxes)
    return 'inference.png'

# gr.inputs / gr.outputs were removed in Gradio 3+; gr.Image is the current API.
inputs = gr.Image(label="Original Image")
outputs = gr.Image(type="filepath", label="Output Image")  # inference returns a file path
title = "YOLOv3 model trained on the PASCAL VOC dataset"
description = "YOLOv3 object detection, served as a Gradio demo"
examples = [['examples/car.jpg'], ['examples/home.jpg'], ['examples/train.jpg'], ['examples/train_persons.jpg']]
gr.Interface(
    inference, inputs, outputs,
    title=title, description=description, examples=examples,
    theme='xiaobaiyuan/theme_brief',
).launch(debug=False)
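
# Minimal local smoke test (hypothetical paths; run this instead of launching the
# server to check the pipeline end to end):
#
#   img = cv2.cvtColor(cv2.imread('examples/car.jpg'), cv2.COLOR_BGR2RGB)
#   inference(img)  # writes the annotated detections to inference.png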