|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset, DataLoader, random_split |
|
import pandas as pd |
|
import numpy as np |
|
from transformers import BertTokenizer |
|
import os |
|
from pose_format import Pose |
|
import matplotlib.pyplot as plt |
|
from matplotlib import animation |
|
from fastdtw import fastdtw |
|
from scipy.spatial.distance import cosine |
|
from config import MAX_TEXT_LEN, TARGET_NUM_FRAMES, BATCH_SIZE, TEACHER_FORCING_RATIO, SMOOTHING_ENABLED |
|
from transformers import BertModel |
|
|
|
# Tokenizer for the same IndoBERT checkpoint that TextToPoseSeq2Seq loads as its encoder.
# NOTE(review): from_pretrained runs at import time (presumably downloading/caching the
# checkpoint); consider lazy loading if import-time side effects are a problem.
tokenizer = BertTokenizer.from_pretrained("indobenchmark/indobert-base-p2")

# Indices of the keypoints kept from each full pose frame: 0-24 and 501-542 (the last two
# ranges are contiguous). Presumably body + left hand + right hand of a 543-point
# holistic layout (33 body + 468 face + 21 + 21) — TODO confirm against the pose data.
selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])

# Number of retained keypoints (25 + 21 + 21 = 67).
NUM_KEYPOINTS = len(selected_keypoint_indices)

# Flattened size of one pose frame: 3 values per keypoint (presumably x, y, z).
POSE_DIM = NUM_KEYPOINTS * 3

# Compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
class TextToPoseSeq2Seq(nn.Module):
    """Autoregressive text-to-pose model: IndoBERT encoder + GRUCell decoder.

    The encoder's first-token hidden state (``last_hidden_state[:, 0, :]``) is used as
    a single context vector, concatenated to the decoder input at every step. The
    decoder emits ``target_len`` frames one at a time, each conditioned on the previous
    frame: the ground-truth frame when teacher forcing fires, otherwise the model's own
    (detached) previous prediction.

    Args:
        vocab_size: Unused; kept so existing call sites keep working.
        hidden_dim: Size of the GRU hidden state and of the projected pose input.
        pose_dim: Size of one flattened pose frame.
        max_len: Unused; kept so existing call sites keep working.
        target_len: Number of frames generated per forward pass.
    """

    def __init__(self, vocab_size, hidden_dim=512, pose_dim=POSE_DIM, max_len=MAX_TEXT_LEN, target_len=TARGET_NUM_FRAMES):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.target_len = target_len
        self.pose_dim = pose_dim

        # Pretrained text encoder (same checkpoint as the module-level tokenizer).
        self.encoder = BertModel.from_pretrained("indobenchmark/indobert-base-p2")
        bert_hidden = self.encoder.config.hidden_size

        # Decoder: previous frame -> hidden_dim embedding, concatenated with the text context.
        self.input_proj = nn.Linear(pose_dim, hidden_dim)
        self.gru_cell = nn.GRUCell(hidden_dim + bert_hidden, hidden_dim)
        self.dropout = nn.Dropout(0.3)

        # Output heads: one flattened pose frame and one confidence per keypoint.
        self.fc_pose = nn.Linear(hidden_dim, pose_dim)
        self.fc_conf = nn.Linear(hidden_dim, NUM_KEYPOINTS)
        self.output_scale = 1.0

    def forward(self, input_ids, attention_mask=None, target_pose=None, teacher_forcing_ratio=TEACHER_FORCING_RATIO):
        """Decode a pose sequence from token ids.

        Args:
            input_ids: Token id tensor; dim 0 is the batch.
            attention_mask: Optional mask forwarded unchanged to the encoder.
            target_pose: Optional ground-truth frames indexed as ``[:, t, :]``; only
                read for teacher forcing while ``self.training`` is True. Presumably
                (B, >= target_len - 1, pose_dim) — TODO confirm with callers.
            teacher_forcing_ratio: Per-step probability of feeding the ground-truth
                previous frame instead of the model's own prediction.

        Returns:
            ``(poses, confidences)`` with shapes (B, target_len, pose_dim) and
            (B, target_len, NUM_KEYPOINTS); confidences pass through a sigmoid.
        """
        batch_size = input_ids.size(0)
        device = input_ids.device

        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        context = encoder_outputs.last_hidden_state[:, 0, :]

        # Allocate the start frame and hidden state directly on the target device and in
        # the decoder's dtype. A plain torch.zeros(...) is float32 on the CPU: it costs a
        # host->device copy and breaks the matmuls once the model is cast to fp16/bf16.
        dtype = self.input_proj.weight.dtype
        input_pose = torch.zeros(batch_size, self.pose_dim, device=device, dtype=dtype)
        h = torch.zeros(batch_size, self.hidden_dim, device=device, dtype=dtype)

        can_teacher_force = self.training and target_pose is not None
        pose_outputs = []
        conf_outputs = []

        for t in range(self.target_len):
            use_teacher = can_teacher_force and torch.rand(1).item() < teacher_forcing_ratio
            if t > 0:
                # Previous ground-truth frame, or our own previous prediction (detached so
                # gradients do not flow back through the fed-back input).
                input_pose = target_pose[:, t - 1, :] if use_teacher else pose_outputs[-1].detach()

            gru_input = torch.cat([self.input_proj(input_pose), context], dim=-1)
            h = self.gru_cell(gru_input, h)

            # Dropout only on the output path: dropping units of h itself would corrupt the
            # recurrent state carried into the next step with a fresh random mask every step.
            # In eval mode dropout is the identity, so inference is unaffected.
            out = self.dropout(h)
            pred_pose = self.fc_pose(out) * self.output_scale
            pred_conf = torch.sigmoid(self.fc_conf(out))

            pose_outputs.append(pred_pose)
            conf_outputs.append(pred_conf)

        return torch.stack(pose_outputs, dim=1), torch.stack(conf_outputs, dim=1)
|
|
|
|
|
def mpjpe(pred, target, mask=None, num_keypoints=None):
    """Mean per-joint position error: average Euclidean distance between keypoints.

    Args:
        pred: Predicted poses, shape (B, T, ...) holding ``num_keypoints * 3`` values
            per frame, e.g. (B, T, K*3) or (B, T, K, 3).
        target: Ground-truth poses laid out like ``pred`` (its own first two dims are used).
        mask: Optional per-keypoint weights with B*T*K elements; when given, the result
            is the mask-weighted mean of the per-keypoint errors.
        num_keypoints: Keypoints per frame; defaults to the module-level ``NUM_KEYPOINTS``.

    Returns:
        Scalar tensor. Nothing is detached, so gradients flow through it.
    """
    k = NUM_KEYPOINTS if num_keypoints is None else num_keypoints

    pred = pred.reshape(pred.size(0), pred.size(1), k, 3)
    target = target.reshape(target.size(0), target.size(1), k, 3)

    # Euclidean distance per keypoint over the coordinate axis -> (B, T, K).
    error = torch.norm(pred - target, dim=3)

    if mask is None:
        return error.mean()

    mask = mask.reshape(pred.size(0), pred.size(1), k)
    # Epsilon keeps an all-zero mask from dividing by zero (the result is then 0).
    return (error * mask).sum() / (mask.sum() + 1e-8)
|
|
|
|
|
def per_joint_mpjpe(pred, target, mask=None, num_keypoints=None):
    """Mean Euclidean error for each keypoint, averaged over every frame.

    Args:
        pred: Predicted poses with a multiple of ``num_keypoints * 3`` elements; all
            leading dims are flattened into a single frame axis.
        target: Ground-truth poses laid out like ``pred``.
        mask: Optional per-keypoint weights, reshaped to (frames, num_keypoints); when
            given, each keypoint's error is the mask-weighted mean over frames.
        num_keypoints: Keypoints per frame; defaults to the module-level ``NUM_KEYPOINTS``.

    Returns:
        NumPy array of shape (num_keypoints,).
    """
    k = NUM_KEYPOINTS if num_keypoints is None else num_keypoints

    # reshape (not view): merging batch and time dims fails with view() on
    # non-contiguous input such as transposed or strided tensors.
    pred = pred.reshape(-1, k, 3)
    target = target.reshape(-1, k, 3)
    error = torch.norm(pred - target, dim=2)  # (frames, K)

    if mask is None:
        joint_means = error.mean(dim=0)
    else:
        mask = mask.reshape(-1, k)
        # Epsilon keeps keypoints with an all-zero mask from dividing by zero.
        joint_means = (error * mask).sum(dim=0) / (mask.sum(dim=0) + 1e-8)

    # detach(): .numpy() raises on tensors that still require grad.
    return joint_means.detach().cpu().numpy()
|
|
|
def pose_velocity(pose_seq, num_keypoints=None):
    """Mean per-keypoint displacement between consecutive frames.

    Args:
        pose_seq: Poses of shape (B, T, ...) with ``num_keypoints * 3`` values per frame.
        num_keypoints: Keypoints per frame; defaults to the module-level ``NUM_KEYPOINTS``.

    Returns:
        Python float: the mean Euclidean distance a keypoint moves between adjacent
        frames. Returns ``0.0`` when there are fewer than two frames, rather than the NaN
        that a mean over an empty tensor produces.
    """
    # With < 2 frames there is no frame-to-frame motion; avoid propagating NaN into
    # aggregated metrics.
    if pose_seq.size(1) < 2:
        return 0.0

    k = NUM_KEYPOINTS if num_keypoints is None else num_keypoints

    diffs = pose_seq[:, 1:, :] - pose_seq[:, :-1, :]
    diffs = diffs.reshape(diffs.size(0), diffs.size(1), k, 3)
    return torch.norm(diffs, dim=3).mean().item()
|
|
|
|
|
def cosine_similarity(pred, target):
    """Cosine similarity between two pose tensors, each flattened into one vector.

    Args:
        pred: Predicted poses; every element is used.
        target: Ground-truth poses with the same number of elements as ``pred``.

    Returns:
        ``1 - scipy.spatial.distance.cosine(pred, target)``, or ``0.0`` when either
        tensor is all zeros (cosine similarity is undefined there).
    """
    # detach() so .numpy() also works on tensors that require grad; reshape(-1)
    # flattens non-contiguous tensors that view() would reject.
    pred_vec = pred.detach().reshape(-1).cpu().numpy()
    target_vec = target.detach().reshape(-1).cpu().numpy()

    if np.linalg.norm(pred_vec) == 0 or np.linalg.norm(target_vec) == 0:
        return 0.0
    return 1 - cosine(pred_vec, target_vec)
|
|
|
def dtw_distance(pred, target):
    """DTW alignment cost between the first sequence of ``pred`` and of ``target``.

    Only ``pred[0]`` and ``target[0]`` are compared; any further batch entries are
    ignored. Each is reshaped into frames of ``POSE_DIM`` values and aligned with
    ``fastdtw`` using the Euclidean distance between frames.

    Args:
        pred: Predicted poses, batch-first.
        target: Ground-truth poses, batch-first.

    Returns:
        The distance value returned by ``fastdtw``.
    """
    # detach(): .numpy() raises on tensors that still require grad (e.g. raw model
    # outputs evaluated outside torch.no_grad()). reshape() also accepts non-contiguous
    # inputs, where view() may raise.
    pred_seq = pred[0].detach().reshape(-1, POSE_DIM).cpu().numpy()
    target_seq = target[0].detach().reshape(-1, POSE_DIM).cpu().numpy()

    distance, _ = fastdtw(pred_seq, target_seq, dist=lambda a, b: np.linalg.norm(a - b))
    return distance
|
|