| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader, random_split |
| import pandas as pd |
| import numpy as np |
| from transformers import BertTokenizer |
| import os |
| from pose_format import Pose |
| import matplotlib.pyplot as plt |
| from matplotlib import animation |
| from fastdtw import fastdtw |
| from scipy.spatial.distance import cosine |
| from config import MAX_TEXT_LEN, TARGET_NUM_FRAMES, BATCH_SIZE, TEACHER_FORCING_RATIO, SMOOTHING_ENABLED |
| from transformers import BertModel |
|
|
# Word-piece tokenizer for the IndoBERT checkpoint; the same checkpoint name is
# loaded as the text encoder inside TextToPoseSeq2Seq.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path="indobenchmark/indobert-base-p2"
)
|
|
| |
# Keypoint indices kept from every frame: 0-24 followed by the two consecutive
# 21-index ranges 501-521 and 522-542.
# NOTE(review): these look like body + left/right hand ranges of a MediaPipe
# Holistic layout in pose_format — confirm against the dataset's pose header.
selected_keypoint_indices = [*np.r_[0:25, 501:522, 522:543]]
NUM_KEYPOINTS = len(selected_keypoint_indices)  # 25 + 21 + 21 = 67
POSE_DIM = 3 * NUM_KEYPOINTS  # three values per keypoint, flattened per frame

# Prefer the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
class TextToPoseSeq2Seq(nn.Module):
    """Text-to-pose model: IndoBERT encoder plus an autoregressive GRUCell decoder.

    The encoder's final-layer hidden state at token position 0 serves as one
    context vector for the whole sentence. The decoder unrolls for
    ``target_len`` steps. At each step it:

    - embeds the previous pose;
    - concatenates the embedding with the context;
    - advances a GRUCell;
    - predicts a flat pose vector plus one confidence in (0, 1) per keypoint.
    """

    def __init__(self, vocab_size, hidden_dim=512, pose_dim=POSE_DIM, max_len=MAX_TEXT_LEN, target_len=TARGET_NUM_FRAMES):
        """Build the encoder, the recurrent decoder cell and the output heads.

        Args:
            vocab_size: Not used by this implementation (the pretrained BERT
                vocabulary is used instead).
            hidden_dim: Size of the pose embedding and of the GRU hidden state.
            pose_dim: Length of the flat per-frame pose vector.
            max_len: Not used by this implementation.
            target_len: Number of frames generated per forward pass.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.target_len = target_len
        self.pose_dim = pose_dim

        # Pretrained IndoBERT text encoder (same checkpoint as the module-level tokenizer).
        self.encoder = BertModel.from_pretrained("indobenchmark/indobert-base-p2")

        # Decoder: embed the previous pose, then feed [pose_emb, context] to a GRUCell.
        self.input_proj = nn.Linear(pose_dim, hidden_dim)
        bert_hidden = self.encoder.config.hidden_size
        self.gru_cell = nn.GRUCell(hidden_dim + bert_hidden, hidden_dim)
        self.dropout = nn.Dropout(0.3)

        # Output heads: a flat pose vector and one confidence logit per keypoint.
        self.fc_pose = nn.Linear(hidden_dim, pose_dim)
        self.fc_conf = nn.Linear(hidden_dim, NUM_KEYPOINTS)
        # Constant multiplier applied to the predicted pose (1.0 = no scaling).
        self.output_scale = 1.0

    def forward(self, input_ids, attention_mask=None, target_pose=None, teacher_forcing_ratio=TEACHER_FORCING_RATIO):
        """Generate ``target_len`` pose frames for a batch of tokenized texts.

        Args:
            input_ids: Token ids for the BERT encoder; dim 0 is the batch.
            attention_mask: Optional mask forwarded unchanged to the encoder.
            target_pose: Optional ground-truth poses, indexed as
                ``target_pose[:, t, :]`` with a last dim of ``pose_dim``; needs at
                least ``target_len - 1`` frames along dim 1. Only used for
                teacher forcing while ``self.training`` is True.
            teacher_forcing_ratio: Per-step probability of feeding the ground-truth
                previous frame instead of the model's own prediction.

        Returns:
            ``(poses, confs)`` with shapes (B, target_len, pose_dim) and
            (B, target_len, NUM_KEYPOINTS); confidences are sigmoid outputs.
        """
        B = input_ids.size(0)
        run_device = input_ids.device
        # Allocate directly on the right device, in the parameters' dtype, so the
        # model also runs after .half()/.double() without a dtype mismatch.
        param_dtype = self.input_proj.weight.dtype
        input_pose = torch.zeros(B, self.pose_dim, device=run_device, dtype=param_dtype)
        h = torch.zeros(B, self.hidden_dim, device=run_device, dtype=param_dtype)

        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Sentence context: final-layer hidden state of the first token.
        context = encoder_outputs.last_hidden_state[:, 0, :]

        pose_outputs = []
        conf_outputs = []
        for t in range(self.target_len):
            # Drawn at every step (t == 0 included) so RNG consumption is unchanged.
            use_teacher = self.training and target_pose is not None and torch.rand(1).item() < teacher_forcing_ratio
            if use_teacher and t > 0:
                input_pose = target_pose[:, t - 1, :]
            # Otherwise input_pose already holds zeros (t == 0) or the previous
            # step's detached prediction, assigned at the end of the last iteration.

            pose_emb = self.input_proj(input_pose)
            gru_input = torch.cat([pose_emb, context], dim=-1)
            h = self.gru_cell(gru_input, h)
            # Dropout only on the output path. Dropping units of the recurrent
            # state itself would inject fresh noise into the memory at every step.
            out = self.dropout(h)

            pred_pose = self.fc_pose(out) * self.output_scale
            pred_conf = torch.sigmoid(self.fc_conf(out))

            pose_outputs.append(pred_pose.unsqueeze(1))
            conf_outputs.append(pred_conf.unsqueeze(1))
            # Next step's input. Detached so gradients do not flow back through
            # the model's own earlier predictions.
            input_pose = pred_pose.detach()

        return torch.cat(pose_outputs, dim=1), torch.cat(conf_outputs, dim=1)
|
|
| |
def mpjpe(pred, target, mask=None, num_keypoints=None):
    """Mean per-joint position error (MPJPE) over a batch of pose sequences.

    Args:
        pred: Predicted poses. The first two dims are (B, T), and each frame holds
            ``num_keypoints * 3`` values, e.g. shape (B, T, K * 3).
        target: Ground-truth poses with the same layout as ``pred``.
        mask: Optional weights viewable as (B, T, K), e.g. 0/1 visibility or
            confidences. When given, the result is the mask-weighted mean error.
        num_keypoints: K. Defaults to the module-level ``NUM_KEYPOINTS``.

    Returns:
        Scalar tensor: the (optionally mask-weighted) mean Euclidean distance
        between predicted and target keypoints. An all-zero mask yields 0,
        because the denominator is guarded by 1e-8.
    """
    k = NUM_KEYPOINTS if num_keypoints is None else num_keypoints
    batch, frames = pred.size(0), pred.size(1)
    pred = pred.view(batch, frames, k, 3)
    target = target.view(target.size(0), target.size(1), k, 3)

    # Euclidean distance per keypoint over its three values -> (B, T, K).
    error = torch.norm(pred - target, dim=3)

    if mask is None:
        return error.mean()
    mask = mask.view(batch, frames, k)
    return (error * mask).sum() / (mask.sum() + 1e-8)
|
|
|
|
def per_joint_mpjpe(pred, target, mask=None, num_keypoints=None):
    """Mean position error of each keypoint, pooled over every batch item and frame.

    Args:
        pred: Predicted poses holding ``num_keypoints * 3`` values per frame,
            e.g. shape (B, T, K * 3).
        target: Ground-truth poses with the same layout as ``pred``.
        mask: Optional weights reshapeable to (N, K), where N is the number of
            frames across the batch. When given, each keypoint's error is
            mask-weighted.
        num_keypoints: K. Defaults to the module-level ``NUM_KEYPOINTS``.

    Returns:
        NumPy array of shape (K,) holding one mean error per keypoint. A keypoint
        whose mask is all zero reports 0, because the denominator is guarded by
        1e-8.
    """
    k = NUM_KEYPOINTS if num_keypoints is None else num_keypoints
    # reshape, not view: merging the batch and frame axes fails with view() on
    # non-contiguous inputs such as time-strided slices.
    # detach() allows inputs that still require grad to reach .numpy() below.
    pred = pred.detach().reshape(-1, k, 3)
    target = target.detach().reshape(-1, k, 3)
    error = torch.norm(pred - target, dim=2)  # (N, K)

    if mask is None:
        joint_means = error.mean(dim=0)
    else:
        mask = mask.detach().reshape(-1, k)
        joint_means = (error * mask).sum(dim=0) / (mask.sum(dim=0) + 1e-8)
    return joint_means.cpu().numpy()
|
|
def pose_velocity(pose_seq, num_keypoints=None):
    """Average distance a keypoint moves between consecutive frames.

    Args:
        pose_seq: Poses indexed as ``pose_seq[:, t, :]`` with
            ``num_keypoints * 3`` values per frame, e.g. shape (B, T, K * 3).
        num_keypoints: K. Defaults to the module-level ``NUM_KEYPOINTS``.

    Returns:
        Python float: the mean Euclidean frame-to-frame displacement over all
        batch items, frame pairs and keypoints. NaN when T < 2, because there
        are no frame pairs to average.
    """
    k = NUM_KEYPOINTS if num_keypoints is None else num_keypoints
    # Frame-to-frame differences: (B, T - 1, K * 3).
    diffs = pose_seq[:, 1:, :] - pose_seq[:, :-1, :]
    diffs = diffs.view(diffs.size(0), diffs.size(1), k, 3)
    # Per-keypoint step length, then one global mean.
    return torch.norm(diffs, dim=3).mean().item()
|
|
|
|
def cosine_similarity(pred, target, pose_dim=None):
    """Cosine similarity between two pose tensors, each treated as one long vector.

    All batch, frame and coordinate values are flattened together, so this is a
    single global score, not a per-frame average.

    Args:
        pred: Predicted poses. The element count must be a multiple of ``pose_dim``.
        target: Ground-truth poses; must flatten to the same length as ``pred``.
        pose_dim: Values per frame. Defaults to the module-level ``POSE_DIM``.

    Returns:
        ``1 - scipy.spatial.distance.cosine(pred, target)``, or 0.0 when either
        input is all zeros (cosine is undefined for a zero vector).
    """
    dim = POSE_DIM if pose_dim is None else pose_dim
    # detach() lets tensors that still track gradients be converted to NumPy.
    # reshape() (unlike view()) also accepts non-contiguous tensors, while still
    # checking that the size is a whole number of frames.
    pred = pred.detach().reshape(-1, dim).cpu().numpy()
    target = target.detach().reshape(-1, dim).cpu().numpy()

    if np.linalg.norm(pred) == 0 or np.linalg.norm(target) == 0:
        return 0.0
    return 1 - cosine(pred.flatten(), target.flatten())
|
|
def dtw_distance(pred, target, pose_dim=None):
    """FastDTW alignment cost between the first sequence of ``pred`` and of ``target``.

    Only batch element 0 of each input is compared; the rest of the batch is
    ignored.

    Args:
        pred: Predicted poses; ``pred[0]`` must hold a whole number of
            ``pose_dim``-sized frames (presumably shape (B, T, pose_dim) — TODO
            confirm with callers).
        target: Ground-truth poses in the same layout. Its frame count may differ
            from ``pred``'s.
        pose_dim: Values per frame. Defaults to the module-level ``POSE_DIM``.

    Returns:
        The distance reported by ``fastdtw``, using the Euclidean (L2) distance
        between frames.
    """
    dim = POSE_DIM if pose_dim is None else pose_dim
    # detach() allows inputs that still require grad; reshape() tolerates
    # non-contiguous tensors where view() would raise.
    pred_seq = pred[0].detach().reshape(-1, dim).cpu().numpy()
    target_seq = target[0].detach().reshape(-1, dim).cpu().numpy()

    distance, _ = fastdtw(pred_seq, target_seq, dist=lambda a, b: np.linalg.norm(a - b))
    return distance
|
|