import matplotlib.pyplot as plt
import mediapipe as mp
import gradio as gr
import numpy as np
import torch
from torch import nn
import cv2 as cv

mp_drawing = mp.solutions.drawing_utils
mp_holistic = mp.solutions.holistic

frame_size = (350, 200)
NUM_FRAMES = 30
device = "cpu"

unique_symbols = [' ', '!', '(', ')', ',', '-', '0', '1', '2', '4', '5', '6', '7', ':', ';', '?', 'D', 'M', 'a', 'd',
                  'k', 'l', 'n', 'o', 's', 'no_event', 'Ё', 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'Й', 'К', 'Л',
                  'М', 'Н', 'О', 'П', 'Р', 'С', 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Ъ', 'Ы', 'Ь', 'Э', 'Ю', 'Я',
                  'а', 'б', 'в', 'г', 'д', 'е', 'ж', 'з', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у',
                  'ф', 'х', 'ц', 'ч', 'ш', 'щ', 'ъ', 'ы', 'ь', 'э', 'ю', 'я', 'ё', "#", "<", ">"]

label2idx = {unique_symbols[i]: i for i in range(len(unique_symbols))}
idx2label = {i: unique_symbols[i] for i in range(len(unique_symbols))}

bos_token = "<"
eos_token = ">"
pad_token = "#"

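# Token embedding for the decoder: learned symbol embeddings plus learned
# absolute positional embeddings, summed per position.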
class TokenEmbedding(nn.Module):
    def __init__(self, num_vocab=1000, maxlen=100, num_hid=64):
        super().__init__()
        self.emb = nn.Embedding(num_vocab, num_hid)
        self.pos_emb = nn.Embedding(maxlen, num_hid)

    def forward(self, x):
        maxlen = x.size()[-1]
        x = self.emb(x)
        positions = torch.arange(start=0, end=maxlen).to(device)
        positions = self.pos_emb(positions)
        return x + positions

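# Landmark embedding for the encoder: a stack of 1D convolutions over time that
# projects the flattened per-frame landmark coordinates to the model dimension.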
class LandmarkEmbedding(nn.Module):
    def __init__(self, in_ch, num_hid=64):
        super().__init__()
        self.emb = nn.Sequential(
            nn.Conv1d(in_channels=in_ch, out_channels=num_hid, kernel_size=11, padding="same"),
            nn.ReLU(),
            nn.Conv1d(in_channels=num_hid, out_channels=num_hid, kernel_size=11, padding="same"),
            nn.ReLU(),
            nn.Conv1d(in_channels=num_hid, out_channels=num_hid, kernel_size=11, padding="same"),
            nn.ReLU()
        )

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.emb(x)
        x = x.permute(0, 2, 1)
        return x

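# Transformer encoder block: multi-head self-attention followed by a position-wise
# feed-forward network, each with dropout, a residual connection and layer norm.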
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):
        super().__init__()
        self.att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=feed_forward_dim),
            nn.ReLU(),
            nn.Linear(in_features=feed_forward_dim, out_features=embed_dim)
        )
        self.layernorm1 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, inputs):
        attn_output = self.att(inputs, inputs, inputs)[0]
        attn_output = self.dropout1(attn_output)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)

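# Transformer decoder block: causal self-attention over the target tokens,
# cross-attention over the encoder output, and a feed-forward network, each
# followed by dropout, a residual connection and layer normalization.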
class TransformerDecoder(nn.Module):
    def __init__(self, embed_dim, num_heads, feed_forward_dim, dropout_rate=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.layernorm1 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm3 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.self_att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.enc_att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.self_dropout = nn.Dropout(0.5)
        self.enc_dropout = nn.Dropout(0.1)
        self.ffn_dropout = nn.Dropout(0.1)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=feed_forward_dim),
            nn.ReLU(),
            nn.Linear(in_features=feed_forward_dim, out_features=embed_dim)
        )

    def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):
        """Masks the upper half of the dot-product matrix in self-attention.

        This prevents the flow of information from future tokens to the current
        token. The mask has 1s in the lower triangle, counting from the lower
        right corner.
        """
        i = torch.arange(start=0, end=n_dest)[:, None]
        j = torch.arange(start=0, end=n_src)
        m = i >= j - n_src + n_dest
        mask = m.type(dtype)
        mask = torch.reshape(mask, [1, n_dest, n_src])
        # Repeat the mask once per attention head in the batch.
        return mask.repeat(batch_size * self.num_heads, 1, 1).to(device)

    def forward(self, enc_out, target):
        input_shape = target.size()
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, torch.bool)
        # nn.MultiheadAttention expects True at *disallowed* positions, so the
        # lower-triangular "keep" mask is inverted before being passed in.
        target_att = self.self_att(target, target, target, attn_mask=~causal_mask)[0]
        self_dropout = self.self_dropout(target_att)
        target_norm = self.layernorm1(target + self_dropout)
        enc_out = self.enc_att(target_norm, enc_out, enc_out)[0]
        enc_out_norm = self.layernorm2(self.enc_dropout(enc_out) + target_norm)
        ffn_out = self.ffn(enc_out_norm)
        ffn_out_norm = self.layernorm3(enc_out_norm + self.ffn_dropout(ffn_out))
        return ffn_out_norm

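# Full landmark-to-text model: the landmark embedding plus encoder blocks run as
# one nn.Sequential, the decoder layers are attached as dec_layer_{i} attributes,
# and a linear classifier maps the decoder output to vocabulary logits.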
class Transformer(nn.Module):
    def __init__(
            self,
            num_hid=64,
            num_head=2,
            num_feed_forward=128,
            target_maxlen=100,
            num_layers_enc=4,
            num_layers_dec=1,
            num_classes=10,
            in_ch=126
    ):
        super().__init__()
        self.num_layers_enc = num_layers_enc
        self.num_layers_dec = num_layers_dec
        self.target_maxlen = target_maxlen
        self.num_classes = num_classes

        self.enc_input = LandmarkEmbedding(in_ch=in_ch, num_hid=num_hid)
        self.dec_input = TokenEmbedding(
            num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid
        )

        list_encoder = [self.enc_input] + [
            TransformerEncoder(num_hid, num_head, num_feed_forward)
            for _ in range(num_layers_enc)
        ]
        self.encoder = nn.Sequential(*list_encoder)

        for i in range(num_layers_dec):
            setattr(
                self,
                f"dec_layer_{i}",
                TransformerDecoder(num_hid, num_head, num_feed_forward),
            )

        self.classifier = nn.Linear(in_features=num_hid, out_features=num_classes)

    def decode(self, enc_out, target):
        y = self.dec_input(target)
        for i in range(self.num_layers_dec):
            y = getattr(self, f"dec_layer_{i}")(enc_out, y)
        return y

    def forward(self, source, target):
        x = self.encoder(source)
        y = self.decode(x, target)
        y = self.classifier(y)
        return y

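    # Greedy autoregressive decoding: start from the BOS token and repeatedly
    # append the argmax prediction for the last position until target_maxlen.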
    def generate(self, source, target_start_token_idx):
        """Performs inference over one batch of inputs using greedy decoding."""
        bs = source.size()[0]
        enc = self.encoder(source)
        dec_input = torch.ones((bs, 1), dtype=torch.int32) * target_start_token_idx
        dec_input = dec_input.to(device)
        dec_logits = []
        for i in range(self.target_maxlen - 1):
            dec_out = self.decode(enc, dec_input)
            logits = self.classifier(dec_out)
            logits = torch.argmax(logits, dim=-1).type(torch.int32)
            # Keep only the prediction for the last position, shaped (bs, 1),
            # the equivalent of tf.expand_dims(logits[:, -1], axis=-1).
            last_logit = logits[:, -1].unsqueeze(-1)
            dec_logits.append(last_logit)
            dec_input = torch.concat([dec_input, last_logit], axis=-1)
        dec_input = dec_input.squeeze(0).cpu()
        return dec_input

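# Load the trained model, pickled as a whole module rather than a state_dict;
# unpickling relies on the class definitions above being present in __main__.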
model = torch.load("weights.pt", map_location=torch.device('cpu'))
model.eval()

def predict(inp):
    """Debug helper: runs greedy decoding on an already preprocessed landmark array."""
    x = torch.from_numpy(inp).to(device)
    out = model.generate(x.unsqueeze(0), label2idx[bos_token]).numpy()
    res1 = ""
    for p in out:
        res1 += idx2label[p]
        if p == label2idx[eos_token]:
            break
    print(f"prediction: {res1}\n")
    return res1

def mediapipe_detection(image, model, show_landmarks):
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)      # color conversion BGR -> RGB
    image = cv.flip(image, 1)
    image.flags.writeable = False                     # image is no longer writeable
    results = model.process(image)                    # make prediction
    if show_landmarks:
        image.flags.writeable = True                  # image is writeable again
        image = cv.cvtColor(image, cv.COLOR_RGB2BGR)  # color conversion RGB -> BGR
    return image, results

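# Gradio handler: extracts hand landmarks from every frame of the uploaded video,
# pads or center-crops the sequence to NUM_FRAMES frames, and decodes the
# resulting (1, 30, 126) tensor into text with greedy decoding.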
def classify_image(inp):
    cap = cv.VideoCapture(inp)
    landmark_list = []
    frame_counter = 0
    show_landmarks = False  # FIX ME
    with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv.resize(frame, frame_size)
            image, results = mediapipe_detection(frame, holistic, show_landmarks)
            # pose
            try:
                pose = results.pose_landmarks.landmark
                pose_mat = list([landmark.x, landmark.y, landmark.z] for landmark in pose[11:17])
                if show_landmarks:
                    mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS,
                                              mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=1, circle_radius=2),
                                              mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=1, circle_radius=1)
                                              )
            except AttributeError:
                pose_mat = [[0, 0, 0]] * 6
            # left hand
            try:
                left = results.left_hand_landmarks.landmark
                left_mat = list([landmark.x, landmark.y, landmark.z] for landmark in left)
                if show_landmarks:
                    mp_drawing.draw_landmarks(image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS,
                                              mp_drawing.DrawingSpec(color=(121, 22, 76), thickness=1, circle_radius=2),
                                              mp_drawing.DrawingSpec(color=(121, 44, 250), thickness=1, circle_radius=1)
                                              )
            except AttributeError:
                left_mat = [[0, 0, 0]] * 21
            # right hand
            try:
                right = results.right_hand_landmarks.landmark
                right_mat = list([landmark.x, landmark.y, landmark.z] for landmark in right)
                if show_landmarks:
                    mp_drawing.draw_landmarks(image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS,
                                              mp_drawing.DrawingSpec(color=(76, 22, 121), thickness=1, circle_radius=2),
                                              mp_drawing.DrawingSpec(color=(44, 250, 44), thickness=1, circle_radius=1)
                                              )
            except AttributeError:
                right_mat = [[0, 0, 0]] * 21
            # Only the two hands (21 landmarks each) are fed to the model; the pose
            # landmarks are extracted but currently unused.
            iter_landmarks = left_mat + right_mat  # + pose_mat
            landmark_list.append(iter_landmarks)
            if show_landmarks:
                plt.imshow(image)
                plt.show()
            frame_counter += 1
    cap.release()

    # Pad short clips by repeating the first frame, or center-crop long clips,
    # so that exactly NUM_FRAMES frames are kept.
    frames = len(landmark_list)
    if frames < NUM_FRAMES:
        for i in range(NUM_FRAMES - frames):
            landmark_list = [landmark_list[0]] + landmark_list
    elif frames > NUM_FRAMES:
        start = (frames - NUM_FRAMES) // 2
        landmark_list = landmark_list[start:start + NUM_FRAMES]

    landmark_list = np.array([landmark_list], dtype=np.float32)
    if landmark_list.shape == (1, 30, 42, 3):
        # Flatten the per-frame landmarks to (1, 30, 126) before decoding.
        landmark_list = landmark_list.reshape(landmark_list.shape[0], landmark_list.shape[1], -1)
        inp = torch.from_numpy(landmark_list).to(device)
        with torch.no_grad():
            out = model.generate(inp, label2idx[bos_token]).numpy()
        res1 = ""
        for p in out:
            res1 += idx2label[p]
            if p == label2idx[eos_token]:
                break
        return res1
    else:
        return f'Classification Error {landmark_list.shape}'

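# Expose the classifier as a simple video-in / text-out Gradio demo.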
gr.Interface(fn=classify_image,
             inputs=gr.Video(height=360, width=480),
             outputs='text').launch(share=True)