File size: 3,464 Bytes
05e82a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f3089f
9424cd4
05e82a9
 
 
 
 
d9f5212
9424cd4
05e82a9
 
 
 
 
9424cd4
 
 
 
05e82a9
 
 
 
 
 
 
 
 
 
 
06f784c
05e82a9
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119

import torch
from transformers import pipeline

from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.patches as patches

from random import choice
import io

# Build both DETR pipelines once at import time so every request reuses them.
# The checkpoints are loaded in the order listed: ResNet-50 first, then ResNet-101.
detector50, detector101 = (
    pipeline(model=checkpoint)
    for checkpoint in ("facebook/detr-resnet-50", "facebook/detr-resnet-101")
)


import gradio as gr

# Twelve hex colours; get_figure picks one at random for each bounding box.
COLORS = [
    "#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
    "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
    "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f",
]

# Font properties handed to ``ax.text(..., fontdict=fdic)`` for the box labels.
fdic = dict(
    family="Impact",
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)


def get_figure(in_pil_img, in_results):
    """Draw detection results over an image and return the matplotlib figure.

    Args:
        in_pil_img: Image to display; passed straight to ``Axes.imshow``.
        in_results: Iterable of predictions. Each item is read as a mapping
            with a ``'box'`` entry (keys ``xmin``, ``ymin``, ``xmax``,
            ``ymax``), a ``'label'`` and a ``'score'``.

    Returns:
        The figure containing the image with one randomly coloured rectangle
        and one ``"label: score%"`` caption per prediction. The figure is
        created through pyplot, so the caller should close it when done.
    """
    # Keep explicit figure/axes handles instead of asking pyplot for the
    # "current" figure: that state is process-global, so gca()/gcf() could
    # return a different call's figure if two calls overlap.
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(in_pil_img)

    for prediction in in_results:
        selected_color = choice(COLORS)

        box = prediction['box']
        x, y = box['xmin'], box['ymin']
        w, h = box['xmax'] - x, box['ymax'] - y

        ax.add_patch(patches.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
        # Caption is anchored at the box's (xmin, ymin) corner; score shown as a percentage.
        ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)

    ax.axis("off")

    return fig


def infer(model, in_pil_img):
    """Run the selected DETR detector on an image and return an annotated copy.

    Args:
        model: Model name chosen in the UI. ``"detr-resnet-101"`` selects
            ``detector101``; any other value falls back to ``detector50``.
        in_pil_img: Input image, or ``None`` when nothing has been uploaded.

    Returns:
        A PIL image of the rendered figure (input image with boxes and
        labels drawn on it), or ``None`` when no input image was given.
    """
    # Nothing to detect on: return an empty result instead of passing None
    # into the detection pipeline.
    if in_pil_img is None:
        return None

    detector = detector101 if model == "detr-resnet-101" else detector50
    results = detector(in_pil_img)

    figure = get_figure(in_pil_img, results)

    buf = io.BytesIO()
    try:
        figure.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps a reference to every figure it creates; without an
        # explicit close each request would leave one more figure in memory.
        plt.close(figure)
    buf.seek(0)

    # Image.open reads lazily from buf, so the buffer is deliberately left open.
    output_pil_img = Image.open(buf)

    return output_pil_img


# Build the demo page: model selector, input/output images, sample images,
# an "Infer" button wired to infer(), and a reference link.
with gr.Blocks(title="DETR Object Detection by orYx Models") as demo:
    # Page header: logo styling, title and the first instruction.
    # NOTE(review): the logo src is a bare relative filename; confirm the app
    # actually serves "oryx_logo (2).png" at that URL.
    gr.HTML("""
    <style>
    .logo {
        position: absolute;
        top: 10px;
        right: 10px;
        width: 100px; /* Adjust the width of the logo as needed */
        height: auto;
    }
    </style>
    <div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">DETR Object Detection by orYx Models</div>
    <img class="logo" src="oryx_logo (2).png" alt="Logo">
    <h4 style="color:navy;">1. Select a model.</h4>
    """)
    # The selected string is passed to infer() as its `model` argument.
    model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">2. Please upload an image. Or choose one from sample below by clicking on any.</h4>""")

    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image with predicted instances", type="pil")

    # Sample images; selecting one fills `input_image`.
    gr.Examples(['trees_traffic.jpg',
                 'traffic.jpg',
                 'flyover.jpg',
                 'Saudi_traffic.jpg'], inputs=input_image)

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")

    send_btn = gr.Button("Infer")
    send_btn.click(fn=infer, inputs=[model, input_image], outputs=[output_image])

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
    # Emit the list as one fragment: each gr.HTML call is its own component,
    # so a <ul> opened in one call cannot wrap an <li> from another.
    gr.HTML("""<ul><li><a href="https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb" target="_blank">Hands-on tutorial for DETR</a></li></ul>""")


# NOTE(review): request queuing (demo.queue()) was left disabled in the
# original; confirm before enabling.
demo.launch(debug=True)


### EOF ###