File size: 4,643 Bytes
cb670b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7171a3
 
cb670b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7171a3
 
 
 
60b1040
 
 
f7171a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7161b2
f7171a3
 
 
 
 
 
 
f7161b2
f7171a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6193542
f7171a3
f7161b2
f7171a3
 
f7161b2
f7171a3
 
 
 
 
 
 
 
 
 
 
 
 
195912f
f7171a3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from transformers import AutoFeatureExtractor, YolosForObjectDetection
import gradio as gr
from PIL import Image
import torch
import matplotlib.pyplot as plt
import io
import numpy as np
import os
os.system("pip -qq install yoloxdetect==0.0.7")
from yoloxdetect import YoloxDetector

# Example images for the demo, fetched at import time: source URL -> local file name.
EXAMPLE_IMAGE_URLS = {
    'https://tochkanews.ru/wp-content/uploads/2020/09/0.jpg': '1.jpg',
    'https://s.rdrom.ru/1/pubs/4/35893/1906770.jpg': '2.jpg',
    'https://static.mk.ru/upload/entities/2022/04/17/07/articles/detailPicture/5b/39/28/b6/ffb1aa636dd62c30e6ff670f84474f75.jpg': '3.jpg',
}
for _image_url, _local_name in EXAMPLE_IMAGE_URLS.items():
    torch.hub.download_url_to_file(_image_url, _local_name)


# Colour triples cycled through by plot_results when drawing box outlines.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]




def get_class_list_from_input(classes_string: str):
    """Parse a comma-separated class filter into a list of class names.

    Every entry is stripped of surrounding whitespace. Entries that are empty
    after stripping are dropped, so input such as ``" "`` or
    ``"person,,boat,"`` never yields ``""`` items; a list containing ``""``
    would be non-empty and would make a caller's filter reject everything.

    Args:
        classes_string: Comma-separated class names, e.g. ``"person, boat"``.
            ``None`` and the empty string both mean "no filter".

    Returns:
        The cleaned class names in input order, or an empty list when no
        class name was supplied.
    """
    if not classes_string:
        return []
    stripped = (entry.strip() for entry in classes_string.split(","))
    return [entry for entry in stripped if entry]

def plot_results(pil_img, prob, boxes, model, classes_list):
    """Draw labelled bounding boxes over an image and return the rendering.

    Args:
        pil_img: Image to display as the background (passed to ``imshow``).
        prob: Iterable of per-detection score vectors; the ``argmax`` of each
            vector selects the predicted class id.
        boxes: Object with ``tolist()`` yielding one
            ``(xmin, ymin, xmax, ymax)`` tuple per detection.
        model: Object exposing ``config.id2label`` to map class ids to names.
        classes_list: Class names to draw. An empty (or ``None``) value draws
            every detection.

    Returns:
        Whatever ``fig2img`` returns for the finished figure.
    """
    wanted = set(classes_list or ())
    # Use explicit figure/axes handles instead of pyplot's implicit
    # "current figure" so the figure can be closed reliably afterwards.
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        # Repeat the palette so zip() does not stop after six detections.
        palette = COLORS * 100
        for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), palette):
            cl = p.argmax()
            object_class = model.config.id2label[cl.item()]

            if wanted and object_class not in wanted:
                continue

            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=c, linewidth=3))
            text = f'{object_class}: {p[cl]:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        ax.axis('off')
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running app accumulates one figure per request.
        plt.close(fig)
    
def fig2img(fig):
    """Render a figure into an in-memory buffer and open it as a PIL image.

    Args:
        fig: Object with a ``savefig(file_like)`` method.

    Returns:
        The result of ``Image.open`` on the rewound buffer.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    # Rewind so the image reader starts at the first byte written.
    stream.seek(0)
    return Image.open(stream)


    
def inference(
    image_path: str = None,
    model_path: str = 'kadirnar/yolox_s-v0.1.1',
    image_size: int = 640,
    prob_threshold: float = 0.8,
    classes_to_show: str = "",
):
    """Run object detection on one image with a YOLOX or YOLOS model.

    Args:
        image_path: Path of the image file to analyse.
        model_path: Model identifier, with or without an owner prefix
            (``"yolox_s-v0.1.1"`` and ``"kadirnar/yolox_s-v0.1.1"`` are
            treated the same). The three YOLOX names are loaded from the
            ``kadirnar`` namespace via ``YoloxDetector``; any other name is
            loaded from the ``hustvl`` namespace as a YOLOS checkpoint.
        image_size: Inference size forwarded to the YOLOX detector. Not used
            by the YOLOS branch.
        prob_threshold: Minimum class probability for a YOLOS detection to be
            drawn. Not used by the YOLOX branch.
        classes_to_show: Comma-separated class names to draw in the YOLOS
            branch; empty means all classes. Not used by the YOLOX branch.

    Returns:
        YOLOX: whatever ``YoloxDetector.predict`` returns.
        YOLOS: the annotated image produced by ``plot_results``.

    Raises:
        ValueError: If no image path is given.
    """
    if image_path is None:
        raise ValueError("An input image is required.")

    # Only the final path component identifies the checkpoint; the owner
    # namespace is chosen below depending on the model family.
    model_name = (model_path or 'kadirnar/yolox_s-v0.1.1').rsplit("/", 1)[-1]

    if model_name in ("yolox_s-v0.1.1", "yolox_m-v0.1.1", "yolox_tiny-v0.1.1"):
        model = YoloxDetector(f"kadirnar/{model_name}", device="cpu", hf_model=True)
        pred = model.predict(image_path=image_path, image_size=image_size)
        return pred

    feature_extractor = AutoFeatureExtractor.from_pretrained(f"hustvl/{model_name}")
    model = YolosForObjectDetection.from_pretrained(f"hustvl/{model_name}")

    # The input is a file path, so open it; convert to RGB so grayscale or
    # alpha-channel files give a three-channel image.
    img = Image.open(image_path).convert("RGB")

    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values

    with torch.no_grad():
        outputs = model(pixel_values)

    # Drop the last logit column before picking the confident detections.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    # PIL reports (width, height); post-processing takes (height, width).
    target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]['boxes']

    classes_list = get_class_list_from_input(classes_to_show)
    res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)

    return res_img

        
# Free-text filter for the class names to draw (empty means all classes).
classes_to_show = gr.components.Textbox(placeholder="e.g. person, boat", label="Classes to use (empty means all classes)")

# Gradio passes component values to `fn` positionally, so the order here must
# follow the parameters of `inference`:
# image path, model, image size, probability threshold, class filter.
inputs = [
    gr.inputs.Image(type="filepath", label="Input Image"),
    gr.inputs.Dropdown(
        label="Model Path",
        choices=[
            "yolox_s-v0.1.1",
            "yolox_m-v0.1.1",
            "yolox_tiny-v0.1.1",
            "yolos-tiny",
            "yolos-small",
            "yolos-base",
            "yolos-small-300",
            "yolos-small-dwr",
        ],
        # The default must be one of the choices above.
        default="yolox_s-v0.1.1",
    ),
    gr.inputs.Slider(minimum=320, maximum=1280, default=640, step=32, label="Image Size"),
    gr.inputs.Slider(minimum=0, maximum=1.0, step=0.01, default=0.9, label="Probability Threshold"),
    classes_to_show,
]


outputs = gr.outputs.Image(type="filepath", label="Output Image")

# One row per example, with values in the same order as `inputs`; model names
# are taken from the Dropdown choices.
examples = [
    ["1.jpg", "yolox_m-v0.1.1", 640, 0.8, ""],
    ["2.jpg", "yolox_s-v0.1.1", 640, 0.8, ""],
    ["3.jpg", "yolox_tiny-v0.1.1", 640, 0.8, ""],
]

demo_app = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title="Object Detection with YOLO",
    examples=examples,
    cache_examples=True,
    theme='huggingface',
)
demo_app.launch(debug=True, enable_queue=True)