Upload yolos_minimal_inference_example.py
Browse files
yolos_minimal_inference_example.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""YOLOS minimal inference example.ipynb
|
3 |
+
|
4 |
+
Automatically generated by Colaboratory.
|
5 |
+
|
6 |
+
Original file is located at
|
7 |
+
https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/YOLOS/YOLOS_minimal_inference_example.ipynb
|
8 |
+
|
9 |
+
## Set-up environment
|
10 |
+
|
11 |
+
First, we install the HuggingFace Transformers library (from source for now, as the model was only just added to the library and is not yet included in a PyPI release).
|
12 |
+
"""
|
13 |
+
|
14 |
+
# Notebook-only setup. These are shell commands, not Python statements, so
# they are kept as comments; run them in a terminal (or a notebook cell)
# before executing this script:
#
#   pip install -q git+https://github.com/huggingface/transformers.git
#   pip install gradio
|
17 |
+
|
18 |
+
import gradio as gr
|
19 |
+
from gradio.mix import Series
|
20 |
+
from PIL import Image
|
21 |
+
import requests
|
22 |
+
from transformers import AutoFeatureExtractor, YolosForObjectDetection
|
23 |
+
import torch
|
24 |
+
import matplotlib.pyplot as plt
|
25 |
+
import cv2
|
26 |
+
|
27 |
+
import os
|
28 |
+
# Notebook leftover: queries the working directory; the result is discarded.
os.getcwd()

# Palette used when drawing detections: six colours, each a list of three
# floats between 0 and 1.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
|
33 |
+
|
34 |
+
def plot_results(pil_img, prob, boxes, count):
    """Draw the detections on a frame and save it as 'exp2/frame<count>.png'.

    Args:
        pil_img: Image shown as the figure background (passed to ``plt.imshow``).
        prob: Iterable of per-detection class-score tensors; the arg-max of
            each one selects the label and the score printed next to the box.
        boxes: Tensor of ``(xmin, ymin, xmax, ymax)`` rows, one per detection
            (converted with ``.tolist()``).
        count: Frame counter used in the output file name, zero-padded to at
            least two digits (assumes a non-negative integer).

    Relies on the module-level ``model`` for the ``id2label`` mapping and on
    ``COLORS`` for the box colours.
    """
    fig = plt.figure(figsize=(16, 10))
    try:
        plt.imshow(pil_img)
        ax = plt.gca()
        # Repeat the palette so every detection gets a colour.
        colors = COLORS * 100
        for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=c, linewidth=3))
            cl = p.argmax()
            text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        plt.axis('off')

        # savefig does not create missing directories.
        os.makedirs('exp2', exist_ok=True)
        # %02d yields 'frame0N' below 10 and 'frameN' otherwise.
        plt.savefig('exp2/frame%02d.png' % count)
    finally:
        # This function runs once per processed frame; without closing, every
        # figure stays alive in pyplot and memory grows with the video length.
        plt.close(fig)
|
50 |
+
|
51 |
+
# Pretrained YOLOS-small detector plus the matching pre-/post-processor.
model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small")
feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small")

vidcap = cv2.VideoCapture('/content/2022-08-10_ППП-стоянки_кам-3_191356 (online-video-cutter.com).mp4')
# This first read only primes the loop condition; its frame is never processed.
success, image = vidcap.read()
count = 0

try:
    while success:
        success, image = vidcap.read()
        count += 1

        # Run detection on every 10th frame only. `success` must be checked
        # here: the read that reaches the end of the video yields no image,
        # and converting that to a PIL image would raise.
        if success and count % 10 == 0:
            # OpenCV decodes frames in BGR channel order, while PIL and
            # matplotlib interpret arrays as RGB.
            image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

            # Inference only: no gradients needed. Attention maps are not
            # requested because nothing below uses them.
            with torch.no_grad():
                outputs = model(pixel_values)

            # keep only predictions of queries with 0.8+ confidence (excluding no-object class)
            probas = outputs.logits.softmax(-1)[0, :, :-1]
            keep = probas.max(-1).values > 0.8

            # rescale bounding boxes to the frame; PIL's size is (width, height),
            # so it is reversed to (height, width) before being passed on
            # NOTE(review): newer transformers releases appear to deprecate
            # `post_process` in favour of `post_process_object_detection` —
            # confirm against the installed version.
            target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
            postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
            bboxes_scaled = postprocessed_outputs[0]['boxes']
            plot_results(image, probas[keep], bboxes_scaled[keep], count)

        print('Process a new frame: ', success)
finally:
    # Free the video file handle/decoder even if inference fails midway.
    vidcap.release()
|
80 |
+
|
81 |
+
"""Set model and directory parameters:
|
82 |
+
|
83 |
+
Assemble the saved frames from the given folder into a video:
|
84 |
+
"""
|
85 |
+
|
86 |
+
# Folder that is scanned for frames to stitch together.
# NOTE(review): presumably the same directory plot_results writes to via the
# relative path 'exp2/' — that only holds when the working directory is /content.
image_folder = '/content/exp2'

# Names of every entry in that folder, in whatever order the OS returns them.
file_list = os.listdir(path=image_folder)
|
88 |
+
|
89 |
+
# Sort key for frame files named 'frame<N>.png'.
def last_2chars(x):
    """Return a string key that orders 'frame<N>.png' names by frame number.

    The key is the run of digits that follows the 5-character prefix,
    left-padded with zeros so that plain string comparison agrees with
    numeric order. Looking at only two characters would give 'frame10.png',
    'frame100.png' and 'frame1000.png' the same key and leave their relative
    order to chance.

    Args:
        x: File name; its first five characters are treated as the prefix.

    Returns:
        The zero-padded digit run, or ``x[5:7]`` when no digit follows the
        prefix (so unrelated directory entries still get a string key).
    """
    tail = x[5:]
    end = 0
    while end < len(tail) and tail[end].isdigit():
        end += 1
    if end == 0:
        return x[5:7]
    # 12 digits is far wider than any realistic frame counter.
    return tail[:end].zfill(12)
|
92 |
+
|
93 |
+
# Order the directory entries with the frame sort key, then keep only PNGs.
srtd = sorted(file_list, key=last_2chars)
images = [img for img in srtd if img.endswith(".png")]
if not images:
    # Fail with a clear message instead of an IndexError on images[0].
    raise FileNotFoundError(f"no .png frames found in {image_folder!r}")

video_name = 'video.avi'

# The first frame determines the size of the output video.
frame = cv2.imread(os.path.join(image_folder, images[0]))
if frame is None:
    # cv2.imread signals an unreadable file by returning None.
    raise ValueError(f"could not read frame {images[0]!r} from {image_folder!r}")
height, width, layers = frame.shape

# Positional arguments: output file, fourcc code 0, 5 frames per second,
# frame size given as (width, height).
video = cv2.VideoWriter(video_name, 0, 5, (width, height))
try:
    for image_name in images:
        current = cv2.imread(os.path.join(image_folder, image_name))
        if current is None:
            print('Skipping unreadable frame: ', image_name)
            continue
        video.write(current)
finally:
    # Finalise the file first so it is complete even if the window cleanup
    # below is unavailable in the installed OpenCV build.
    video.release()
    cv2.destroyAllWindows()
|