yuragoithf commited on
Commit
14bf43e
1 Parent(s): e3c83f9

Updated app

Browse files
Files changed (2) hide show
  1. app.py +103 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoFeatureExtractor, YolosForObjectDetection
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ import io
7
+ import numpy as np
8
+
9
+
10
+ COLORS = [
11
+ [0.000, 0.447, 0.741],
12
+ [0.850, 0.325, 0.098],
13
+ [0.929, 0.694, 0.125],
14
+ [0.494, 0.184, 0.556],
15
+ [0.466, 0.674, 0.188],
16
+ [0.301, 0.745, 0.933],
17
+ ]
18
+
19
+
20
+ def process_class_list(classes_string: str):
21
+ if classes_string == "":
22
+ return []
23
+ classes_list = classes_string.split(",")
24
+ classes_list = [x.strip() for x in classes_list]
25
+ return classes_list
26
+
27
+
28
+ def model_inference(img, model_name: str, prob_threshold: int, classes_to_show=str):
29
+ feature_extractor = AutoFeatureExtractor.from_pretrained(f"hustvl/{model_name}")
30
+ model = YolosForObjectDetection.from_pretrained(f"hustvl/{model_name}")
31
+
32
+ img = Image.fromarray(img)
33
+
34
+ pixel_values = feature_extractor(img, return_tensors="pt").pixel_values
35
+
36
+ with torch.no_grad():
37
+ outputs = model(pixel_values, output_attentions=True)
38
+
39
+ probas = outputs.logits.softmax(-1)[0, :, :-1]
40
+ keep = probas.max(-1).values > prob_threshold
41
+
42
+ target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
43
+ postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
44
+ bboxes_scaled = postprocessed_outputs[0]["boxes"]
45
+
46
+ classes_list = process_class_list(classes_to_show)
47
+ res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)
48
+
49
+ return res_img
50
+
51
+
52
+ def plot_results(pil_img, prob, boxes, model, classes_list):
53
+ plt.figure(figsize=(16, 10))
54
+ plt.imshow(pil_img)
55
+ ax = plt.gca()
56
+ colors = COLORS * 100
57
+ for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
58
+ cl = p.argmax()
59
+ object_class = model.config.id2label[cl.item()]
60
+
61
+ if len(classes_list) > 0:
62
+ if object_class not in classes_list:
63
+ continue
64
+
65
+ ax.add_patch(
66
+ plt.Rectangle(
67
+ (xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3
68
+ )
69
+ )
70
+ text = f"{object_class}: {p[cl]:0.2f}"
71
+ ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
72
+ plt.axis("off")
73
+ return fig2img(plt.gcf())
74
+
75
+
76
+ def fig2img(fig):
77
+ buf = io.BytesIO()
78
+ fig.savefig(buf)
79
+ buf.seek(0)
80
+ img = Image.open(buf)
81
+ return img
82
+
83
+
84
+ description = """Object Detection"""
85
+
86
+ image_in = gr.components.Image()
87
+ image_out = gr.components.Image()
88
+ model_choice = "yolos-small-dwr"
89
+ prob_threshold_slider = gr.components.Slider(
90
+ minimum=0, maximum=1.0, step=0.01, value=0.9, label="Probability Threshold"
91
+ )
92
+ classes_to_show = gr.components.Textbox(
93
+ placeholder="e.g. car, dog",
94
+ label="Classes to filter (leave empty to detect all classes)",
95
+ )
96
+
97
+ Iface = gr.Interface(
98
+ fn=model_inference,
99
+ inputs=[image_in, model_choice, prob_threshold_slider, classes_to_show],
100
+ outputs=image_out,
101
+ title="Object Detection",
102
+ description=description,
103
+ ).launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ matplotlib
4
+ pillow