# gradio_test_app / app.py
# derat0r's picture
# Update app.py
# ab068e8
import matplotlib.pyplot as plt
import mediapipe as mp
import gradio as gr
import numpy as np
import torch
from torch import nn
import cv2 as cv
# Short aliases for the two MediaPipe sub-modules used in this file:
# the landmark drawing helpers and the Holistic solution.
mp_drawing = mp.solutions.drawing_utils
mp_holistic = mp.solutions.holistic
# Target size handed to cv.resize for every frame before landmark detection.
frame_size = (350, 200)
# Length of the fixed-size frame sequence that is fed to the model.
NUM_FRAMES = 30
# Device on which input tensors are placed.
device = "cpu"

# Output vocabulary: a symbol's position in this list is its class index.
unique_symbols = [
    ' ', '!', '(', ')', ',', '-', '0', '1', '2', '4', '5', '6', '7', ':', ';', '?', 'D', 'M', 'a', 'd',
    'k', 'l', 'n', 'o', 's', 'no_event', 'Ё', 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'Й', 'К', 'Л',
    'М', 'Н', 'О', 'П', 'Р', 'С', 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Ъ', 'Ы', 'Ь', 'Э', 'Ю', 'Я',
    'а', 'б', 'в', 'г', 'д', 'е', 'ж', 'з', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у',
    'ф', 'х', 'ц', 'ч', 'ш', 'щ', 'ъ', 'ы', 'ь', 'э', 'ю', 'я', 'ё', "#", "<", ">",
]

# Lookup tables in both directions between symbols and class indices.
label2idx = {symbol: index for index, symbol in enumerate(unique_symbols)}
idx2label = dict(enumerate(unique_symbols))

# Special symbols; all three are the last entries of unique_symbols.
bos_token = "<"  # start-of-sequence marker fed to the decoder first
eos_token = ">"  # end-of-sequence marker; decoding of the text stops here
pad_token = "#"  # padding symbol
class TokenEmbedding(nn.Module):
    """Embed token ids and add a learned positional embedding.

    Args:
        num_vocab: number of rows in the token embedding table.
        maxlen: number of rows in the positional embedding table; a sequence
            longer than this cannot be embedded.
        num_hid: width of both embeddings.
    """

    def __init__(self, num_vocab=1000, maxlen=100, num_hid=64):
        super().__init__()
        self.emb = nn.Embedding(num_vocab, num_hid)
        self.pos_emb = nn.Embedding(maxlen, num_hid)

    def forward(self, x):
        """Return token embeddings plus positional embeddings.

        Args:
            x: integer tensor of token ids; positions run along its last
                dimension.

        Returns:
            Tensor with one extra trailing dimension of size ``num_hid``.
        """
        seq_len = x.size(-1)
        # Build the position indices on the same device as the input so the
        # module keeps working after being moved, instead of relying on a
        # module-level device constant.
        positions = torch.arange(seq_len, device=x.device)
        return self.emb(x) + self.pos_emb(positions)
class LandmarkEmbedding(nn.Module):
    """Embed a landmark sequence with three stacked 1-D convolutions.

    Each convolution uses kernel size 11 with "same" padding and is followed
    by a ReLU, so the sequence length is preserved.

    Args:
        in_ch: number of features per time step in the input.
        num_hid: number of output channels of every convolution.
    """

    def __init__(self, in_ch, num_hid=64):
        super().__init__()
        stages = []
        channels = in_ch
        for _ in range(3):
            stages.append(
                nn.Conv1d(channels, num_hid, kernel_size=11, padding="same")
            )
            stages.append(nn.ReLU())
            channels = num_hid
        self.emb = nn.Sequential(*stages)

    def forward(self, x):
        """Map ``(batch, time, in_ch)`` to ``(batch, time, num_hid)``."""
        # Conv1d convolves over the last axis, so features are moved to the
        # channel axis for the convolution and moved back afterwards.
        channels_first = x.transpose(1, 2)
        embedded = self.emb(channels_first)
        return embedded.transpose(1, 2)
class TransformerEncoder(nn.Module):
    """Transformer encoder block with layer normalisation after each sub-layer.

    Self-attention and a two-layer feed-forward network are each followed by
    dropout, a residual connection and layer normalisation.

    Args:
        embed_dim: width of the input and output features.
        num_heads: number of attention heads.
        feed_forward_dim: hidden width of the feed-forward network.
        rate: dropout probability used after both sub-layers.
    """

    def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):
        super().__init__()
        self.att = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, feed_forward_dim),
            nn.ReLU(),
            nn.Linear(feed_forward_dim, embed_dim),
        )
        self.layernorm1 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, inputs):
        """Apply the block to a batch-first tensor and keep its shape."""
        # The attention module returns (output, weights); only the output is used.
        attended, _ = self.att(inputs, inputs, inputs)
        residual = self.layernorm1(inputs + self.dropout1(attended))
        transformed = self.dropout2(self.ffn(residual))
        return self.layernorm2(residual + transformed)
class TransformerDecoder(nn.Module):
    """Transformer decoder block with layer normalisation after each sub-layer.

    The block applies, in order: causally masked self-attention over the
    target sequence, attention over the encoder output, and a two-layer
    feed-forward network. Each sub-layer is followed by dropout, a residual
    connection and layer normalisation.

    Args:
        embed_dim: width of the input and output features.
        num_heads: number of attention heads.
        feed_forward_dim: hidden width of the feed-forward network.
        dropout_rate: accepted for interface compatibility but not used; the
            dropout probabilities are fixed at 0.5 (self-attention) and 0.1
            (encoder attention and feed-forward).
    """

    def __init__(self, embed_dim, num_heads, feed_forward_dim, dropout_rate=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.layernorm1 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm3 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.self_att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.enc_att = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embed_dim, batch_first=True)
        self.self_dropout = nn.Dropout(0.5)
        self.enc_dropout = nn.Dropout(0.1)
        self.ffn_dropout = nn.Dropout(0.1)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=feed_forward_dim),
            nn.ReLU(),
            nn.Linear(in_features=feed_forward_dim, out_features=embed_dim)
        )

    def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):
        """Build a mask that is set where attention is ALLOWED.

        Entry ``[b, i, j]`` is set when destination position ``i`` may attend
        to source position ``j``: the lower triangle, counting from the lower
        right corner. This prevents information flowing from future tokens to
        the current token.

        Args:
            batch_size: number of sequences in the batch.
            n_dest: number of destination (query) positions.
            n_src: number of source (key) positions.
            dtype: dtype of the returned mask.

        Returns:
            Tensor of shape ``(batch_size * num_heads, n_dest, n_src)``.
        """
        dest = torch.arange(n_dest)[:, None]
        src = torch.arange(n_src)
        allowed = (dest >= src - n_src + n_dest).type(dtype)
        # One copy of the mask per (sequence, head) pair.
        return allowed.unsqueeze(0).repeat(int(batch_size) * self.num_heads, 1, 1)

    def forward(self, enc_out, target):
        """Run the decoder block.

        Args:
            enc_out: encoder output, batch-first.
            target: embedded target sequence, batch-first
                ``(batch, seq_len, embed_dim)``.

        Returns:
            Tensor with the same shape as ``target``.
        """
        batch_size, seq_len = target.shape[0], target.shape[1]
        allowed = self.causal_attention_mask(batch_size, seq_len, seq_len, torch.bool)
        # nn.MultiheadAttention treats True in a boolean attn_mask as
        # "do not attend", the opposite of the mask built above, so invert it.
        blocked = ~allowed.to(target.device)
        # The mask is passed explicitly: the is_causal flag is only a hint
        # about attn_mask and does not supply a mask by itself.
        self_attended, _ = self.self_att(target, target, target, attn_mask=blocked)
        target_norm = self.layernorm1(target + self.self_dropout(self_attended))
        cross_attended, _ = self.enc_att(target_norm, enc_out, enc_out)
        cross_norm = self.layernorm2(self.enc_dropout(cross_attended) + target_norm)
        ffn_out = self.ffn(cross_norm)
        return self.layernorm3(cross_norm + self.ffn_dropout(ffn_out))
class Transformer(nn.Module):
    """Encoder-decoder transformer that turns a landmark sequence into symbols.

    The encoder is a ``LandmarkEmbedding`` followed by ``num_layers_enc``
    ``TransformerEncoder`` blocks. The decoder embeds the target tokens with
    ``TokenEmbedding`` and runs ``num_layers_dec`` ``TransformerDecoder``
    blocks; a linear layer maps the result to per-class logits.

    Args:
        num_hid: feature width used throughout the model.
        num_head: number of attention heads in every block.
        num_feed_forward: hidden width of the feed-forward networks.
        target_maxlen: maximum target length; also the length produced by
            ``generate``.
        num_layers_enc: number of encoder blocks.
        num_layers_dec: number of decoder blocks.
        num_classes: size of the output vocabulary.
        in_ch: number of input features per time step.
    """

    def __init__(
        self,
        num_hid=64,
        num_head=2,
        num_feed_forward=128,
        target_maxlen=100,
        num_layers_enc=4,
        num_layers_dec=1,
        num_classes=10,
        in_ch=126
    ):
        super().__init__()
        self.num_layers_enc = num_layers_enc
        self.num_layers_dec = num_layers_dec
        self.target_maxlen = target_maxlen
        self.num_classes = num_classes

        self.enc_input = LandmarkEmbedding(in_ch=in_ch, num_hid=num_hid)
        self.dec_input = TokenEmbedding(
            num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid
        )

        list_encoder = [self.enc_input] + [
            TransformerEncoder(num_hid, num_head, num_feed_forward)
            for _ in range(num_layers_enc)
        ]
        self.encoder = nn.Sequential(*list_encoder)

        # Decoder blocks are registered as attributes dec_layer_0, dec_layer_1, ...
        # The attribute names are kept as they are because they are part of
        # the module's parameter names.
        for i in range(num_layers_dec):
            setattr(
                self,
                f"dec_layer_{i}",
                TransformerDecoder(num_hid, num_head, num_feed_forward),
            )

        self.classifier = nn.Linear(in_features=num_hid, out_features=num_classes)

    def decode(self, enc_out, target):
        """Embed ``target`` token ids and run them through every decoder block."""
        y = self.dec_input(target)
        for i in range(self.num_layers_dec):
            y = getattr(self, f"dec_layer_{i}")(enc_out, y)
        return y

    def forward(self, source, target):
        """Return per-position class logits for ``target`` given ``source``."""
        x = self.encoder(source)
        y = self.decode(x, target)
        return self.classifier(y)

    def generate(self, source, target_start_token_idx):
        """Perform inference over one batch of inputs using greedy decoding.

        Decoding always runs for ``target_maxlen - 1`` steps; it does not stop
        at an end token, so callers truncate the result themselves.

        Args:
            source: batch of input sequences, first dimension is the batch.
            target_start_token_idx: id of the token every sequence starts with.

        Returns:
            int32 CPU tensor of token ids, start token included. A leading
            batch dimension of size 1 is squeezed away, giving shape
            ``(target_maxlen,)``; otherwise the shape is
            ``(batch, target_maxlen)``.
        """
        batch_size = source.size(0)
        enc = self.encoder(source)
        dec_input = torch.full(
            (batch_size, 1),
            target_start_token_idx,
            dtype=torch.int32,
            device=source.device,
        )
        for _ in range(self.target_maxlen - 1):
            logits = self.classifier(self.decode(enc, dec_input))
            # Most likely class at the last position, kept as a (batch, 1)
            # column so it lines up with dec_input for any batch size.
            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True).type(torch.int32)
            dec_input = torch.cat([dec_input, next_token], dim=-1)
        return dec_input.squeeze(0).cpu()
# weights.pt is loaded as a complete model object and used directly below, so
# the model classes defined above must be importable when it is unpickled.
# weights_only=False is stated explicitly because PyTorch releases whose
# default is weights_only=True refuse to load full-module pickles.
# SECURITY: unpickling can execute arbitrary code; only load a weights.pt
# that comes from a trusted source.
model = torch.load("weights.pt", map_location=torch.device('cpu'), weights_only=False)
# Inference only: switch off dropout.
model.eval()
def predict(inp):
    """Decode one landmark sequence with the global model and print the text.

    Args:
        inp: numpy array holding a single, un-batched input sequence; a batch
            dimension is added before it is passed to ``model.generate``.

    The decoded symbols, up to and including the first end token, are printed;
    nothing is returned.
    """
    features = torch.from_numpy(inp).to(device)
    token_ids = model.generate(features.unsqueeze(0), label2idx[bos_token]).numpy()

    pieces = []
    for token in token_ids:
        pieces.append(idx2label[token])
        # The end token itself is part of the printed text.
        if token == label2idx[eos_token]:
            break
    decoded = "".join(pieces)
    print(f"prediction: {decoded}\n")
def mediapipe_detection(image, model, show_landmarks):
    """Run a MediaPipe model on a single frame.

    The frame is converted from BGR to RGB and mirrored horizontally before
    it is handed to ``model.process``.

    Args:
        image: frame to analyse; presumably a BGR numpy array as read by
            OpenCV (verify against the caller).
        model: object exposing ``process(image)``; the caller in this file
            passes an ``mp_holistic.Holistic`` instance.
        show_landmarks: when true, the image is made writeable again and
            converted back to BGR so it can be drawn on and displayed.

    Returns:
        Tuple ``(image, results)``: the processed frame and whatever
        ``model.process`` returned.
    """
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)  # BGR -> RGB for the model
    image = cv.flip(image, 1)  # mirror around the vertical axis
    image.flags.writeable = False  # mark read-only while the model processes it
    results = model.process(image)  # run detection
    # NOTE(review): the indentation of this function was lost in the copy this
    # was recovered from; both statements below are assumed to belong to the
    # `if`. With show_landmarks false the returned image therefore stays RGB
    # and read-only -- confirm against the original file.
    if show_landmarks:
        image.flags.writeable = True  # allow drawing on the image again
        image = cv.cvtColor(image, cv.COLOR_RGB2BGR)  # RGB -> BGR
    return image, results
# Number of landmarks in one hand's landmark list.
_HAND_LANDMARK_COUNT = 21


def _hand_coordinates(hand_landmarks):
    """Return one [x, y, z] row per hand landmark, or zeros for a missing hand."""
    if hand_landmarks is None:
        return [[0.0, 0.0, 0.0] for _ in range(_HAND_LANDMARK_COUNT)]
    return [[point.x, point.y, point.z] for point in hand_landmarks.landmark]


def _draw_landmarks(image, results):
    """Overlay pose and hand landmarks on ``image`` (debug visualisation only)."""
    overlays = (
        (results.pose_landmarks, mp_holistic.POSE_CONNECTIONS, (245, 117, 66), (245, 66, 230)),
        (results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS, (121, 22, 76), (121, 44, 250)),
        (results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS, (76, 22, 121), (44, 250, 44)),
    )
    for landmarks, connections, point_color, line_color in overlays:
        if landmarks is None:
            continue
        mp_drawing.draw_landmarks(
            image,
            landmarks,
            connections,
            mp_drawing.DrawingSpec(color=point_color, thickness=1, circle_radius=2),
            mp_drawing.DrawingSpec(color=line_color, thickness=1, circle_radius=1),
        )


def _fit_to_length(frames, length):
    """Front-pad with the first frame, or centre-crop, to exactly ``length`` frames."""
    missing = length - len(frames)
    if missing > 0:
        return [frames[0]] * missing + frames
    start = (len(frames) - length) // 2
    return frames[start:start + length]


def _tokens_to_text(token_ids):
    """Join the symbols for ``token_ids`` up to and including the first end token."""
    eos_idx = label2idx[eos_token]
    symbols = []
    for token in token_ids:
        symbols.append(idx2label[token])
        if token == eos_idx:
            break
    return "".join(symbols)


def classify_image(inp, show_landmarks=False):
    """Recognise the text signed in a video.

    Every frame is resized to ``frame_size`` and run through MediaPipe
    Holistic. The left- and right-hand landmarks of each frame (zeros for a
    hand that was not detected) form one feature row. The rows are padded or
    centre-cropped to ``NUM_FRAMES`` and decoded greedily by the global
    ``model``. Pose landmarks are not part of the model input.

    Args:
        inp: video source accepted by ``cv.VideoCapture``; presumably the file
            path supplied by the Gradio video input (verify against caller).
        show_landmarks: debugging aid; when true every frame is shown with
            matplotlib with the detected landmarks drawn on it.

    Returns:
        The decoded text, including the start token and, if one was produced,
        the end token; or a string starting with ``Classification Error``
        when the video yields no usable input.
    """
    if inp is None:
        return "Classification Error: no video provided"

    cap = cv.VideoCapture(inp)
    landmark_list = []
    try:
        with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv.resize(frame, frame_size)
                image, results = mediapipe_detection(frame, holistic, show_landmarks)
                landmark_list.append(
                    _hand_coordinates(results.left_hand_landmarks)
                    + _hand_coordinates(results.right_hand_landmarks)
                )
                if show_landmarks:
                    _draw_landmarks(image, results)
                    plt.imshow(image)
                    plt.show()
    finally:
        # Release the capture even when detection fails part-way through.
        cap.release()

    if not landmark_list:
        # Nothing to pad from: the source could not be opened or has no frames.
        return "Classification Error: no frames could be read from the video"

    landmark_list = _fit_to_length(landmark_list, NUM_FRAMES)
    features = np.array([landmark_list], dtype=np.float32)
    if features.shape != (1, NUM_FRAMES, 2 * _HAND_LANDMARK_COUNT, 3):
        return f'Classification Error {features.shape}'

    # Flatten the (landmark, coordinate) axes into one feature axis per frame.
    features = features.reshape(features.shape[0], features.shape[1], -1)
    model_input = torch.from_numpy(features).to(device)
    with torch.no_grad():
        token_ids = model.generate(model_input, label2idx[bos_token]).tolist()
    return _tokens_to_text(token_ids)
# Web UI: one video input, classify_image as the handler, plain-text output.
video_input = gr.Video(height=360, width=480)
demo = gr.Interface(fn=classify_image, inputs=video_input, outputs="text")
demo.launch()