derat0r committed
Commit b6f3637
1 Parent(s): 325e3b4

Upload n.py

Files changed (1)
  1. n.py +324 -0
n.py ADDED
import matplotlib.pyplot as plt
import mediapipe as mp
import gradio as gr
import numpy as np
import torch
from torch import nn
import cv2 as cv

mp_drawing = mp.solutions.drawing_utils
mp_holistic = mp.solutions.holistic
frame_size = (350, 200)
NUM_FRAMES = 30

device = "cpu"
unique_symbols = [' ', '!', '(', ')', ',', '-', '0', '1', '2', '4', '5', '6', '7', ':', ';', '?', 'D', 'M', 'a', 'd',
                  'k', 'l', 'n', 'o', 's', 'no_event', 'Ё', 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'Й', 'К', 'Л',
                  'М', 'Н', 'О', 'П', 'Р', 'С', 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Ъ', 'Ы', 'Ь', 'Э', 'Ю', 'Я',
                  'а', 'б', 'в', 'г', 'д', 'е', 'ж', 'з', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у',
                  'ф', 'х', 'ц', 'ч', 'ш', 'щ', 'ъ', 'ы', 'ь', 'э', 'ю', 'я', 'ё', "#", "<", ">"]
label2idx = {unique_symbols[i]: i for i in range(len(unique_symbols))}
idx2label = {i: unique_symbols[i] for i in range(len(unique_symbols))}
bos_token = "<"
eos_token = ">"
pad_token = "#"
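# Vocabulary note: unique_symbols mixes Cyrillic and Latin characters, digits,
# punctuation and a 'no_event' label with three special tokens used by the decoder:
# "<" (beginning of sequence), ">" (end of sequence) and "#" (padding).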


class TokenEmbedding(nn.Module):
    def __init__(self, num_vocab=1000, maxlen=100, num_hid=64):
        super().__init__()
        self.emb = nn.Embedding(num_vocab, num_hid)
        self.pos_emb = nn.Embedding(maxlen, num_hid)

    def forward(self, x):
        maxlen = x.size()[-1]
        x = self.emb(x)
        positions = torch.arange(start=0, end=maxlen).to(device)
        positions = self.pos_emb(positions)
        return x + positions


class LandmarkEmbedding(nn.Module):
    def __init__(self, in_ch, num_hid=64):
        super().__init__()
        self.emb = nn.Sequential(
            nn.Conv1d(in_channels=in_ch, out_channels=num_hid, kernel_size=11, padding="same"),
            nn.ReLU(),
            nn.Conv1d(in_channels=num_hid, out_channels=num_hid, kernel_size=11, padding="same"),
            nn.ReLU(),
            nn.Conv1d(in_channels=num_hid, out_channels=num_hid, kernel_size=11, padding="same"),
            nn.ReLU()
        )

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (batch, time, features) -> (batch, features, time) for Conv1d
        x = self.emb(x)
        x = x.permute(0, 2, 1)  # back to (batch, time, features)
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):
        super().__init__()
        self.att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=feed_forward_dim),
            nn.ReLU(),
            nn.Linear(in_features=feed_forward_dim, out_features=embed_dim)
        )
        self.layernorm1 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, inputs):
        attn_output = self.att(inputs, inputs, inputs)[0]
        attn_output = self.dropout1(attn_output)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)


class TransformerDecoder(nn.Module):
    def __init__(self, embed_dim, num_heads, feed_forward_dim, dropout_rate=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.layernorm1 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm3 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.self_att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.enc_att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.self_dropout = nn.Dropout(0.5)
        self.enc_dropout = nn.Dropout(0.1)
        self.ffn_dropout = nn.Dropout(0.1)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=feed_forward_dim),
            nn.ReLU(),
            nn.Linear(in_features=feed_forward_dim, out_features=embed_dim)
        )

    def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):
        """Masks the upper half of the dot product matrix in self attention.

        This prevents flow of information from future tokens to current token.
        1's in the lower triangle, counting from the lower right corner.
        """
        i = torch.arange(start=0, end=n_dest)[:, None]
        j = torch.arange(start=0, end=n_src)
        m = i >= j - n_src + n_dest
        mask = m.type(dtype)
        mask = torch.reshape(mask, [1, n_dest, n_src])
        batch_size = torch.LongTensor([batch_size])
        mult = torch.cat((batch_size * self.num_heads, torch.ones(1, 2).type(torch.int32).squeeze(0)), axis=0)
        mult = tuple(mult.detach().cpu().numpy())
        return torch.tile(mask, mult).to(device)
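    # For example, with n_dest = n_src = 3 each head's mask is the lower-triangular matrix
    #   [[1, 0, 0],
    #    [1, 1, 0],
    #    [1, 1, 1]]
    # where 1 marks a position the query token is allowed to attend to.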

    def forward(self, enc_out, target):
        input_shape = target.size()
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, torch.bool)
        # nn.MultiheadAttention treats True in attn_mask as "not allowed to attend",
        # so the lower-triangular "allowed" mask is inverted before being passed in.
        target_att = self.self_att(target, target, target, attn_mask=torch.logical_not(causal_mask))[0]
        self_dropout = self.self_dropout(target_att)
        target_norm = self.layernorm1(target + self_dropout)
        enc_out = self.enc_att(target_norm, enc_out, enc_out)[0]
        enc_out_norm = self.layernorm2(self.enc_dropout(enc_out) + target_norm)
        ffn_out = self.ffn(enc_out_norm)
        ffn_out_norm = self.layernorm3(enc_out_norm + self.ffn_dropout(ffn_out))
        return ffn_out_norm


class Transformer(nn.Module):
    def __init__(
        self,
        num_hid=64,
        num_head=2,
        num_feed_forward=128,
        target_maxlen=100,
        num_layers_enc=4,
        num_layers_dec=1,
        num_classes=10,
        in_ch=126
    ):
        super().__init__()
        self.num_layers_enc = num_layers_enc
        self.num_layers_dec = num_layers_dec
        self.target_maxlen = target_maxlen
        self.num_classes = num_classes

        self.enc_input = LandmarkEmbedding(in_ch=in_ch, num_hid=num_hid)
        self.dec_input = TokenEmbedding(
            num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid
        )

        list_encoder = [self.enc_input] + [
            TransformerEncoder(num_hid, num_head, num_feed_forward)
            for _ in range(num_layers_enc)
        ]
        self.encoder = nn.Sequential(*list_encoder)

        for i in range(num_layers_dec):
            setattr(
                self,
                f"dec_layer_{i}",
                TransformerDecoder(num_hid, num_head, num_feed_forward),
            )

        self.classifier = nn.Linear(in_features=num_hid, out_features=num_classes)

    def decode(self, enc_out, target):
        y = self.dec_input(target)
        for i in range(self.num_layers_dec):
            y = getattr(self, f"dec_layer_{i}")(enc_out, y)
        return y

    def forward(self, source, target):
        x = self.encoder(source)
        y = self.decode(x, target)
        y = self.classifier(y)
        return y

    def generate(self, source, target_start_token_idx):
        """Performs inference over one batch of inputs using greedy decoding."""
        bs = source.size()[0]
        enc = self.encoder(source)
        dec_input = torch.ones((bs, 1), dtype=torch.int32) * target_start_token_idx
        dec_input = dec_input.to(device)
        dec_logits = []
        for i in range(self.target_maxlen - 1):
            dec_out = self.decode(enc, dec_input)
            logits = self.classifier(dec_out)
            logits = torch.argmax(logits, dim=-1).type(torch.int32)
            # last_logit = tf.expand_dims(logits[:, -1], axis=-1)
            last_logit = logits[:, -1].unsqueeze(-1)  # shape (bs, 1), matching the TF expand_dims above
            dec_logits.append(last_logit)
            dec_input = torch.cat([dec_input, last_logit], dim=-1)
        dec_input = dec_input.squeeze(0).cpu()
        return dec_input


model = torch.load("weights.pt", map_location=torch.device('cpu'))
model.eval()
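# Note: "weights.pt" is expected to hold a full pickled Transformer instance (not just a
# state_dict), which requires the class definitions above to be importable when the
# checkpoint is unpickled.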


def predict(inp):
    # Small debugging helper: runs greedy decoding on a single landmark array and prints
    # the result. It is not wired into the Gradio interface below.
    x = torch.from_numpy(inp).to(device)

    enc_out = model.generate(x.unsqueeze(0), label2idx[bos_token]).numpy()
    res1 = ""
    for p in enc_out:
        res1 += idx2label[p]
        if p == label2idx[eos_token]:
            break
    print(f"prediction: {res1}\n")


def mediapipe_detection(image, model, show_landmarks):
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)  # COLOR CONVERSION BGR 2 RGB
    image = cv.flip(image, 1)
    image.flags.writeable = False                 # Image is no longer writeable
    results = model.process(image)                # Make prediction
    if show_landmarks:
        image.flags.writeable = True              # Image is now writeable
        image = cv.cvtColor(image, cv.COLOR_RGB2BGR)  # COLOR CONVERSION RGB 2 BGR
    return image, results

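# classify_image extracts 21 left-hand + 21 right-hand MediaPipe landmarks per frame
# (42 points x 3 coordinates = 126 features), pads short clips by repeating the first
# frame and center-crops long ones to NUM_FRAMES = 30, then feeds the resulting
# (1, 30, 126) tensor to the transformer for greedy decoding.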
def classify_image(inp):
    cap = cv.VideoCapture(inp)
    landmark_list = []
    frame_counter = 0
    with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
        while cap.isOpened():
            ret, frame = cap.read()

            if not ret:
                break

            frame = cv.resize(frame, frame_size)
            show_landmarks = False  # FIX ME
            image, results = mediapipe_detection(frame, holistic, show_landmarks)

            # pose
            try:
                pose = results.pose_landmarks.landmark
                pose_mat = list([landmark.x, landmark.y, landmark.z] for landmark in pose[11:17])

                if show_landmarks:
                    mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS,
                                              mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=1, circle_radius=2),
                                              mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=1, circle_radius=1)
                                              )
            except AttributeError:
                pose_mat = [[0, 0, 0]] * 6
            # print(pose_show)

            # left hand
            try:
                left = results.left_hand_landmarks.landmark
                left_mat = list([landmark.x, landmark.y, landmark.z] for landmark in left)

                if show_landmarks:
                    mp_drawing.draw_landmarks(image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS,
                                              mp_drawing.DrawingSpec(color=(121, 22, 76), thickness=1, circle_radius=2),
                                              mp_drawing.DrawingSpec(color=(121, 44, 250), thickness=1, circle_radius=1)
                                              )
            except AttributeError:
                left_mat = [[0, 0, 0]] * 21

            # right hand
            try:
                right = results.right_hand_landmarks.landmark
                right_mat = list([landmark.x, landmark.y, landmark.z] for landmark in right)

                if show_landmarks:
                    mp_drawing.draw_landmarks(image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS,
                                              mp_drawing.DrawingSpec(color=(76, 22, 121), thickness=1, circle_radius=2),
                                              mp_drawing.DrawingSpec(color=(44, 250, 44), thickness=1, circle_radius=1)
                                              )
            except AttributeError:
                right_mat = [[0, 0, 0]] * 21

            iter_landmarks = left_mat + right_mat  # + pose_mat
            landmark_list.append(iter_landmarks)

            if show_landmarks:
                plt.imshow(image)
                plt.show()

            frame_counter += 1

        cap.release()

    frames = len(landmark_list)
    if frames < NUM_FRAMES:
        for i in range(NUM_FRAMES - frames):
            landmark_list = [landmark_list[0]] + landmark_list
    elif frames > NUM_FRAMES:
        start = (frames - NUM_FRAMES) // 2
        landmark_list = landmark_list[start:start + NUM_FRAMES]

    landmark_list = np.array([landmark_list], dtype=np.float32)

    if landmark_list.shape == (1, 30, 42, 3):
        landmark_list = landmark_list.reshape(landmark_list.shape[0], landmark_list.shape[1], -1)
        inp = torch.from_numpy(landmark_list).to(device)

        # inp = torch.randn(size=[1, 30, 126], dtype=torch.float32)

        with torch.no_grad():
            out = model.generate(inp, label2idx[bos_token]).numpy()
        res1 = ""
        for p in out:
            res1 += idx2label[p]
            if p == label2idx[eos_token]:
                break

        return res1
    else:
        return f'Classification Error {landmark_list.shape}'


gr.Interface(fn=classify_image,
             inputs=gr.Video(height=360, width=480),
             outputs='text').launch(share=True)
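
The Interface hands the uploaded video to classify_image as a file path, so for a quick smoke test without the web UI the function can also be called directly; the file name below is only a hypothetical example:

    # hypothetical local test clip; any short sign-language video file would do
    print(classify_image("sample_sign.mp4"))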