ClassCat commited on
Commit
5e03e65
1 Parent(s): 220a9d6

add app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+
4
+ from transformers import AutoImageProcessor, AutoModelForObjectDetection
5
+ #from transformers import pipeline
6
+
7
+ from PIL import Image
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib.patches as patches
10
+
11
+ import io
12
+ from random import choice
13
+
14
+
15
+ image_processor_tiny = AutoImageProcessor.from_pretrained("hustvl/yolos-tiny")
16
+ model_tiny = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny")
17
+
18
+ image_processor_small = AutoImageProcessor.from_pretrained("hustvl/yolos-small")
19
+ model_small = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-small")
20
+
21
+
22
+ import gradio as gr
23
+
24
+
25
+ COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
26
+ "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
27
+ "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
28
+
29
+ fdic = {
30
+ "family" : "Impact",
31
+ "style" : "italic",
32
+ "size" : 15,
33
+ "color" : "yellow",
34
+ "weight" : "bold"
35
+ }
36
+
37
+
38
+ def get_figure(in_pil_img, in_results):
39
+ plt.figure(figsize=(16, 10))
40
+ plt.imshow(in_pil_img)
41
+ ax = plt.gca()
42
+
43
+ for score, label, box in zip(in_results["scores"], in_results["labels"], in_results["boxes"]):
44
+ selected_color = choice(COLORS)
45
+
46
+ #box = [round(i, 2) for i in box.tolist()]
47
+ x, y, w, h = int(box[0]), int(box[1]), int(box[2]-box[0]), int(box[3]-box[1])
48
+ print(x, y, w, h)
49
+ ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
50
+ ax.text(x, y, f"{model_tiny.config.id2label[label.item()]}: {round(score.item()*100, 1)}%", fontdict=fdic)
51
+ #print(
52
+ # f"Detected {model_tiny.config.id2label[label.item()]} with confidence "
53
+ # f"{round(score.item(), 3)} at location {box}"
54
+ #)
55
+
56
+ plt.axis("off")
57
+
58
+ return plt.gcf()
59
+
60
+
61
+ def infer(in_model, in_threshold, in_pil_img):
62
+ print(type(in_pil_img))
63
+ print(threshold)
64
+ inputs = image_processor_tiny(images=in_pil_img, return_tensors="pt")
65
+ outputs = model_tiny(**inputs)
66
+
67
+ # convert outputs (bounding boxes and class logits) to COCO API
68
+ target_sizes = torch.tensor([in_pil_img.size[::-1]])
69
+ results = image_processor_tiny.post_process_object_detection(outputs, threshold=in_threshold, target_sizes=target_sizes)[
70
+ 0
71
+ ]
72
+ print(results)
73
+
74
+ figure = get_figure(in_pil_img, results)
75
+
76
+ buf = io.BytesIO()
77
+ figure.savefig(buf, bbox_inches='tight')
78
+ buf.seek(0)
79
+ output_pil_img = Image.open(buf)
80
+
81
+ return output_pil_img
82
+
83
+
84
+ #from transformers.models.flava import modeling_flava
85
+ with gr.Blocks(css=".gradio-container {background:lightyellow;color:red;}", title="テスト"
86
+ ) as demo:
87
+ #sample_index = gr.State([])
88
+
89
+ gr.HTML('<div style="font-size:12pt; text-align:center; color:yellow;">MNIST 分類器</div>')
90
+
91
+ model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50")
92
+
93
+ with gr.Row():
94
+ input_image = gr.Image(label="", type="pil")
95
+ output_image = gr.Image(type="pil")
96
+
97
+
98
+ threshold = gr.Slider(0, 1.0, value=0.9, label='threshold')
99
+
100
+ send_btn = gr.Button("予測する")
101
+ send_btn.click(fn=infer, inputs=[model, threshold, input_image], outputs=[output_image])
102
+
103
+ #demo.queue()
104
+ demo.launch(debug=True)
105
+
106
+
107
+
108
+
109
+ ### EOF ###