Upload n.py
n.py
ADDED
@@ -0,0 +1,324 @@
import matplotlib.pyplot as plt
import mediapipe as mp
import gradio as gr
import numpy as np
import torch
from torch import nn
import cv2 as cv

mp_drawing = mp.solutions.drawing_utils
mp_holistic = mp.solutions.holistic
frame_size = (350, 200)
NUM_FRAMES = 30

device = "cpu"
unique_symbols = [' ', '!', '(', ')', ',', '-', '0', '1', '2', '4', '5', '6', '7', ':', ';', '?', 'D', 'M', 'a', 'd',
                  'k', 'l', 'n', 'o', 's', 'no_event', 'Ё', 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'Й', 'К', 'Л',
                  'М', 'Н', 'О', 'П', 'Р', 'С', 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Ъ', 'Ы', 'Ь', 'Э', 'Ю', 'Я',
                  'а', 'б', 'в', 'г', 'д', 'е', 'ж', 'з', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у',
                  'ф', 'х', 'ц', 'ч', 'ш', 'щ', 'ъ', 'ы', 'ь', 'э', 'ю', 'я', 'ё', "#", "<", ">"]
label2idx = {unique_symbols[i]: i for i in range(len(unique_symbols))}
idx2label = {i: unique_symbols[i] for i in range(len(unique_symbols))}
bos_token = "<"
eos_token = ">"
pad_token = "#"


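# Character-level vocabulary: the model emits one symbol index per decoding step and
# idx2label maps indices back to text. "<" marks the start of decoding, ">" marks the
# end, and "#" is the padding symbol (pad_token).
# Round-trip sketch (illustrative only):
#     [idx2label[label2idx[c]] for c in "<да>"]   # -> ['<', 'д', 'а', '>']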
class TokenEmbedding(nn.Module):
    def __init__(self, num_vocab=1000, maxlen=100, num_hid=64):
        super().__init__()
        self.emb = nn.Embedding(num_vocab, num_hid)
        self.pos_emb = nn.Embedding(maxlen, num_hid)

    def forward(self, x):
        maxlen = x.size()[-1]
        x = self.emb(x)
        positions = torch.arange(start=0, end=maxlen).to(device)
        positions = self.pos_emb(positions)
        return x + positions


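# LandmarkEmbedding treats the flattened per-frame landmark vector as Conv1d channels:
# an input of shape (batch, frames, features) is permuted to (batch, features, frames)
# for the convolutions and permuted back, so the output is (batch, frames, num_hid).
# Shape sketch (assuming the defaults used below, in_ch=126 and num_hid=64):
#     (1, 30, 126) -> permute -> (1, 126, 30) -> convs -> (1, 64, 30) -> permute -> (1, 30, 64)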
class LandmarkEmbedding(nn.Module):
    def __init__(self, in_ch, num_hid=64):
        super().__init__()
        self.emb = nn.Sequential(
            nn.Conv1d(in_channels=in_ch, out_channels=num_hid, kernel_size=11, padding="same"),
            nn.ReLU(),
            nn.Conv1d(in_channels=num_hid, out_channels=num_hid, kernel_size=11, padding="same"),
            nn.ReLU(),
            nn.Conv1d(in_channels=num_hid, out_channels=num_hid, kernel_size=11, padding="same"),
            nn.ReLU()
        )

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.emb(x)
        x = x.permute(0, 2, 1)
        return x


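# A standard post-norm Transformer encoder block: multi-head self-attention with a
# residual connection and LayerNorm, followed by a position-wise feed-forward network
# with its own residual connection and LayerNorm.
# Quick shape check (illustrative):
#     enc = TransformerEncoder(embed_dim=64, num_heads=2, feed_forward_dim=128)
#     enc(torch.zeros(1, 30, 64)).shape   # -> torch.Size([1, 30, 64])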
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):
        super().__init__()
        self.att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=feed_forward_dim),
            nn.ReLU(),
            nn.Linear(in_features=feed_forward_dim, out_features=embed_dim)
        )
        self.layernorm1 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, inputs):
        attn_output = self.att(inputs, inputs, inputs)[0]
        attn_output = self.dropout1(attn_output)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)


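# The decoder block mirrors the encoder but adds causally masked self-attention over the
# target tokens and cross-attention over the encoder output, each followed by a residual
# connection and LayerNorm (post-norm), with a feed-forward network at the end.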
class TransformerDecoder(nn.Module):
    def __init__(self, embed_dim, num_heads, feed_forward_dim, dropout_rate=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.layernorm1 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm3 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.self_att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.enc_att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.self_dropout = nn.Dropout(0.5)
        self.enc_dropout = nn.Dropout(0.1)
        self.ffn_dropout = nn.Dropout(0.1)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=feed_forward_dim),
            nn.ReLU(),
            nn.Linear(in_features=feed_forward_dim, out_features=embed_dim)
        )

    def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):
        """Masks the upper half of the dot product matrix in self-attention.

        This prevents the flow of information from future tokens to the current token.
        1's are placed in the lower triangle, counting from the lower right corner.
        """
        i = torch.arange(start=0, end=n_dest)[:, None]
        j = torch.arange(start=0, end=n_src)
        m = i >= j - n_src + n_dest
        mask = m.type(dtype)
        mask = torch.reshape(mask, [1, n_dest, n_src])
        batch_size = torch.LongTensor([batch_size])
        mult = torch.cat((batch_size * self.num_heads, torch.ones(1, 2).type(torch.int32).squeeze(0)), axis=0)
        mult = tuple(mult.detach().cpu().numpy())
        return torch.tile(mask, mult).to(device)

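    # Note on conventions: causal_attention_mask keeps the Keras-style layout of the
    # original port, returning a (batch_size * num_heads, n_dest, n_src) boolean mask
    # with True on positions that ARE allowed to attend (lower triangle). PyTorch's
    # nn.MultiheadAttention uses the opposite convention for boolean attn_mask
    # (True = masked out); forward() below relies on the is_causal=True hint rather
    # than passing this mask explicitly.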
    def forward(self, enc_out, target):
        input_shape = target.size()
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, torch.bool)
        target_att = self.self_att(target, target, target, is_causal=True)[0]
        self_dropout = self.self_dropout(target_att)
        target_norm = self.layernorm1(target + self_dropout)
        enc_out = self.enc_att(target_norm, enc_out, enc_out)[0]
        enc_out_norm = self.layernorm2(self.enc_dropout(enc_out) + target_norm)
        ffn_out = self.ffn(enc_out_norm)
        ffn_out_norm = self.layernorm3(enc_out_norm + self.ffn_dropout(ffn_out))
        return ffn_out_norm


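# End-to-end model: LandmarkEmbedding plus a stack of TransformerEncoder blocks wrapped
# in nn.Sequential form the encoder; decoder layers are registered as attributes
# (dec_layer_0, dec_layer_1, ...) and applied in decode(); generate() performs greedy
# autoregressive decoding starting from the BOS index.
# Instantiation sketch (illustrative; the deployed weights come from torch.load below,
# and in_ch=126 matches the 42 hand landmarks x 3 coordinates built by classify_image):
#     m = Transformer(num_hid=64, num_head=2, num_feed_forward=128, target_maxlen=100,
#                     num_classes=len(unique_symbols), in_ch=126)
#     m(torch.zeros(1, 30, 126), torch.zeros(1, 10, dtype=torch.long)).shape
#     # -> (1, 10, len(unique_symbols))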
class Transformer(nn.Module):
    def __init__(
        self,
        num_hid=64,
        num_head=2,
        num_feed_forward=128,
        target_maxlen=100,
        num_layers_enc=4,
        num_layers_dec=1,
        num_classes=10,
        in_ch=126
    ):
        super().__init__()
        self.num_layers_enc = num_layers_enc
        self.num_layers_dec = num_layers_dec
        self.target_maxlen = target_maxlen
        self.num_classes = num_classes

        self.enc_input = LandmarkEmbedding(in_ch=in_ch, num_hid=num_hid)
        self.dec_input = TokenEmbedding(
            num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid
        )

        list_encoder = [self.enc_input] + [
            TransformerEncoder(num_hid, num_head, num_feed_forward)
            for _ in range(num_layers_enc)
        ]
        self.encoder = nn.Sequential(*list_encoder)

        for i in range(num_layers_dec):
            setattr(
                self,
                f"dec_layer_{i}",
                TransformerDecoder(num_hid, num_head, num_feed_forward),
            )

        self.classifier = nn.Linear(in_features=num_hid, out_features=num_classes)

    def decode(self, enc_out, target):
        y = self.dec_input(target)
        for i in range(self.num_layers_dec):
            y = getattr(self, f"dec_layer_{i}")(enc_out, y)
        return y

    def forward(self, source, target):
        x = self.encoder(source)
        y = self.decode(x, target)
        y = self.classifier(y)
        return y

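    # Greedy decoding: feed the sequence generated so far back in and append the argmax
    # token at each step. Note that last_logit uses unsqueeze(0), which only yields the
    # intended (batch, 1) shape when the batch size is 1; classify_image below always
    # calls generate() with a single video, so that assumption holds here.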
    def generate(self, source, target_start_token_idx):
        """Performs inference over one batch of inputs using greedy decoding."""
        bs = source.size()[0]
        enc = self.encoder(source)
        dec_input = torch.ones((bs, 1), dtype=torch.int32) * target_start_token_idx
        dec_input = dec_input.to(device)
        dec_logits = []
        for i in range(self.target_maxlen - 1):
            dec_out = self.decode(enc, dec_input)
            logits = self.classifier(dec_out)
            logits = torch.argmax(logits, dim=-1).type(torch.int32)
            # last_logit = tf.expand_dims(logits[:, -1], axis=-1)
            last_logit = logits[:, -1].unsqueeze(0)
            dec_logits.append(last_logit)
            dec_input = torch.concat([dec_input, last_logit], axis=-1)
        dec_input = dec_input.squeeze(0).cpu()
        return dec_input


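# weights.pt stores the whole pickled Transformer module (not just a state_dict), so the
# class definitions above must be present and compatible when the file is unpickled.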
model = torch.load("weights.pt", map_location=torch.device('cpu'))
model.eval()


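# predict() is not wired into the Gradio interface below (that uses classify_image); it
# transcribes an already-extracted landmark array and prints the result.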
def predict(inp):
    x = torch.from_numpy(inp).to(device)

    enc_out = model.generate(x.unsqueeze(0), label2idx[bos_token]).numpy()
    res1 = ""
    for p in enc_out:
        res1 += idx2label[p]
        if p == label2idx[eos_token]:
            break
    print(f"prediction: {res1}\n")


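# Runs MediaPipe Holistic on a single frame: converts BGR to RGB, mirrors the frame, and
# returns the (optionally annotated) image together with the detection results. The
# conversion back to BGR only happens when show_landmarks is True, since the image is
# only displayed in that case.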
def mediapipe_detection(image, model, show_landmarks):
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)  # COLOR CONVERSION BGR 2 RGB
    image = cv.flip(image, 1)
    image.flags.writeable = False  # Image is no longer writeable
    results = model.process(image)  # Make prediction
    if show_landmarks:
        image.flags.writeable = True  # Image is now writeable
        image = cv.cvtColor(image, cv.COLOR_RGB2BGR)  # COLOR CONVERSION RGB 2 BGR
    return image, results


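# Gradio entry point: extracts 21 left-hand + 21 right-hand landmarks (x, y, z) per frame
# with MediaPipe Holistic, normalises the clip to NUM_FRAMES frames, flattens the result
# to (1, 30, 126) and greedily decodes a transcription with the Transformer. Missing
# detections fall back to zero-filled landmark blocks so every frame contributes a
# fixed-size feature vector.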
def classify_image(inp):
    cap = cv.VideoCapture(inp)
    landmark_list = []
    frame_counter = 0
    with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
        while cap.isOpened():
            ret, frame = cap.read()

            if not ret:
                break

            frame = cv.resize(frame, frame_size)
            show_landmarks = False  # FIX ME
            image, results = mediapipe_detection(frame, holistic, show_landmarks)

            # pose
            try:
                pose = results.pose_landmarks.landmark
                pose_mat = list([landmark.x, landmark.y, landmark.z] for landmark in pose[11:17])

                mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS,
                                          mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=1, circle_radius=2),
                                          mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=1, circle_radius=1)
                                          )
            except:
                pose_mat = [[0, 0, 0]] * 6
            # print(pose_show)

            # left hand
            try:
                left = results.left_hand_landmarks.landmark
                left_mat = list([landmark.x, landmark.y, landmark.z] for landmark in left)

                if show_landmarks:
                    mp_drawing.draw_landmarks(image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS,
                                              mp_drawing.DrawingSpec(color=(121, 22, 76), thickness=1, circle_radius=2),
                                              mp_drawing.DrawingSpec(color=(121, 44, 250), thickness=1, circle_radius=1)
                                              )
            except:
                left_mat = [[0, 0, 0]] * 21

            # right hand
            try:
                right = results.right_hand_landmarks.landmark
                right_mat = list([landmark.x, landmark.y, landmark.z] for landmark in right)

                if show_landmarks:
                    mp_drawing.draw_landmarks(image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS,
                                              mp_drawing.DrawingSpec(color=(76, 22, 121), thickness=1, circle_radius=2),
                                              mp_drawing.DrawingSpec(color=(44, 250, 44), thickness=1, circle_radius=1)
                                              )
            except:
                right_mat = [[0, 0, 0]] * 21

            iter_landmarks = left_mat + right_mat  # + pose_mat
            landmark_list.append(iter_landmarks)

            if show_landmarks:
                plt.imshow(image)
                plt.show()

            frame_counter += 1

    cap.release()

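    # Normalise the clip length to exactly NUM_FRAMES: short clips are padded by
    # repeating their first frame, longer clips are centre-cropped.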
    frames = len(landmark_list)
    if frames < NUM_FRAMES:
        for i in range(NUM_FRAMES - frames):
            landmark_list = [landmark_list[0]] + landmark_list
    elif frames > NUM_FRAMES:
        start = (frames - NUM_FRAMES) // 2
        landmark_list = landmark_list[start:start + NUM_FRAMES]

    landmark_list = np.array([landmark_list], dtype=np.float32)

    if landmark_list.shape == (1, 30, 42, 3):
        landmark_list = landmark_list.reshape(landmark_list.shape[0], landmark_list.shape[1], -1)
        inp = torch.from_numpy(landmark_list).to(device)

        # inp = torch.randn(size=[1, 30, 126], dtype=torch.float32)

        with torch.no_grad():
            out = model.generate(inp, label2idx[bos_token]).numpy()
            res1 = ""
            for p in out:
                res1 += idx2label[p]
                if p == label2idx[eos_token]:
                    break

            return res1
    else:
        return f'Classification Error {landmark_list.shape}'


gr.Interface(fn=classify_image,
             inputs=gr.Video(height=360, width=480),
             outputs='text').launch(share=True)
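# Local smoke test (illustrative; assumes a short sign-language clip named sample.mp4 is
# available next to this script):
#     print(classify_image("sample.mp4"))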