diff --git a/.ipynb_checkpoints/README-checkpoint.md b/.ipynb_checkpoints/README-checkpoint.md new file mode 100644 index 0000000000000000000000000000000000000000..290eb050c2de0c8a7305fe604ca61d5cda6a61a9 --- /dev/null +++ b/.ipynb_checkpoints/README-checkpoint.md @@ -0,0 +1,13 @@ +--- +title: EMAGE +emoji: ⚡ +colorFrom: yellow +colorTo: green +sdk: gradio +sdk_version: 4.24.0 +app_file: app.py +pinned: false +license: apache-2.0 +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/.ipynb_checkpoints/app-checkpoint.py b/.ipynb_checkpoints/app-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..11a141959a6f3b4638139e16772a0fa4431e69e4 --- /dev/null +++ b/.ipynb_checkpoints/app-checkpoint.py @@ -0,0 +1,664 @@ +import spaces +import os +# os.system("Xvfb :99 -ac &") +# os.environ["DISPLAY"] = ":99" +import OpenGL.GL as gl +os.environ["PYOPENGL_PLATFORM"] = "egl" +os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1" +import signal +import time +import csv +import sys +import warnings +import random +import gradio as gr +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp +import numpy as np +import time +import pprint +from loguru import logger +import smplx +from torch.utils.tensorboard import SummaryWriter +import wandb +import matplotlib.pyplot as plt +from utils import config, logger_tools, other_tools_hf, metric, data_transfer +from dataloaders import data_tools +from dataloaders.build_vocab import Vocab +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func +from dataloaders.data_tools import joints_list +from utils import rotation_conversions as rc +import soundfile as sf +import librosa + +def inverse_selection_tensor(filtered_t, selection_array, n): + selection_array = torch.from_numpy(selection_array).cuda() + original_shape_t = torch.zeros((n, 165)).cuda() + selected_indices = torch.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + +@spaces.GPU(duration=120) +def test_demo_gpu( + model, vq_model_face, vq_model_upper, vq_model_hands, vq_model_lower, global_motion, smplx_model, + dict_data, + args, + joints, joint_mask_upper, joint_mask_lower, joint_mask_hands, + log_softmax, +): + rank = 0 + other_tools_hf.load_checkpoints(vq_model_face, args.data_path_1 + "pretrained_vq/last_790_face_v2.bin", args.e_name) + other_tools_hf.load_checkpoints(vq_model_upper, args.data_path_1 + "pretrained_vq/upper_vertex_1layer_710.bin", args.e_name) + other_tools_hf.load_checkpoints(vq_model_hands, args.data_path_1 + "pretrained_vq/hands_vertex_1layer_710.bin", args.e_name) + other_tools_hf.load_checkpoints(vq_model_lower, args.data_path_1 + "pretrained_vq/lower_foot_600.bin", args.e_name) + other_tools_hf.load_checkpoints(global_motion, args.data_path_1 + "pretrained_vq/last_1700_foot.bin", args.e_name) + other_tools_hf.load_checkpoints(model, args.test_ckpt, args.g_name) + model.to(rank).eval() + smplx_model.to(rank).eval() + vq_model_face.to(rank).eval() + vq_model_upper.to(rank).eval() + vq_model_hands.to(rank).eval() + vq_model_lower.to(rank).eval() + global_motion.to(rank).eval() + + with torch.no_grad(): + tar_pose_raw = dict_data["pose"] + tar_pose = tar_pose_raw[:, 
:, :165].to(rank) + tar_contact = tar_pose_raw[:, :, 165:169].to(rank) + tar_trans = dict_data["trans"].to(rank) + tar_exps = dict_data["facial"].to(rank) + in_audio = dict_data["audio"].to(rank) + in_word = None# dict_data["word"].to(rank) + tar_beta = dict_data["beta"].to(rank) + tar_id = dict_data["id"].to(rank).long() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], joints + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) + tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + # tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + # tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2) + + tar_index_value_face_top = vq_model_face.map2index(tar_pose_face) # bs*n/4 + tar_index_value_upper_top = vq_model_upper.map2index(tar_pose_upper) # bs*n/4 + tar_index_value_hands_top = vq_model_hands.map2index(tar_pose_hands) # bs*n/4 + tar_index_value_lower_top = vq_model_lower.map2index(tar_pose_lower) # bs*n/4 + + latent_face_top = vq_model_face.map2latent(tar_pose_face) # bs*n/4 + latent_upper_top = vq_model_upper.map2latent(tar_pose_upper) # bs*n/4 + latent_hands_top = vq_model_hands.map2latent(tar_pose_hands) # bs*n/4 + latent_lower_top = vq_model_lower.map2latent(tar_pose_lower) # bs*n/4 + + latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) + + index_in = torch.stack([tar_index_value_upper_top, tar_index_value_hands_top, tar_index_value_lower_top], dim=-1).long() + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + + loaded_data = { + "tar_pose_jaw": tar_pose_jaw, + "tar_pose_face": tar_pose_face, + "tar_pose_upper": tar_pose_upper, + "tar_pose_lower": tar_pose_lower, + "tar_pose_hands": tar_pose_hands, + 'tar_pose_leg': tar_pose_leg, + "in_audio": in_audio, + "in_word": in_word, + "tar_trans": tar_trans, + "tar_exps": tar_exps, + "tar_beta": tar_beta, + "tar_pose": tar_pose, + "tar4dis": tar4dis, + "tar_index_value_face_top": tar_index_value_face_top, + "tar_index_value_upper_top": tar_index_value_upper_top, + "tar_index_value_hands_top": tar_index_value_hands_top, + "tar_index_value_lower_top": tar_index_value_lower_top, + "latent_face_top": latent_face_top, + "latent_upper_top": latent_upper_top, + "latent_hands_top": latent_hands_top, + "latent_lower_top": latent_lower_top, + "latent_in": latent_in, + "index_in": index_in, + "tar_id": tar_id, + "latent_all": latent_all, + "tar_pose_6d": tar_pose_6d, + 
"tar_contact": tar_contact, + } + + mode = 'test' + bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], joints + tar_pose = loaded_data["tar_pose"] + tar_beta = loaded_data["tar_beta"] + in_word =None# loaded_data["in_word"] + tar_exps = loaded_data["tar_exps"] + tar_contact = loaded_data["tar_contact"] + in_audio = loaded_data["in_audio"] + tar_trans = loaded_data["tar_trans"] + + remain = n%8 + if remain != 0: + tar_pose = tar_pose[:, :-remain, :] + tar_beta = tar_beta[:, :-remain, :] + tar_trans = tar_trans[:, :-remain, :] + # in_word = in_word[:, :-remain] + tar_exps = tar_exps[:, :-remain, :] + tar_contact = tar_contact[:, :-remain, :] + n = n - remain + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) + tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + + rec_index_all_face = [] + rec_index_all_upper = [] + rec_index_all_lower = [] + rec_index_all_hands = [] + + roundt = (n - args.pre_frames) // (args.pose_length - args.pre_frames) + remain = (n - args.pre_frames) % (args.pose_length - args.pre_frames) + round_l = args.pose_length - args.pre_frames + + for i in range(0, roundt): + # in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+args.pre_frames] + # audio fps is 16000 and pose fps is 30 + in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*args.pre_frames] + in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+args.pre_frames] + mask_val = torch.ones(bs, args.pose_length, args.pose_dims+3+4).float().cuda() + mask_val[:, :args.pre_frames, :] = 0.0 + if i == 0: + latent_all_tmp = latent_all[:, i*(round_l):(i+1)*(round_l)+args.pre_frames, :] + else: + latent_all_tmp = latent_all[:, i*(round_l):(i+1)*(round_l)+args.pre_frames, :] + # print(latent_all_tmp.shape, latent_last.shape) + latent_all_tmp[:, :args.pre_frames, :] = latent_last[:, -args.pre_frames:, :] + + net_out_val = model( + in_audio = in_audio_tmp, + in_word=None, #in_word_tmp, + mask=mask_val, + in_motion = latent_all_tmp, + in_id = in_id_tmp, + use_attentions=True,) + + if args.cu != 0: + rec_index_upper = log_softmax(net_out_val["cls_upper"]).reshape(-1, args.vae_codebook_size) + _, rec_index_upper = torch.max(rec_index_upper.reshape(-1, args.pose_length, args.vae_codebook_size), dim=2) + #rec_upper = vq_model_upper.decode(rec_index_upper) + else: + _, rec_index_upper, _, _ = vq_model_upper.quantizer(net_out_val["rec_upper"]) + #rec_upper = 
vq_model_upper.decoder(rec_index_upper) + if args.cl != 0: + rec_index_lower = log_softmax(net_out_val["cls_lower"]).reshape(-1, args.vae_codebook_size) + _, rec_index_lower = torch.max(rec_index_lower.reshape(-1, args.pose_length, args.vae_codebook_size), dim=2) + #rec_lower = vq_model_lower.decode(rec_index_lower) + else: + _, rec_index_lower, _, _ = vq_model_lower.quantizer(net_out_val["rec_lower"]) + #rec_lower = vq_model_lower.decoder(rec_index_lower) + if args.ch != 0: + rec_index_hands = log_softmax(net_out_val["cls_hands"]).reshape(-1, args.vae_codebook_size) + _, rec_index_hands = torch.max(rec_index_hands.reshape(-1, args.pose_length, args.vae_codebook_size), dim=2) + #rec_hands = vq_model_hands.decode(rec_index_hands) + else: + _, rec_index_hands, _, _ = vq_model_hands.quantizer(net_out_val["rec_hands"]) + #rec_hands = vq_model_hands.decoder(rec_index_hands) + if args.cf != 0: + rec_index_face = log_softmax(net_out_val["cls_face"]).reshape(-1, args.vae_codebook_size) + _, rec_index_face = torch.max(rec_index_face.reshape(-1, args.pose_length, args.vae_codebook_size), dim=2) + #rec_face = vq_model_face.decoder(rec_index_face) + else: + _, rec_index_face, _, _ = vq_model_face.quantizer(net_out_val["rec_face"]) + #rec_face = vq_model_face.decoder(rec_index_face) + + if i == 0: + rec_index_all_face.append(rec_index_face) + rec_index_all_upper.append(rec_index_upper) + rec_index_all_lower.append(rec_index_lower) + rec_index_all_hands.append(rec_index_hands) + else: + rec_index_all_face.append(rec_index_face[:, args.pre_frames:]) + rec_index_all_upper.append(rec_index_upper[:, args.pre_frames:]) + rec_index_all_lower.append(rec_index_lower[:, args.pre_frames:]) + rec_index_all_hands.append(rec_index_hands[:, args.pre_frames:]) + + if args.cu != 0: + rec_upper_last = vq_model_upper.decode(rec_index_upper) + else: + rec_upper_last = vq_model_upper.decoder(rec_index_upper) + if args.cl != 0: + rec_lower_last = vq_model_lower.decode(rec_index_lower) + else: + rec_lower_last = vq_model_lower.decoder(rec_index_lower) + if args.ch != 0: + rec_hands_last = vq_model_hands.decode(rec_index_hands) + else: + rec_hands_last = vq_model_hands.decoder(rec_index_hands) + # if args.cf != 0: + # rec_face_last = vq_model_face.decode(rec_index_face) + # else: + # rec_face_last = vq_model_face.decoder(rec_index_face) + + rec_pose_legs = rec_lower_last[:, :, :54] + bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1] + rec_pose_upper = rec_upper_last.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = inverse_selection_tensor(rec_pose_upper, joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = inverse_selection_tensor(rec_pose_lower, joint_mask_lower, bs*n) + rec_pose_hands = rec_hands_last.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = inverse_selection_tensor(rec_pose_hands, joint_mask_hands, bs*n) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + rec_trans_v_s = 
rec_lower_last[:, :, 54:57] + rec_x_trans = other_tools_hf.velocity2position(rec_trans_v_s[:, :, 0:1], 1/args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools_hf.velocity2position(rec_trans_v_s[:, :, 2:3], 1/args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + latent_last = torch.cat([rec_pose, rec_trans, rec_lower_last[:, :, 57:61]], dim=-1) + + rec_index_face = torch.cat(rec_index_all_face, dim=1) + rec_index_upper = torch.cat(rec_index_all_upper, dim=1) + rec_index_lower = torch.cat(rec_index_all_lower, dim=1) + rec_index_hands = torch.cat(rec_index_all_hands, dim=1) + if args.cu != 0: + rec_upper = vq_model_upper.decode(rec_index_upper) + else: + rec_upper = vq_model_upper.decoder(rec_index_upper) + if args.cl != 0: + rec_lower = vq_model_lower.decode(rec_index_lower) + else: + rec_lower = vq_model_lower.decoder(rec_index_lower) + if args.ch != 0: + rec_hands = vq_model_hands.decode(rec_index_hands) + else: + rec_hands = vq_model_hands.decoder(rec_index_hands) + if args.cf != 0: + rec_face = vq_model_face.decode(rec_index_face) + else: + rec_face = vq_model_face.decoder(rec_index_face) + + rec_exps = rec_face[:, :, 6:] + rec_pose_jaw = rec_face[:, :, :6] + rec_pose_legs = rec_lower[:, :, :54] + bs, n = rec_pose_jaw.shape[0], rec_pose_jaw.shape[1] + rec_pose_upper = rec_upper.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = inverse_selection_tensor(rec_pose_upper, joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = inverse_selection_tensor(rec_pose_lower, joint_mask_lower, bs*n) + rec_pose_hands = rec_hands.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = inverse_selection_tensor(rec_pose_hands, joint_mask_hands, bs*n) + rec_pose_jaw = rec_pose_jaw.reshape(bs*n, 6) + rec_pose_jaw = rc.rotation_6d_to_matrix(rec_pose_jaw) + rec_pose_jaw = rc.matrix_to_axis_angle(rec_pose_jaw).reshape(bs*n, 1*3) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose[:, 66:69] = rec_pose_jaw + + to_global = rec_lower + to_global[:, :, 54:57] = 0.0 + to_global[:, :, :54] = rec_lower2global + rec_global = global_motion(to_global) + + rec_trans_v_s = rec_global["rec_pose"][:, :, 54:57] + rec_x_trans = other_tools_hf.velocity2position(rec_trans_v_s[:, :, 0:1], 1/args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools_hf.velocity2position(rec_trans_v_s[:, :, 2:3], 1/args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + tar_pose = tar_pose[:, :n, :] + tar_exps = tar_exps[:, :n, :] + tar_trans = tar_trans[:, :n, :] + tar_beta = tar_beta[:, :n, :] + + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) 
+ + net_out = { + 'rec_pose': rec_pose, + 'rec_trans': rec_trans, + 'tar_pose': tar_pose, + 'tar_exps': tar_exps, + 'tar_beta': tar_beta, + 'tar_trans': tar_trans, + 'rec_exps': rec_exps, + } + + + tar_pose = net_out['tar_pose'] + rec_pose = net_out['rec_pose'] + tar_exps = net_out['tar_exps'] + tar_beta = net_out['tar_beta'] + rec_trans = net_out['rec_trans'] + tar_trans = net_out['tar_trans'] + rec_exps = net_out['rec_exps'] + # print(rec_pose.shape, tar_pose.shape) + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], joints + # interpolate to 30fps + if (30/args.pose_fps) != 1: + assert 30%args.pose_fps == 0 + n *= int(30/args.pose_fps) + tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/args.pose_fps, mode='linear').permute(0,2,1) + rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/args.pose_fps, mode='linear').permute(0,2,1) + + # print(rec_pose.shape, tar_pose.shape) + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + + return tar_pose, rec_pose, tar_exps, tar_beta, rec_trans, tar_trans, rec_exps, bs, n, j + + +class BaseTrainer(object): + def __init__(self, args, sp, ap, tp): + hf_dir = "hf" + if not os.path.exists(args.out_path + "custom/" + hf_dir + "/"): + os.makedirs(args.out_path + "custom/" + hf_dir + "/") + sf.write(args.out_path + "custom/" + hf_dir + "/tmp.wav", ap[1][:ap[0]*8], ap[0]) + self.audio_path = args.out_path + "custom/" + hf_dir + "/tmp.wav" + audio, ssr = librosa.load(self.audio_path) + ap = (ssr, audio) + self.args = args + self.rank = 0 # dist.get_rank() + + #self.checkpoint_path = args.out_path + "custom/" + args.name + args.notes + "/" #wandb.run.dir #args.cache_path+args.out_path+"/"+args.name + self.checkpoint_path = args.out_path + "custom/" + hf_dir + "/" + if self.rank == 0: + self.test_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "test", smplx_path=sp, audio_path=ap, text_path=tp) + self.test_loader = torch.utils.data.DataLoader( + self.test_data, + batch_size=1, + shuffle=False, + num_workers=args.loader_workers, + drop_last=False, + ) + logger.info(f"Init test dataloader success") + model_module = __import__(f"models.{args.model}", fromlist=["something"]) + + if args.ddp: + self.model = getattr(model_module, args.g_name)(args).to(self.rank) + process_group = torch.distributed.new_group() + self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model, process_group) + self.model = DDP(self.model, device_ids=[self.rank], output_device=self.rank, + broadcast_buffers=False, find_unused_parameters=False) + else: + self.model = torch.nn.DataParallel(getattr(model_module, args.g_name)(args), args.gpus).cpu() + + if self.rank == 0: + logger.info(self.model) + logger.info(f"init {args.g_name} success") + + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ) + + self.args = args + self.joints = self.test_data.joints + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list_face = joints_list["beat_smplx_face"] + self.tar_joint_list_upper = joints_list["beat_smplx_upper"] + self.tar_joint_list_hands = joints_list["beat_smplx_hands"] 
+ self.tar_joint_list_lower = joints_list["beat_smplx_lower"] + + self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = 55 + for joint_name in self.tar_joint_list_face: + self.joint_mask_face[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_upper: + self.joint_mask_upper[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_hands: + self.joint_mask_hands[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_lower: + self.joint_mask_lower[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + + self.tracker = other_tools_hf.EpochTracker(["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'exp', 'lvd', 'mse', "cls", "rec_face", "latent", "cls_full", "cls_self", "cls_word", "latent_word","latent_self"], [False,True,True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False,False,False,False]) + + vq_model_module = __import__(f"models.motion_representation", fromlist=["something"]) + self.args.vae_layer = 2 + self.args.vae_length = 256 + self.args.vae_test_dim = 106 + self.vq_model_face = getattr(vq_model_module, "VQVAEConvZero")(self.args).cpu() + # print(self.vq_model_face) + # other_tools_hf.load_checkpoints(self.vq_model_face, self.args.data_path_1 + "pretrained_vq/last_790_face_v2.bin", args.e_name) + self.args.vae_test_dim = 78 + self.vq_model_upper = getattr(vq_model_module, "VQVAEConvZero")(self.args).cpu() + # other_tools_hf.load_checkpoints(self.vq_model_upper, self.args.data_path_1 + "pretrained_vq/upper_vertex_1layer_710.bin", args.e_name) + self.args.vae_test_dim = 180 + self.vq_model_hands = getattr(vq_model_module, "VQVAEConvZero")(self.args).cpu() + # other_tools_hf.load_checkpoints(self.vq_model_hands, self.args.data_path_1 + "pretrained_vq/hands_vertex_1layer_710.bin", args.e_name) + self.args.vae_test_dim = 61 + self.args.vae_layer = 4 + self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).cpu() + # other_tools_hf.load_checkpoints(self.vq_model_lower, self.args.data_path_1 + "pretrained_vq/lower_foot_600.bin", args.e_name) + self.args.vae_test_dim = 61 + self.args.vae_layer = 4 + self.global_motion = getattr(vq_model_module, "VAEConvZero")(self.args).cpu() + # other_tools_hf.load_checkpoints(self.global_motion, self.args.data_path_1 + "pretrained_vq/last_1700_foot.bin", args.e_name) + self.args.vae_test_dim = 330 + self.args.vae_layer = 4 + self.args.vae_length = 240 + + # self.cls_loss = nn.NLLLoss().to(self.rank) + # self.reclatent_loss = nn.MSELoss().to(self.rank) + # self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank) + # self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank) + self.log_softmax = nn.LogSoftmax(dim=2) + + + def inverse_selection(self, filtered_t, selection_array, n): + original_shape_t = np.zeros((n, selection_array.size)) + selected_indices = np.where(selection_array == 1)[0] + for i in range(n): + 
original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + selection_array = torch.from_numpy(selection_array).cuda() + original_shape_t = torch.zeros((n, 165)).cuda() + selected_indices = torch.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + + def test_demo(self, epoch): + ''' + input audio and text, output motion + do not calculate loss and metric + save video + ''' + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + import shutil + shutil.rmtree(results_save_path) + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + align = 0 + latent_out = [] + latent_ori = [] + l2_all = 0 + lvel = 0 + for its, batch_data in enumerate(self.test_loader): + tar_pose, rec_pose, tar_exps, tar_beta, rec_trans, tar_trans, rec_exps, bs, n, j = test_demo_gpu( + self.model, self.vq_model_face, self.vq_model_upper, self.vq_model_hands, self.vq_model_lower, self.global_motion, self.smplx, + batch_data, + self.args, + self.joints, self.joint_mask_upper, self.joint_mask_lower, self.joint_mask_hands, + self.log_softmax, + ) + + tar_pose_np = tar_pose.detach().cpu().numpy() + rec_pose_np = rec_pose.detach().cpu().numpy() + rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3) + rec_exp_np = rec_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3) + #''' + # its = 0 + gt_npz = np.load(self.args.data_path+self.args.pose_rep +"/"+test_seq_list.iloc[its]['id']+".npz", allow_pickle=True) + np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose_np, + expressions=tar_exp_np, + trans=tar_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30, + ) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose_np, + expressions=rec_exp_np, + trans=rec_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30, + ) + + total_length += n + render_vid_path = other_tools_hf.render_one_sequence_no_gt( + results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + # results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + results_save_path, + self.audio_path, + self.args.data_path_1+"smplx_models/", + use_matplotlib = False, + args = self.args, + ) + result = gr.Video(value=render_vid_path, visible=True) + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") + return result + + +@logger.catch +def emage(audio_path): + smplx_path = None + text_path = None + rank = 0 + world_size = 1 + args = config.parse_args() + #os.environ['TRANSFORMERS_CACHE'] = args.data_path_1 + "hub/" + if not sys.warnoptions: + warnings.simplefilter("ignore") + # dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + + #logger_tools.set_args_and_logger(args, rank) + other_tools_hf.set_random_seed(args) + other_tools_hf.print_exp_info(args) + + # return one intance of trainer + trainer = BaseTrainer(args, sp = smplx_path, ap = audio_path, tp = text_path) + result = trainer.test_demo(999) + return result + +examples = [ + 
["./EMAGE/test_sequences/wave16k/2_scott_0_1_1.wav"], + ["./EMAGE/test_sequences/wave16k/2_scott_0_2_2.wav"], + ["./EMAGE/test_sequences/wave16k/2_scott_0_3_3.wav"], +] + +demo = gr.Interface( + emage, # function + inputs=[ + # gr.File(label="Please upload SMPL-X file with npz format here.", file_types=["npz", "NPZ"]), + gr.Audio(), + # gr.File(label="Please upload textgrid format file here.", file_types=["TextGrid", "Textgrid", "textgrid"]) + ], # input type + outputs=gr.Video(format="mp4", visible=True), + title='\ +
\ + EMAGE: Towards Unified Holistic Co-Speech Gesture Generation via Expressive Masked Audio Gesture Modeling
\ + CVPR 2024
\ +
', + description='\ +
\ + Haiyang Liu1*, Zihao Zhu2*, Giorgio Becherini3, Yichen Peng4, Mingyang Su5,
\ + You Zhou, Xuefei Zhe, Naoya Iwamoto, Bo Zheng, Michael J. Black3
\ + (*Equal Contribution)
\ + 1The University of Tokyo, 2Keio University, 4Japan Advanced Institute of Science and Technology,
\ + 3Max Planck Institute for Intelligent Systems, 5Tsinghua University
\ +
\ + ', + article="\ + Due to the limited resources in this space, we process the first 8s of your uploaded audio.
\ + Try running this space locally for longer motion generation, e.g., 60s.
\ + Relevant links: [Project Page (https://pantomatrix.github.io/EMAGE/)\ + ", + examples=examples, +) + + +if __name__ == "__main__": + os.environ["MASTER_ADDR"]='127.0.0.1' + os.environ["MASTER_PORT"]='8675' + #os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + demo.launch(share=True) \ No newline at end of file diff --git a/.ipynb_checkpoints/packages-checkpoint.txt b/.ipynb_checkpoints/packages-checkpoint.txt new file mode 100644 index 0000000000000000000000000000000000000000..9695433780ba56a2bbc4a9ba13e2725d6f9bbb7a --- /dev/null +++ b/.ipynb_checkpoints/packages-checkpoint.txt @@ -0,0 +1,4 @@ +libgl1-mesa-dev +libglu1-mesa-dev +freeglut3-dev +mesa-common-dev \ No newline at end of file diff --git a/.ipynb_checkpoints/requirements-checkpoint.txt b/.ipynb_checkpoints/requirements-checkpoint.txt new file mode 100644 index 0000000000000000000000000000000000000000..924c282b799630240602eb639d8fb25d0505d74c --- /dev/null +++ b/.ipynb_checkpoints/requirements-checkpoint.txt @@ -0,0 +1,39 @@ +ffmpeg +ConfigArgParse==1.7 +fasttext==0.9.2 +h5py==3.10.0 +imageio==2.31.4 +ipython==8.12.3 +joblib==1.3.2 +librosa==0.10.1 +lmdb==1.4.1 +loguru==0.7.2 +matplotlib==3.7.3 +moviepy==1.0.3 +gradio +fasttext-wheel +opencv_contrib_python==4.8.1.78 +opencv_python==4.8.1.78 +pandas==1.5.3 +peakutils==1.3.4 +ptflops==0.7.1.2 +python_igraph==0.11.3 +pyvirtualdisplay==3.0 +PyYAML==6.0.1 +replicate==0.15.4 +scikit_learn==1.3.2 +scipy +soundfile==0.12.1 +termcolor==2.4.0 +textgrid==1.5 +torch==2.1.0 +torchvision +tqdm==4.66.1 +transformers==4.35.2 +trimesh==3.23.5 +wandb==0.16.0 +pyglet<2 +smplx +tensorboard +pyrender +pyarrow \ No newline at end of file diff --git a/.ipynb_checkpoints/test_demo-checkpoint.py b/.ipynb_checkpoints/test_demo-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad0b83c01947ccba896e8dd4d20a5b36860a060 --- /dev/null +++ b/.ipynb_checkpoints/test_demo-checkpoint.py @@ -0,0 +1,581 @@ +import os +import signal +import time +import csv +import sys +import warnings +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp +import numpy as np +import time +import pprint +from loguru import logger +import smplx +from torch.utils.tensorboard import SummaryWriter +import wandb +import matplotlib.pyplot as plt +from utils import config, logger_tools, other_tools, metric, data_transfer +from dataloaders import data_tools +from dataloaders.build_vocab import Vocab +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func +from dataloaders.data_tools import joints_list +from utils import rotation_conversions as rc + +class BaseTrainer(object): + def __init__(self, args): + self.args = args + self.rank = dist.get_rank() + self.checkpoint_path = args.out_path + "custom/" + args.name + args.notes + "/" #wandb.run.dir #args.cache_path+args.out_path+"/"+args.name + if self.rank == 0: + self.test_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "test") + self.test_loader = torch.utils.data.DataLoader( + self.test_data, + batch_size=1, + shuffle=False, + num_workers=args.loader_workers, + drop_last=False, + ) + logger.info(f"Init test dataloader success") + model_module = __import__(f"models.{args.model}", fromlist=["something"]) + + if args.ddp: + self.model = 
getattr(model_module, args.g_name)(args).to(self.rank) + process_group = torch.distributed.new_group() + self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model, process_group) + self.model = DDP(self.model, device_ids=[self.rank], output_device=self.rank, + broadcast_buffers=False, find_unused_parameters=False) + else: + self.model = torch.nn.DataParallel(getattr(model_module, args.g_name)(args), args.gpus).cuda() + + if self.rank == 0: + logger.info(self.model) + logger.info(f"init {args.g_name} success") + + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).to(self.rank).eval() + + self.args = args + self.joints = self.test_data.joints + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list_face = joints_list["beat_smplx_face"] + self.tar_joint_list_upper = joints_list["beat_smplx_upper"] + self.tar_joint_list_hands = joints_list["beat_smplx_hands"] + self.tar_joint_list_lower = joints_list["beat_smplx_lower"] + + self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = 55 + for joint_name in self.tar_joint_list_face: + self.joint_mask_face[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_upper: + self.joint_mask_upper[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_hands: + self.joint_mask_hands[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_lower: + self.joint_mask_lower[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + + self.tracker = other_tools.EpochTracker(["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'exp', 'lvd', 'mse', "cls", "rec_face", "latent", "cls_full", "cls_self", "cls_word", "latent_word","latent_self"], [False,True,True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False,False,False,False]) + + vq_model_module = __import__(f"models.motion_representation", fromlist=["something"]) + self.args.vae_layer = 2 + self.args.vae_length = 256 + self.args.vae_test_dim = 106 + self.vq_model_face = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + # print(self.vq_model_face) + other_tools.load_checkpoints(self.vq_model_face, self.args.data_path_1 + "pretrained_vq/last_790_face_v2.bin", args.e_name) + self.args.vae_test_dim = 78 + self.vq_model_upper = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.vq_model_upper, self.args.data_path_1 + "pretrained_vq/upper_vertex_1layer_710.bin", args.e_name) + self.args.vae_test_dim = 180 + self.vq_model_hands = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.vq_model_hands, self.args.data_path_1 + "pretrained_vq/hands_vertex_1layer_710.bin", args.e_name) + self.args.vae_test_dim = 61 + 
self.args.vae_layer = 4 + self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.vq_model_lower, self.args.data_path_1 + "pretrained_vq/lower_foot_600.bin", args.e_name) + self.args.vae_test_dim = 61 + self.args.vae_layer = 4 + self.global_motion = getattr(vq_model_module, "VAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.global_motion, self.args.data_path_1 + "pretrained_vq/last_1700_foot.bin", args.e_name) + self.args.vae_test_dim = 330 + self.args.vae_layer = 4 + self.args.vae_length = 240 + + self.vq_model_face.eval() + self.vq_model_upper.eval() + self.vq_model_hands.eval() + self.vq_model_lower.eval() + self.global_motion.eval() + + self.cls_loss = nn.NLLLoss().to(self.rank) + self.reclatent_loss = nn.MSELoss().to(self.rank) + self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank) + self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank) + self.log_softmax = nn.LogSoftmax(dim=2).to(self.rank) + + + def inverse_selection(self, filtered_t, selection_array, n): + original_shape_t = np.zeros((n, selection_array.size)) + selected_indices = np.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + selection_array = torch.from_numpy(selection_array).cuda() + original_shape_t = torch.zeros((n, 165)).cuda() + selected_indices = torch.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def _load_data(self, dict_data): + tar_pose_raw = dict_data["pose"] + tar_pose = tar_pose_raw[:, :, :165].to(self.rank) + tar_contact = tar_pose_raw[:, :, 165:169].to(self.rank) + tar_trans = dict_data["trans"].to(self.rank) + tar_exps = dict_data["facial"].to(self.rank) + in_audio = dict_data["audio"].to(self.rank) + in_word = dict_data["word"].to(self.rank) + tar_beta = dict_data["beta"].to(self.rank) + tar_id = dict_data["id"].to(self.rank).long() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) + tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + # tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + # tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2) + + tar_index_value_face_top = self.vq_model_face.map2index(tar_pose_face) # bs*n/4 + tar_index_value_upper_top = 
self.vq_model_upper.map2index(tar_pose_upper) # bs*n/4 + tar_index_value_hands_top = self.vq_model_hands.map2index(tar_pose_hands) # bs*n/4 + tar_index_value_lower_top = self.vq_model_lower.map2index(tar_pose_lower) # bs*n/4 + + latent_face_top = self.vq_model_face.map2latent(tar_pose_face) # bs*n/4 + latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper) # bs*n/4 + latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands) # bs*n/4 + latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower) # bs*n/4 + + latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) + + index_in = torch.stack([tar_index_value_upper_top, tar_index_value_hands_top, tar_index_value_lower_top], dim=-1).long() + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + # print(tar_index_value_upper_top.shape, index_in.shape) + return { + "tar_pose_jaw": tar_pose_jaw, + "tar_pose_face": tar_pose_face, + "tar_pose_upper": tar_pose_upper, + "tar_pose_lower": tar_pose_lower, + "tar_pose_hands": tar_pose_hands, + 'tar_pose_leg': tar_pose_leg, + "in_audio": in_audio, + "in_word": in_word, + "tar_trans": tar_trans, + "tar_exps": tar_exps, + "tar_beta": tar_beta, + "tar_pose": tar_pose, + "tar4dis": tar4dis, + "tar_index_value_face_top": tar_index_value_face_top, + "tar_index_value_upper_top": tar_index_value_upper_top, + "tar_index_value_hands_top": tar_index_value_hands_top, + "tar_index_value_lower_top": tar_index_value_lower_top, + "latent_face_top": latent_face_top, + "latent_upper_top": latent_upper_top, + "latent_hands_top": latent_hands_top, + "latent_lower_top": latent_lower_top, + "latent_in": latent_in, + "index_in": index_in, + "tar_id": tar_id, + "latent_all": latent_all, + "tar_pose_6d": tar_pose_6d, + "tar_contact": tar_contact, + } + + def _g_test(self, loaded_data): + mode = 'test' + bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints + tar_pose = loaded_data["tar_pose"] + tar_beta = loaded_data["tar_beta"] + in_word = loaded_data["in_word"] + tar_exps = loaded_data["tar_exps"] + tar_contact = loaded_data["tar_contact"] + in_audio = loaded_data["in_audio"] + tar_trans = loaded_data["tar_trans"] + + remain = n%8 + if remain != 0: + tar_pose = tar_pose[:, :-remain, :] + tar_beta = tar_beta[:, :-remain, :] + tar_trans = tar_trans[:, :-remain, :] + in_word = in_word[:, :-remain] + tar_exps = tar_exps[:, :-remain, :] + tar_contact = tar_contact[:, :-remain, :] + n = n - remain + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) + tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = 
rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + + rec_index_all_face = [] + rec_index_all_upper = [] + rec_index_all_lower = [] + rec_index_all_hands = [] + + roundt = (n - self.args.pre_frames) // (self.args.pose_length - self.args.pre_frames) + remain = (n - self.args.pre_frames) % (self.args.pose_length - self.args.pre_frames) + round_l = self.args.pose_length - self.args.pre_frames + + for i in range(0, roundt): + in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames] + # audio fps is 16000 and pose fps is 30 + in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*self.args.pre_frames] + in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames] + mask_val = torch.ones(bs, self.args.pose_length, self.args.pose_dims+3+4).float().cuda() + mask_val[:, :self.args.pre_frames, :] = 0.0 + if i == 0: + latent_all_tmp = latent_all[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames, :] + else: + latent_all_tmp = latent_all[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames, :] + # print(latent_all_tmp.shape, latent_last.shape) + latent_all_tmp[:, :self.args.pre_frames, :] = latent_last[:, -self.args.pre_frames:, :] + + net_out_val = self.model( + in_audio = in_audio_tmp, + in_word=in_word_tmp, + mask=mask_val, + in_motion = latent_all_tmp, + in_id = in_id_tmp, + use_attentions=True,) + + if self.args.cu != 0: + rec_index_upper = self.log_softmax(net_out_val["cls_upper"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_upper = torch.max(rec_index_upper.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_upper = self.vq_model_upper.decode(rec_index_upper) + else: + _, rec_index_upper, _, _ = self.vq_model_upper.quantizer(net_out_val["rec_upper"]) + #rec_upper = self.vq_model_upper.decoder(rec_index_upper) + if self.args.cl != 0: + rec_index_lower = self.log_softmax(net_out_val["cls_lower"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_lower = torch.max(rec_index_lower.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_lower = self.vq_model_lower.decode(rec_index_lower) + else: + _, rec_index_lower, _, _ = self.vq_model_lower.quantizer(net_out_val["rec_lower"]) + #rec_lower = self.vq_model_lower.decoder(rec_index_lower) + if self.args.ch != 0: + rec_index_hands = self.log_softmax(net_out_val["cls_hands"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_hands = torch.max(rec_index_hands.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_hands = self.vq_model_hands.decode(rec_index_hands) + else: + _, rec_index_hands, _, _ = self.vq_model_hands.quantizer(net_out_val["rec_hands"]) + #rec_hands = self.vq_model_hands.decoder(rec_index_hands) + if self.args.cf != 0: + rec_index_face = self.log_softmax(net_out_val["cls_face"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_face = torch.max(rec_index_face.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_face = self.vq_model_face.decoder(rec_index_face) + else: + _, rec_index_face, _, _ = self.vq_model_face.quantizer(net_out_val["rec_face"]) + #rec_face = self.vq_model_face.decoder(rec_index_face) + + if i == 0: + 
rec_index_all_face.append(rec_index_face) + rec_index_all_upper.append(rec_index_upper) + rec_index_all_lower.append(rec_index_lower) + rec_index_all_hands.append(rec_index_hands) + else: + rec_index_all_face.append(rec_index_face[:, self.args.pre_frames:]) + rec_index_all_upper.append(rec_index_upper[:, self.args.pre_frames:]) + rec_index_all_lower.append(rec_index_lower[:, self.args.pre_frames:]) + rec_index_all_hands.append(rec_index_hands[:, self.args.pre_frames:]) + + if self.args.cu != 0: + rec_upper_last = self.vq_model_upper.decode(rec_index_upper) + else: + rec_upper_last = self.vq_model_upper.decoder(rec_index_upper) + if self.args.cl != 0: + rec_lower_last = self.vq_model_lower.decode(rec_index_lower) + else: + rec_lower_last = self.vq_model_lower.decoder(rec_index_lower) + if self.args.ch != 0: + rec_hands_last = self.vq_model_hands.decode(rec_index_hands) + else: + rec_hands_last = self.vq_model_hands.decoder(rec_index_hands) + # if self.args.cf != 0: + # rec_face_last = self.vq_model_face.decode(rec_index_face) + # else: + # rec_face_last = self.vq_model_face.decoder(rec_index_face) + + rec_pose_legs = rec_lower_last[:, :, :54] + bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1] + rec_pose_upper = rec_upper_last.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n) + rec_pose_hands = rec_hands_last.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + rec_trans_v_s = rec_lower_last[:, :, 54:57] + rec_x_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + latent_last = torch.cat([rec_pose, rec_trans, rec_lower_last[:, :, 57:61]], dim=-1) + + rec_index_face = torch.cat(rec_index_all_face, dim=1) + rec_index_upper = torch.cat(rec_index_all_upper, dim=1) + rec_index_lower = torch.cat(rec_index_all_lower, dim=1) + rec_index_hands = torch.cat(rec_index_all_hands, dim=1) + if self.args.cu != 0: + rec_upper = self.vq_model_upper.decode(rec_index_upper) + else: + rec_upper = self.vq_model_upper.decoder(rec_index_upper) + if self.args.cl != 0: + rec_lower = self.vq_model_lower.decode(rec_index_lower) + else: + rec_lower = self.vq_model_lower.decoder(rec_index_lower) + if self.args.ch != 0: + rec_hands = self.vq_model_hands.decode(rec_index_hands) + else: + rec_hands = self.vq_model_hands.decoder(rec_index_hands) + if self.args.cf != 0: + rec_face = 
self.vq_model_face.decode(rec_index_face) + else: + rec_face = self.vq_model_face.decoder(rec_index_face) + + rec_exps = rec_face[:, :, 6:] + rec_pose_jaw = rec_face[:, :, :6] + rec_pose_legs = rec_lower[:, :, :54] + bs, n = rec_pose_jaw.shape[0], rec_pose_jaw.shape[1] + rec_pose_upper = rec_upper.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n) + rec_pose_hands = rec_hands.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n) + rec_pose_jaw = rec_pose_jaw.reshape(bs*n, 6) + rec_pose_jaw = rc.rotation_6d_to_matrix(rec_pose_jaw) + rec_pose_jaw = rc.matrix_to_axis_angle(rec_pose_jaw).reshape(bs*n, 1*3) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose[:, 66:69] = rec_pose_jaw + + to_global = rec_lower + to_global[:, :, 54:57] = 0.0 + to_global[:, :, :54] = rec_lower2global + rec_global = self.global_motion(to_global) + + rec_trans_v_s = rec_global["rec_pose"][:, :, 54:57] + rec_x_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + tar_pose = tar_pose[:, :n, :] + tar_exps = tar_exps[:, :n, :] + tar_trans = tar_trans[:, :n, :] + tar_beta = tar_beta[:, :n, :] + + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + + return { + 'rec_pose': rec_pose, + 'rec_trans': rec_trans, + 'tar_pose': tar_pose, + 'tar_exps': tar_exps, + 'tar_beta': tar_beta, + 'tar_trans': tar_trans, + 'rec_exps': rec_exps, + } + + + def test_demo(self, epoch): + ''' + input audio and text, output motion + do not calculate loss and metric + save video + ''' + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + return 0 + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + align = 0 + latent_out = [] + latent_ori = [] + l2_all = 0 + lvel = 0 + self.model.eval() + self.smplx.eval() + # self.eval_copy.eval() + with torch.no_grad(): + for its, batch_data in enumerate(self.test_loader): + loaded_data = self._load_data(batch_data) + net_out = self._g_test(loaded_data) + tar_pose = net_out['tar_pose'] + rec_pose = net_out['rec_pose'] + tar_exps = net_out['tar_exps'] + tar_beta = net_out['tar_beta'] + rec_trans = net_out['rec_trans'] + tar_trans = net_out['tar_trans'] + rec_exps = 
net_out['rec_exps'] + # print(rec_pose.shape, tar_pose.shape) + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + + # interpolate to 30fps + if (30/self.args.pose_fps) != 1: + assert 30%self.args.pose_fps == 0 + n *= int(30/self.args.pose_fps) + tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + + # print(rec_pose.shape, tar_pose.shape) + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + + tar_pose_np = tar_pose.detach().cpu().numpy() + rec_pose_np = rec_pose.detach().cpu().numpy() + rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3) + rec_exp_np = rec_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3) + + gt_npz = np.load(self.args.data_path+self.args.pose_rep +"/"+test_seq_list.iloc[its]['id']+".npz", allow_pickle=True) + np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose_np, + expressions=tar_exp_np, + trans=tar_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose_np, + expressions=rec_exp_np, + trans=rec_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30, + ) + total_length += n + # other_tools.render_one_sequence( + # results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + # results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + # results_save_path, + # self.args.data_path+"wave16k/"+test_seq_list.iloc[its]['id']+".wav", + # self.args.data_path_1+"smplx_models/", + # use_matplotlib = False, + # args = self.args, + # ) + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") + +@logger.catch +def main_worker(rank, world_size, args): + #os.environ['TRANSFORMERS_CACHE'] = args.data_path_1 + "hub/" + if not sys.warnoptions: + warnings.simplefilter("ignore") + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + logger_tools.set_args_and_logger(args, rank) + other_tools.set_random_seed(args) + other_tools.print_exp_info(args) + + # return one intance of trainer + other_tools.write_wav_names_to_csv(args.data_path, args.data_path+"test.csv") + trainer = BaseTrainer(args) + other_tools.load_checkpoints(trainer.model, args.test_ckpt, args.g_name) + trainer.test_demo(999) + + + +if __name__ == "__main__": + os.environ["MASTER_ADDR"]='127.0.0.1' + os.environ["MASTER_PORT"]='8675' + #os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + args = config.parse_args() + if args.ddp: + mp.set_start_method("spawn", force=True) + mp.spawn( + main_worker, + args=(len(args.gpus), args,), + nprocs=len(args.gpus), + ) + else: + main_worker(0, 1, args) \ No newline at end of file diff --git a/EMAGE/emage_audio_175.bin b/EMAGE/emage_audio_175.bin new file mode 100644 index 0000000000000000000000000000000000000000..a65200bc53a2212e78567c7bbda12ec2e6e8dc46 --- 
/dev/null +++ b/EMAGE/emage_audio_175.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b19f845300e7f52c77eddfb6307f48c8fd2766edada3efa8ad1973a87990c1ea +size 556333206 diff --git a/EMAGE/pretrained_vq/.DS_Store b/EMAGE/pretrained_vq/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/EMAGE/pretrained_vq/.DS_Store differ diff --git a/EMAGE/pretrained_vq/hands_vertex_1layer_710.bin b/EMAGE/pretrained_vq/hands_vertex_1layer_710.bin new file mode 100644 index 0000000000000000000000000000000000000000..da28ca39541d2811d9753c84bee2d8ead465826e --- /dev/null +++ b/EMAGE/pretrained_vq/hands_vertex_1layer_710.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1595a13fbdf38b95da2baf6a4ba9f0c62cd6af8b8f537da12c1c90321affa3b3 +size 9644516 diff --git a/EMAGE/pretrained_vq/last_1700_foot.bin b/EMAGE/pretrained_vq/last_1700_foot.bin new file mode 100644 index 0000000000000000000000000000000000000000..06d2da3a17f60955e3e78333d0c3a2e3e99acb86 --- /dev/null +++ b/EMAGE/pretrained_vq/last_1700_foot.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f293265b828c6b45e12068c9b7956283c92b40cfdc9dd56ae960bbeb7bba1ad6 +size 14611444 diff --git a/EMAGE/pretrained_vq/last_790_face_v2.bin b/EMAGE/pretrained_vq/last_790_face_v2.bin new file mode 100644 index 0000000000000000000000000000000000000000..427de62ce83e964d2628d8b56105f80719c8ae21 --- /dev/null +++ b/EMAGE/pretrained_vq/last_790_face_v2.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13ff79afef2c3209804c0cae2b9a7c467c1a39268efa87a637e860b8e6b1b4c0 +size 8935204 diff --git a/EMAGE/pretrained_vq/lower_foot_600.bin b/EMAGE/pretrained_vq/lower_foot_600.bin new file mode 100644 index 0000000000000000000000000000000000000000..0270a5e73202732adde4e6b88cb7d24cf8adc3b2 --- /dev/null +++ b/EMAGE/pretrained_vq/lower_foot_600.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e323ed5f7014957433b59249497188656811b76952d666eae5f4affdc341786 +size 14873924 diff --git a/EMAGE/pretrained_vq/upper_vertex_1layer_710.bin b/EMAGE/pretrained_vq/upper_vertex_1layer_710.bin new file mode 100644 index 0000000000000000000000000000000000000000..33ccfbdf3aaf92df0e06bf3779ba7730bfb75503 --- /dev/null +++ b/EMAGE/pretrained_vq/upper_vertex_1layer_710.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58ffcb34ff18f3aeaf53898980ef623ea9ce36f0302a005b5f95ceef1a206a8f +size 8701092 diff --git a/EMAGE/smplx_models/.DS_Store b/EMAGE/smplx_models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..64c4cc33f2dfef96d20152797da59df9627ab464 Binary files /dev/null and b/EMAGE/smplx_models/.DS_Store differ diff --git a/EMAGE/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz b/EMAGE/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz new file mode 100644 index 0000000000000000000000000000000000000000..998c299e1ea055a6091e87edb92774c888940cd7 --- /dev/null +++ b/EMAGE/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdf06146e27d92022fe5dadad3b9203373f6879eca8e4d8235359ee3ec6a5a74 +size 167264530 diff --git a/EMAGE/test_sequences/smplxflame_30/2_scott_0_1_1.npz b/EMAGE/test_sequences/smplxflame_30/2_scott_0_1_1.npz new file mode 100644 index 0000000000000000000000000000000000000000..4d96616f0a09bd6df0971b9547d2e8376b078fcf --- /dev/null +++ 
b/EMAGE/test_sequences/smplxflame_30/2_scott_0_1_1.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37b112fd59fcabb09270d6ca3c74e7459cc5b9729564bcacf1f75609f3999592 +size 2831524 diff --git a/EMAGE/test_sequences/smplxflame_30/2_scott_0_2_2.npz b/EMAGE/test_sequences/smplxflame_30/2_scott_0_2_2.npz new file mode 100644 index 0000000000000000000000000000000000000000..f52cab82761973524dbd96ba8d442591779073bd --- /dev/null +++ b/EMAGE/test_sequences/smplxflame_30/2_scott_0_2_2.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5875f768aa4600af7d767625e0d87941b1cca9555855d8c6b509004116790f7d +size 2754356 diff --git a/EMAGE/test_sequences/smplxflame_30/2_scott_0_3_3.npz b/EMAGE/test_sequences/smplxflame_30/2_scott_0_3_3.npz new file mode 100644 index 0000000000000000000000000000000000000000..67846e0f8dd34b32f72946fff008c3ba3598ef0b --- /dev/null +++ b/EMAGE/test_sequences/smplxflame_30/2_scott_0_3_3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23ace88c7ff0288af83cc30d2428e0cb70c3d92bce981a67a5811cd53ab96db4 +size 3021476 diff --git a/EMAGE/test_sequences/smplxflame_30/2_scott_0_4_4.npz b/EMAGE/test_sequences/smplxflame_30/2_scott_0_4_4.npz new file mode 100644 index 0000000000000000000000000000000000000000..a7f4b518c86337929556e42a2f79cae7bdfd66f2 --- /dev/null +++ b/EMAGE/test_sequences/smplxflame_30/2_scott_0_4_4.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ede3993db9565b7b3a945532def69d617d6b2338f488a746a7be998f3b0685d8 +size 2976956 diff --git a/EMAGE/test_sequences/test.csv b/EMAGE/test_sequences/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..1f33f8b45bcd303b94c113e4865ab89ae14cbd24 --- /dev/null +++ b/EMAGE/test_sequences/test.csv @@ -0,0 +1,5 @@ +id,type +2_scott_0_3_3,test +2_scott_0_2_2,test +2_scott_0_1_1,test +2_scott_0_4_4,test diff --git a/EMAGE/test_sequences/textgrid/2_scott_0_1_1.TextGrid b/EMAGE/test_sequences/textgrid/2_scott_0_1_1.TextGrid new file mode 100644 index 0000000000000000000000000000000000000000..abd0228f0e219b86cc370b11c03712d85afa21cc --- /dev/null +++ b/EMAGE/test_sequences/textgrid/2_scott_0_1_1.TextGrid @@ -0,0 +1,3636 @@ +File type = "ooTextFile" +Object class = "TextGrid" + +xmin = 0 +xmax = 64.097375 +tiers? 
+size = 2 +item []: + item [1]: + class = "IntervalTier" + name = "words" + xmin = 0 + xmax = 64.097375 + intervals: size = 220 + intervals [1]: + xmin = 0 + xmax = 1.42 + text = "" + intervals [2]: + xmin = 1.42 + xmax = 1.52 + text = "the" + intervals [3]: + xmin = 1.52 + xmax = 1.78 + text = "first" + intervals [4]: + xmin = 1.78 + xmax = 1.97 + text = "thing" + intervals [5]: + xmin = 1.97 + xmax = 2.04 + text = "i" + intervals [6]: + xmin = 2.04 + xmax = 2.21 + text = "like" + intervals [7]: + xmin = 2.21 + xmax = 2.28 + text = "to" + intervals [8]: + xmin = 2.28 + xmax = 2.47 + text = "do" + intervals [9]: + xmin = 2.47 + xmax = 2.63 + text = "on" + intervals [10]: + xmin = 2.63 + xmax = 3.32 + text = "weekends" + intervals [11]: + xmin = 3.32 + xmax = 3.58 + text = "is" + intervals [12]: + xmin = 3.58 + xmax = 4.41 + text = "relaxing" + intervals [13]: + xmin = 4.41 + xmax = 4.52 + text = "" + intervals [14]: + xmin = 4.52 + xmax = 5.05 + text = "and" + intervals [15]: + xmin = 5.05 + xmax = 5.14 + text = "i" + intervals [16]: + xmin = 5.14 + xmax = 5.33 + text = "think" + intervals [17]: + xmin = 5.33 + xmax = 5.41 + text = "i'll" + intervals [18]: + xmin = 5.41 + xmax = 5.5 + text = "go" + intervals [19]: + xmin = 5.5 + xmax = 6 + text = "shopping" + intervals [20]: + xmin = 6 + xmax = 6.11 + text = "if" + intervals [21]: + xmin = 6.11 + xmax = 6.29 + text = "i'm" + intervals [22]: + xmin = 6.29 + xmax = 6.54 + text = "not" + intervals [23]: + xmin = 6.54 + xmax = 6.7 + text = "that" + intervals [24]: + xmin = 6.7 + xmax = 7.19 + text = "tired" + intervals [25]: + xmin = 7.19 + xmax = 7.45 + text = "" + intervals [26]: + xmin = 7.45 + xmax = 7.62 + text = "so" + intervals [27]: + xmin = 7.62 + xmax = 7.74 + text = "that" + intervals [28]: + xmin = 7.74 + xmax = 7.85 + text = "you" + intervals [29]: + xmin = 7.85 + xmax = 8.14 + text = "started" + intervals [30]: + xmin = 8.14 + xmax = 8.24 + text = "by" + intervals [31]: + xmin = 8.24 + xmax = 8.52 + text = "job" + intervals [32]: + xmin = 8.52 + xmax = 8.59 + text = "i" + intervals [33]: + xmin = 8.59 + xmax = 8.75 + text = "think" + intervals [34]: + xmin = 8.75 + xmax = 8.88 + text = "it's" + intervals [35]: + xmin = 8.88 + xmax = 9.35 + text = "very" + intervals [36]: + xmin = 9.35 + xmax = 9.8 + text = "important" + intervals [37]: + xmin = 9.8 + xmax = 9.87 + text = "to" + intervals [38]: + xmin = 9.87 + xmax = 9.99 + text = "get" + intervals [39]: + xmin = 9.99 + xmax = 10.03 + text = "a" + intervals [40]: + xmin = 10.03 + xmax = 10.17 + text = "good" + intervals [41]: + xmin = 10.17 + xmax = 10.56 + text = "sleep" + intervals [42]: + xmin = 10.56 + xmax = 11.14 + text = "during" + intervals [43]: + xmin = 11.14 + xmax = 11.32 + text = "your" + intervals [44]: + xmin = 11.32 + xmax = 11.77 + text = "weekend" + intervals [45]: + xmin = 11.77 + xmax = 12.4 + text = "because" + intervals [46]: + xmin = 12.4 + xmax = 12.95 + text = "when" + intervals [47]: + xmin = 12.95 + xmax = 13.04 + text = "you" + intervals [48]: + xmin = 13.04 + xmax = 13.19 + text = "have" + intervals [49]: + xmin = 13.19 + xmax = 13.27 + text = "to" + intervals [50]: + xmin = 13.27 + xmax = 13.44 + text = "work" + intervals [51]: + xmin = 13.44 + xmax = 13.58 + text = "on" + intervals [52]: + xmin = 13.58 + xmax = 13.96 + text = "monday" + intervals [53]: + xmin = 13.96 + xmax = 14.1 + text = "through" + intervals [54]: + xmin = 14.1 + xmax = 14.75 + text = "friday" + intervals [55]: + xmin = 14.75 + xmax = 15.41 + text = "" + intervals [56]: + xmin = 
15.41 + xmax = 15.53 + text = "the" + intervals [57]: + xmin = 15.53 + xmax = 15.75 + text = "whole" + intervals [58]: + xmin = 15.75 + xmax = 16.09 + text = "week" + intervals [59]: + xmin = 16.09 + xmax = 16.28 + text = "" + intervals [60]: + xmin = 16.28 + xmax = 16.42 + text = "you" + intervals [61]: + xmin = 16.42 + xmax = 16.49 + text = "are" + intervals [62]: + xmin = 16.49 + xmax = 16.73 + text = "very" + intervals [63]: + xmin = 16.73 + xmax = 17.59 + text = "tired" + intervals [64]: + xmin = 17.59 + xmax = 17.83 + text = "" + intervals [65]: + xmin = 17.83 + xmax = 18.29 + text = "so" + intervals [66]: + xmin = 18.29 + xmax = 18.55 + text = "getting" + intervals [67]: + xmin = 18.55 + xmax = 18.61 + text = "a" + intervals [68]: + xmin = 18.61 + xmax = 18.78 + text = "good" + intervals [69]: + xmin = 18.78 + xmax = 19.08 + text = "rest" + intervals [70]: + xmin = 19.08 + xmax = 19.21 + text = "is" + intervals [71]: + xmin = 19.21 + xmax = 19.3 + text = "as" + intervals [72]: + xmin = 19.3 + xmax = 19.77 + text = "important" + intervals [73]: + xmin = 19.77 + xmax = 20.16 + text = "as" + intervals [74]: + xmin = 20.16 + xmax = 20.3 + text = "" + intervals [75]: + xmin = 20.3 + xmax = 20.66 + text = "complain" + intervals [76]: + xmin = 20.66 + xmax = 20.75 + text = "to" + intervals [77]: + xmin = 20.75 + xmax = 21.09 + text = "jaw" + intervals [78]: + xmin = 21.09 + xmax = 21.3 + text = "or" + intervals [79]: + xmin = 21.3 + xmax = 21.79 + text = "completing" + intervals [80]: + xmin = 21.79 + xmax = 21.9 + text = "an" + intervals [81]: + xmin = 21.9 + xmax = 22.23 + text = "excellent" + intervals [82]: + xmin = 22.23 + xmax = 22.64 + text = "job" + intervals [83]: + xmin = 22.64 + xmax = 23.04 + text = "" + intervals [84]: + xmin = 23.04 + xmax = 23.17 + text = "in" + intervals [85]: + xmin = 23.17 + xmax = 23.29 + text = "my" + intervals [86]: + xmin = 23.29 + xmax = 23.56 + text = "spare" + intervals [87]: + xmin = 23.56 + xmax = 23.8 + text = "time" + intervals [88]: + xmin = 23.8 + xmax = 23.88 + text = "if" + intervals [89]: + xmin = 23.88 + xmax = 23.98 + text = "i" + intervals [90]: + xmin = 23.98 + xmax = 24.18 + text = "feel" + intervals [91]: + xmin = 24.18 + xmax = 24.84 + text = "okay" + intervals [92]: + xmin = 24.84 + xmax = 25.07 + text = "i" + intervals [93]: + xmin = 25.07 + xmax = 25.1 + text = "" + intervals [94]: + xmin = 25.1 + xmax = 25.38 + text = "like" + intervals [95]: + xmin = 25.38 + xmax = 25.44 + text = "to" + intervals [96]: + xmin = 25.44 + xmax = 25.55 + text = "go" + intervals [97]: + xmin = 25.55 + xmax = 25.79 + text = "for" + intervals [98]: + xmin = 25.79 + xmax = 25.83 + text = "a" + intervals [99]: + xmin = 25.83 + xmax = 26.12 + text = "hike" + intervals [100]: + xmin = 26.12 + xmax = 26.21 + text = "in" + intervals [101]: + xmin = 26.21 + xmax = 26.81 + text = "nature" + intervals [102]: + xmin = 26.81 + xmax = 27.11 + text = "" + intervals [103]: + xmin = 27.11 + xmax = 27.45 + text = "sometimes" + intervals [104]: + xmin = 27.45 + xmax = 27.51 + text = "i" + intervals [105]: + xmin = 27.51 + xmax = 27.74 + text = "try" + intervals [106]: + xmin = 27.74 + xmax = 27.88 + text = "to" + intervals [107]: + xmin = 27.88 + xmax = 28.37 + text = "organize" + intervals [108]: + xmin = 28.37 + xmax = 28.94 + text = "something" + intervals [109]: + xmin = 28.94 + xmax = 28.98 + text = "" + intervals [110]: + xmin = 28.98 + xmax = 29.19 + text = "for" + intervals [111]: + xmin = 29.19 + xmax = 29.32 + text = "my" + intervals [112]: + xmin = 29.32 + 
xmax = 29.89 + text = "friends" + intervals [113]: + xmin = 29.89 + xmax = 29.92 + text = "" + intervals [114]: + xmin = 29.92 + xmax = 29.95 + text = "i" + intervals [115]: + xmin = 29.95 + xmax = 30.2 + text = "" + intervals [116]: + xmin = 30.2 + xmax = 30.73 + text = "volunteer" + intervals [117]: + xmin = 30.73 + xmax = 30.86 + text = "at" + intervals [118]: + xmin = 30.86 + xmax = 30.97 + text = "the" + intervals [119]: + xmin = 30.97 + xmax = 31.38 + text = "buddhist" + intervals [120]: + xmin = 31.38 + xmax = 31.83 + text = "temple" + intervals [121]: + xmin = 31.83 + xmax = 31.94 + text = "on" + intervals [122]: + xmin = 31.94 + xmax = 32.01 + text = "the" + intervals [123]: + xmin = 32.01 + xmax = 32.6 + text = "weekend" + intervals [124]: + xmin = 32.6 + xmax = 33.01 + text = "or" + intervals [125]: + xmin = 33.01 + xmax = 33.24 + text = "i" + intervals [126]: + xmin = 33.24 + xmax = 33.62 + text = "can" + intervals [127]: + xmin = 33.62 + xmax = 33.91 + text = "just" + intervals [128]: + xmin = 33.91 + xmax = 34.3 + text = "walk" + intervals [129]: + xmin = 34.3 + xmax = 34.69 + text = "around" + intervals [130]: + xmin = 34.69 + xmax = 35.08 + text = "enjoying" + intervals [131]: + xmin = 35.08 + xmax = 35.17 + text = "the" + intervals [132]: + xmin = 35.17 + xmax = 35.87 + text = "sunshine" + intervals [133]: + xmin = 35.87 + xmax = 36.15 + text = "" + intervals [134]: + xmin = 36.15 + xmax = 36.34 + text = "i'd" + intervals [135]: + xmin = 36.34 + xmax = 36.52 + text = "like" + intervals [136]: + xmin = 36.52 + xmax = 36.59 + text = "to" + intervals [137]: + xmin = 36.59 + xmax = 36.74 + text = "have" + intervals [138]: + xmin = 36.74 + xmax = 36.79 + text = "a" + intervals [139]: + xmin = 36.79 + xmax = 37.06 + text = "healthy" + intervals [140]: + xmin = 37.06 + xmax = 37.66 + text = "lifestyle" + intervals [141]: + xmin = 37.66 + xmax = 38.06 + text = "considering" + intervals [142]: + xmin = 38.06 + xmax = 38.17 + text = "how" + intervals [143]: + xmin = 38.17 + xmax = 38.38 + text = "much" + intervals [144]: + xmin = 38.38 + xmax = 38.74 + text = "time" + intervals [145]: + xmin = 38.74 + xmax = 38.81 + text = "i" + intervals [146]: + xmin = 38.81 + xmax = 39.18 + text = "spend" + intervals [147]: + xmin = 39.18 + xmax = 39.29 + text = "at" + intervals [148]: + xmin = 39.29 + xmax = 39.84 + text = "work" + intervals [149]: + xmin = 39.84 + xmax = 40.29 + text = "" + intervals [150]: + xmin = 40.29 + xmax = 40.52 + text = "i" + intervals [151]: + xmin = 40.52 + xmax = 40.79 + text = "always" + intervals [152]: + xmin = 40.79 + xmax = 41.28 + text = "try" + intervals [153]: + xmin = 41.28 + xmax = 41.47 + text = "to" + intervals [154]: + xmin = 41.47 + xmax = 41.85 + text = "move" + intervals [155]: + xmin = 41.85 + xmax = 42 + text = "as" + intervals [156]: + xmin = 42 + xmax = 42.22 + text = "much" + intervals [157]: + xmin = 42.22 + xmax = 42.31 + text = "as" + intervals [158]: + xmin = 42.31 + xmax = 42.4 + text = "i" + intervals [159]: + xmin = 42.4 + xmax = 42.76 + text = "can" + intervals [160]: + xmin = 42.76 + xmax = 42.89 + text = "when" + intervals [161]: + xmin = 42.89 + xmax = 42.98 + text = "i'm" + intervals [162]: + xmin = 42.98 + xmax = 43.18 + text = "not" + intervals [163]: + xmin = 43.18 + xmax = 43.76 + text = "working" + intervals [164]: + xmin = 43.76 + xmax = 44.5 + text = "" + intervals [165]: + xmin = 44.5 + xmax = 45.19 + text = "and" + intervals [166]: + xmin = 45.19 + xmax = 45.32 + text = "on" + intervals [167]: + xmin = 45.32 + xmax = 45.49 + 
text = "other" + intervals [168]: + xmin = 45.49 + xmax = 45.82 + text = "days" + intervals [169]: + xmin = 45.82 + xmax = 45.96 + text = "when" + intervals [170]: + xmin = 45.96 + xmax = 46.16 + text = "i'm" + intervals [171]: + xmin = 46.16 + xmax = 46.65 + text = "free" + intervals [172]: + xmin = 46.65 + xmax = 46.86 + text = "i" + intervals [173]: + xmin = 46.86 + xmax = 47.16 + text = "like" + intervals [174]: + xmin = 47.16 + xmax = 47.39 + text = "to" + intervals [175]: + xmin = 47.39 + xmax = 47.86 + text = "listen" + intervals [176]: + xmin = 47.86 + xmax = 48.03 + text = "to" + intervals [177]: + xmin = 48.03 + xmax = 48.41 + text = "music" + intervals [178]: + xmin = 48.41 + xmax = 48.73 + text = "and" + intervals [179]: + xmin = 48.73 + xmax = 48.76 + text = "" + intervals [180]: + xmin = 48.76 + xmax = 49.01 + text = "we're" + intervals [181]: + xmin = 49.01 + xmax = 49.3 + text = "watch" + intervals [182]: + xmin = 49.3 + xmax = 49.38 + text = "a" + intervals [183]: + xmin = 49.38 + xmax = 50.05 + text = "documentary" + intervals [184]: + xmin = 50.05 + xmax = 50.51 + text = "movies" + intervals [185]: + xmin = 50.51 + xmax = 50.82 + text = "on" + intervals [186]: + xmin = 50.82 + xmax = 51.11 + text = "my" + intervals [187]: + xmin = 51.11 + xmax = 51.81 + text = "laptop" + intervals [188]: + xmin = 51.81 + xmax = 52.14 + text = "" + intervals [189]: + xmin = 52.14 + xmax = 52.44 + text = "but" + intervals [190]: + xmin = 52.44 + xmax = 52.86 + text = "sometimes" + intervals [191]: + xmin = 52.86 + xmax = 52.93 + text = "it" + intervals [192]: + xmin = 52.93 + xmax = 53.13 + text = "just" + intervals [193]: + xmin = 53.13 + xmax = 53.61 + text = "sleep" + intervals [194]: + xmin = 53.61 + xmax = 53.65 + text = "" + intervals [195]: + xmin = 53.65 + xmax = 53.83 + text = "i" + intervals [196]: + xmin = 53.83 + xmax = 54.27 + text = "especially" + intervals [197]: + xmin = 54.27 + xmax = 54.61 + text = "liked" + intervals [198]: + xmin = 54.61 + xmax = 55.01 + text = "watching" + intervals [199]: + xmin = 55.01 + xmax = 55.62 + text = "japanese" + intervals [200]: + xmin = 55.62 + xmax = 55.91 + text = "anime" + intervals [201]: + xmin = 55.91 + xmax = 56.33 + text = "i" + intervals [202]: + xmin = 56.33 + xmax = 56.85 + text = "" + intervals [203]: + xmin = 56.85 + xmax = 57.12 + text = "think" + intervals [204]: + xmin = 57.12 + xmax = 57.43 + text = "watching" + intervals [205]: + xmin = 57.43 + xmax = 57.62 + text = "a" + intervals [206]: + xmin = 57.62 + xmax = 57.79 + text = "me" + intervals [207]: + xmin = 57.79 + xmax = 58.09 + text = "is" + intervals [208]: + xmin = 58.09 + xmax = 58.39 + text = "anime" + intervals [209]: + xmin = 58.39 + xmax = 59.06 + text = "is" + intervals [210]: + xmin = 59.06 + xmax = 59.31 + text = "very" + intervals [211]: + xmin = 59.31 + xmax = 59.67 + text = "helpful" + intervals [212]: + xmin = 59.67 + xmax = 59.81 + text = "for" + intervals [213]: + xmin = 59.81 + xmax = 59.98 + text = "me" + intervals [214]: + xmin = 59.98 + xmax = 60.28 + text = "to" + intervals [215]: + xmin = 60.28 + xmax = 60.69 + text = "learn" + intervals [216]: + xmin = 60.69 + xmax = 60.78 + text = "and" + intervals [217]: + xmin = 60.78 + xmax = 61.21 + text = "express" + intervals [218]: + xmin = 61.21 + xmax = 61.89 + text = "japanese" + intervals [219]: + xmin = 61.89 + xmax = 62.42 + text = "better" + intervals [220]: + xmin = 62.42 + xmax = 64.097375 + text = "" + item [2]: + class = "IntervalTier" + name = "phones" + xmin = 0 + xmax = 64.097375 + 
intervals: size = 684 + intervals [1]: + xmin = 0 + xmax = 1.42 + text = "" + intervals [2]: + xmin = 1.42 + xmax = 1.48 + text = "DH" + intervals [3]: + xmin = 1.48 + xmax = 1.52 + text = "AH0" + intervals [4]: + xmin = 1.52 + xmax = 1.62 + text = "F" + intervals [5]: + xmin = 1.62 + xmax = 1.72 + text = "ER1" + intervals [6]: + xmin = 1.72 + xmax = 1.75 + text = "S" + intervals [7]: + xmin = 1.75 + xmax = 1.78 + text = "T" + intervals [8]: + xmin = 1.78 + xmax = 1.81 + text = "TH" + intervals [9]: + xmin = 1.81 + xmax = 1.88 + text = "IH1" + intervals [10]: + xmin = 1.88 + xmax = 1.97 + text = "NG" + intervals [11]: + xmin = 1.97 + xmax = 2.04 + text = "AY1" + intervals [12]: + xmin = 2.04 + xmax = 2.08 + text = "L" + intervals [13]: + xmin = 2.08 + xmax = 2.17 + text = "AY1" + intervals [14]: + xmin = 2.17 + xmax = 2.21 + text = "K" + intervals [15]: + xmin = 2.21 + xmax = 2.24 + text = "T" + intervals [16]: + xmin = 2.24 + xmax = 2.28 + text = "IH0" + intervals [17]: + xmin = 2.28 + xmax = 2.34 + text = "D" + intervals [18]: + xmin = 2.34 + xmax = 2.47 + text = "UW1" + intervals [19]: + xmin = 2.47 + xmax = 2.58 + text = "AA1" + intervals [20]: + xmin = 2.58 + xmax = 2.63 + text = "N" + intervals [21]: + xmin = 2.63 + xmax = 2.68 + text = "W" + intervals [22]: + xmin = 2.68 + xmax = 2.78 + text = "IY1" + intervals [23]: + xmin = 2.78 + xmax = 2.88 + text = "K" + intervals [24]: + xmin = 2.88 + xmax = 3.01 + text = "EH2" + intervals [25]: + xmin = 3.01 + xmax = 3.14 + text = "N" + intervals [26]: + xmin = 3.14 + xmax = 3.2 + text = "D" + intervals [27]: + xmin = 3.2 + xmax = 3.32 + text = "Z" + intervals [28]: + xmin = 3.32 + xmax = 3.47 + text = "IH1" + intervals [29]: + xmin = 3.47 + xmax = 3.58 + text = "Z" + intervals [30]: + xmin = 3.58 + xmax = 3.64 + text = "R" + intervals [31]: + xmin = 3.64 + xmax = 3.7 + text = "IY0" + intervals [32]: + xmin = 3.7 + xmax = 3.8 + text = "L" + intervals [33]: + xmin = 3.8 + xmax = 3.96 + text = "AE1" + intervals [34]: + xmin = 3.96 + xmax = 4.02 + text = "K" + intervals [35]: + xmin = 4.02 + xmax = 4.11 + text = "S" + intervals [36]: + xmin = 4.11 + xmax = 4.2 + text = "IH0" + intervals [37]: + xmin = 4.2 + xmax = 4.41 + text = "NG" + intervals [38]: + xmin = 4.41 + xmax = 4.52 + text = "" + intervals [39]: + xmin = 4.52 + xmax = 4.97 + text = "AH0" + intervals [40]: + xmin = 4.97 + xmax = 5.01 + text = "N" + intervals [41]: + xmin = 5.01 + xmax = 5.05 + text = "D" + intervals [42]: + xmin = 5.05 + xmax = 5.14 + text = "AY1" + intervals [43]: + xmin = 5.14 + xmax = 5.19 + text = "TH" + intervals [44]: + xmin = 5.19 + xmax = 5.25 + text = "IH1" + intervals [45]: + xmin = 5.25 + xmax = 5.29 + text = "NG" + intervals [46]: + xmin = 5.29 + xmax = 5.33 + text = "K" + intervals [47]: + xmin = 5.33 + xmax = 5.36 + text = "AY1" + intervals [48]: + xmin = 5.36 + xmax = 5.41 + text = "L" + intervals [49]: + xmin = 5.41 + xmax = 5.44 + text = "G" + intervals [50]: + xmin = 5.44 + xmax = 5.5 + text = "OW1" + intervals [51]: + xmin = 5.5 + xmax = 5.68 + text = "SH" + intervals [52]: + xmin = 5.68 + xmax = 5.87 + text = "AA1" + intervals [53]: + xmin = 5.87 + xmax = 5.92 + text = "P" + intervals [54]: + xmin = 5.92 + xmax = 5.96 + text = "IH0" + intervals [55]: + xmin = 5.96 + xmax = 6 + text = "NG" + intervals [56]: + xmin = 6 + xmax = 6.06 + text = "IH0" + intervals [57]: + xmin = 6.06 + xmax = 6.11 + text = "F" + intervals [58]: + xmin = 6.11 + xmax = 6.16 + text = "AY1" + intervals [59]: + xmin = 6.16 + xmax = 6.29 + text = "M" + intervals [60]: + xmin = 
6.29 + xmax = 6.35 + text = "N" + intervals [61]: + xmin = 6.35 + xmax = 6.48 + text = "AA1" + intervals [62]: + xmin = 6.48 + xmax = 6.54 + text = "T" + intervals [63]: + xmin = 6.54 + xmax = 6.58 + text = "DH" + intervals [64]: + xmin = 6.58 + xmax = 6.64 + text = "AE1" + intervals [65]: + xmin = 6.64 + xmax = 6.7 + text = "T" + intervals [66]: + xmin = 6.7 + xmax = 6.78 + text = "T" + intervals [67]: + xmin = 6.78 + xmax = 6.93 + text = "AY1" + intervals [68]: + xmin = 6.93 + xmax = 7.08 + text = "ER0" + intervals [69]: + xmin = 7.08 + xmax = 7.19 + text = "D" + intervals [70]: + xmin = 7.19 + xmax = 7.45 + text = "" + intervals [71]: + xmin = 7.45 + xmax = 7.59 + text = "S" + intervals [72]: + xmin = 7.59 + xmax = 7.62 + text = "OW1" + intervals [73]: + xmin = 7.62 + xmax = 7.66 + text = "DH" + intervals [74]: + xmin = 7.66 + xmax = 7.71 + text = "AE1" + intervals [75]: + xmin = 7.71 + xmax = 7.74 + text = "T" + intervals [76]: + xmin = 7.74 + xmax = 7.77 + text = "Y" + intervals [77]: + xmin = 7.77 + xmax = 7.85 + text = "UW1" + intervals [78]: + xmin = 7.85 + xmax = 7.92 + text = "S" + intervals [79]: + xmin = 7.92 + xmax = 7.97 + text = "T" + intervals [80]: + xmin = 7.97 + xmax = 8.02 + text = "AA1" + intervals [81]: + xmin = 8.02 + xmax = 8.05 + text = "R" + intervals [82]: + xmin = 8.05 + xmax = 8.08 + text = "T" + intervals [83]: + xmin = 8.08 + xmax = 8.11 + text = "AH0" + intervals [84]: + xmin = 8.11 + xmax = 8.14 + text = "D" + intervals [85]: + xmin = 8.14 + xmax = 8.17 + text = "B" + intervals [86]: + xmin = 8.17 + xmax = 8.24 + text = "AY1" + intervals [87]: + xmin = 8.24 + xmax = 8.35 + text = "JH" + intervals [88]: + xmin = 8.35 + xmax = 8.48 + text = "AA1" + intervals [89]: + xmin = 8.48 + xmax = 8.52 + text = "B" + intervals [90]: + xmin = 8.52 + xmax = 8.59 + text = "AY1" + intervals [91]: + xmin = 8.59 + xmax = 8.64 + text = "TH" + intervals [92]: + xmin = 8.64 + xmax = 8.69 + text = "IH1" + intervals [93]: + xmin = 8.69 + xmax = 8.72 + text = "NG" + intervals [94]: + xmin = 8.72 + xmax = 8.75 + text = "K" + intervals [95]: + xmin = 8.75 + xmax = 8.79 + text = "IH1" + intervals [96]: + xmin = 8.79 + xmax = 8.84 + text = "T" + intervals [97]: + xmin = 8.84 + xmax = 8.88 + text = "S" + intervals [98]: + xmin = 8.88 + xmax = 9.08 + text = "V" + intervals [99]: + xmin = 9.08 + xmax = 9.2 + text = "EH1" + intervals [100]: + xmin = 9.2 + xmax = 9.28 + text = "R" + intervals [101]: + xmin = 9.28 + xmax = 9.35 + text = "IY0" + intervals [102]: + xmin = 9.35 + xmax = 9.4 + text = "IH0" + intervals [103]: + xmin = 9.4 + xmax = 9.46 + text = "M" + intervals [104]: + xmin = 9.46 + xmax = 9.55 + text = "P" + intervals [105]: + xmin = 9.55 + xmax = 9.63 + text = "AO1" + intervals [106]: + xmin = 9.63 + xmax = 9.68 + text = "R" + intervals [107]: + xmin = 9.68 + xmax = 9.71 + text = "T" + intervals [108]: + xmin = 9.71 + xmax = 9.74 + text = "AH0" + intervals [109]: + xmin = 9.74 + xmax = 9.77 + text = "N" + intervals [110]: + xmin = 9.77 + xmax = 9.8 + text = "T" + intervals [111]: + xmin = 9.8 + xmax = 9.83 + text = "T" + intervals [112]: + xmin = 9.83 + xmax = 9.87 + text = "IH0" + intervals [113]: + xmin = 9.87 + xmax = 9.93 + text = "G" + intervals [114]: + xmin = 9.93 + xmax = 9.96 + text = "EH1" + intervals [115]: + xmin = 9.96 + xmax = 9.99 + text = "T" + intervals [116]: + xmin = 9.99 + xmax = 10.03 + text = "AH0" + intervals [117]: + xmin = 10.03 + xmax = 10.07 + text = "G" + intervals [118]: + xmin = 10.07 + xmax = 10.1 + text = "IH0" + intervals [119]: + xmin = 10.1 + 
xmax = 10.17 + text = "D" + intervals [120]: + xmin = 10.17 + xmax = 10.35 + text = "S" + intervals [121]: + xmin = 10.35 + xmax = 10.43 + text = "L" + intervals [122]: + xmin = 10.43 + xmax = 10.53 + text = "IY1" + intervals [123]: + xmin = 10.53 + xmax = 10.56 + text = "P" + intervals [124]: + xmin = 10.56 + xmax = 10.8 + text = "D" + intervals [125]: + xmin = 10.8 + xmax = 10.92 + text = "ER1" + intervals [126]: + xmin = 10.92 + xmax = 10.99 + text = "IH0" + intervals [127]: + xmin = 10.99 + xmax = 11.14 + text = "NG" + intervals [128]: + xmin = 11.14 + xmax = 11.2 + text = "Y" + intervals [129]: + xmin = 11.2 + xmax = 11.23 + text = "UH1" + intervals [130]: + xmin = 11.23 + xmax = 11.32 + text = "R" + intervals [131]: + xmin = 11.32 + xmax = 11.4 + text = "W" + intervals [132]: + xmin = 11.4 + xmax = 11.51 + text = "IY1" + intervals [133]: + xmin = 11.51 + xmax = 11.6 + text = "K" + intervals [134]: + xmin = 11.6 + xmax = 11.68 + text = "EH2" + intervals [135]: + xmin = 11.68 + xmax = 11.74 + text = "N" + intervals [136]: + xmin = 11.74 + xmax = 11.77 + text = "D" + intervals [137]: + xmin = 11.77 + xmax = 11.8 + text = "B" + intervals [138]: + xmin = 11.8 + xmax = 11.88 + text = "IH0" + intervals [139]: + xmin = 11.88 + xmax = 12 + text = "K" + intervals [140]: + xmin = 12 + xmax = 12.26 + text = "AH1" + intervals [141]: + xmin = 12.26 + xmax = 12.4 + text = "Z" + intervals [142]: + xmin = 12.4 + xmax = 12.6 + text = "W" + intervals [143]: + xmin = 12.6 + xmax = 12.88 + text = "EH1" + intervals [144]: + xmin = 12.88 + xmax = 12.95 + text = "N" + intervals [145]: + xmin = 12.95 + xmax = 12.99 + text = "Y" + intervals [146]: + xmin = 12.99 + xmax = 13.04 + text = "UW1" + intervals [147]: + xmin = 13.04 + xmax = 13.07 + text = "HH" + intervals [148]: + xmin = 13.07 + xmax = 13.16 + text = "AE1" + intervals [149]: + xmin = 13.16 + xmax = 13.19 + text = "V" + intervals [150]: + xmin = 13.19 + xmax = 13.22 + text = "T" + intervals [151]: + xmin = 13.22 + xmax = 13.27 + text = "UW1" + intervals [152]: + xmin = 13.27 + xmax = 13.32 + text = "W" + intervals [153]: + xmin = 13.32 + xmax = 13.4 + text = "ER1" + intervals [154]: + xmin = 13.4 + xmax = 13.44 + text = "K" + intervals [155]: + xmin = 13.44 + xmax = 13.51 + text = "AA1" + intervals [156]: + xmin = 13.51 + xmax = 13.58 + text = "N" + intervals [157]: + xmin = 13.58 + xmax = 13.66 + text = "M" + intervals [158]: + xmin = 13.66 + xmax = 13.76 + text = "AH1" + intervals [159]: + xmin = 13.76 + xmax = 13.81 + text = "N" + intervals [160]: + xmin = 13.81 + xmax = 13.85 + text = "D" + intervals [161]: + xmin = 13.85 + xmax = 13.96 + text = "EY2" + intervals [162]: + xmin = 13.96 + xmax = 14.01 + text = "TH" + intervals [163]: + xmin = 14.01 + xmax = 14.04 + text = "R" + intervals [164]: + xmin = 14.04 + xmax = 14.1 + text = "UW1" + intervals [165]: + xmin = 14.1 + xmax = 14.17 + text = "F" + intervals [166]: + xmin = 14.17 + xmax = 14.26 + text = "R" + intervals [167]: + xmin = 14.26 + xmax = 14.4 + text = "AY1" + intervals [168]: + xmin = 14.4 + xmax = 14.45 + text = "D" + intervals [169]: + xmin = 14.45 + xmax = 14.75 + text = "EY2" + intervals [170]: + xmin = 14.75 + xmax = 15.41 + text = "" + intervals [171]: + xmin = 15.41 + xmax = 15.49 + text = "DH" + intervals [172]: + xmin = 15.49 + xmax = 15.53 + text = "AH1" + intervals [173]: + xmin = 15.53 + xmax = 15.62 + text = "HH" + intervals [174]: + xmin = 15.62 + xmax = 15.67 + text = "OW1" + intervals [175]: + xmin = 15.67 + xmax = 15.75 + text = "L" + intervals [176]: + xmin = 15.75 + 
xmax = 15.8 + text = "W" + intervals [177]: + xmin = 15.8 + xmax = 15.94 + text = "IY1" + intervals [178]: + xmin = 15.94 + xmax = 16.09 + text = "K" + intervals [179]: + xmin = 16.09 + xmax = 16.28 + text = "" + intervals [180]: + xmin = 16.28 + xmax = 16.38 + text = "Y" + intervals [181]: + xmin = 16.38 + xmax = 16.42 + text = "UW1" + intervals [182]: + xmin = 16.42 + xmax = 16.49 + text = "ER0" + intervals [183]: + xmin = 16.49 + xmax = 16.55 + text = "V" + intervals [184]: + xmin = 16.55 + xmax = 16.58 + text = "EH1" + intervals [185]: + xmin = 16.58 + xmax = 16.65 + text = "R" + intervals [186]: + xmin = 16.65 + xmax = 16.73 + text = "IY0" + intervals [187]: + xmin = 16.73 + xmax = 16.92 + text = "T" + intervals [188]: + xmin = 16.92 + xmax = 17.08 + text = "AY1" + intervals [189]: + xmin = 17.08 + xmax = 17.22 + text = "ER0" + intervals [190]: + xmin = 17.22 + xmax = 17.59 + text = "D" + intervals [191]: + xmin = 17.59 + xmax = 17.83 + text = "" + intervals [192]: + xmin = 17.83 + xmax = 18.02 + text = "S" + intervals [193]: + xmin = 18.02 + xmax = 18.29 + text = "OW1" + intervals [194]: + xmin = 18.29 + xmax = 18.37 + text = "G" + intervals [195]: + xmin = 18.37 + xmax = 18.42 + text = "IH1" + intervals [196]: + xmin = 18.42 + xmax = 18.46 + text = "T" + intervals [197]: + xmin = 18.46 + xmax = 18.5 + text = "IH0" + intervals [198]: + xmin = 18.5 + xmax = 18.55 + text = "NG" + intervals [199]: + xmin = 18.55 + xmax = 18.61 + text = "EY1" + intervals [200]: + xmin = 18.61 + xmax = 18.67 + text = "G" + intervals [201]: + xmin = 18.67 + xmax = 18.73 + text = "UH1" + intervals [202]: + xmin = 18.73 + xmax = 18.78 + text = "D" + intervals [203]: + xmin = 18.78 + xmax = 18.86 + text = "R" + intervals [204]: + xmin = 18.86 + xmax = 18.97 + text = "EH1" + intervals [205]: + xmin = 18.97 + xmax = 19.05 + text = "S" + intervals [206]: + xmin = 19.05 + xmax = 19.08 + text = "T" + intervals [207]: + xmin = 19.08 + xmax = 19.13 + text = "IH0" + intervals [208]: + xmin = 19.13 + xmax = 19.21 + text = "Z" + intervals [209]: + xmin = 19.21 + xmax = 19.24 + text = "EH1" + intervals [210]: + xmin = 19.24 + xmax = 19.3 + text = "Z" + intervals [211]: + xmin = 19.3 + xmax = 19.34 + text = "IH0" + intervals [212]: + xmin = 19.34 + xmax = 19.38 + text = "M" + intervals [213]: + xmin = 19.38 + xmax = 19.48 + text = "P" + intervals [214]: + xmin = 19.48 + xmax = 19.55 + text = "AO1" + intervals [215]: + xmin = 19.55 + xmax = 19.59 + text = "R" + intervals [216]: + xmin = 19.59 + xmax = 19.62 + text = "T" + intervals [217]: + xmin = 19.62 + xmax = 19.65 + text = "AH0" + intervals [218]: + xmin = 19.65 + xmax = 19.68 + text = "N" + intervals [219]: + xmin = 19.68 + xmax = 19.77 + text = "T" + intervals [220]: + xmin = 19.77 + xmax = 19.94 + text = "AE1" + intervals [221]: + xmin = 19.94 + xmax = 20.16 + text = "Z" + intervals [222]: + xmin = 20.16 + xmax = 20.3 + text = "" + intervals [223]: + xmin = 20.3 + xmax = 20.39 + text = "K" + intervals [224]: + xmin = 20.39 + xmax = 20.43 + text = "AH0" + intervals [225]: + xmin = 20.43 + xmax = 20.46 + text = "M" + intervals [226]: + xmin = 20.46 + xmax = 20.53 + text = "P" + intervals [227]: + xmin = 20.53 + xmax = 20.59 + text = "L" + intervals [228]: + xmin = 20.59 + xmax = 20.63 + text = "EY1" + intervals [229]: + xmin = 20.63 + xmax = 20.66 + text = "N" + intervals [230]: + xmin = 20.66 + xmax = 20.69 + text = "T" + intervals [231]: + xmin = 20.69 + xmax = 20.75 + text = "AH0" + intervals [232]: + xmin = 20.75 + xmax = 20.87 + text = "JH" + intervals [233]: + 
xmin = 20.87 + xmax = 21.09 + text = "AO1" + intervals [234]: + xmin = 21.09 + xmax = 21.3 + text = "ER0" + intervals [235]: + xmin = 21.3 + xmax = 21.44 + text = "K" + intervals [236]: + xmin = 21.44 + xmax = 21.47 + text = "AH0" + intervals [237]: + xmin = 21.47 + xmax = 21.5 + text = "M" + intervals [238]: + xmin = 21.5 + xmax = 21.53 + text = "P" + intervals [239]: + xmin = 21.53 + xmax = 21.6 + text = "L" + intervals [240]: + xmin = 21.6 + xmax = 21.63 + text = "IY1" + intervals [241]: + xmin = 21.63 + xmax = 21.66 + text = "T" + intervals [242]: + xmin = 21.66 + xmax = 21.72 + text = "IH0" + intervals [243]: + xmin = 21.72 + xmax = 21.79 + text = "NG" + intervals [244]: + xmin = 21.79 + xmax = 21.83 + text = "AH0" + intervals [245]: + xmin = 21.83 + xmax = 21.9 + text = "N" + intervals [246]: + xmin = 21.9 + xmax = 21.98 + text = "EH1" + intervals [247]: + xmin = 21.98 + xmax = 22.03 + text = "K" + intervals [248]: + xmin = 22.03 + xmax = 22.07 + text = "S" + intervals [249]: + xmin = 22.07 + xmax = 22.11 + text = "AH0" + intervals [250]: + xmin = 22.11 + xmax = 22.14 + text = "L" + intervals [251]: + xmin = 22.14 + xmax = 22.17 + text = "AH0" + intervals [252]: + xmin = 22.17 + xmax = 22.2 + text = "N" + intervals [253]: + xmin = 22.2 + xmax = 22.23 + text = "T" + intervals [254]: + xmin = 22.23 + xmax = 22.34 + text = "JH" + intervals [255]: + xmin = 22.34 + xmax = 22.5 + text = "AA1" + intervals [256]: + xmin = 22.5 + xmax = 22.64 + text = "B" + intervals [257]: + xmin = 22.64 + xmax = 23.04 + text = "" + intervals [258]: + xmin = 23.04 + xmax = 23.14 + text = "IH0" + intervals [259]: + xmin = 23.14 + xmax = 23.17 + text = "N" + intervals [260]: + xmin = 23.17 + xmax = 23.2 + text = "M" + intervals [261]: + xmin = 23.2 + xmax = 23.29 + text = "AY1" + intervals [262]: + xmin = 23.29 + xmax = 23.36 + text = "S" + intervals [263]: + xmin = 23.36 + xmax = 23.41 + text = "P" + intervals [264]: + xmin = 23.41 + xmax = 23.52 + text = "EH1" + intervals [265]: + xmin = 23.52 + xmax = 23.56 + text = "R" + intervals [266]: + xmin = 23.56 + xmax = 23.65 + text = "T" + intervals [267]: + xmin = 23.65 + xmax = 23.76 + text = "AY1" + intervals [268]: + xmin = 23.76 + xmax = 23.8 + text = "M" + intervals [269]: + xmin = 23.8 + xmax = 23.85 + text = "IH0" + intervals [270]: + xmin = 23.85 + xmax = 23.88 + text = "F" + intervals [271]: + xmin = 23.88 + xmax = 23.98 + text = "AY1" + intervals [272]: + xmin = 23.98 + xmax = 24.04 + text = "F" + intervals [273]: + xmin = 24.04 + xmax = 24.13 + text = "IY1" + intervals [274]: + xmin = 24.13 + xmax = 24.18 + text = "L" + intervals [275]: + xmin = 24.18 + xmax = 24.26 + text = "OW2" + intervals [276]: + xmin = 24.26 + xmax = 24.39 + text = "K" + intervals [277]: + xmin = 24.39 + xmax = 24.84 + text = "EY1" + intervals [278]: + xmin = 24.84 + xmax = 25.07 + text = "AY1" + intervals [279]: + xmin = 25.07 + xmax = 25.1 + text = "" + intervals [280]: + xmin = 25.1 + xmax = 25.29 + text = "L" + intervals [281]: + xmin = 25.29 + xmax = 25.35 + text = "AY1" + intervals [282]: + xmin = 25.35 + xmax = 25.38 + text = "K" + intervals [283]: + xmin = 25.38 + xmax = 25.41 + text = "T" + intervals [284]: + xmin = 25.41 + xmax = 25.44 + text = "IH0" + intervals [285]: + xmin = 25.44 + xmax = 25.5 + text = "G" + intervals [286]: + xmin = 25.5 + xmax = 25.55 + text = "OW1" + intervals [287]: + xmin = 25.55 + xmax = 25.59 + text = "F" + intervals [288]: + xmin = 25.59 + xmax = 25.79 + text = "ER0" + intervals [289]: + xmin = 25.79 + xmax = 25.83 + text = "AH0" + intervals 
[290]: + xmin = 25.83 + xmax = 25.94 + text = "HH" + intervals [291]: + xmin = 25.94 + xmax = 26.06 + text = "AY1" + intervals [292]: + xmin = 26.06 + xmax = 26.12 + text = "K" + intervals [293]: + xmin = 26.12 + xmax = 26.17 + text = "IH1" + intervals [294]: + xmin = 26.17 + xmax = 26.21 + text = "N" + intervals [295]: + xmin = 26.21 + xmax = 26.27 + text = "N" + intervals [296]: + xmin = 26.27 + xmax = 26.4 + text = "EY1" + intervals [297]: + xmin = 26.4 + xmax = 26.53 + text = "CH" + intervals [298]: + xmin = 26.53 + xmax = 26.81 + text = "ER0" + intervals [299]: + xmin = 26.81 + xmax = 27.11 + text = "" + intervals [300]: + xmin = 27.11 + xmax = 27.21 + text = "S" + intervals [301]: + xmin = 27.21 + xmax = 27.25 + text = "AH1" + intervals [302]: + xmin = 27.25 + xmax = 27.28 + text = "M" + intervals [303]: + xmin = 27.28 + xmax = 27.31 + text = "T" + intervals [304]: + xmin = 27.31 + xmax = 27.38 + text = "AY2" + intervals [305]: + xmin = 27.38 + xmax = 27.41 + text = "M" + intervals [306]: + xmin = 27.41 + xmax = 27.45 + text = "Z" + intervals [307]: + xmin = 27.45 + xmax = 27.51 + text = "AY1" + intervals [308]: + xmin = 27.51 + xmax = 27.6 + text = "T" + intervals [309]: + xmin = 27.6 + xmax = 27.67 + text = "R" + intervals [310]: + xmin = 27.67 + xmax = 27.74 + text = "AY1" + intervals [311]: + xmin = 27.74 + xmax = 27.77 + text = "T" + intervals [312]: + xmin = 27.77 + xmax = 27.88 + text = "AH0" + intervals [313]: + xmin = 27.88 + xmax = 28.02 + text = "AO1" + intervals [314]: + xmin = 28.02 + xmax = 28.07 + text = "R" + intervals [315]: + xmin = 28.07 + xmax = 28.12 + text = "G" + intervals [316]: + xmin = 28.12 + xmax = 28.15 + text = "AH0" + intervals [317]: + xmin = 28.15 + xmax = 28.18 + text = "N" + intervals [318]: + xmin = 28.18 + xmax = 28.3 + text = "AY2" + intervals [319]: + xmin = 28.3 + xmax = 28.37 + text = "Z" + intervals [320]: + xmin = 28.37 + xmax = 28.42 + text = "S" + intervals [321]: + xmin = 28.42 + xmax = 28.47 + text = "AH1" + intervals [322]: + xmin = 28.47 + xmax = 28.5 + text = "M" + intervals [323]: + xmin = 28.5 + xmax = 28.53 + text = "TH" + intervals [324]: + xmin = 28.53 + xmax = 28.61 + text = "IH0" + intervals [325]: + xmin = 28.61 + xmax = 28.94 + text = "NG" + intervals [326]: + xmin = 28.94 + xmax = 28.98 + text = "" + intervals [327]: + xmin = 28.98 + xmax = 29.08 + text = "F" + intervals [328]: + xmin = 29.08 + xmax = 29.13 + text = "AO1" + intervals [329]: + xmin = 29.13 + xmax = 29.19 + text = "R" + intervals [330]: + xmin = 29.19 + xmax = 29.23 + text = "M" + intervals [331]: + xmin = 29.23 + xmax = 29.32 + text = "AY1" + intervals [332]: + xmin = 29.32 + xmax = 29.41 + text = "F" + intervals [333]: + xmin = 29.41 + xmax = 29.49 + text = "R" + intervals [334]: + xmin = 29.49 + xmax = 29.6 + text = "EH1" + intervals [335]: + xmin = 29.6 + xmax = 29.65 + text = "N" + intervals [336]: + xmin = 29.65 + xmax = 29.7 + text = "D" + intervals [337]: + xmin = 29.7 + xmax = 29.89 + text = "Z" + intervals [338]: + xmin = 29.89 + xmax = 29.92 + text = "" + intervals [339]: + xmin = 29.92 + xmax = 29.95 + text = "AY1" + intervals [340]: + xmin = 29.95 + xmax = 30.2 + text = "" + intervals [341]: + xmin = 30.2 + xmax = 30.26 + text = "V" + intervals [342]: + xmin = 30.26 + xmax = 30.39 + text = "AA2" + intervals [343]: + xmin = 30.39 + xmax = 30.45 + text = "L" + intervals [344]: + xmin = 30.45 + xmax = 30.48 + text = "AH0" + intervals [345]: + xmin = 30.48 + xmax = 30.51 + text = "N" + intervals [346]: + xmin = 30.51 + xmax = 30.6 + text = "T" + 
intervals [347]: + xmin = 30.6 + xmax = 30.67 + text = "IH1" + intervals [348]: + xmin = 30.67 + xmax = 30.73 + text = "R" + intervals [349]: + xmin = 30.73 + xmax = 30.77 + text = "AE1" + intervals [350]: + xmin = 30.77 + xmax = 30.86 + text = "T" + intervals [351]: + xmin = 30.86 + xmax = 30.91 + text = "DH" + intervals [352]: + xmin = 30.91 + xmax = 30.97 + text = "AH1" + intervals [353]: + xmin = 30.97 + xmax = 31.13 + text = "B" + intervals [354]: + xmin = 31.13 + xmax = 31.19 + text = "UW1" + intervals [355]: + xmin = 31.19 + xmax = 31.24 + text = "D" + intervals [356]: + xmin = 31.24 + xmax = 31.3 + text = "AH0" + intervals [357]: + xmin = 31.3 + xmax = 31.35 + text = "S" + intervals [358]: + xmin = 31.35 + xmax = 31.38 + text = "T" + intervals [359]: + xmin = 31.38 + xmax = 31.41 + text = "T" + intervals [360]: + xmin = 31.41 + xmax = 31.47 + text = "EH1" + intervals [361]: + xmin = 31.47 + xmax = 31.52 + text = "M" + intervals [362]: + xmin = 31.52 + xmax = 31.56 + text = "P" + intervals [363]: + xmin = 31.56 + xmax = 31.61 + text = "AH0" + intervals [364]: + xmin = 31.61 + xmax = 31.83 + text = "L" + intervals [365]: + xmin = 31.83 + xmax = 31.9 + text = "AO1" + intervals [366]: + xmin = 31.9 + xmax = 31.94 + text = "N" + intervals [367]: + xmin = 31.94 + xmax = 31.97 + text = "DH" + intervals [368]: + xmin = 31.97 + xmax = 32.01 + text = "AH1" + intervals [369]: + xmin = 32.01 + xmax = 32.08 + text = "W" + intervals [370]: + xmin = 32.08 + xmax = 32.17 + text = "IY1" + intervals [371]: + xmin = 32.17 + xmax = 32.26 + text = "K" + intervals [372]: + xmin = 32.26 + xmax = 32.45 + text = "EH2" + intervals [373]: + xmin = 32.45 + xmax = 32.51 + text = "N" + intervals [374]: + xmin = 32.51 + xmax = 32.6 + text = "D" + intervals [375]: + xmin = 32.6 + xmax = 32.88 + text = "AO1" + intervals [376]: + xmin = 32.88 + xmax = 33.01 + text = "R" + intervals [377]: + xmin = 33.01 + xmax = 33.24 + text = "AY1" + intervals [378]: + xmin = 33.24 + xmax = 33.36 + text = "K" + intervals [379]: + xmin = 33.36 + xmax = 33.51 + text = "AE1" + intervals [380]: + xmin = 33.51 + xmax = 33.62 + text = "N" + intervals [381]: + xmin = 33.62 + xmax = 33.7 + text = "JH" + intervals [382]: + xmin = 33.7 + xmax = 33.77 + text = "IH0" + intervals [383]: + xmin = 33.77 + xmax = 33.8 + text = "S" + intervals [384]: + xmin = 33.8 + xmax = 33.91 + text = "T" + intervals [385]: + xmin = 33.91 + xmax = 33.96 + text = "W" + intervals [386]: + xmin = 33.96 + xmax = 34.2 + text = "AO1" + intervals [387]: + xmin = 34.2 + xmax = 34.3 + text = "K" + intervals [388]: + xmin = 34.3 + xmax = 34.42 + text = "ER0" + intervals [389]: + xmin = 34.42 + xmax = 34.63 + text = "AW1" + intervals [390]: + xmin = 34.63 + xmax = 34.69 + text = "N" + intervals [391]: + xmin = 34.69 + xmax = 34.76 + text = "IH0" + intervals [392]: + xmin = 34.76 + xmax = 34.8 + text = "N" + intervals [393]: + xmin = 34.8 + xmax = 34.9 + text = "JH" + intervals [394]: + xmin = 34.9 + xmax = 34.99 + text = "OY1" + intervals [395]: + xmin = 34.99 + xmax = 35.03 + text = "IH0" + intervals [396]: + xmin = 35.03 + xmax = 35.08 + text = "NG" + intervals [397]: + xmin = 35.08 + xmax = 35.12 + text = "DH" + intervals [398]: + xmin = 35.12 + xmax = 35.17 + text = "AH0" + intervals [399]: + xmin = 35.17 + xmax = 35.26 + text = "S" + intervals [400]: + xmin = 35.26 + xmax = 35.33 + text = "AH1" + intervals [401]: + xmin = 35.33 + xmax = 35.4 + text = "N" + intervals [402]: + xmin = 35.4 + xmax = 35.53 + text = "SH" + intervals [403]: + xmin = 35.53 + xmax = 35.69 + 
text = "AY2" + intervals [404]: + xmin = 35.69 + xmax = 35.87 + text = "N" + intervals [405]: + xmin = 35.87 + xmax = 36.15 + text = "" + intervals [406]: + xmin = 36.15 + xmax = 36.3 + text = "AY1" + intervals [407]: + xmin = 36.3 + xmax = 36.34 + text = "D" + intervals [408]: + xmin = 36.34 + xmax = 36.38 + text = "L" + intervals [409]: + xmin = 36.38 + xmax = 36.49 + text = "AY1" + intervals [410]: + xmin = 36.49 + xmax = 36.52 + text = "K" + intervals [411]: + xmin = 36.52 + xmax = 36.56 + text = "T" + intervals [412]: + xmin = 36.56 + xmax = 36.59 + text = "AH0" + intervals [413]: + xmin = 36.59 + xmax = 36.62 + text = "HH" + intervals [414]: + xmin = 36.62 + xmax = 36.7 + text = "AE1" + intervals [415]: + xmin = 36.7 + xmax = 36.74 + text = "V" + intervals [416]: + xmin = 36.74 + xmax = 36.79 + text = "AH0" + intervals [417]: + xmin = 36.79 + xmax = 36.83 + text = "HH" + intervals [418]: + xmin = 36.83 + xmax = 36.88 + text = "EH1" + intervals [419]: + xmin = 36.88 + xmax = 36.93 + text = "L" + intervals [420]: + xmin = 36.93 + xmax = 37.01 + text = "TH" + intervals [421]: + xmin = 37.01 + xmax = 37.06 + text = "IY0" + intervals [422]: + xmin = 37.06 + xmax = 37.12 + text = "L" + intervals [423]: + xmin = 37.12 + xmax = 37.23 + text = "AY1" + intervals [424]: + xmin = 37.23 + xmax = 37.27 + text = "F" + intervals [425]: + xmin = 37.27 + xmax = 37.34 + text = "S" + intervals [426]: + xmin = 37.34 + xmax = 37.39 + text = "T" + intervals [427]: + xmin = 37.39 + xmax = 37.56 + text = "AY2" + intervals [428]: + xmin = 37.56 + xmax = 37.66 + text = "L" + intervals [429]: + xmin = 37.66 + xmax = 37.73 + text = "K" + intervals [430]: + xmin = 37.73 + xmax = 37.77 + text = "AH0" + intervals [431]: + xmin = 37.77 + xmax = 37.82 + text = "N" + intervals [432]: + xmin = 37.82 + xmax = 37.87 + text = "S" + intervals [433]: + xmin = 37.87 + xmax = 37.91 + text = "IH1" + intervals [434]: + xmin = 37.91 + xmax = 37.94 + text = "D" + intervals [435]: + xmin = 37.94 + xmax = 37.98 + text = "ER0" + intervals [436]: + xmin = 37.98 + xmax = 38.02 + text = "IH0" + intervals [437]: + xmin = 38.02 + xmax = 38.06 + text = "NG" + intervals [438]: + xmin = 38.06 + xmax = 38.13 + text = "HH" + intervals [439]: + xmin = 38.13 + xmax = 38.17 + text = "AW1" + intervals [440]: + xmin = 38.17 + xmax = 38.23 + text = "M" + intervals [441]: + xmin = 38.23 + xmax = 38.27 + text = "AH1" + intervals [442]: + xmin = 38.27 + xmax = 38.38 + text = "CH" + intervals [443]: + xmin = 38.38 + xmax = 38.5 + text = "T" + intervals [444]: + xmin = 38.5 + xmax = 38.67 + text = "AY1" + intervals [445]: + xmin = 38.67 + xmax = 38.74 + text = "M" + intervals [446]: + xmin = 38.74 + xmax = 38.81 + text = "AY1" + intervals [447]: + xmin = 38.81 + xmax = 38.95 + text = "S" + intervals [448]: + xmin = 38.95 + xmax = 39.02 + text = "P" + intervals [449]: + xmin = 39.02 + xmax = 39.09 + text = "EH1" + intervals [450]: + xmin = 39.09 + xmax = 39.12 + text = "N" + intervals [451]: + xmin = 39.12 + xmax = 39.18 + text = "D" + intervals [452]: + xmin = 39.18 + xmax = 39.21 + text = "AE1" + intervals [453]: + xmin = 39.21 + xmax = 39.29 + text = "T" + intervals [454]: + xmin = 39.29 + xmax = 39.47 + text = "W" + intervals [455]: + xmin = 39.47 + xmax = 39.69 + text = "ER1" + intervals [456]: + xmin = 39.69 + xmax = 39.84 + text = "K" + intervals [457]: + xmin = 39.84 + xmax = 40.29 + text = "" + intervals [458]: + xmin = 40.29 + xmax = 40.52 + text = "AY1" + intervals [459]: + xmin = 40.52 + xmax = 40.56 + text = "AO1" + intervals [460]: + xmin = 
40.56 + xmax = 40.59 + text = "L" + intervals [461]: + xmin = 40.59 + xmax = 40.66 + text = "W" + intervals [462]: + xmin = 40.66 + xmax = 40.7 + text = "IY0" + intervals [463]: + xmin = 40.7 + xmax = 40.79 + text = "Z" + intervals [464]: + xmin = 40.79 + xmax = 40.94 + text = "T" + intervals [465]: + xmin = 40.94 + xmax = 41.05 + text = "R" + intervals [466]: + xmin = 41.05 + xmax = 41.28 + text = "AY1" + intervals [467]: + xmin = 41.28 + xmax = 41.38 + text = "T" + intervals [468]: + xmin = 41.38 + xmax = 41.47 + text = "IH0" + intervals [469]: + xmin = 41.47 + xmax = 41.7 + text = "M" + intervals [470]: + xmin = 41.7 + xmax = 41.77 + text = "UW1" + intervals [471]: + xmin = 41.77 + xmax = 41.85 + text = "V" + intervals [472]: + xmin = 41.85 + xmax = 41.9 + text = "EH1" + intervals [473]: + xmin = 41.9 + xmax = 42 + text = "Z" + intervals [474]: + xmin = 42 + xmax = 42.08 + text = "M" + intervals [475]: + xmin = 42.08 + xmax = 42.13 + text = "AH1" + intervals [476]: + xmin = 42.13 + xmax = 42.22 + text = "CH" + intervals [477]: + xmin = 42.22 + xmax = 42.26 + text = "EH1" + intervals [478]: + xmin = 42.26 + xmax = 42.31 + text = "Z" + intervals [479]: + xmin = 42.31 + xmax = 42.4 + text = "AY1" + intervals [480]: + xmin = 42.4 + xmax = 42.51 + text = "K" + intervals [481]: + xmin = 42.51 + xmax = 42.64 + text = "AE1" + intervals [482]: + xmin = 42.64 + xmax = 42.76 + text = "N" + intervals [483]: + xmin = 42.76 + xmax = 42.81 + text = "W" + intervals [484]: + xmin = 42.81 + xmax = 42.84 + text = "EH1" + intervals [485]: + xmin = 42.84 + xmax = 42.89 + text = "N" + intervals [486]: + xmin = 42.89 + xmax = 42.95 + text = "AH0" + intervals [487]: + xmin = 42.95 + xmax = 42.98 + text = "M" + intervals [488]: + xmin = 42.98 + xmax = 43.03 + text = "N" + intervals [489]: + xmin = 43.03 + xmax = 43.12 + text = "AA1" + intervals [490]: + xmin = 43.12 + xmax = 43.18 + text = "T" + intervals [491]: + xmin = 43.18 + xmax = 43.28 + text = "W" + intervals [492]: + xmin = 43.28 + xmax = 43.42 + text = "ER1" + intervals [493]: + xmin = 43.42 + xmax = 43.49 + text = "K" + intervals [494]: + xmin = 43.49 + xmax = 43.53 + text = "IH0" + intervals [495]: + xmin = 43.53 + xmax = 43.76 + text = "NG" + intervals [496]: + xmin = 43.76 + xmax = 44.5 + text = "" + intervals [497]: + xmin = 44.5 + xmax = 44.86 + text = "AH0" + intervals [498]: + xmin = 44.86 + xmax = 45.15 + text = "N" + intervals [499]: + xmin = 45.15 + xmax = 45.19 + text = "D" + intervals [500]: + xmin = 45.19 + xmax = 45.27 + text = "AA1" + intervals [501]: + xmin = 45.27 + xmax = 45.32 + text = "N" + intervals [502]: + xmin = 45.32 + xmax = 45.4 + text = "AH1" + intervals [503]: + xmin = 45.4 + xmax = 45.46 + text = "DH" + intervals [504]: + xmin = 45.46 + xmax = 45.49 + text = "ER0" + intervals [505]: + xmin = 45.49 + xmax = 45.55 + text = "D" + intervals [506]: + xmin = 45.55 + xmax = 45.74 + text = "EY1" + intervals [507]: + xmin = 45.74 + xmax = 45.82 + text = "Z" + intervals [508]: + xmin = 45.82 + xmax = 45.89 + text = "W" + intervals [509]: + xmin = 45.89 + xmax = 45.92 + text = "EH1" + intervals [510]: + xmin = 45.92 + xmax = 45.96 + text = "N" + intervals [511]: + xmin = 45.96 + xmax = 46.09 + text = "AY1" + intervals [512]: + xmin = 46.09 + xmax = 46.16 + text = "M" + intervals [513]: + xmin = 46.16 + xmax = 46.29 + text = "F" + intervals [514]: + xmin = 46.29 + xmax = 46.39 + text = "R" + intervals [515]: + xmin = 46.39 + xmax = 46.65 + text = "IY1" + intervals [516]: + xmin = 46.65 + xmax = 46.86 + text = "AY1" + intervals [517]: 
+ xmin = 46.86 + xmax = 46.94 + text = "L" + intervals [518]: + xmin = 46.94 + xmax = 47.08 + text = "AY1" + intervals [519]: + xmin = 47.08 + xmax = 47.16 + text = "K" + intervals [520]: + xmin = 47.16 + xmax = 47.25 + text = "T" + intervals [521]: + xmin = 47.25 + xmax = 47.39 + text = "UW1" + intervals [522]: + xmin = 47.39 + xmax = 47.48 + text = "L" + intervals [523]: + xmin = 47.48 + xmax = 47.53 + text = "IH1" + intervals [524]: + xmin = 47.53 + xmax = 47.6 + text = "S" + intervals [525]: + xmin = 47.6 + xmax = 47.64 + text = "AH0" + intervals [526]: + xmin = 47.64 + xmax = 47.86 + text = "N" + intervals [527]: + xmin = 47.86 + xmax = 47.93 + text = "T" + intervals [528]: + xmin = 47.93 + xmax = 48.03 + text = "IH0" + intervals [529]: + xmin = 48.03 + xmax = 48.07 + text = "M" + intervals [530]: + xmin = 48.07 + xmax = 48.15 + text = "Y" + intervals [531]: + xmin = 48.15 + xmax = 48.2 + text = "UW1" + intervals [532]: + xmin = 48.2 + xmax = 48.27 + text = "Z" + intervals [533]: + xmin = 48.27 + xmax = 48.35 + text = "IH0" + intervals [534]: + xmin = 48.35 + xmax = 48.41 + text = "K" + intervals [535]: + xmin = 48.41 + xmax = 48.48 + text = "AH0" + intervals [536]: + xmin = 48.48 + xmax = 48.56 + text = "N" + intervals [537]: + xmin = 48.56 + xmax = 48.73 + text = "D" + intervals [538]: + xmin = 48.73 + xmax = 48.76 + text = "" + intervals [539]: + xmin = 48.76 + xmax = 48.91 + text = "W" + intervals [540]: + xmin = 48.91 + xmax = 49.01 + text = "ER1" + intervals [541]: + xmin = 49.01 + xmax = 49.13 + text = "W" + intervals [542]: + xmin = 49.13 + xmax = 49.23 + text = "AA1" + intervals [543]: + xmin = 49.23 + xmax = 49.3 + text = "CH" + intervals [544]: + xmin = 49.3 + xmax = 49.38 + text = "AH0" + intervals [545]: + xmin = 49.38 + xmax = 49.46 + text = "D" + intervals [546]: + xmin = 49.46 + xmax = 49.56 + text = "AA2" + intervals [547]: + xmin = 49.56 + xmax = 49.62 + text = "K" + intervals [548]: + xmin = 49.62 + xmax = 49.66 + text = "Y" + intervals [549]: + xmin = 49.66 + xmax = 49.7 + text = "AH0" + intervals [550]: + xmin = 49.7 + xmax = 49.76 + text = "M" + intervals [551]: + xmin = 49.76 + xmax = 49.81 + text = "EH1" + intervals [552]: + xmin = 49.81 + xmax = 49.85 + text = "N" + intervals [553]: + xmin = 49.85 + xmax = 49.98 + text = "ER0" + intervals [554]: + xmin = 49.98 + xmax = 50.05 + text = "IY0" + intervals [555]: + xmin = 50.05 + xmax = 50.17 + text = "M" + intervals [556]: + xmin = 50.17 + xmax = 50.2 + text = "UW1" + intervals [557]: + xmin = 50.2 + xmax = 50.28 + text = "V" + intervals [558]: + xmin = 50.28 + xmax = 50.38 + text = "IY0" + intervals [559]: + xmin = 50.38 + xmax = 50.51 + text = "Z" + intervals [560]: + xmin = 50.51 + xmax = 50.75 + text = "AA1" + intervals [561]: + xmin = 50.75 + xmax = 50.82 + text = "N" + intervals [562]: + xmin = 50.82 + xmax = 50.9 + text = "M" + intervals [563]: + xmin = 50.9 + xmax = 51.11 + text = "AY1" + intervals [564]: + xmin = 51.11 + xmax = 51.22 + text = "L" + intervals [565]: + xmin = 51.22 + xmax = 51.39 + text = "AE1" + intervals [566]: + xmin = 51.39 + xmax = 51.44 + text = "P" + intervals [567]: + xmin = 51.44 + xmax = 51.49 + text = "T" + intervals [568]: + xmin = 51.49 + xmax = 51.66 + text = "AA2" + intervals [569]: + xmin = 51.66 + xmax = 51.81 + text = "P" + intervals [570]: + xmin = 51.81 + xmax = 52.14 + text = "" + intervals [571]: + xmin = 52.14 + xmax = 52.2 + text = "B" + intervals [572]: + xmin = 52.2 + xmax = 52.33 + text = "AH1" + intervals [573]: + xmin = 52.33 + xmax = 52.44 + text = "T" + 
intervals [574]: + xmin = 52.44 + xmax = 52.51 + text = "S" + intervals [575]: + xmin = 52.51 + xmax = 52.59 + text = "AH1" + intervals [576]: + xmin = 52.59 + xmax = 52.64 + text = "M" + intervals [577]: + xmin = 52.64 + xmax = 52.67 + text = "T" + intervals [578]: + xmin = 52.67 + xmax = 52.77 + text = "AY2" + intervals [579]: + xmin = 52.77 + xmax = 52.82 + text = "M" + intervals [580]: + xmin = 52.82 + xmax = 52.86 + text = "Z" + intervals [581]: + xmin = 52.86 + xmax = 52.9 + text = "IH1" + intervals [582]: + xmin = 52.9 + xmax = 52.93 + text = "T" + intervals [583]: + xmin = 52.93 + xmax = 52.98 + text = "JH" + intervals [584]: + xmin = 52.98 + xmax = 53.07 + text = "IH0" + intervals [585]: + xmin = 53.07 + xmax = 53.1 + text = "S" + intervals [586]: + xmin = 53.1 + xmax = 53.13 + text = "T" + intervals [587]: + xmin = 53.13 + xmax = 53.18 + text = "S" + intervals [588]: + xmin = 53.18 + xmax = 53.26 + text = "L" + intervals [589]: + xmin = 53.26 + xmax = 53.35 + text = "IY1" + intervals [590]: + xmin = 53.35 + xmax = 53.61 + text = "P" + intervals [591]: + xmin = 53.61 + xmax = 53.65 + text = "" + intervals [592]: + xmin = 53.65 + xmax = 53.83 + text = "AY1" + intervals [593]: + xmin = 53.83 + xmax = 53.88 + text = "AH0" + intervals [594]: + xmin = 53.88 + xmax = 53.95 + text = "S" + intervals [595]: + xmin = 53.95 + xmax = 54 + text = "P" + intervals [596]: + xmin = 54 + xmax = 54.09 + text = "EH1" + intervals [597]: + xmin = 54.09 + xmax = 54.19 + text = "SH" + intervals [598]: + xmin = 54.19 + xmax = 54.22 + text = "L" + intervals [599]: + xmin = 54.22 + xmax = 54.27 + text = "IY0" + intervals [600]: + xmin = 54.27 + xmax = 54.33 + text = "L" + intervals [601]: + xmin = 54.33 + xmax = 54.43 + text = "AY1" + intervals [602]: + xmin = 54.43 + xmax = 54.57 + text = "K" + intervals [603]: + xmin = 54.57 + xmax = 54.61 + text = "T" + intervals [604]: + xmin = 54.61 + xmax = 54.69 + text = "W" + intervals [605]: + xmin = 54.69 + xmax = 54.79 + text = "AA1" + intervals [606]: + xmin = 54.79 + xmax = 54.85 + text = "CH" + intervals [607]: + xmin = 54.85 + xmax = 54.89 + text = "IH0" + intervals [608]: + xmin = 54.89 + xmax = 55.01 + text = "NG" + intervals [609]: + xmin = 55.01 + xmax = 55.12 + text = "JH" + intervals [610]: + xmin = 55.12 + xmax = 55.25 + text = "AE2" + intervals [611]: + xmin = 55.25 + xmax = 55.3 + text = "P" + intervals [612]: + xmin = 55.3 + xmax = 55.35 + text = "AH0" + intervals [613]: + xmin = 55.35 + xmax = 55.4 + text = "N" + intervals [614]: + xmin = 55.4 + xmax = 55.59 + text = "IY1" + intervals [615]: + xmin = 55.59 + xmax = 55.62 + text = "Z" + intervals [616]: + xmin = 55.62 + xmax = 55.77 + text = "AE1" + intervals [617]: + xmin = 55.77 + xmax = 55.83 + text = "N" + intervals [618]: + xmin = 55.83 + xmax = 55.87 + text = "AH0" + intervals [619]: + xmin = 55.87 + xmax = 55.91 + text = "M" + intervals [620]: + xmin = 55.91 + xmax = 56.33 + text = "AY1" + intervals [621]: + xmin = 56.33 + xmax = 56.85 + text = "" + intervals [622]: + xmin = 56.85 + xmax = 56.99 + text = "TH" + intervals [623]: + xmin = 56.99 + xmax = 57.05 + text = "IH1" + intervals [624]: + xmin = 57.05 + xmax = 57.09 + text = "NG" + intervals [625]: + xmin = 57.09 + xmax = 57.12 + text = "K" + intervals [626]: + xmin = 57.12 + xmax = 57.2 + text = "W" + intervals [627]: + xmin = 57.2 + xmax = 57.27 + text = "AA1" + intervals [628]: + xmin = 57.27 + xmax = 57.35 + text = "CH" + intervals [629]: + xmin = 57.35 + xmax = 57.4 + text = "IH0" + intervals [630]: + xmin = 57.4 + xmax = 57.43 + 
text = "NG" + intervals [631]: + xmin = 57.43 + xmax = 57.62 + text = "EY1" + intervals [632]: + xmin = 57.62 + xmax = 57.69 + text = "M" + intervals [633]: + xmin = 57.69 + xmax = 57.79 + text = "IY1" + intervals [634]: + xmin = 57.79 + xmax = 57.92 + text = "IH0" + intervals [635]: + xmin = 57.92 + xmax = 58.09 + text = "Z" + intervals [636]: + xmin = 58.09 + xmax = 58.12 + text = "AE1" + intervals [637]: + xmin = 58.12 + xmax = 58.19 + text = "N" + intervals [638]: + xmin = 58.19 + xmax = 58.23 + text = "AH0" + intervals [639]: + xmin = 58.23 + xmax = 58.39 + text = "M" + intervals [640]: + xmin = 58.39 + xmax = 58.97 + text = "IH1" + intervals [641]: + xmin = 58.97 + xmax = 59.06 + text = "Z" + intervals [642]: + xmin = 59.06 + xmax = 59.11 + text = "V" + intervals [643]: + xmin = 59.11 + xmax = 59.15 + text = "EH1" + intervals [644]: + xmin = 59.15 + xmax = 59.24 + text = "R" + intervals [645]: + xmin = 59.24 + xmax = 59.31 + text = "IY0" + intervals [646]: + xmin = 59.31 + xmax = 59.38 + text = "HH" + intervals [647]: + xmin = 59.38 + xmax = 59.43 + text = "EH1" + intervals [648]: + xmin = 59.43 + xmax = 59.52 + text = "L" + intervals [649]: + xmin = 59.52 + xmax = 59.55 + text = "P" + intervals [650]: + xmin = 59.55 + xmax = 59.58 + text = "F" + intervals [651]: + xmin = 59.58 + xmax = 59.61 + text = "AH0" + intervals [652]: + xmin = 59.61 + xmax = 59.67 + text = "L" + intervals [653]: + xmin = 59.67 + xmax = 59.72 + text = "F" + intervals [654]: + xmin = 59.72 + xmax = 59.75 + text = "R" + intervals [655]: + xmin = 59.75 + xmax = 59.81 + text = "ER0" + intervals [656]: + xmin = 59.81 + xmax = 59.88 + text = "M" + intervals [657]: + xmin = 59.88 + xmax = 59.98 + text = "IY1" + intervals [658]: + xmin = 59.98 + xmax = 60.08 + text = "T" + intervals [659]: + xmin = 60.08 + xmax = 60.28 + text = "UW1" + intervals [660]: + xmin = 60.28 + xmax = 60.42 + text = "L" + intervals [661]: + xmin = 60.42 + xmax = 60.63 + text = "ER1" + intervals [662]: + xmin = 60.63 + xmax = 60.69 + text = "N" + intervals [663]: + xmin = 60.69 + xmax = 60.72 + text = "AE1" + intervals [664]: + xmin = 60.72 + xmax = 60.75 + text = "N" + intervals [665]: + xmin = 60.75 + xmax = 60.78 + text = "D" + intervals [666]: + xmin = 60.78 + xmax = 60.84 + text = "IH0" + intervals [667]: + xmin = 60.84 + xmax = 60.88 + text = "K" + intervals [668]: + xmin = 60.88 + xmax = 60.95 + text = "S" + intervals [669]: + xmin = 60.95 + xmax = 61.01 + text = "P" + intervals [670]: + xmin = 61.01 + xmax = 61.09 + text = "R" + intervals [671]: + xmin = 61.09 + xmax = 61.14 + text = "EH1" + intervals [672]: + xmin = 61.14 + xmax = 61.21 + text = "S" + intervals [673]: + xmin = 61.21 + xmax = 61.33 + text = "JH" + intervals [674]: + xmin = 61.33 + xmax = 61.45 + text = "AE2" + intervals [675]: + xmin = 61.45 + xmax = 61.51 + text = "P" + intervals [676]: + xmin = 61.51 + xmax = 61.55 + text = "AH0" + intervals [677]: + xmin = 61.55 + xmax = 61.59 + text = "N" + intervals [678]: + xmin = 61.59 + xmax = 61.75 + text = "IY1" + intervals [679]: + xmin = 61.75 + xmax = 61.89 + text = "Z" + intervals [680]: + xmin = 61.89 + xmax = 62.02 + text = "B" + intervals [681]: + xmin = 62.02 + xmax = 62.11 + text = "EH1" + intervals [682]: + xmin = 62.11 + xmax = 62.19 + text = "T" + intervals [683]: + xmin = 62.19 + xmax = 62.42 + text = "ER0" + intervals [684]: + xmin = 62.42 + xmax = 64.097375 + text = "" diff --git a/EMAGE/test_sequences/textgrid/2_scott_0_2_2.TextGrid b/EMAGE/test_sequences/textgrid/2_scott_0_2_2.TextGrid new file mode 100644 
index 0000000000000000000000000000000000000000..faef4cc5a37423c8f9ad3a32097bb27a75cd9d0e --- /dev/null +++ b/EMAGE/test_sequences/textgrid/2_scott_0_2_2.TextGrid @@ -0,0 +1,3716 @@ +File type = "ooTextFile" +Object class = "TextGrid" + +xmin = 0.0 +xmax = 62 +tiers? +size = 2 +item []: + item [1]: + class = "IntervalTier" + name = "words" + xmin = 0.0 + xmax = 62 + intervals: size = 223 + intervals [1]: + xmin = 0.0 + xmax = 1.45 + text = "" + intervals [2]: + xmin = 1.45 + xmax = 1.87 + text = "so" + intervals [3]: + xmin = 1.87 + xmax = 2.02 + text = "when" + intervals [4]: + xmin = 2.02 + xmax = 2.13 + text = "i" + intervals [5]: + xmin = 2.13 + xmax = 2.35 + text = "have" + intervals [6]: + xmin = 2.35 + xmax = 2.57 + text = "time" + intervals [7]: + xmin = 2.57 + xmax = 2.65 + text = "to" + intervals [8]: + xmin = 2.65 + xmax = 3.18 + text = "kill" + intervals [9]: + xmin = 3.18 + xmax = 3.22 + text = "" + intervals [10]: + xmin = 3.22 + xmax = 3.41 + text = "i" + intervals [11]: + xmin = 3.41 + xmax = 3.6 + text = "like" + intervals [12]: + xmin = 3.6 + xmax = 3.68 + text = "to" + intervals [13]: + xmin = 3.68 + xmax = 3.88 + text = "play" + intervals [14]: + xmin = 3.88 + xmax = 3.96 + text = "on" + intervals [15]: + xmin = 3.96 + xmax = 4.08 + text = "the" + intervals [16]: + xmin = 4.08 + xmax = 4.5 + text = "internet" + intervals [17]: + xmin = 4.5 + xmax = 4.66 + text = "and" + intervals [18]: + xmin = 4.66 + xmax = 4.87 + text = "play" + intervals [19]: + xmin = 4.87 + xmax = 5.19 + text = "close" + intervals [20]: + xmin = 5.19 + xmax = 5.67 + text = "attention" + intervals [21]: + xmin = 5.67 + xmax = 6.0 + text = "to" + intervals [22]: + xmin = 6.0 + xmax = 6.26 + text = "new" + intervals [23]: + xmin = 6.26 + xmax = 6.71 + text = "fashion" + intervals [24]: + xmin = 6.71 + xmax = 7.17 + text = "events" + intervals [25]: + xmin = 7.17 + xmax = 7.43 + text = "" + intervals [26]: + xmin = 7.43 + xmax = 7.76 + text = "such" + intervals [27]: + xmin = 7.76 + xmax = 8.14 + text = "as" + intervals [28]: + xmin = 8.14 + xmax = 8.19 + text = "" + intervals [29]: + xmin = 8.19 + xmax = 8.34 + text = "the" + intervals [30]: + xmin = 8.34 + xmax = 8.47 + text = "new" + intervals [31]: + xmin = 8.47 + xmax = 8.68 + text = "york" + intervals [32]: + xmin = 8.68 + xmax = 9.12 + text = "fashion" + intervals [33]: + xmin = 9.12 + xmax = 9.42 + text = "week" + intervals [34]: + xmin = 9.42 + xmax = 9.49 + text = "the" + intervals [35]: + xmin = 9.49 + xmax = 9.87 + text = "paris" + intervals [36]: + xmin = 9.87 + xmax = 10.25 + text = "fashion" + intervals [37]: + xmin = 10.25 + xmax = 10.56 + text = "week" + intervals [38]: + xmin = 10.56 + xmax = 10.66 + text = "the" + intervals [39]: + xmin = 10.66 + xmax = 11.07 + text = "london" + intervals [40]: + xmin = 11.07 + xmax = 11.51 + text = "fashion" + intervals [41]: + xmin = 11.51 + xmax = 11.78 + text = "week" + intervals [42]: + xmin = 11.78 + xmax = 12.17 + text = "and" + intervals [43]: + xmin = 12.17 + xmax = 12.21 + text = "" + intervals [44]: + xmin = 12.21 + xmax = 12.83 + text = "milan" + intervals [45]: + xmin = 12.83 + xmax = 13.24 + text = "fashion" + intervals [46]: + xmin = 13.24 + xmax = 13.62 + text = "week" + intervals [47]: + xmin = 13.62 + xmax = 14.03 + text = "" + intervals [48]: + xmin = 14.03 + xmax = 14.15 + text = "the" + intervals [49]: + xmin = 14.15 + xmax = 14.35 + text = "rest" + intervals [50]: + xmin = 14.35 + xmax = 14.43 + text = "of" + intervals [51]: + xmin = 14.43 + xmax = 14.49 + text = "the" + 
intervals [52]: + xmin = 14.49 + xmax = 14.8 + text = "time" + intervals [53]: + xmin = 14.8 + xmax = 14.87 + text = "i" + intervals [54]: + xmin = 14.87 + xmax = 15.2 + text = "usually" + intervals [55]: + xmin = 15.2 + xmax = 15.3 + text = "go" + intervals [56]: + xmin = 15.3 + xmax = 15.36 + text = "to" + intervals [57]: + xmin = 15.36 + xmax = 15.44 + text = "the" + intervals [58]: + xmin = 15.44 + xmax = 15.93 + text = "library" + intervals [59]: + xmin = 15.93 + xmax = 16.04 + text = "and" + intervals [60]: + xmin = 16.04 + xmax = 16.25 + text = "find" + intervals [61]: + xmin = 16.25 + xmax = 16.35 + text = "some" + intervals [62]: + xmin = 16.35 + xmax = 16.71 + text = "interesting" + intervals [63]: + xmin = 16.71 + xmax = 17.19 + text = "books" + intervals [64]: + xmin = 17.19 + xmax = 17.31 + text = "and" + intervals [65]: + xmin = 17.31 + xmax = 17.51 + text = "then" + intervals [66]: + xmin = 17.51 + xmax = 17.63 + text = "go" + intervals [67]: + xmin = 17.63 + xmax = 17.7 + text = "to" + intervals [68]: + xmin = 17.7 + xmax = 17.78 + text = "a" + intervals [69]: + xmin = 17.78 + xmax = 18.08 + text = "park" + intervals [70]: + xmin = 18.08 + xmax = 18.17 + text = "and" + intervals [71]: + xmin = 18.17 + xmax = 18.75 + text = "relax" + intervals [72]: + xmin = 18.75 + xmax = 19.04 + text = "" + intervals [73]: + xmin = 19.04 + xmax = 19.22 + text = "there" + intervals [74]: + xmin = 19.22 + xmax = 19.27 + text = "are" + intervals [75]: + xmin = 19.27 + xmax = 19.5 + text = "many" + intervals [76]: + xmin = 19.5 + xmax = 19.78 + text = "books" + intervals [77]: + xmin = 19.78 + xmax = 19.93 + text = "that" + intervals [78]: + xmin = 19.93 + xmax = 20.11 + text = "i" + intervals [79]: + xmin = 20.11 + xmax = 20.4 + text = "find" + intervals [80]: + xmin = 20.4 + xmax = 20.92 + text = "interesting" + intervals [81]: + xmin = 20.92 + xmax = 21.15 + text = "such" + intervals [82]: + xmin = 21.15 + xmax = 21.3 + text = "as" + intervals [83]: + xmin = 21.3 + xmax = 21.62 + text = "fashion" + intervals [84]: + xmin = 21.62 + xmax = 22.19 + text = "magazines" + intervals [85]: + xmin = 22.19 + xmax = 22.8 + text = "inspirational" + intervals [86]: + xmin = 22.8 + xmax = 23.15 + text = "books" + intervals [87]: + xmin = 23.15 + xmax = 23.44 + text = "and" + intervals [88]: + xmin = 23.44 + xmax = 24.04 + text = "professional" + intervals [89]: + xmin = 24.04 + xmax = 24.46 + text = "books" + intervals [90]: + xmin = 24.46 + xmax = 24.83 + text = "" + intervals [91]: + xmin = 24.83 + xmax = 25.06 + text = "these" + intervals [92]: + xmin = 25.06 + xmax = 25.37 + text = "books" + intervals [93]: + xmin = 25.37 + xmax = 25.54 + text = "can" + intervals [94]: + xmin = 25.54 + xmax = 25.66 + text = "give" + intervals [95]: + xmin = 25.66 + xmax = 25.76 + text = "me" + intervals [96]: + xmin = 25.76 + xmax = 25.86 + text = "the" + intervals [97]: + xmin = 25.86 + xmax = 26.85 + text = "motivation" + intervals [98]: + xmin = 26.85 + xmax = 26.88 + text = "" + intervals [99]: + xmin = 26.88 + xmax = 27.07 + text = "to" + intervals [100]: + xmin = 27.07 + xmax = 27.37 + text = "be" + intervals [101]: + xmin = 27.37 + xmax = 28.01 + text = "healthier" + intervals [102]: + xmin = 28.01 + xmax = 28.18 + text = "and" + intervals [103]: + xmin = 28.18 + xmax = 28.9 + text = "energetic" + intervals [104]: + xmin = 28.9 + xmax = 29.1 + text = "" + intervals [105]: + xmin = 29.1 + xmax = 29.3 + text = "and" + intervals [106]: + xmin = 29.3 + xmax = 29.37 + text = "the" + intervals [107]: + xmin = 29.37 
+ xmax = 29.74 + text = "last" + intervals [108]: + xmin = 29.74 + xmax = 29.94 + text = "thing" + intervals [109]: + xmin = 29.94 + xmax = 30.14 + text = "i" + intervals [110]: + xmin = 30.14 + xmax = 30.42 + text = "like" + intervals [111]: + xmin = 30.42 + xmax = 30.53 + text = "to" + intervals [112]: + xmin = 30.53 + xmax = 30.84 + text = "do" + intervals [113]: + xmin = 30.84 + xmax = 31.22 + text = "when" + intervals [114]: + xmin = 31.22 + xmax = 31.43 + text = "i'm" + intervals [115]: + xmin = 31.43 + xmax = 31.87 + text = "free" + intervals [116]: + xmin = 31.87 + xmax = 31.99 + text = "is" + intervals [117]: + xmin = 31.99 + xmax = 32.11 + text = "it" + intervals [118]: + xmin = 32.11 + xmax = 32.23 + text = "out" + intervals [119]: + xmin = 32.23 + xmax = 32.35 + text = "with" + intervals [120]: + xmin = 32.35 + xmax = 32.48 + text = "my" + intervals [121]: + xmin = 32.48 + xmax = 32.86 + text = "family" + intervals [122]: + xmin = 32.86 + xmax = 33.33 + text = "members" + intervals [123]: + xmin = 33.33 + xmax = 33.51 + text = "" + intervals [124]: + xmin = 33.51 + xmax = 33.89 + text = "you" + intervals [125]: + xmin = 33.89 + xmax = 34.11 + text = "would" + intervals [126]: + xmin = 34.11 + xmax = 34.29 + text = "be" + intervals [127]: + xmin = 34.29 + xmax = 35.07 + text = "surprised" + intervals [128]: + xmin = 35.07 + xmax = 35.16 + text = "to" + intervals [129]: + xmin = 35.16 + xmax = 35.36 + text = "know" + intervals [130]: + xmin = 35.36 + xmax = 35.5 + text = "that" + intervals [131]: + xmin = 35.5 + xmax = 35.64 + text = "i" + intervals [132]: + xmin = 35.64 + xmax = 35.84 + text = "have" + intervals [133]: + xmin = 35.84 + xmax = 36.3 + text = "tried" + intervals [134]: + xmin = 36.3 + xmax = 36.57 + text = "" + intervals [135]: + xmin = 36.57 + xmax = 36.99 + text = "all" + intervals [136]: + xmin = 36.99 + xmax = 37.08 + text = "the" + intervals [137]: + xmin = 37.08 + xmax = 37.68 + text = "restaurants" + intervals [138]: + xmin = 37.68 + xmax = 37.71 + text = "" + intervals [139]: + xmin = 37.71 + xmax = 37.83 + text = "in" + intervals [140]: + xmin = 37.83 + xmax = 37.95 + text = "our" + intervals [141]: + xmin = 37.95 + xmax = 38.5 + text = "huge" + intervals [142]: + xmin = 38.5 + xmax = 39.07 + text = "community" + intervals [143]: + xmin = 39.07 + xmax = 39.23 + text = "" + intervals [144]: + xmin = 39.23 + xmax = 39.6 + text = "i" + intervals [145]: + xmin = 39.6 + xmax = 40.09 + text = "actually" + intervals [146]: + xmin = 40.09 + xmax = 40.32 + text = "give" + intervals [147]: + xmin = 40.32 + xmax = 40.61 + text = "each" + intervals [148]: + xmin = 40.61 + xmax = 41.08 + text = "restaurant" + intervals [149]: + xmin = 41.08 + xmax = 41.15 + text = "a" + intervals [150]: + xmin = 41.15 + xmax = 41.55 + text = "score" + intervals [151]: + xmin = 41.55 + xmax = 41.82 + text = "based" + intervals [152]: + xmin = 41.82 + xmax = 41.89 + text = "on" + intervals [153]: + xmin = 41.89 + xmax = 42.05 + text = "how" + intervals [154]: + xmin = 42.05 + xmax = 42.17 + text = "good" + intervals [155]: + xmin = 42.17 + xmax = 42.23 + text = "the" + intervals [156]: + xmin = 42.23 + xmax = 42.51 + text = "food" + intervals [157]: + xmin = 42.51 + xmax = 42.85 + text = "is" + intervals [158]: + xmin = 42.85 + xmax = 43.13 + text = "" + intervals [159]: + xmin = 43.13 + xmax = 43.36 + text = "how" + intervals [160]: + xmin = 43.36 + xmax = 43.51 + text = "good" + intervals [161]: + xmin = 43.51 + xmax = 43.62 + text = "the" + intervals [162]: + xmin = 43.62 + xmax = 
44.1 + text = "environment" + intervals [163]: + xmin = 44.1 + xmax = 44.4 + text = "is" + intervals [164]: + xmin = 44.4 + xmax = 44.49 + text = "" + intervals [165]: + xmin = 44.49 + xmax = 44.98 + text = "and" + intervals [166]: + xmin = 44.98 + xmax = 45.34 + text = "at" + intervals [167]: + xmin = 45.34 + xmax = 45.62 + text = "the" + intervals [168]: + xmin = 45.62 + xmax = 45.91 + text = "same" + intervals [169]: + xmin = 45.91 + xmax = 46.29 + text = "time" + intervals [170]: + xmin = 46.29 + xmax = 46.42 + text = "i" + intervals [171]: + xmin = 46.42 + xmax = 46.54 + text = "will" + intervals [172]: + xmin = 46.54 + xmax = 46.74 + text = "write" + intervals [173]: + xmin = 46.74 + xmax = 46.94 + text = "down" + intervals [174]: + xmin = 46.94 + xmax = 47.02 + text = "the" + intervals [175]: + xmin = 47.02 + xmax = 47.24 + text = "type" + intervals [176]: + xmin = 47.24 + xmax = 47.39 + text = "of" + intervals [177]: + xmin = 47.39 + xmax = 47.8 + text = "food" + intervals [178]: + xmin = 47.8 + xmax = 48.03 + text = "" + intervals [179]: + xmin = 48.03 + xmax = 48.24 + text = "they" + intervals [180]: + xmin = 48.24 + xmax = 48.76 + text = "serve" + intervals [181]: + xmin = 48.76 + xmax = 49.42 + text = "" + intervals [182]: + xmin = 49.42 + xmax = 49.9 + text = "so" + intervals [183]: + xmin = 49.9 + xmax = 50.46 + text = "when" + intervals [184]: + xmin = 50.46 + xmax = 50.49 + text = "" + intervals [185]: + xmin = 50.49 + xmax = 50.85 + text = "you're" + intervals [186]: + xmin = 50.85 + xmax = 50.98 + text = "so" + intervals [187]: + xmin = 50.98 + xmax = 51.13 + text = "when" + intervals [188]: + xmin = 51.13 + xmax = 51.35 + text = "each" + intervals [189]: + xmin = 51.35 + xmax = 51.55 + text = "time" + intervals [190]: + xmin = 51.55 + xmax = 51.62 + text = "a" + intervals [191]: + xmin = 51.62 + xmax = 51.91 + text = "friend" + intervals [192]: + xmin = 51.91 + xmax = 52.32 + text = "comes" + intervals [193]: + xmin = 52.32 + xmax = 52.46 + text = "to" + intervals [194]: + xmin = 52.46 + xmax = 52.59 + text = "the" + intervals [195]: + xmin = 52.59 + xmax = 52.9 + text = "city" + intervals [196]: + xmin = 52.9 + xmax = 53.07 + text = "to" + intervals [197]: + xmin = 53.07 + xmax = 53.35 + text = "enjoy" + intervals [198]: + xmin = 53.35 + xmax = 53.62 + text = "time" + intervals [199]: + xmin = 53.62 + xmax = 53.74 + text = "with" + intervals [200]: + xmin = 53.74 + xmax = 54.02 + text = "me" + intervals [201]: + xmin = 54.02 + xmax = 54.31 + text = "" + intervals [202]: + xmin = 54.31 + xmax = 54.54 + text = "i" + intervals [203]: + xmin = 54.54 + xmax = 54.69 + text = "will" + intervals [204]: + xmin = 54.69 + xmax = 54.84 + text = "give" + intervals [205]: + xmin = 54.84 + xmax = 54.97 + text = "them" + intervals [206]: + xmin = 54.97 + xmax = 55.07 + text = "the" + intervals [207]: + xmin = 55.07 + xmax = 55.38 + text = "top" + intervals [208]: + xmin = 55.38 + xmax = 55.53 + text = "5" + intervals [209]: + xmin = 55.53 + xmax = 56.1 + text = "restaurants" + intervals [210]: + xmin = 56.1 + xmax = 56.44 + text = "based" + intervals [211]: + xmin = 56.44 + xmax = 56.68 + text = "on" + intervals [212]: + xmin = 56.68 + xmax = 56.99 + text = "this" + intervals [213]: + xmin = 56.99 + xmax = 57.35 + text = "ranking" + intervals [214]: + xmin = 57.35 + xmax = 57.53 + text = "and" + intervals [215]: + xmin = 57.53 + xmax = 57.86 + text = "every" + intervals [216]: + xmin = 57.86 + xmax = 58.44 + text = "time" + intervals [217]: + xmin = 58.44 + xmax = 59.02 + text = "" + 
intervals [218]: + xmin = 59.02 + xmax = 59.2 + text = "you're" + intervals [219]: + xmin = 59.2 + xmax = 59.72 + text = "satisfied" + intervals [220]: + xmin = 59.72 + xmax = 59.85 + text = "with" + intervals [221]: + xmin = 59.85 + xmax = 60.1 + text = "these" + intervals [222]: + xmin = 60.1 + xmax = 60.81 + text = "restaurants" + intervals [223]: + xmin = 60.81 + xmax = 62 + text = "" + item [2]: + class = "IntervalTier" + name = "phones" + xmin = 0.0 + xmax = 62 + intervals: size = 701 + intervals [1]: + xmin = 0.0 + xmax = 1.45 + text = "" + intervals [2]: + xmin = 1.45 + xmax = 1.64 + text = "S" + intervals [3]: + xmin = 1.64 + xmax = 1.87 + text = "OW1" + intervals [4]: + xmin = 1.87 + xmax = 1.94 + text = "W" + intervals [5]: + xmin = 1.94 + xmax = 1.97 + text = "EH1" + intervals [6]: + xmin = 1.97 + xmax = 2.02 + text = "N" + intervals [7]: + xmin = 2.02 + xmax = 2.13 + text = "AY1" + intervals [8]: + xmin = 2.13 + xmax = 2.21 + text = "HH" + intervals [9]: + xmin = 2.21 + xmax = 2.29 + text = "AE1" + intervals [10]: + xmin = 2.29 + xmax = 2.35 + text = "V" + intervals [11]: + xmin = 2.35 + xmax = 2.43 + text = "T" + intervals [12]: + xmin = 2.43 + xmax = 2.52 + text = "AY1" + intervals [13]: + xmin = 2.52 + xmax = 2.57 + text = "M" + intervals [14]: + xmin = 2.57 + xmax = 2.6 + text = "T" + intervals [15]: + xmin = 2.6 + xmax = 2.65 + text = "AH0" + intervals [16]: + xmin = 2.65 + xmax = 2.75 + text = "K" + intervals [17]: + xmin = 2.75 + xmax = 2.81 + text = "IH1" + intervals [18]: + xmin = 2.81 + xmax = 3.18 + text = "L" + intervals [19]: + xmin = 3.18 + xmax = 3.22 + text = "" + intervals [20]: + xmin = 3.22 + xmax = 3.41 + text = "AY1" + intervals [21]: + xmin = 3.41 + xmax = 3.46 + text = "L" + intervals [22]: + xmin = 3.46 + xmax = 3.57 + text = "AY1" + intervals [23]: + xmin = 3.57 + xmax = 3.6 + text = "K" + intervals [24]: + xmin = 3.6 + xmax = 3.63 + text = "T" + intervals [25]: + xmin = 3.63 + xmax = 3.68 + text = "IH0" + intervals [26]: + xmin = 3.68 + xmax = 3.75 + text = "P" + intervals [27]: + xmin = 3.75 + xmax = 3.83 + text = "L" + intervals [28]: + xmin = 3.83 + xmax = 3.88 + text = "EY1" + intervals [29]: + xmin = 3.88 + xmax = 3.93 + text = "AA1" + intervals [30]: + xmin = 3.93 + xmax = 3.96 + text = "N" + intervals [31]: + xmin = 3.96 + xmax = 4.01 + text = "DH" + intervals [32]: + xmin = 4.01 + xmax = 4.08 + text = "IY0" + intervals [33]: + xmin = 4.08 + xmax = 4.13 + text = "IH1" + intervals [34]: + xmin = 4.13 + xmax = 4.16 + text = "N" + intervals [35]: + xmin = 4.16 + xmax = 4.19 + text = "T" + intervals [36]: + xmin = 4.19 + xmax = 4.25 + text = "ER0" + intervals [37]: + xmin = 4.25 + xmax = 4.29 + text = "N" + intervals [38]: + xmin = 4.29 + xmax = 4.42 + text = "EH2" + intervals [39]: + xmin = 4.42 + xmax = 4.5 + text = "T" + intervals [40]: + xmin = 4.5 + xmax = 4.58 + text = "AE1" + intervals [41]: + xmin = 4.58 + xmax = 4.62 + text = "N" + intervals [42]: + xmin = 4.62 + xmax = 4.66 + text = "D" + intervals [43]: + xmin = 4.66 + xmax = 4.71 + text = "P" + intervals [44]: + xmin = 4.71 + xmax = 4.8 + text = "L" + intervals [45]: + xmin = 4.8 + xmax = 4.87 + text = "EY1" + intervals [46]: + xmin = 4.87 + xmax = 4.97 + text = "K" + intervals [47]: + xmin = 4.97 + xmax = 5.02 + text = "L" + intervals [48]: + xmin = 5.02 + xmax = 5.09 + text = "OW1" + intervals [49]: + xmin = 5.09 + xmax = 5.19 + text = "S" + intervals [50]: + xmin = 5.19 + xmax = 5.23 + text = "AH0" + intervals [51]: + xmin = 5.23 + xmax = 5.32 + text = "T" + intervals [52]: + xmin = 
5.32 + xmax = 5.36 + text = "EH1" + intervals [53]: + xmin = 5.36 + xmax = 5.42 + text = "N" + intervals [54]: + xmin = 5.42 + xmax = 5.49 + text = "SH" + intervals [55]: + xmin = 5.49 + xmax = 5.55 + text = "AH0" + intervals [56]: + xmin = 5.55 + xmax = 5.67 + text = "N" + intervals [57]: + xmin = 5.67 + xmax = 5.8 + text = "T" + intervals [58]: + xmin = 5.8 + xmax = 6.0 + text = "UW1" + intervals [59]: + xmin = 6.0 + xmax = 6.03 + text = "N" + intervals [60]: + xmin = 6.03 + xmax = 6.15 + text = "Y" + intervals [61]: + xmin = 6.15 + xmax = 6.26 + text = "UW1" + intervals [62]: + xmin = 6.26 + xmax = 6.41 + text = "F" + intervals [63]: + xmin = 6.41 + xmax = 6.54 + text = "AE1" + intervals [64]: + xmin = 6.54 + xmax = 6.63 + text = "SH" + intervals [65]: + xmin = 6.63 + xmax = 6.66 + text = "AH0" + intervals [66]: + xmin = 6.66 + xmax = 6.71 + text = "N" + intervals [67]: + xmin = 6.71 + xmax = 6.75 + text = "IH0" + intervals [68]: + xmin = 6.75 + xmax = 6.81 + text = "V" + intervals [69]: + xmin = 6.81 + xmax = 6.93 + text = "EH1" + intervals [70]: + xmin = 6.93 + xmax = 6.97 + text = "N" + intervals [71]: + xmin = 6.97 + xmax = 7.02 + text = "T" + intervals [72]: + xmin = 7.02 + xmax = 7.17 + text = "S" + intervals [73]: + xmin = 7.17 + xmax = 7.43 + text = "" + intervals [74]: + xmin = 7.43 + xmax = 7.55 + text = "S" + intervals [75]: + xmin = 7.55 + xmax = 7.63 + text = "AH1" + intervals [76]: + xmin = 7.63 + xmax = 7.76 + text = "CH" + intervals [77]: + xmin = 7.76 + xmax = 7.94 + text = "EH1" + intervals [78]: + xmin = 7.94 + xmax = 8.14 + text = "Z" + intervals [79]: + xmin = 8.14 + xmax = 8.19 + text = "" + intervals [80]: + xmin = 8.19 + xmax = 8.28 + text = "DH" + intervals [81]: + xmin = 8.28 + xmax = 8.34 + text = "AH0" + intervals [82]: + xmin = 8.34 + xmax = 8.44 + text = "N" + intervals [83]: + xmin = 8.44 + xmax = 8.47 + text = "UW1" + intervals [84]: + xmin = 8.47 + xmax = 8.52 + text = "Y" + intervals [85]: + xmin = 8.52 + xmax = 8.56 + text = "AO1" + intervals [86]: + xmin = 8.56 + xmax = 8.62 + text = "R" + intervals [87]: + xmin = 8.62 + xmax = 8.68 + text = "K" + intervals [88]: + xmin = 8.68 + xmax = 8.79 + text = "F" + intervals [89]: + xmin = 8.79 + xmax = 8.93 + text = "AE1" + intervals [90]: + xmin = 8.93 + xmax = 9.03 + text = "SH" + intervals [91]: + xmin = 9.03 + xmax = 9.07 + text = "AH0" + intervals [92]: + xmin = 9.07 + xmax = 9.12 + text = "N" + intervals [93]: + xmin = 9.12 + xmax = 9.19 + text = "W" + intervals [94]: + xmin = 9.19 + xmax = 9.33 + text = "IY1" + intervals [95]: + xmin = 9.33 + xmax = 9.42 + text = "K" + intervals [96]: + xmin = 9.42 + xmax = 9.46 + text = "DH" + intervals [97]: + xmin = 9.46 + xmax = 9.49 + text = "AH0" + intervals [98]: + xmin = 9.49 + xmax = 9.57 + text = "P" + intervals [99]: + xmin = 9.57 + xmax = 9.64 + text = "EH1" + intervals [100]: + xmin = 9.64 + xmax = 9.75 + text = "R" + intervals [101]: + xmin = 9.75 + xmax = 9.8 + text = "IH0" + intervals [102]: + xmin = 9.8 + xmax = 9.87 + text = "S" + intervals [103]: + xmin = 9.87 + xmax = 9.93 + text = "F" + intervals [104]: + xmin = 9.93 + xmax = 10.09 + text = "AE1" + intervals [105]: + xmin = 10.09 + xmax = 10.19 + text = "SH" + intervals [106]: + xmin = 10.19 + xmax = 10.22 + text = "AH0" + intervals [107]: + xmin = 10.22 + xmax = 10.25 + text = "N" + intervals [108]: + xmin = 10.25 + xmax = 10.32 + text = "W" + intervals [109]: + xmin = 10.32 + xmax = 10.49 + text = "IY1" + intervals [110]: + xmin = 10.49 + xmax = 10.56 + text = "K" + intervals [111]: + xmin = 10.56 
+ xmax = 10.6 + text = "DH" + intervals [112]: + xmin = 10.6 + xmax = 10.66 + text = "AH0" + intervals [113]: + xmin = 10.66 + xmax = 10.76 + text = "L" + intervals [114]: + xmin = 10.76 + xmax = 10.81 + text = "AH1" + intervals [115]: + xmin = 10.81 + xmax = 10.87 + text = "N" + intervals [116]: + xmin = 10.87 + xmax = 10.92 + text = "D" + intervals [117]: + xmin = 10.92 + xmax = 10.97 + text = "AH0" + intervals [118]: + xmin = 10.97 + xmax = 11.07 + text = "N" + intervals [119]: + xmin = 11.07 + xmax = 11.18 + text = "F" + intervals [120]: + xmin = 11.18 + xmax = 11.31 + text = "AE1" + intervals [121]: + xmin = 11.31 + xmax = 11.42 + text = "SH" + intervals [122]: + xmin = 11.42 + xmax = 11.46 + text = "AH0" + intervals [123]: + xmin = 11.46 + xmax = 11.51 + text = "N" + intervals [124]: + xmin = 11.51 + xmax = 11.56 + text = "W" + intervals [125]: + xmin = 11.56 + xmax = 11.69 + text = "IY1" + intervals [126]: + xmin = 11.69 + xmax = 11.78 + text = "K" + intervals [127]: + xmin = 11.78 + xmax = 11.84 + text = "AE1" + intervals [128]: + xmin = 11.84 + xmax = 11.99 + text = "N" + intervals [129]: + xmin = 11.99 + xmax = 12.17 + text = "D" + intervals [130]: + xmin = 12.17 + xmax = 12.21 + text = "" + intervals [131]: + xmin = 12.21 + xmax = 12.36 + text = "M" + intervals [132]: + xmin = 12.36 + xmax = 12.48 + text = "IH0" + intervals [133]: + xmin = 12.48 + xmax = 12.55 + text = "L" + intervals [134]: + xmin = 12.55 + xmax = 12.73 + text = "AA1" + intervals [135]: + xmin = 12.73 + xmax = 12.83 + text = "N" + intervals [136]: + xmin = 12.83 + xmax = 12.92 + text = "F" + intervals [137]: + xmin = 12.92 + xmax = 13.06 + text = "AE1" + intervals [138]: + xmin = 13.06 + xmax = 13.16 + text = "SH" + intervals [139]: + xmin = 13.16 + xmax = 13.21 + text = "AH0" + intervals [140]: + xmin = 13.21 + xmax = 13.24 + text = "N" + intervals [141]: + xmin = 13.24 + xmax = 13.29 + text = "W" + intervals [142]: + xmin = 13.29 + xmax = 13.41 + text = "IY1" + intervals [143]: + xmin = 13.41 + xmax = 13.62 + text = "K" + intervals [144]: + xmin = 13.62 + xmax = 14.03 + text = "" + intervals [145]: + xmin = 14.03 + xmax = 14.11 + text = "DH" + intervals [146]: + xmin = 14.11 + xmax = 14.15 + text = "AH1" + intervals [147]: + xmin = 14.15 + xmax = 14.22 + text = "R" + intervals [148]: + xmin = 14.22 + xmax = 14.26 + text = "EH1" + intervals [149]: + xmin = 14.26 + xmax = 14.31 + text = "S" + intervals [150]: + xmin = 14.31 + xmax = 14.35 + text = "T" + intervals [151]: + xmin = 14.35 + xmax = 14.4 + text = "AH0" + intervals [152]: + xmin = 14.4 + xmax = 14.43 + text = "V" + intervals [153]: + xmin = 14.43 + xmax = 14.46 + text = "DH" + intervals [154]: + xmin = 14.46 + xmax = 14.49 + text = "AH0" + intervals [155]: + xmin = 14.49 + xmax = 14.55 + text = "T" + intervals [156]: + xmin = 14.55 + xmax = 14.73 + text = "AY1" + intervals [157]: + xmin = 14.73 + xmax = 14.8 + text = "M" + intervals [158]: + xmin = 14.8 + xmax = 14.87 + text = "AY1" + intervals [159]: + xmin = 14.87 + xmax = 14.96 + text = "Y" + intervals [160]: + xmin = 14.96 + xmax = 14.99 + text = "UW1" + intervals [161]: + xmin = 14.99 + xmax = 15.08 + text = "ZH" + intervals [162]: + xmin = 15.08 + xmax = 15.11 + text = "AH0" + intervals [163]: + xmin = 15.11 + xmax = 15.14 + text = "L" + intervals [164]: + xmin = 15.14 + xmax = 15.2 + text = "IY0" + intervals [165]: + xmin = 15.2 + xmax = 15.23 + text = "G" + intervals [166]: + xmin = 15.23 + xmax = 15.3 + text = "OW1" + intervals [167]: + xmin = 15.3 + xmax = 15.33 + text = "T" + intervals 
[168]: + xmin = 15.33 + xmax = 15.36 + text = "AH0" + intervals [169]: + xmin = 15.36 + xmax = 15.39 + text = "DH" + intervals [170]: + xmin = 15.39 + xmax = 15.44 + text = "AH1" + intervals [171]: + xmin = 15.44 + xmax = 15.51 + text = "L" + intervals [172]: + xmin = 15.51 + xmax = 15.67 + text = "AY1" + intervals [173]: + xmin = 15.67 + xmax = 15.71 + text = "B" + intervals [174]: + xmin = 15.71 + xmax = 15.74 + text = "R" + intervals [175]: + xmin = 15.74 + xmax = 15.83 + text = "EH2" + intervals [176]: + xmin = 15.83 + xmax = 15.9 + text = "R" + intervals [177]: + xmin = 15.9 + xmax = 15.93 + text = "IY0" + intervals [178]: + xmin = 15.93 + xmax = 15.96 + text = "AH0" + intervals [179]: + xmin = 15.96 + xmax = 15.99 + text = "N" + intervals [180]: + xmin = 15.99 + xmax = 16.04 + text = "D" + intervals [181]: + xmin = 16.04 + xmax = 16.11 + text = "F" + intervals [182]: + xmin = 16.11 + xmax = 16.18 + text = "AY1" + intervals [183]: + xmin = 16.18 + xmax = 16.21 + text = "N" + intervals [184]: + xmin = 16.21 + xmax = 16.25 + text = "D" + intervals [185]: + xmin = 16.25 + xmax = 16.29 + text = "S" + intervals [186]: + xmin = 16.29 + xmax = 16.32 + text = "AH1" + intervals [187]: + xmin = 16.32 + xmax = 16.35 + text = "M" + intervals [188]: + xmin = 16.35 + xmax = 16.38 + text = "IH1" + intervals [189]: + xmin = 16.38 + xmax = 16.41 + text = "N" + intervals [190]: + xmin = 16.41 + xmax = 16.46 + text = "T" + intervals [191]: + xmin = 16.46 + xmax = 16.49 + text = "R" + intervals [192]: + xmin = 16.49 + xmax = 16.53 + text = "IH0" + intervals [193]: + xmin = 16.53 + xmax = 16.57 + text = "S" + intervals [194]: + xmin = 16.57 + xmax = 16.6 + text = "T" + intervals [195]: + xmin = 16.6 + xmax = 16.64 + text = "IH0" + intervals [196]: + xmin = 16.64 + xmax = 16.71 + text = "NG" + intervals [197]: + xmin = 16.71 + xmax = 16.78 + text = "B" + intervals [198]: + xmin = 16.78 + xmax = 17.0 + text = "UH1" + intervals [199]: + xmin = 17.0 + xmax = 17.09 + text = "K" + intervals [200]: + xmin = 17.09 + xmax = 17.19 + text = "S" + intervals [201]: + xmin = 17.19 + xmax = 17.25 + text = "AH0" + intervals [202]: + xmin = 17.25 + xmax = 17.28 + text = "N" + intervals [203]: + xmin = 17.28 + xmax = 17.31 + text = "D" + intervals [204]: + xmin = 17.31 + xmax = 17.34 + text = "DH" + intervals [205]: + xmin = 17.34 + xmax = 17.42 + text = "EH1" + intervals [206]: + xmin = 17.42 + xmax = 17.51 + text = "N" + intervals [207]: + xmin = 17.51 + xmax = 17.58 + text = "G" + intervals [208]: + xmin = 17.58 + xmax = 17.63 + text = "OW1" + intervals [209]: + xmin = 17.63 + xmax = 17.67 + text = "T" + intervals [210]: + xmin = 17.67 + xmax = 17.7 + text = "AH0" + intervals [211]: + xmin = 17.7 + xmax = 17.78 + text = "AH0" + intervals [212]: + xmin = 17.78 + xmax = 17.89 + text = "P" + intervals [213]: + xmin = 17.89 + xmax = 17.95 + text = "AA1" + intervals [214]: + xmin = 17.95 + xmax = 18.04 + text = "R" + intervals [215]: + xmin = 18.04 + xmax = 18.08 + text = "K" + intervals [216]: + xmin = 18.08 + xmax = 18.11 + text = "AH0" + intervals [217]: + xmin = 18.11 + xmax = 18.14 + text = "N" + intervals [218]: + xmin = 18.14 + xmax = 18.17 + text = "D" + intervals [219]: + xmin = 18.17 + xmax = 18.2 + text = "R" + intervals [220]: + xmin = 18.2 + xmax = 18.25 + text = "IH0" + intervals [221]: + xmin = 18.25 + xmax = 18.33 + text = "L" + intervals [222]: + xmin = 18.33 + xmax = 18.53 + text = "AE1" + intervals [223]: + xmin = 18.53 + xmax = 18.58 + text = "K" + intervals [224]: + xmin = 18.58 + xmax = 18.75 + text = 
"S" + intervals [225]: + xmin = 18.75 + xmax = 19.04 + text = "" + intervals [226]: + xmin = 19.04 + xmax = 19.14 + text = "DH" + intervals [227]: + xmin = 19.14 + xmax = 19.18 + text = "EH1" + intervals [228]: + xmin = 19.18 + xmax = 19.22 + text = "R" + intervals [229]: + xmin = 19.22 + xmax = 19.27 + text = "ER0" + intervals [230]: + xmin = 19.27 + xmax = 19.34 + text = "M" + intervals [231]: + xmin = 19.34 + xmax = 19.39 + text = "EH1" + intervals [232]: + xmin = 19.39 + xmax = 19.43 + text = "N" + intervals [233]: + xmin = 19.43 + xmax = 19.5 + text = "IY0" + intervals [234]: + xmin = 19.5 + xmax = 19.56 + text = "B" + intervals [235]: + xmin = 19.56 + xmax = 19.66 + text = "UH1" + intervals [236]: + xmin = 19.66 + xmax = 19.72 + text = "K" + intervals [237]: + xmin = 19.72 + xmax = 19.78 + text = "S" + intervals [238]: + xmin = 19.78 + xmax = 19.81 + text = "DH" + intervals [239]: + xmin = 19.81 + xmax = 19.84 + text = "AH0" + intervals [240]: + xmin = 19.84 + xmax = 19.93 + text = "T" + intervals [241]: + xmin = 19.93 + xmax = 20.11 + text = "AY1" + intervals [242]: + xmin = 20.11 + xmax = 20.22 + text = "F" + intervals [243]: + xmin = 20.22 + xmax = 20.3 + text = "AY1" + intervals [244]: + xmin = 20.3 + xmax = 20.37 + text = "N" + intervals [245]: + xmin = 20.37 + xmax = 20.4 + text = "D" + intervals [246]: + xmin = 20.4 + xmax = 20.52 + text = "IH1" + intervals [247]: + xmin = 20.52 + xmax = 20.55 + text = "N" + intervals [248]: + xmin = 20.55 + xmax = 20.59 + text = "T" + intervals [249]: + xmin = 20.59 + xmax = 20.62 + text = "R" + intervals [250]: + xmin = 20.62 + xmax = 20.67 + text = "AH0" + intervals [251]: + xmin = 20.67 + xmax = 20.74 + text = "S" + intervals [252]: + xmin = 20.74 + xmax = 20.78 + text = "T" + intervals [253]: + xmin = 20.78 + xmax = 20.85 + text = "IH0" + intervals [254]: + xmin = 20.85 + xmax = 20.92 + text = "NG" + intervals [255]: + xmin = 20.92 + xmax = 21.02 + text = "S" + intervals [256]: + xmin = 21.02 + xmax = 21.06 + text = "AH1" + intervals [257]: + xmin = 21.06 + xmax = 21.15 + text = "CH" + intervals [258]: + xmin = 21.15 + xmax = 21.2 + text = "EH1" + intervals [259]: + xmin = 21.2 + xmax = 21.3 + text = "Z" + intervals [260]: + xmin = 21.3 + xmax = 21.36 + text = "F" + intervals [261]: + xmin = 21.36 + xmax = 21.47 + text = "AE1" + intervals [262]: + xmin = 21.47 + xmax = 21.56 + text = "SH" + intervals [263]: + xmin = 21.56 + xmax = 21.59 + text = "AH0" + intervals [264]: + xmin = 21.59 + xmax = 21.62 + text = "N" + intervals [265]: + xmin = 21.62 + xmax = 21.68 + text = "M" + intervals [266]: + xmin = 21.68 + xmax = 21.76 + text = "AE1" + intervals [267]: + xmin = 21.76 + xmax = 21.81 + text = "G" + intervals [268]: + xmin = 21.81 + xmax = 21.85 + text = "AH0" + intervals [269]: + xmin = 21.85 + xmax = 21.9 + text = "Z" + intervals [270]: + xmin = 21.9 + xmax = 22.0 + text = "IY2" + intervals [271]: + xmin = 22.0 + xmax = 22.1 + text = "N" + intervals [272]: + xmin = 22.1 + xmax = 22.19 + text = "Z" + intervals [273]: + xmin = 22.19 + xmax = 22.22 + text = "IH2" + intervals [274]: + xmin = 22.22 + xmax = 22.29 + text = "N" + intervals [275]: + xmin = 22.29 + xmax = 22.34 + text = "S" + intervals [276]: + xmin = 22.34 + xmax = 22.38 + text = "P" + intervals [277]: + xmin = 22.38 + xmax = 22.48 + text = "ER0" + intervals [278]: + xmin = 22.48 + xmax = 22.55 + text = "EY1" + intervals [279]: + xmin = 22.55 + xmax = 22.64 + text = "SH" + intervals [280]: + xmin = 22.64 + xmax = 22.67 + text = "AH0" + intervals [281]: + xmin = 22.67 + xmax = 
22.7 + text = "N" + intervals [282]: + xmin = 22.7 + xmax = 22.73 + text = "AH0" + intervals [283]: + xmin = 22.73 + xmax = 22.8 + text = "L" + intervals [284]: + xmin = 22.8 + xmax = 22.88 + text = "B" + intervals [285]: + xmin = 22.88 + xmax = 23.03 + text = "UH1" + intervals [286]: + xmin = 23.03 + xmax = 23.09 + text = "K" + intervals [287]: + xmin = 23.09 + xmax = 23.15 + text = "S" + intervals [288]: + xmin = 23.15 + xmax = 23.24 + text = "AH0" + intervals [289]: + xmin = 23.24 + xmax = 23.35 + text = "N" + intervals [290]: + xmin = 23.35 + xmax = 23.44 + text = "D" + intervals [291]: + xmin = 23.44 + xmax = 23.5 + text = "P" + intervals [292]: + xmin = 23.5 + xmax = 23.55 + text = "R" + intervals [293]: + xmin = 23.55 + xmax = 23.59 + text = "AH0" + intervals [294]: + xmin = 23.59 + xmax = 23.69 + text = "F" + intervals [295]: + xmin = 23.69 + xmax = 23.76 + text = "EH1" + intervals [296]: + xmin = 23.76 + xmax = 23.87 + text = "SH" + intervals [297]: + xmin = 23.87 + xmax = 23.9 + text = "AH0" + intervals [298]: + xmin = 23.9 + xmax = 23.94 + text = "N" + intervals [299]: + xmin = 23.94 + xmax = 23.98 + text = "AH0" + intervals [300]: + xmin = 23.98 + xmax = 24.04 + text = "L" + intervals [301]: + xmin = 24.04 + xmax = 24.12 + text = "B" + intervals [302]: + xmin = 24.12 + xmax = 24.24 + text = "UH1" + intervals [303]: + xmin = 24.24 + xmax = 24.32 + text = "K" + intervals [304]: + xmin = 24.32 + xmax = 24.46 + text = "S" + intervals [305]: + xmin = 24.46 + xmax = 24.83 + text = "" + intervals [306]: + xmin = 24.83 + xmax = 24.91 + text = "DH" + intervals [307]: + xmin = 24.91 + xmax = 24.98 + text = "IY1" + intervals [308]: + xmin = 24.98 + xmax = 25.06 + text = "Z" + intervals [309]: + xmin = 25.06 + xmax = 25.13 + text = "B" + intervals [310]: + xmin = 25.13 + xmax = 25.23 + text = "UH1" + intervals [311]: + xmin = 25.23 + xmax = 25.3 + text = "K" + intervals [312]: + xmin = 25.3 + xmax = 25.37 + text = "S" + intervals [313]: + xmin = 25.37 + xmax = 25.44 + text = "K" + intervals [314]: + xmin = 25.44 + xmax = 25.51 + text = "AH0" + intervals [315]: + xmin = 25.51 + xmax = 25.54 + text = "N" + intervals [316]: + xmin = 25.54 + xmax = 25.59 + text = "G" + intervals [317]: + xmin = 25.59 + xmax = 25.63 + text = "IH1" + intervals [318]: + xmin = 25.63 + xmax = 25.66 + text = "V" + intervals [319]: + xmin = 25.66 + xmax = 25.71 + text = "M" + intervals [320]: + xmin = 25.71 + xmax = 25.76 + text = "IY1" + intervals [321]: + xmin = 25.76 + xmax = 25.82 + text = "DH" + intervals [322]: + xmin = 25.82 + xmax = 25.86 + text = "AH0" + intervals [323]: + xmin = 25.86 + xmax = 25.95 + text = "M" + intervals [324]: + xmin = 25.95 + xmax = 26.01 + text = "OW2" + intervals [325]: + xmin = 26.01 + xmax = 26.06 + text = "T" + intervals [326]: + xmin = 26.06 + xmax = 26.1 + text = "AH0" + intervals [327]: + xmin = 26.1 + xmax = 26.19 + text = "V" + intervals [328]: + xmin = 26.19 + xmax = 26.32 + text = "EY1" + intervals [329]: + xmin = 26.32 + xmax = 26.42 + text = "SH" + intervals [330]: + xmin = 26.42 + xmax = 26.51 + text = "AH0" + intervals [331]: + xmin = 26.51 + xmax = 26.85 + text = "N" + intervals [332]: + xmin = 26.85 + xmax = 26.88 + text = "" + intervals [333]: + xmin = 26.88 + xmax = 27.0 + text = "T" + intervals [334]: + xmin = 27.0 + xmax = 27.07 + text = "IH0" + intervals [335]: + xmin = 27.07 + xmax = 27.13 + text = "B" + intervals [336]: + xmin = 27.13 + xmax = 27.37 + text = "IY1" + intervals [337]: + xmin = 27.37 + xmax = 27.5 + text = "HH" + intervals [338]: + xmin = 27.5 + 
xmax = 27.55 + text = "EH1" + intervals [339]: + xmin = 27.55 + xmax = 27.68 + text = "L" + intervals [340]: + xmin = 27.68 + xmax = 27.72 + text = "TH" + intervals [341]: + xmin = 27.72 + xmax = 27.86 + text = "IY0" + intervals [342]: + xmin = 27.86 + xmax = 28.01 + text = "ER0" + intervals [343]: + xmin = 28.01 + xmax = 28.09 + text = "AE1" + intervals [344]: + xmin = 28.09 + xmax = 28.12 + text = "N" + intervals [345]: + xmin = 28.12 + xmax = 28.18 + text = "D" + intervals [346]: + xmin = 28.18 + xmax = 28.25 + text = "EH2" + intervals [347]: + xmin = 28.25 + xmax = 28.32 + text = "N" + intervals [348]: + xmin = 28.32 + xmax = 28.41 + text = "ER0" + intervals [349]: + xmin = 28.41 + xmax = 28.51 + text = "JH" + intervals [350]: + xmin = 28.51 + xmax = 28.59 + text = "EH1" + intervals [351]: + xmin = 28.59 + xmax = 28.62 + text = "T" + intervals [352]: + xmin = 28.62 + xmax = 28.71 + text = "IH0" + intervals [353]: + xmin = 28.71 + xmax = 28.9 + text = "K" + intervals [354]: + xmin = 28.9 + xmax = 29.1 + text = "" + intervals [355]: + xmin = 29.1 + xmax = 29.24 + text = "AE1" + intervals [356]: + xmin = 29.24 + xmax = 29.27 + text = "N" + intervals [357]: + xmin = 29.27 + xmax = 29.3 + text = "D" + intervals [358]: + xmin = 29.3 + xmax = 29.33 + text = "DH" + intervals [359]: + xmin = 29.33 + xmax = 29.37 + text = "AH0" + intervals [360]: + xmin = 29.37 + xmax = 29.47 + text = "L" + intervals [361]: + xmin = 29.47 + xmax = 29.62 + text = "AE1" + intervals [362]: + xmin = 29.62 + xmax = 29.74 + text = "S" + intervals [363]: + xmin = 29.74 + xmax = 29.8 + text = "TH" + intervals [364]: + xmin = 29.8 + xmax = 29.86 + text = "IH1" + intervals [365]: + xmin = 29.86 + xmax = 29.94 + text = "NG" + intervals [366]: + xmin = 29.94 + xmax = 30.14 + text = "AY1" + intervals [367]: + xmin = 30.14 + xmax = 30.23 + text = "L" + intervals [368]: + xmin = 30.23 + xmax = 30.38 + text = "AY1" + intervals [369]: + xmin = 30.38 + xmax = 30.42 + text = "K" + intervals [370]: + xmin = 30.42 + xmax = 30.48 + text = "T" + intervals [371]: + xmin = 30.48 + xmax = 30.53 + text = "IH0" + intervals [372]: + xmin = 30.53 + xmax = 30.59 + text = "D" + intervals [373]: + xmin = 30.59 + xmax = 30.84 + text = "UW1" + intervals [374]: + xmin = 30.84 + xmax = 30.97 + text = "W" + intervals [375]: + xmin = 30.97 + xmax = 31.03 + text = "EH1" + intervals [376]: + xmin = 31.03 + xmax = 31.22 + text = "N" + intervals [377]: + xmin = 31.22 + xmax = 31.35 + text = "AY1" + intervals [378]: + xmin = 31.35 + xmax = 31.43 + text = "M" + intervals [379]: + xmin = 31.43 + xmax = 31.55 + text = "F" + intervals [380]: + xmin = 31.55 + xmax = 31.65 + text = "R" + intervals [381]: + xmin = 31.65 + xmax = 31.87 + text = "IY1" + intervals [382]: + xmin = 31.87 + xmax = 31.91 + text = "IH1" + intervals [383]: + xmin = 31.91 + xmax = 31.99 + text = "Z" + intervals [384]: + xmin = 31.99 + xmax = 32.06 + text = "IH1" + intervals [385]: + xmin = 32.06 + xmax = 32.11 + text = "T" + intervals [386]: + xmin = 32.11 + xmax = 32.2 + text = "AW1" + intervals [387]: + xmin = 32.2 + xmax = 32.23 + text = "T" + intervals [388]: + xmin = 32.23 + xmax = 32.27 + text = "W" + intervals [389]: + xmin = 32.27 + xmax = 32.32 + text = "IH0" + intervals [390]: + xmin = 32.32 + xmax = 32.35 + text = "TH" + intervals [391]: + xmin = 32.35 + xmax = 32.39 + text = "M" + intervals [392]: + xmin = 32.39 + xmax = 32.48 + text = "AY1" + intervals [393]: + xmin = 32.48 + xmax = 32.61 + text = "F" + intervals [394]: + xmin = 32.61 + xmax = 32.72 + text = "AE1" + intervals 
[395]: + xmin = 32.72 + xmax = 32.76 + text = "M" + intervals [396]: + xmin = 32.76 + xmax = 32.81 + text = "L" + intervals [397]: + xmin = 32.81 + xmax = 32.86 + text = "IY0" + intervals [398]: + xmin = 32.86 + xmax = 32.92 + text = "M" + intervals [399]: + xmin = 32.92 + xmax = 32.97 + text = "EH1" + intervals [400]: + xmin = 32.97 + xmax = 33.0 + text = "M" + intervals [401]: + xmin = 33.0 + xmax = 33.05 + text = "B" + intervals [402]: + xmin = 33.05 + xmax = 33.16 + text = "ER0" + intervals [403]: + xmin = 33.16 + xmax = 33.33 + text = "Z" + intervals [404]: + xmin = 33.33 + xmax = 33.51 + text = "" + intervals [405]: + xmin = 33.51 + xmax = 33.75 + text = "Y" + intervals [406]: + xmin = 33.75 + xmax = 33.89 + text = "UW1" + intervals [407]: + xmin = 33.89 + xmax = 33.97 + text = "W" + intervals [408]: + xmin = 33.97 + xmax = 34.02 + text = "UH1" + intervals [409]: + xmin = 34.02 + xmax = 34.11 + text = "D" + intervals [410]: + xmin = 34.11 + xmax = 34.16 + text = "B" + intervals [411]: + xmin = 34.16 + xmax = 34.29 + text = "IY1" + intervals [412]: + xmin = 34.29 + xmax = 34.36 + text = "S" + intervals [413]: + xmin = 34.36 + xmax = 34.42 + text = "AH0" + intervals [414]: + xmin = 34.42 + xmax = 34.53 + text = "P" + intervals [415]: + xmin = 34.53 + xmax = 34.63 + text = "R" + intervals [416]: + xmin = 34.63 + xmax = 34.91 + text = "AY1" + intervals [417]: + xmin = 34.91 + xmax = 35.02 + text = "Z" + intervals [418]: + xmin = 35.02 + xmax = 35.07 + text = "D" + intervals [419]: + xmin = 35.07 + xmax = 35.1 + text = "T" + intervals [420]: + xmin = 35.1 + xmax = 35.16 + text = "IH0" + intervals [421]: + xmin = 35.16 + xmax = 35.26 + text = "N" + intervals [422]: + xmin = 35.26 + xmax = 35.36 + text = "OW1" + intervals [423]: + xmin = 35.36 + xmax = 35.44 + text = "DH" + intervals [424]: + xmin = 35.44 + xmax = 35.47 + text = "AE1" + intervals [425]: + xmin = 35.47 + xmax = 35.5 + text = "T" + intervals [426]: + xmin = 35.5 + xmax = 35.64 + text = "AY1" + intervals [427]: + xmin = 35.64 + xmax = 35.7 + text = "HH" + intervals [428]: + xmin = 35.7 + xmax = 35.75 + text = "AE1" + intervals [429]: + xmin = 35.75 + xmax = 35.84 + text = "V" + intervals [430]: + xmin = 35.84 + xmax = 35.9 + text = "T" + intervals [431]: + xmin = 35.9 + xmax = 35.98 + text = "R" + intervals [432]: + xmin = 35.98 + xmax = 36.15 + text = "AY1" + intervals [433]: + xmin = 36.15 + xmax = 36.3 + text = "D" + intervals [434]: + xmin = 36.3 + xmax = 36.57 + text = "" + intervals [435]: + xmin = 36.57 + xmax = 36.92 + text = "AO1" + intervals [436]: + xmin = 36.92 + xmax = 36.99 + text = "L" + intervals [437]: + xmin = 36.99 + xmax = 37.03 + text = "DH" + intervals [438]: + xmin = 37.03 + xmax = 37.08 + text = "AH1" + intervals [439]: + xmin = 37.08 + xmax = 37.17 + text = "R" + intervals [440]: + xmin = 37.17 + xmax = 37.28 + text = "EH1" + intervals [441]: + xmin = 37.28 + xmax = 37.34 + text = "S" + intervals [442]: + xmin = 37.34 + xmax = 37.39 + text = "T" + intervals [443]: + xmin = 37.39 + xmax = 37.44 + text = "R" + intervals [444]: + xmin = 37.44 + xmax = 37.55 + text = "AA2" + intervals [445]: + xmin = 37.55 + xmax = 37.6 + text = "N" + intervals [446]: + xmin = 37.6 + xmax = 37.64 + text = "T" + intervals [447]: + xmin = 37.64 + xmax = 37.68 + text = "S" + intervals [448]: + xmin = 37.68 + xmax = 37.71 + text = "" + intervals [449]: + xmin = 37.71 + xmax = 37.77 + text = "IH0" + intervals [450]: + xmin = 37.77 + xmax = 37.83 + text = "N" + intervals [451]: + xmin = 37.83 + xmax = 37.87 + text = "AA1" + 
intervals [452]: + xmin = 37.87 + xmax = 37.95 + text = "R" + intervals [453]: + xmin = 37.95 + xmax = 38.12 + text = "HH" + intervals [454]: + xmin = 38.12 + xmax = 38.2 + text = "Y" + intervals [455]: + xmin = 38.2 + xmax = 38.33 + text = "UW1" + intervals [456]: + xmin = 38.33 + xmax = 38.5 + text = "JH" + intervals [457]: + xmin = 38.5 + xmax = 38.53 + text = "K" + intervals [458]: + xmin = 38.53 + xmax = 38.59 + text = "AH0" + intervals [459]: + xmin = 38.59 + xmax = 38.64 + text = "M" + intervals [460]: + xmin = 38.64 + xmax = 38.67 + text = "Y" + intervals [461]: + xmin = 38.67 + xmax = 38.7 + text = "UW1" + intervals [462]: + xmin = 38.7 + xmax = 38.76 + text = "N" + intervals [463]: + xmin = 38.76 + xmax = 38.79 + text = "AH0" + intervals [464]: + xmin = 38.79 + xmax = 38.82 + text = "T" + intervals [465]: + xmin = 38.82 + xmax = 39.07 + text = "IY0" + intervals [466]: + xmin = 39.07 + xmax = 39.23 + text = "" + intervals [467]: + xmin = 39.23 + xmax = 39.6 + text = "AY1" + intervals [468]: + xmin = 39.6 + xmax = 39.81 + text = "AE1" + intervals [469]: + xmin = 39.81 + xmax = 39.86 + text = "K" + intervals [470]: + xmin = 39.86 + xmax = 39.93 + text = "SH" + intervals [471]: + xmin = 39.93 + xmax = 39.97 + text = "AH0" + intervals [472]: + xmin = 39.97 + xmax = 40.0 + text = "L" + intervals [473]: + xmin = 40.0 + xmax = 40.09 + text = "IY0" + intervals [474]: + xmin = 40.09 + xmax = 40.17 + text = "G" + intervals [475]: + xmin = 40.17 + xmax = 40.26 + text = "IH1" + intervals [476]: + xmin = 40.26 + xmax = 40.32 + text = "V" + intervals [477]: + xmin = 40.32 + xmax = 40.5 + text = "IY1" + intervals [478]: + xmin = 40.5 + xmax = 40.61 + text = "CH" + intervals [479]: + xmin = 40.61 + xmax = 40.7 + text = "R" + intervals [480]: + xmin = 40.7 + xmax = 40.78 + text = "EH1" + intervals [481]: + xmin = 40.78 + xmax = 40.83 + text = "S" + intervals [482]: + xmin = 40.83 + xmax = 40.9 + text = "T" + intervals [483]: + xmin = 40.9 + xmax = 40.94 + text = "R" + intervals [484]: + xmin = 40.94 + xmax = 41.02 + text = "AA2" + intervals [485]: + xmin = 41.02 + xmax = 41.05 + text = "N" + intervals [486]: + xmin = 41.05 + xmax = 41.08 + text = "T" + intervals [487]: + xmin = 41.08 + xmax = 41.15 + text = "AH0" + intervals [488]: + xmin = 41.15 + xmax = 41.29 + text = "S" + intervals [489]: + xmin = 41.29 + xmax = 41.33 + text = "K" + intervals [490]: + xmin = 41.33 + xmax = 41.44 + text = "AO1" + intervals [491]: + xmin = 41.44 + xmax = 41.55 + text = "R" + intervals [492]: + xmin = 41.55 + xmax = 41.61 + text = "B" + intervals [493]: + xmin = 41.61 + xmax = 41.73 + text = "EY1" + intervals [494]: + xmin = 41.73 + xmax = 41.77 + text = "S" + intervals [495]: + xmin = 41.77 + xmax = 41.82 + text = "T" + intervals [496]: + xmin = 41.82 + xmax = 41.85 + text = "AA1" + intervals [497]: + xmin = 41.85 + xmax = 41.89 + text = "N" + intervals [498]: + xmin = 41.89 + xmax = 41.98 + text = "HH" + intervals [499]: + xmin = 41.98 + xmax = 42.05 + text = "AW1" + intervals [500]: + xmin = 42.05 + xmax = 42.11 + text = "G" + intervals [501]: + xmin = 42.11 + xmax = 42.14 + text = "IH0" + intervals [502]: + xmin = 42.14 + xmax = 42.17 + text = "D" + intervals [503]: + xmin = 42.17 + xmax = 42.2 + text = "DH" + intervals [504]: + xmin = 42.2 + xmax = 42.23 + text = "IY0" + intervals [505]: + xmin = 42.23 + xmax = 42.35 + text = "F" + intervals [506]: + xmin = 42.35 + xmax = 42.48 + text = "UW1" + intervals [507]: + xmin = 42.48 + xmax = 42.51 + text = "D" + intervals [508]: + xmin = 42.51 + xmax = 42.71 + text 
= "IH1" + intervals [509]: + xmin = 42.71 + xmax = 42.85 + text = "Z" + intervals [510]: + xmin = 42.85 + xmax = 43.13 + text = "" + intervals [511]: + xmin = 43.13 + xmax = 43.28 + text = "HH" + intervals [512]: + xmin = 43.28 + xmax = 43.36 + text = "AW1" + intervals [513]: + xmin = 43.36 + xmax = 43.43 + text = "G" + intervals [514]: + xmin = 43.43 + xmax = 43.46 + text = "IH0" + intervals [515]: + xmin = 43.46 + xmax = 43.51 + text = "D" + intervals [516]: + xmin = 43.51 + xmax = 43.56 + text = "DH" + intervals [517]: + xmin = 43.56 + xmax = 43.62 + text = "IY0" + intervals [518]: + xmin = 43.62 + xmax = 43.65 + text = "IH0" + intervals [519]: + xmin = 43.65 + xmax = 43.69 + text = "N" + intervals [520]: + xmin = 43.69 + xmax = 43.78 + text = "V" + intervals [521]: + xmin = 43.78 + xmax = 43.89 + text = "AY1" + intervals [522]: + xmin = 43.89 + xmax = 43.92 + text = "R" + intervals [523]: + xmin = 43.92 + xmax = 43.95 + text = "AH0" + intervals [524]: + xmin = 43.95 + xmax = 43.98 + text = "N" + intervals [525]: + xmin = 43.98 + xmax = 44.01 + text = "M" + intervals [526]: + xmin = 44.01 + xmax = 44.04 + text = "AH0" + intervals [527]: + xmin = 44.04 + xmax = 44.07 + text = "N" + intervals [528]: + xmin = 44.07 + xmax = 44.1 + text = "T" + intervals [529]: + xmin = 44.1 + xmax = 44.25 + text = "IH1" + intervals [530]: + xmin = 44.25 + xmax = 44.4 + text = "Z" + intervals [531]: + xmin = 44.4 + xmax = 44.49 + text = "" + intervals [532]: + xmin = 44.49 + xmax = 44.79 + text = "AE1" + intervals [533]: + xmin = 44.79 + xmax = 44.9 + text = "N" + intervals [534]: + xmin = 44.9 + xmax = 44.98 + text = "D" + intervals [535]: + xmin = 44.98 + xmax = 45.26 + text = "AE1" + intervals [536]: + xmin = 45.26 + xmax = 45.34 + text = "T" + intervals [537]: + xmin = 45.34 + xmax = 45.39 + text = "DH" + intervals [538]: + xmin = 45.39 + xmax = 45.62 + text = "AH1" + intervals [539]: + xmin = 45.62 + xmax = 45.75 + text = "S" + intervals [540]: + xmin = 45.75 + xmax = 45.87 + text = "EY1" + intervals [541]: + xmin = 45.87 + xmax = 45.91 + text = "M" + intervals [542]: + xmin = 45.91 + xmax = 46.01 + text = "T" + intervals [543]: + xmin = 46.01 + xmax = 46.19 + text = "AY1" + intervals [544]: + xmin = 46.19 + xmax = 46.29 + text = "M" + intervals [545]: + xmin = 46.29 + xmax = 46.42 + text = "AY1" + intervals [546]: + xmin = 46.42 + xmax = 46.45 + text = "W" + intervals [547]: + xmin = 46.45 + xmax = 46.48 + text = "AH0" + intervals [548]: + xmin = 46.48 + xmax = 46.54 + text = "L" + intervals [549]: + xmin = 46.54 + xmax = 46.62 + text = "R" + intervals [550]: + xmin = 46.62 + xmax = 46.69 + text = "AY1" + intervals [551]: + xmin = 46.69 + xmax = 46.74 + text = "T" + intervals [552]: + xmin = 46.74 + xmax = 46.82 + text = "D" + intervals [553]: + xmin = 46.82 + xmax = 46.91 + text = "AW1" + intervals [554]: + xmin = 46.91 + xmax = 46.94 + text = "N" + intervals [555]: + xmin = 46.94 + xmax = 46.97 + text = "DH" + intervals [556]: + xmin = 46.97 + xmax = 47.02 + text = "AH1" + intervals [557]: + xmin = 47.02 + xmax = 47.1 + text = "T" + intervals [558]: + xmin = 47.1 + xmax = 47.19 + text = "AY1" + intervals [559]: + xmin = 47.19 + xmax = 47.24 + text = "P" + intervals [560]: + xmin = 47.24 + xmax = 47.29 + text = "AH0" + intervals [561]: + xmin = 47.29 + xmax = 47.39 + text = "V" + intervals [562]: + xmin = 47.39 + xmax = 47.45 + text = "F" + intervals [563]: + xmin = 47.45 + xmax = 47.64 + text = "UW1" + intervals [564]: + xmin = 47.64 + xmax = 47.8 + text = "D" + intervals [565]: + xmin = 47.8 + xmax 
= 48.03 + text = "" + intervals [566]: + xmin = 48.03 + xmax = 48.1 + text = "DH" + intervals [567]: + xmin = 48.1 + xmax = 48.24 + text = "EY1" + intervals [568]: + xmin = 48.24 + xmax = 48.37 + text = "S" + intervals [569]: + xmin = 48.37 + xmax = 48.58 + text = "ER1" + intervals [570]: + xmin = 48.58 + xmax = 48.76 + text = "V" + intervals [571]: + xmin = 48.76 + xmax = 49.42 + text = "" + intervals [572]: + xmin = 49.42 + xmax = 49.61 + text = "S" + intervals [573]: + xmin = 49.61 + xmax = 49.9 + text = "OW1" + intervals [574]: + xmin = 49.9 + xmax = 50.09 + text = "W" + intervals [575]: + xmin = 50.09 + xmax = 50.22 + text = "EH1" + intervals [576]: + xmin = 50.22 + xmax = 50.46 + text = "N" + intervals [577]: + xmin = 50.46 + xmax = 50.49 + text = "" + intervals [578]: + xmin = 50.49 + xmax = 50.58 + text = "Y" + intervals [579]: + xmin = 50.58 + xmax = 50.67 + text = "UW1" + intervals [580]: + xmin = 50.67 + xmax = 50.85 + text = "R" + intervals [581]: + xmin = 50.85 + xmax = 50.94 + text = "S" + intervals [582]: + xmin = 50.94 + xmax = 50.98 + text = "OW1" + intervals [583]: + xmin = 50.98 + xmax = 51.03 + text = "W" + intervals [584]: + xmin = 51.03 + xmax = 51.06 + text = "EH1" + intervals [585]: + xmin = 51.06 + xmax = 51.13 + text = "N" + intervals [586]: + xmin = 51.13 + xmax = 51.24 + text = "IY1" + intervals [587]: + xmin = 51.24 + xmax = 51.35 + text = "CH" + intervals [588]: + xmin = 51.35 + xmax = 51.41 + text = "T" + intervals [589]: + xmin = 51.41 + xmax = 51.49 + text = "AY1" + intervals [590]: + xmin = 51.49 + xmax = 51.55 + text = "M" + intervals [591]: + xmin = 51.55 + xmax = 51.62 + text = "AH0" + intervals [592]: + xmin = 51.62 + xmax = 51.69 + text = "F" + intervals [593]: + xmin = 51.69 + xmax = 51.74 + text = "R" + intervals [594]: + xmin = 51.74 + xmax = 51.8 + text = "EH1" + intervals [595]: + xmin = 51.8 + xmax = 51.85 + text = "N" + intervals [596]: + xmin = 51.85 + xmax = 51.91 + text = "D" + intervals [597]: + xmin = 51.91 + xmax = 51.98 + text = "K" + intervals [598]: + xmin = 51.98 + xmax = 52.03 + text = "AH1" + intervals [599]: + xmin = 52.03 + xmax = 52.14 + text = "M" + intervals [600]: + xmin = 52.14 + xmax = 52.32 + text = "Z" + intervals [601]: + xmin = 52.32 + xmax = 52.37 + text = "T" + intervals [602]: + xmin = 52.37 + xmax = 52.46 + text = "UW1" + intervals [603]: + xmin = 52.46 + xmax = 52.53 + text = "DH" + intervals [604]: + xmin = 52.53 + xmax = 52.59 + text = "AH0" + intervals [605]: + xmin = 52.59 + xmax = 52.68 + text = "S" + intervals [606]: + xmin = 52.68 + xmax = 52.74 + text = "IH1" + intervals [607]: + xmin = 52.74 + xmax = 52.77 + text = "T" + intervals [608]: + xmin = 52.77 + xmax = 52.9 + text = "IY0" + intervals [609]: + xmin = 52.9 + xmax = 52.97 + text = "T" + intervals [610]: + xmin = 52.97 + xmax = 53.07 + text = "UW1" + intervals [611]: + xmin = 53.07 + xmax = 53.12 + text = "IH0" + intervals [612]: + xmin = 53.12 + xmax = 53.17 + text = "N" + intervals [613]: + xmin = 53.17 + xmax = 53.26 + text = "JH" + intervals [614]: + xmin = 53.26 + xmax = 53.35 + text = "OY1" + intervals [615]: + xmin = 53.35 + xmax = 53.46 + text = "T" + intervals [616]: + xmin = 53.46 + xmax = 53.59 + text = "AY1" + intervals [617]: + xmin = 53.59 + xmax = 53.62 + text = "M" + intervals [618]: + xmin = 53.62 + xmax = 53.65 + text = "W" + intervals [619]: + xmin = 53.65 + xmax = 53.7 + text = "IH1" + intervals [620]: + xmin = 53.7 + xmax = 53.74 + text = "DH" + intervals [621]: + xmin = 53.74 + xmax = 53.81 + text = "M" + intervals [622]: + xmin = 
53.81 + xmax = 54.02 + text = "IY1" + intervals [623]: + xmin = 54.02 + xmax = 54.31 + text = "" + intervals [624]: + xmin = 54.31 + xmax = 54.54 + text = "AY1" + intervals [625]: + xmin = 54.54 + xmax = 54.6 + text = "W" + intervals [626]: + xmin = 54.6 + xmax = 54.66 + text = "AH0" + intervals [627]: + xmin = 54.66 + xmax = 54.69 + text = "L" + intervals [628]: + xmin = 54.69 + xmax = 54.77 + text = "G" + intervals [629]: + xmin = 54.77 + xmax = 54.81 + text = "IH1" + intervals [630]: + xmin = 54.81 + xmax = 54.84 + text = "V" + intervals [631]: + xmin = 54.84 + xmax = 54.87 + text = "DH" + intervals [632]: + xmin = 54.87 + xmax = 54.94 + text = "AH0" + intervals [633]: + xmin = 54.94 + xmax = 54.97 + text = "M" + intervals [634]: + xmin = 54.97 + xmax = 55.0 + text = "DH" + intervals [635]: + xmin = 55.0 + xmax = 55.07 + text = "AH0" + intervals [636]: + xmin = 55.07 + xmax = 55.17 + text = "T" + intervals [637]: + xmin = 55.17 + xmax = 55.27 + text = "AA1" + intervals [638]: + xmin = 55.27 + xmax = 55.38 + text = "P" + intervals [639]: + xmin = 55.38 + xmax = 55.53 + text = "spn" + intervals [640]: + xmin = 55.53 + xmax = 55.62 + text = "R" + intervals [641]: + xmin = 55.62 + xmax = 55.7 + text = "EH1" + intervals [642]: + xmin = 55.7 + xmax = 55.76 + text = "S" + intervals [643]: + xmin = 55.76 + xmax = 55.83 + text = "T" + intervals [644]: + xmin = 55.83 + xmax = 55.86 + text = "R" + intervals [645]: + xmin = 55.86 + xmax = 55.97 + text = "AA2" + intervals [646]: + xmin = 55.97 + xmax = 56.01 + text = "N" + intervals [647]: + xmin = 56.01 + xmax = 56.04 + text = "T" + intervals [648]: + xmin = 56.04 + xmax = 56.1 + text = "S" + intervals [649]: + xmin = 56.1 + xmax = 56.17 + text = "B" + intervals [650]: + xmin = 56.17 + xmax = 56.32 + text = "EY1" + intervals [651]: + xmin = 56.32 + xmax = 56.37 + text = "S" + intervals [652]: + xmin = 56.37 + xmax = 56.44 + text = "T" + intervals [653]: + xmin = 56.44 + xmax = 56.64 + text = "AA1" + intervals [654]: + xmin = 56.64 + xmax = 56.68 + text = "N" + intervals [655]: + xmin = 56.68 + xmax = 56.72 + text = "DH" + intervals [656]: + xmin = 56.72 + xmax = 56.9 + text = "IH0" + intervals [657]: + xmin = 56.9 + xmax = 56.99 + text = "S" + intervals [658]: + xmin = 56.99 + xmax = 57.05 + text = "R" + intervals [659]: + xmin = 57.05 + xmax = 57.11 + text = "AE1" + intervals [660]: + xmin = 57.11 + xmax = 57.16 + text = "NG" + intervals [661]: + xmin = 57.16 + xmax = 57.22 + text = "K" + intervals [662]: + xmin = 57.22 + xmax = 57.29 + text = "IH0" + intervals [663]: + xmin = 57.29 + xmax = 57.35 + text = "NG" + intervals [664]: + xmin = 57.35 + xmax = 57.4 + text = "AH0" + intervals [665]: + xmin = 57.4 + xmax = 57.43 + text = "N" + intervals [666]: + xmin = 57.43 + xmax = 57.53 + text = "D" + intervals [667]: + xmin = 57.53 + xmax = 57.7 + text = "EH1" + intervals [668]: + xmin = 57.7 + xmax = 57.76 + text = "V" + intervals [669]: + xmin = 57.76 + xmax = 57.82 + text = "R" + intervals [670]: + xmin = 57.82 + xmax = 57.86 + text = "IY0" + intervals [671]: + xmin = 57.86 + xmax = 57.95 + text = "T" + intervals [672]: + xmin = 57.95 + xmax = 58.19 + text = "AY1" + intervals [673]: + xmin = 58.19 + xmax = 58.44 + text = "M" + intervals [674]: + xmin = 58.44 + xmax = 59.02 + text = "" + intervals [675]: + xmin = 59.02 + xmax = 59.12 + text = "Y" + intervals [676]: + xmin = 59.12 + xmax = 59.15 + text = "UH1" + intervals [677]: + xmin = 59.15 + xmax = 59.2 + text = "R" + intervals [678]: + xmin = 59.2 + xmax = 59.32 + text = "S" + intervals [679]: + 
xmin = 59.32 + xmax = 59.41 + text = "AE1" + intervals [680]: + xmin = 59.41 + xmax = 59.44 + text = "T" + intervals [681]: + xmin = 59.44 + xmax = 59.49 + text = "AH0" + intervals [682]: + xmin = 59.49 + xmax = 59.55 + text = "S" + intervals [683]: + xmin = 59.55 + xmax = 59.62 + text = "F" + intervals [684]: + xmin = 59.62 + xmax = 59.69 + text = "AY2" + intervals [685]: + xmin = 59.69 + xmax = 59.72 + text = "D" + intervals [686]: + xmin = 59.72 + xmax = 59.77 + text = "W" + intervals [687]: + xmin = 59.77 + xmax = 59.82 + text = "IH0" + intervals [688]: + xmin = 59.82 + xmax = 59.85 + text = "DH" + intervals [689]: + xmin = 59.85 + xmax = 59.88 + text = "DH" + intervals [690]: + xmin = 59.88 + xmax = 60.0 + text = "IY1" + intervals [691]: + xmin = 60.0 + xmax = 60.1 + text = "Z" + intervals [692]: + xmin = 60.1 + xmax = 60.15 + text = "R" + intervals [693]: + xmin = 60.15 + xmax = 60.23 + text = "EH1" + intervals [694]: + xmin = 60.23 + xmax = 60.28 + text = "S" + intervals [695]: + xmin = 60.28 + xmax = 60.34 + text = "T" + intervals [696]: + xmin = 60.34 + xmax = 60.4 + text = "R" + intervals [697]: + xmin = 60.4 + xmax = 60.52 + text = "AA2" + intervals [698]: + xmin = 60.52 + xmax = 60.57 + text = "N" + intervals [699]: + xmin = 60.57 + xmax = 60.62 + text = "T" + intervals [700]: + xmin = 60.62 + xmax = 60.81 + text = "S" + intervals [701]: + xmin = 60.81 + xmax = 62 + text = "" diff --git a/EMAGE/test_sequences/textgrid/2_scott_0_3_3.TextGrid b/EMAGE/test_sequences/textgrid/2_scott_0_3_3.TextGrid new file mode 100644 index 0000000000000000000000000000000000000000..42b1321f18dcafe25ab172fe1607812a4c5c8eb9 --- /dev/null +++ b/EMAGE/test_sequences/textgrid/2_scott_0_3_3.TextGrid @@ -0,0 +1,3676 @@ +File type = "ooTextFile" +Object class = "TextGrid" + +xmin = 0.0 +xmax = 68 +tiers? 
+size = 2 +item []: + item [1]: + class = "IntervalTier" + name = "words" + xmin = 0.0 + xmax = 68 + intervals: size = 213 + intervals [1]: + xmin = 0.0 + xmax = 1.47 + text = "" + intervals [2]: + xmin = 1.47 + xmax = 2.36 + text = "well" + intervals [3]: + xmin = 2.36 + xmax = 2.56 + text = "" + intervals [4]: + xmin = 2.56 + xmax = 3.05 + text = "in" + intervals [5]: + xmin = 3.05 + xmax = 3.43 + text = "my" + intervals [6]: + xmin = 3.43 + xmax = 4.29 + text = "opinion" + intervals [7]: + xmin = 4.29 + xmax = 4.43 + text = "i" + intervals [8]: + xmin = 4.43 + xmax = 4.7 + text = "think" + intervals [9]: + xmin = 4.7 + xmax = 4.77 + text = "the" + intervals [10]: + xmin = 4.77 + xmax = 5.06 + text = "best" + intervals [11]: + xmin = 5.06 + xmax = 5.31 + text = "job" + intervals [12]: + xmin = 5.31 + xmax = 5.41 + text = "for" + intervals [13]: + xmin = 5.41 + xmax = 5.62 + text = "me" + intervals [14]: + xmin = 5.62 + xmax = 5.76 + text = "is" + intervals [15]: + xmin = 5.76 + xmax = 5.84 + text = "to" + intervals [16]: + xmin = 5.84 + xmax = 6.11 + text = "become" + intervals [17]: + xmin = 6.11 + xmax = 6.2 + text = "a" + intervals [18]: + xmin = 6.2 + xmax = 6.93 + text = "journalist" + intervals [19]: + xmin = 6.93 + xmax = 7.3 + text = "cuz" + intervals [20]: + xmin = 7.3 + xmax = 7.53 + text = "this" + intervals [21]: + xmin = 7.53 + xmax = 7.63 + text = "is" + intervals [22]: + xmin = 7.63 + xmax = 7.79 + text = "my" + intervals [23]: + xmin = 7.79 + xmax = 8.14 + text = "dream" + intervals [24]: + xmin = 8.14 + xmax = 8.53 + text = "job" + intervals [25]: + xmin = 8.53 + xmax = 8.71 + text = "i've" + intervals [26]: + xmin = 8.71 + xmax = 9.13 + text = "always" + intervals [27]: + xmin = 9.13 + xmax = 9.53 + text = "wanted" + intervals [28]: + xmin = 9.53 + xmax = 9.6 + text = "to" + intervals [29]: + xmin = 9.6 + xmax = 9.77 + text = "be" + intervals [30]: + xmin = 9.77 + xmax = 9.84 + text = "a" + intervals [31]: + xmin = 9.84 + xmax = 10.29 + text = "journalist" + intervals [32]: + xmin = 10.29 + xmax = 10.48 + text = "since" + intervals [33]: + xmin = 10.48 + xmax = 10.54 + text = "i" + intervals [34]: + xmin = 10.54 + xmax = 10.71 + text = "was" + intervals [35]: + xmin = 10.71 + xmax = 10.78 + text = "in" + intervals [36]: + xmin = 10.78 + xmax = 11.01 + text = "middle" + intervals [37]: + xmin = 11.01 + xmax = 11.4 + text = "school" + intervals [38]: + xmin = 11.4 + xmax = 11.96 + text = "" + intervals [39]: + xmin = 11.96 + xmax = 12.92 + text = "journalists" + intervals [40]: + xmin = 12.92 + xmax = 13.73 + text = "never" + intervals [41]: + xmin = 13.73 + xmax = 13.81 + text = "" + intervals [42]: + xmin = 13.81 + xmax = 14.38 + text = "tell" + intervals [43]: + xmin = 14.38 + xmax = 15.08 + text = "lies" + intervals [44]: + xmin = 15.08 + xmax = 15.19 + text = "and" + intervals [45]: + xmin = 15.19 + xmax = 15.38 + text = "are" + intervals [46]: + xmin = 15.38 + xmax = 15.75 + text = "always" + intervals [47]: + xmin = 15.75 + xmax = 16.07 + text = "seeking" + intervals [48]: + xmin = 16.07 + xmax = 16.17 + text = "the" + intervals [49]: + xmin = 16.17 + xmax = 16.69 + text = "truth" + intervals [50]: + xmin = 16.69 + xmax = 16.85 + text = "" + intervals [51]: + xmin = 16.85 + xmax = 17.09 + text = "i" + intervals [52]: + xmin = 17.09 + xmax = 17.27 + text = "want" + intervals [53]: + xmin = 17.27 + xmax = 17.33 + text = "to" + intervals [54]: + xmin = 17.33 + xmax = 17.45 + text = "be" + intervals [55]: + xmin = 17.45 + xmax = 17.73 + text = "just" + intervals [56]: 
+ xmin = 17.73 + xmax = 18.0 + text = "like" + intervals [57]: + xmin = 18.0 + xmax = 18.36 + text = "that" + intervals [58]: + xmin = 18.36 + xmax = 18.73 + text = "" + intervals [59]: + xmin = 18.73 + xmax = 19.54 + text = "i" + intervals [60]: + xmin = 19.54 + xmax = 19.71 + text = "" + intervals [61]: + xmin = 19.71 + xmax = 20.27 + text = "usually" + intervals [62]: + xmin = 20.27 + xmax = 20.52 + text = "feel" + intervals [63]: + xmin = 20.52 + xmax = 20.96 + text = "shy" + intervals [64]: + xmin = 20.96 + xmax = 21.14 + text = "when" + intervals [65]: + xmin = 21.14 + xmax = 21.26 + text = "i" + intervals [66]: + xmin = 21.26 + xmax = 21.38 + text = "am" + intervals [67]: + xmin = 21.38 + xmax = 21.75 + text = "talking" + intervals [68]: + xmin = 21.75 + xmax = 21.86 + text = "to" + intervals [69]: + xmin = 21.86 + xmax = 22.22 + text = "others" + intervals [70]: + xmin = 22.22 + xmax = 22.37 + text = "and" + intervals [71]: + xmin = 22.37 + xmax = 22.46 + text = "i" + intervals [72]: + xmin = 22.46 + xmax = 22.8 + text = "know" + intervals [73]: + xmin = 22.8 + xmax = 23.1 + text = "that" + intervals [74]: + xmin = 23.1 + xmax = 23.79 + text = "journalists" + intervals [75]: + xmin = 23.79 + xmax = 23.99 + text = "are" + intervals [76]: + xmin = 23.99 + xmax = 24.73 + text = "very" + intervals [77]: + xmin = 24.73 + xmax = 25.29 + text = "good" + intervals [78]: + xmin = 25.29 + xmax = 25.44 + text = "" + intervals [79]: + xmin = 25.44 + xmax = 25.7 + text = "at" + intervals [80]: + xmin = 25.7 + xmax = 26.41 + text = "communicating" + intervals [81]: + xmin = 26.41 + xmax = 26.94 + text = "because" + intervals [82]: + xmin = 26.94 + xmax = 26.98 + text = "" + intervals [83]: + xmin = 26.98 + xmax = 27.2 + text = "good" + intervals [84]: + xmin = 27.2 + xmax = 27.95 + text = "communication" + intervals [85]: + xmin = 27.95 + xmax = 28.36 + text = "skills" + intervals [86]: + xmin = 28.36 + xmax = 28.51 + text = "are" + intervals [87]: + xmin = 28.51 + xmax = 28.75 + text = "very" + intervals [88]: + xmin = 28.75 + xmax = 29.22 + text = "important" + intervals [89]: + xmin = 29.22 + xmax = 29.39 + text = "when" + intervals [90]: + xmin = 29.39 + xmax = 29.51 + text = "you're" + intervals [91]: + xmin = 29.51 + xmax = 29.85 + text = "doing" + intervals [92]: + xmin = 29.85 + xmax = 30.43 + text = "interviews" + intervals [93]: + xmin = 30.43 + xmax = 30.71 + text = "" + intervals [94]: + xmin = 30.71 + xmax = 30.99 + text = "i" + intervals [95]: + xmin = 30.99 + xmax = 31.21 + text = "want" + intervals [96]: + xmin = 31.21 + xmax = 31.3 + text = "to" + intervals [97]: + xmin = 31.3 + xmax = 31.71 + text = "possess" + intervals [98]: + xmin = 31.71 + xmax = 31.82 + text = "the" + intervals [99]: + xmin = 31.82 + xmax = 32.16 + text = "skill" + intervals [100]: + xmin = 32.16 + xmax = 32.93 + text = "myself" + intervals [101]: + xmin = 32.93 + xmax = 33.0 + text = "" + intervals [102]: + xmin = 33.0 + xmax = 33.53 + text = "so" + intervals [103]: + xmin = 33.53 + xmax = 33.56 + text = "" + intervals [104]: + xmin = 33.56 + xmax = 33.95 + text = "that's" + intervals [105]: + xmin = 33.95 + xmax = 34.31 + text = "why" + intervals [106]: + xmin = 34.31 + xmax = 34.53 + text = "i" + intervals [107]: + xmin = 34.53 + xmax = 34.83 + text = "want" + intervals [108]: + xmin = 34.83 + xmax = 34.89 + text = "to" + intervals [109]: + xmin = 34.89 + xmax = 35.42 + text = "become" + intervals [110]: + xmin = 35.42 + xmax = 35.46 + text = "" + intervals [111]: + xmin = 35.46 + xmax = 35.59 + text = 
"a" + intervals [112]: + xmin = 35.59 + xmax = 36.37 + text = "journalist" + intervals [113]: + xmin = 36.37 + xmax = 36.74 + text = "" + intervals [114]: + xmin = 36.74 + xmax = 37.02 + text = "other" + intervals [115]: + xmin = 37.02 + xmax = 37.18 + text = "than" + intervals [116]: + xmin = 37.18 + xmax = 37.73 + text = "that" + intervals [117]: + xmin = 37.73 + xmax = 37.76 + text = "" + intervals [118]: + xmin = 37.76 + xmax = 38.72 + text = "photography" + intervals [119]: + xmin = 38.72 + xmax = 38.96 + text = "" + intervals [120]: + xmin = 38.96 + xmax = 39.4 + text = "often" + intervals [121]: + xmin = 39.4 + xmax = 39.66 + text = "makes" + intervals [122]: + xmin = 39.66 + xmax = 39.8 + text = "me" + intervals [123]: + xmin = 39.8 + xmax = 40.39 + text = "feel" + intervals [124]: + xmin = 40.39 + xmax = 40.72 + text = "like" + intervals [125]: + xmin = 40.72 + xmax = 41.42 + text = "i'm" + intervals [126]: + xmin = 41.42 + xmax = 41.77 + text = "" + intervals [127]: + xmin = 41.77 + xmax = 42.06 + text = "doing" + intervals [128]: + xmin = 42.06 + xmax = 42.13 + text = "a" + intervals [129]: + xmin = 42.13 + xmax = 42.41 + text = "job" + intervals [130]: + xmin = 42.41 + xmax = 42.58 + text = "full" + intervals [131]: + xmin = 42.58 + xmax = 42.64 + text = "of" + intervals [132]: + xmin = 42.64 + xmax = 42.99 + text = "design" + intervals [133]: + xmin = 42.99 + xmax = 43.12 + text = "and" + intervals [134]: + xmin = 43.12 + xmax = 43.25 + text = "in" + intervals [135]: + xmin = 43.25 + xmax = 43.52 + text = "for" + intervals [136]: + xmin = 43.52 + xmax = 44.31 + text = "innovation" + intervals [137]: + xmin = 44.31 + xmax = 45.18 + text = "" + intervals [138]: + xmin = 45.18 + xmax = 45.73 + text = "because" + intervals [139]: + xmin = 45.73 + xmax = 45.89 + text = "for" + intervals [140]: + xmin = 45.89 + xmax = 45.99 + text = "the" + intervals [141]: + xmin = 45.99 + xmax = 46.35 + text = "same" + intervals [142]: + xmin = 46.35 + xmax = 46.88 + text = "scenery" + intervals [143]: + xmin = 46.88 + xmax = 47.01 + text = "we" + intervals [144]: + xmin = 47.01 + xmax = 47.12 + text = "can" + intervals [145]: + xmin = 47.12 + xmax = 47.34 + text = "use" + intervals [146]: + xmin = 47.34 + xmax = 47.61 + text = "different" + intervals [147]: + xmin = 47.61 + xmax = 48.12 + text = "angles" + intervals [148]: + xmin = 48.12 + xmax = 48.21 + text = "and" + intervals [149]: + xmin = 48.21 + xmax = 48.48 + text = "different" + intervals [150]: + xmin = 48.48 + xmax = 49.32 + text = "compositions" + intervals [151]: + xmin = 49.32 + xmax = 49.68 + text = "" + intervals [152]: + xmin = 49.68 + xmax = 49.98 + text = "for" + intervals [153]: + xmin = 49.98 + xmax = 50.4 + text = "example" + intervals [154]: + xmin = 50.4 + xmax = 50.91 + text = "people" + intervals [155]: + xmin = 50.91 + xmax = 51.42 + text = "being" + intervals [156]: + xmin = 51.42 + xmax = 51.91 + text = "shifted" + intervals [157]: + xmin = 51.91 + xmax = 52.06 + text = "from" + intervals [158]: + xmin = 52.06 + xmax = 52.15 + text = "the" + intervals [159]: + xmin = 52.15 + xmax = 52.47 + text = "center" + intervals [160]: + xmin = 52.47 + xmax = 52.54 + text = "of" + intervals [161]: + xmin = 52.54 + xmax = 52.63 + text = "the" + intervals [162]: + xmin = 52.63 + xmax = 52.95 + text = "frame" + intervals [163]: + xmin = 52.95 + xmax = 53.1 + text = "to" + intervals [164]: + xmin = 53.1 + xmax = 53.32 + text = "left" + intervals [165]: + xmin = 53.32 + xmax = 53.54 + text = "side" + intervals [166]: + xmin = 53.54 + 
xmax = 53.6 + text = "of" + intervals [167]: + xmin = 53.6 + xmax = 53.69 + text = "the" + intervals [168]: + xmin = 53.69 + xmax = 54.17 + text = "frame" + intervals [169]: + xmin = 54.17 + xmax = 54.62 + text = "" + intervals [170]: + xmin = 54.62 + xmax = 54.75 + text = "it" + intervals [171]: + xmin = 54.75 + xmax = 54.91 + text = "can" + intervals [172]: + xmin = 54.91 + xmax = 55.13 + text = "make" + intervals [173]: + xmin = 55.13 + xmax = 55.22 + text = "a" + intervals [174]: + xmin = 55.22 + xmax = 55.56 + text = "different" + intervals [175]: + xmin = 55.56 + xmax = 56.05 + text = "feeling" + intervals [176]: + xmin = 56.05 + xmax = 56.41 + text = "" + intervals [177]: + xmin = 56.41 + xmax = 56.69 + text = "when" + intervals [178]: + xmin = 56.69 + xmax = 57.25 + text = "we" + intervals [179]: + xmin = 57.25 + xmax = 57.5 + text = "" + intervals [180]: + xmin = 57.5 + xmax = 57.82 + text = "when" + intervals [181]: + xmin = 57.82 + xmax = 58.07 + text = "seen" + intervals [182]: + xmin = 58.07 + xmax = 58.17 + text = "in" + intervals [183]: + xmin = 58.17 + xmax = 58.77 + text = "context" + intervals [184]: + xmin = 58.77 + xmax = 58.86 + text = "with" + intervals [185]: + xmin = 58.86 + xmax = 58.93 + text = "the" + intervals [186]: + xmin = 58.93 + xmax = 59.61 + text = "background" + intervals [187]: + xmin = 59.61 + xmax = 59.96 + text = "" + intervals [188]: + xmin = 59.96 + xmax = 60.31 + text = "when" + intervals [189]: + xmin = 60.31 + xmax = 60.79 + text = "everyone's" + intervals [190]: + xmin = 60.79 + xmax = 61.08 + text = "taking" + intervals [191]: + xmin = 61.08 + xmax = 61.14 + text = "a" + intervals [192]: + xmin = 61.14 + xmax = 61.46 + text = "picture" + intervals [193]: + xmin = 61.46 + xmax = 61.53 + text = "of" + intervals [194]: + xmin = 61.53 + xmax = 61.6 + text = "the" + intervals [195]: + xmin = 61.6 + xmax = 61.98 + text = "exact" + intervals [196]: + xmin = 61.98 + xmax = 62.19 + text = "same" + intervals [197]: + xmin = 62.19 + xmax = 62.64 + text = "scenery" + intervals [198]: + xmin = 62.64 + xmax = 62.67 + text = "" + intervals [199]: + xmin = 62.67 + xmax = 62.89 + text = "i'm" + intervals [200]: + xmin = 62.89 + xmax = 63.28 + text = "very" + intervals [201]: + xmin = 63.28 + xmax = 63.75 + text = "happy" + intervals [202]: + xmin = 63.75 + xmax = 63.9 + text = "when" + intervals [203]: + xmin = 63.9 + xmax = 64.35 + text = "people" + intervals [204]: + xmin = 64.35 + xmax = 64.6 + text = "say" + intervals [205]: + xmin = 64.6 + xmax = 64.84 + text = "my" + intervals [206]: + xmin = 64.84 + xmax = 65.29 + text = "photos" + intervals [207]: + xmin = 65.29 + xmax = 65.47 + text = "look" + intervals [208]: + xmin = 65.47 + xmax = 66.11 + text = "better" + intervals [209]: + xmin = 66.11 + xmax = 66.43 + text = "" + intervals [210]: + xmin = 66.43 + xmax = 66.56 + text = "than" + intervals [211]: + xmin = 66.56 + xmax = 66.7 + text = "the" + intervals [212]: + xmin = 66.7 + xmax = 67.19 + text = "others" + intervals [213]: + xmin = 67.19 + xmax = 68 + text = "" + item [2]: + class = "IntervalTier" + name = "phones" + xmin = 0.0 + xmax = 68 + intervals: size = 701 + intervals [1]: + xmin = 0.0 + xmax = 1.47 + text = "" + intervals [2]: + xmin = 1.47 + xmax = 1.59 + text = "W" + intervals [3]: + xmin = 1.59 + xmax = 1.99 + text = "EH1" + intervals [4]: + xmin = 1.99 + xmax = 2.36 + text = "L" + intervals [5]: + xmin = 2.36 + xmax = 2.56 + text = "" + intervals [6]: + xmin = 2.56 + xmax = 2.88 + text = "IH0" + intervals [7]: + xmin = 2.88 + xmax = 
3.05 + text = "N" + intervals [8]: + xmin = 3.05 + xmax = 3.18 + text = "M" + intervals [9]: + xmin = 3.18 + xmax = 3.43 + text = "AY1" + intervals [10]: + xmin = 3.43 + xmax = 3.53 + text = "AH0" + intervals [11]: + xmin = 3.53 + xmax = 3.6 + text = "P" + intervals [12]: + xmin = 3.6 + xmax = 3.7 + text = "IH1" + intervals [13]: + xmin = 3.7 + xmax = 3.77 + text = "N" + intervals [14]: + xmin = 3.77 + xmax = 3.85 + text = "Y" + intervals [15]: + xmin = 3.85 + xmax = 3.92 + text = "AH0" + intervals [16]: + xmin = 3.92 + xmax = 4.29 + text = "N" + intervals [17]: + xmin = 4.29 + xmax = 4.43 + text = "AY1" + intervals [18]: + xmin = 4.43 + xmax = 4.5 + text = "TH" + intervals [19]: + xmin = 4.5 + xmax = 4.57 + text = "IH1" + intervals [20]: + xmin = 4.57 + xmax = 4.67 + text = "NG" + intervals [21]: + xmin = 4.67 + xmax = 4.7 + text = "K" + intervals [22]: + xmin = 4.7 + xmax = 4.73 + text = "DH" + intervals [23]: + xmin = 4.73 + xmax = 4.77 + text = "AH0" + intervals [24]: + xmin = 4.77 + xmax = 4.84 + text = "B" + intervals [25]: + xmin = 4.84 + xmax = 4.96 + text = "EH1" + intervals [26]: + xmin = 4.96 + xmax = 5.02 + text = "S" + intervals [27]: + xmin = 5.02 + xmax = 5.06 + text = "T" + intervals [28]: + xmin = 5.06 + xmax = 5.17 + text = "JH" + intervals [29]: + xmin = 5.17 + xmax = 5.25 + text = "AA1" + intervals [30]: + xmin = 5.25 + xmax = 5.31 + text = "B" + intervals [31]: + xmin = 5.31 + xmax = 5.35 + text = "F" + intervals [32]: + xmin = 5.35 + xmax = 5.41 + text = "ER0" + intervals [33]: + xmin = 5.41 + xmax = 5.49 + text = "M" + intervals [34]: + xmin = 5.49 + xmax = 5.62 + text = "IY1" + intervals [35]: + xmin = 5.62 + xmax = 5.69 + text = "IH1" + intervals [36]: + xmin = 5.69 + xmax = 5.76 + text = "Z" + intervals [37]: + xmin = 5.76 + xmax = 5.8 + text = "T" + intervals [38]: + xmin = 5.8 + xmax = 5.84 + text = "IH0" + intervals [39]: + xmin = 5.84 + xmax = 5.88 + text = "B" + intervals [40]: + xmin = 5.88 + xmax = 5.94 + text = "IH0" + intervals [41]: + xmin = 5.94 + xmax = 6.01 + text = "K" + intervals [42]: + xmin = 6.01 + xmax = 6.06 + text = "AH1" + intervals [43]: + xmin = 6.06 + xmax = 6.11 + text = "M" + intervals [44]: + xmin = 6.11 + xmax = 6.2 + text = "AH0" + intervals [45]: + xmin = 6.2 + xmax = 6.42 + text = "JH" + intervals [46]: + xmin = 6.42 + xmax = 6.52 + text = "ER1" + intervals [47]: + xmin = 6.52 + xmax = 6.57 + text = "N" + intervals [48]: + xmin = 6.57 + xmax = 6.63 + text = "AH0" + intervals [49]: + xmin = 6.63 + xmax = 6.71 + text = "L" + intervals [50]: + xmin = 6.71 + xmax = 6.77 + text = "AH0" + intervals [51]: + xmin = 6.77 + xmax = 6.87 + text = "S" + intervals [52]: + xmin = 6.87 + xmax = 6.93 + text = "T" + intervals [53]: + xmin = 6.93 + xmax = 7.04 + text = "K" + intervals [54]: + xmin = 7.04 + xmax = 7.18 + text = "UW0" + intervals [55]: + xmin = 7.18 + xmax = 7.3 + text = "Z" + intervals [56]: + xmin = 7.3 + xmax = 7.37 + text = "DH" + intervals [57]: + xmin = 7.37 + xmax = 7.44 + text = "IH1" + intervals [58]: + xmin = 7.44 + xmax = 7.53 + text = "S" + intervals [59]: + xmin = 7.53 + xmax = 7.58 + text = "IH0" + intervals [60]: + xmin = 7.58 + xmax = 7.63 + text = "Z" + intervals [61]: + xmin = 7.63 + xmax = 7.71 + text = "M" + intervals [62]: + xmin = 7.71 + xmax = 7.79 + text = "AY1" + intervals [63]: + xmin = 7.79 + xmax = 7.9 + text = "D" + intervals [64]: + xmin = 7.9 + xmax = 7.97 + text = "R" + intervals [65]: + xmin = 7.97 + xmax = 8.07 + text = "IY1" + intervals [66]: + xmin = 8.07 + xmax = 8.14 + text = "M" + intervals [67]: + 
xmin = 8.14 + xmax = 8.26 + text = "JH" + intervals [68]: + xmin = 8.26 + xmax = 8.44 + text = "AA1" + intervals [69]: + xmin = 8.44 + xmax = 8.53 + text = "B" + intervals [70]: + xmin = 8.53 + xmax = 8.65 + text = "AY1" + intervals [71]: + xmin = 8.65 + xmax = 8.71 + text = "V" + intervals [72]: + xmin = 8.71 + xmax = 8.82 + text = "AO1" + intervals [73]: + xmin = 8.82 + xmax = 8.88 + text = "L" + intervals [74]: + xmin = 8.88 + xmax = 8.98 + text = "W" + intervals [75]: + xmin = 8.98 + xmax = 9.03 + text = "IY0" + intervals [76]: + xmin = 9.03 + xmax = 9.13 + text = "Z" + intervals [77]: + xmin = 9.13 + xmax = 9.3 + text = "W" + intervals [78]: + xmin = 9.3 + xmax = 9.42 + text = "AO1" + intervals [79]: + xmin = 9.42 + xmax = 9.46 + text = "N" + intervals [80]: + xmin = 9.46 + xmax = 9.5 + text = "IH0" + intervals [81]: + xmin = 9.5 + xmax = 9.53 + text = "D" + intervals [82]: + xmin = 9.53 + xmax = 9.56 + text = "T" + intervals [83]: + xmin = 9.56 + xmax = 9.6 + text = "AH0" + intervals [84]: + xmin = 9.6 + xmax = 9.66 + text = "B" + intervals [85]: + xmin = 9.66 + xmax = 9.77 + text = "IY1" + intervals [86]: + xmin = 9.77 + xmax = 9.84 + text = "AH0" + intervals [87]: + xmin = 9.84 + xmax = 9.95 + text = "JH" + intervals [88]: + xmin = 9.95 + xmax = 10.0 + text = "ER1" + intervals [89]: + xmin = 10.0 + xmax = 10.03 + text = "N" + intervals [90]: + xmin = 10.03 + xmax = 10.1 + text = "AH0" + intervals [91]: + xmin = 10.1 + xmax = 10.18 + text = "L" + intervals [92]: + xmin = 10.18 + xmax = 10.23 + text = "AH0" + intervals [93]: + xmin = 10.23 + xmax = 10.26 + text = "S" + intervals [94]: + xmin = 10.26 + xmax = 10.29 + text = "T" + intervals [95]: + xmin = 10.29 + xmax = 10.33 + text = "S" + intervals [96]: + xmin = 10.33 + xmax = 10.37 + text = "IH1" + intervals [97]: + xmin = 10.37 + xmax = 10.42 + text = "N" + intervals [98]: + xmin = 10.42 + xmax = 10.48 + text = "S" + intervals [99]: + xmin = 10.48 + xmax = 10.54 + text = "AY1" + intervals [100]: + xmin = 10.54 + xmax = 10.6 + text = "W" + intervals [101]: + xmin = 10.6 + xmax = 10.64 + text = "AH0" + intervals [102]: + xmin = 10.64 + xmax = 10.71 + text = "Z" + intervals [103]: + xmin = 10.71 + xmax = 10.75 + text = "IH0" + intervals [104]: + xmin = 10.75 + xmax = 10.78 + text = "N" + intervals [105]: + xmin = 10.78 + xmax = 10.84 + text = "M" + intervals [106]: + xmin = 10.84 + xmax = 10.87 + text = "IH1" + intervals [107]: + xmin = 10.87 + xmax = 10.9 + text = "D" + intervals [108]: + xmin = 10.9 + xmax = 10.95 + text = "AH0" + intervals [109]: + xmin = 10.95 + xmax = 11.01 + text = "L" + intervals [110]: + xmin = 11.01 + xmax = 11.11 + text = "S" + intervals [111]: + xmin = 11.11 + xmax = 11.16 + text = "K" + intervals [112]: + xmin = 11.16 + xmax = 11.21 + text = "UW1" + intervals [113]: + xmin = 11.21 + xmax = 11.4 + text = "L" + intervals [114]: + xmin = 11.4 + xmax = 11.96 + text = "" + intervals [115]: + xmin = 11.96 + xmax = 12.2 + text = "JH" + intervals [116]: + xmin = 12.2 + xmax = 12.28 + text = "ER1" + intervals [117]: + xmin = 12.28 + xmax = 12.34 + text = "N" + intervals [118]: + xmin = 12.34 + xmax = 12.38 + text = "AH0" + intervals [119]: + xmin = 12.38 + xmax = 12.5 + text = "L" + intervals [120]: + xmin = 12.5 + xmax = 12.64 + text = "AH0" + intervals [121]: + xmin = 12.64 + xmax = 12.83 + text = "S" + intervals [122]: + xmin = 12.83 + xmax = 12.86 + text = "T" + intervals [123]: + xmin = 12.86 + xmax = 12.92 + text = "S" + intervals [124]: + xmin = 12.92 + xmax = 13.14 + text = "N" + intervals [125]: + xmin = 
13.14 + xmax = 13.34 + text = "EH1" + intervals [126]: + xmin = 13.34 + xmax = 13.46 + text = "V" + intervals [127]: + xmin = 13.46 + xmax = 13.73 + text = "ER0" + intervals [128]: + xmin = 13.73 + xmax = 13.81 + text = "" + intervals [129]: + xmin = 13.81 + xmax = 14.0 + text = "T" + intervals [130]: + xmin = 14.0 + xmax = 14.17 + text = "EH1" + intervals [131]: + xmin = 14.17 + xmax = 14.38 + text = "L" + intervals [132]: + xmin = 14.38 + xmax = 14.43 + text = "L" + intervals [133]: + xmin = 14.43 + xmax = 14.95 + text = "AY1" + intervals [134]: + xmin = 14.95 + xmax = 15.08 + text = "Z" + intervals [135]: + xmin = 15.08 + xmax = 15.13 + text = "AE1" + intervals [136]: + xmin = 15.13 + xmax = 15.16 + text = "N" + intervals [137]: + xmin = 15.16 + xmax = 15.19 + text = "D" + intervals [138]: + xmin = 15.19 + xmax = 15.38 + text = "ER0" + intervals [139]: + xmin = 15.38 + xmax = 15.43 + text = "AO1" + intervals [140]: + xmin = 15.43 + xmax = 15.48 + text = "L" + intervals [141]: + xmin = 15.48 + xmax = 15.57 + text = "W" + intervals [142]: + xmin = 15.57 + xmax = 15.61 + text = "IY0" + intervals [143]: + xmin = 15.61 + xmax = 15.75 + text = "Z" + intervals [144]: + xmin = 15.75 + xmax = 15.8 + text = "S" + intervals [145]: + xmin = 15.8 + xmax = 15.89 + text = "IY1" + intervals [146]: + xmin = 15.89 + xmax = 15.96 + text = "K" + intervals [147]: + xmin = 15.96 + xmax = 16.03 + text = "IH0" + intervals [148]: + xmin = 16.03 + xmax = 16.07 + text = "NG" + intervals [149]: + xmin = 16.07 + xmax = 16.12 + text = "DH" + intervals [150]: + xmin = 16.12 + xmax = 16.17 + text = "AH0" + intervals [151]: + xmin = 16.17 + xmax = 16.3 + text = "T" + intervals [152]: + xmin = 16.3 + xmax = 16.36 + text = "R" + intervals [153]: + xmin = 16.36 + xmax = 16.48 + text = "UW1" + intervals [154]: + xmin = 16.48 + xmax = 16.69 + text = "TH" + intervals [155]: + xmin = 16.69 + xmax = 16.85 + text = "" + intervals [156]: + xmin = 16.85 + xmax = 17.09 + text = "AY1" + intervals [157]: + xmin = 17.09 + xmax = 17.18 + text = "W" + intervals [158]: + xmin = 17.18 + xmax = 17.21 + text = "AA1" + intervals [159]: + xmin = 17.21 + xmax = 17.24 + text = "N" + intervals [160]: + xmin = 17.24 + xmax = 17.27 + text = "T" + intervals [161]: + xmin = 17.27 + xmax = 17.3 + text = "T" + intervals [162]: + xmin = 17.3 + xmax = 17.33 + text = "AH0" + intervals [163]: + xmin = 17.33 + xmax = 17.37 + text = "B" + intervals [164]: + xmin = 17.37 + xmax = 17.45 + text = "IY0" + intervals [165]: + xmin = 17.45 + xmax = 17.57 + text = "JH" + intervals [166]: + xmin = 17.57 + xmax = 17.64 + text = "IH0" + intervals [167]: + xmin = 17.64 + xmax = 17.7 + text = "S" + intervals [168]: + xmin = 17.7 + xmax = 17.73 + text = "T" + intervals [169]: + xmin = 17.73 + xmax = 17.81 + text = "L" + intervals [170]: + xmin = 17.81 + xmax = 17.91 + text = "AY1" + intervals [171]: + xmin = 17.91 + xmax = 18.0 + text = "K" + intervals [172]: + xmin = 18.0 + xmax = 18.06 + text = "DH" + intervals [173]: + xmin = 18.06 + xmax = 18.23 + text = "AE1" + intervals [174]: + xmin = 18.23 + xmax = 18.36 + text = "T" + intervals [175]: + xmin = 18.36 + xmax = 18.73 + text = "" + intervals [176]: + xmin = 18.73 + xmax = 19.54 + text = "AY1" + intervals [177]: + xmin = 19.54 + xmax = 19.71 + text = "" + intervals [178]: + xmin = 19.71 + xmax = 19.92 + text = "Y" + intervals [179]: + xmin = 19.92 + xmax = 20.03 + text = "UW1" + intervals [180]: + xmin = 20.03 + xmax = 20.11 + text = "ZH" + intervals [181]: + xmin = 20.11 + xmax = 20.14 + text = "AH0" + intervals 
[182]: + xmin = 20.14 + xmax = 20.17 + text = "L" + intervals [183]: + xmin = 20.17 + xmax = 20.27 + text = "IY0" + intervals [184]: + xmin = 20.27 + xmax = 20.37 + text = "F" + intervals [185]: + xmin = 20.37 + xmax = 20.45 + text = "IY1" + intervals [186]: + xmin = 20.45 + xmax = 20.52 + text = "L" + intervals [187]: + xmin = 20.52 + xmax = 20.73 + text = "SH" + intervals [188]: + xmin = 20.73 + xmax = 20.96 + text = "AY1" + intervals [189]: + xmin = 20.96 + xmax = 21.06 + text = "W" + intervals [190]: + xmin = 21.06 + xmax = 21.09 + text = "EH1" + intervals [191]: + xmin = 21.09 + xmax = 21.14 + text = "N" + intervals [192]: + xmin = 21.14 + xmax = 21.26 + text = "AY1" + intervals [193]: + xmin = 21.26 + xmax = 21.31 + text = "AE1" + intervals [194]: + xmin = 21.31 + xmax = 21.38 + text = "M" + intervals [195]: + xmin = 21.38 + xmax = 21.5 + text = "T" + intervals [196]: + xmin = 21.5 + xmax = 21.57 + text = "AO1" + intervals [197]: + xmin = 21.57 + xmax = 21.64 + text = "K" + intervals [198]: + xmin = 21.64 + xmax = 21.69 + text = "IH0" + intervals [199]: + xmin = 21.69 + xmax = 21.75 + text = "NG" + intervals [200]: + xmin = 21.75 + xmax = 21.8 + text = "T" + intervals [201]: + xmin = 21.8 + xmax = 21.86 + text = "AH0" + intervals [202]: + xmin = 21.86 + xmax = 21.96 + text = "AH1" + intervals [203]: + xmin = 21.96 + xmax = 22.02 + text = "DH" + intervals [204]: + xmin = 22.02 + xmax = 22.13 + text = "ER0" + intervals [205]: + xmin = 22.13 + xmax = 22.22 + text = "Z" + intervals [206]: + xmin = 22.22 + xmax = 22.29 + text = "AE1" + intervals [207]: + xmin = 22.29 + xmax = 22.32 + text = "N" + intervals [208]: + xmin = 22.32 + xmax = 22.37 + text = "D" + intervals [209]: + xmin = 22.37 + xmax = 22.46 + text = "AY1" + intervals [210]: + xmin = 22.46 + xmax = 22.6 + text = "N" + intervals [211]: + xmin = 22.6 + xmax = 22.8 + text = "OW1" + intervals [212]: + xmin = 22.8 + xmax = 22.85 + text = "DH" + intervals [213]: + xmin = 22.85 + xmax = 22.98 + text = "AH0" + intervals [214]: + xmin = 22.98 + xmax = 23.1 + text = "T" + intervals [215]: + xmin = 23.1 + xmax = 23.27 + text = "JH" + intervals [216]: + xmin = 23.27 + xmax = 23.35 + text = "ER1" + intervals [217]: + xmin = 23.35 + xmax = 23.39 + text = "N" + intervals [218]: + xmin = 23.39 + xmax = 23.46 + text = "AH0" + intervals [219]: + xmin = 23.46 + xmax = 23.54 + text = "L" + intervals [220]: + xmin = 23.54 + xmax = 23.61 + text = "AH0" + intervals [221]: + xmin = 23.61 + xmax = 23.69 + text = "S" + intervals [222]: + xmin = 23.69 + xmax = 23.72 + text = "T" + intervals [223]: + xmin = 23.72 + xmax = 23.79 + text = "S" + intervals [224]: + xmin = 23.79 + xmax = 23.89 + text = "AA1" + intervals [225]: + xmin = 23.89 + xmax = 23.99 + text = "R" + intervals [226]: + xmin = 23.99 + xmax = 24.22 + text = "V" + intervals [227]: + xmin = 24.22 + xmax = 24.49 + text = "EH1" + intervals [228]: + xmin = 24.49 + xmax = 24.64 + text = "R" + intervals [229]: + xmin = 24.64 + xmax = 24.73 + text = "IY0" + intervals [230]: + xmin = 24.73 + xmax = 24.9 + text = "G" + intervals [231]: + xmin = 24.9 + xmax = 25.12 + text = "UH1" + intervals [232]: + xmin = 25.12 + xmax = 25.29 + text = "D" + intervals [233]: + xmin = 25.29 + xmax = 25.44 + text = "" + intervals [234]: + xmin = 25.44 + xmax = 25.61 + text = "AE1" + intervals [235]: + xmin = 25.61 + xmax = 25.7 + text = "T" + intervals [236]: + xmin = 25.7 + xmax = 25.73 + text = "K" + intervals [237]: + xmin = 25.73 + xmax = 25.77 + text = "AH0" + intervals [238]: + xmin = 25.77 + xmax = 25.83 + text 
= "M" + intervals [239]: + xmin = 25.83 + xmax = 25.89 + text = "Y" + intervals [240]: + xmin = 25.89 + xmax = 25.93 + text = "UW1" + intervals [241]: + xmin = 25.93 + xmax = 25.98 + text = "N" + intervals [242]: + xmin = 25.98 + xmax = 26.05 + text = "AH0" + intervals [243]: + xmin = 26.05 + xmax = 26.19 + text = "K" + intervals [244]: + xmin = 26.19 + xmax = 26.25 + text = "EY2" + intervals [245]: + xmin = 26.25 + xmax = 26.28 + text = "T" + intervals [246]: + xmin = 26.28 + xmax = 26.35 + text = "IH0" + intervals [247]: + xmin = 26.35 + xmax = 26.41 + text = "NG" + intervals [248]: + xmin = 26.41 + xmax = 26.46 + text = "B" + intervals [249]: + xmin = 26.46 + xmax = 26.54 + text = "IH0" + intervals [250]: + xmin = 26.54 + xmax = 26.64 + text = "K" + intervals [251]: + xmin = 26.64 + xmax = 26.84 + text = "AH0" + intervals [252]: + xmin = 26.84 + xmax = 26.94 + text = "Z" + intervals [253]: + xmin = 26.94 + xmax = 26.98 + text = "" + intervals [254]: + xmin = 26.98 + xmax = 27.07 + text = "G" + intervals [255]: + xmin = 27.07 + xmax = 27.15 + text = "UH1" + intervals [256]: + xmin = 27.15 + xmax = 27.2 + text = "D" + intervals [257]: + xmin = 27.2 + xmax = 27.27 + text = "K" + intervals [258]: + xmin = 27.27 + xmax = 27.32 + text = "AH0" + intervals [259]: + xmin = 27.32 + xmax = 27.38 + text = "M" + intervals [260]: + xmin = 27.38 + xmax = 27.43 + text = "Y" + intervals [261]: + xmin = 27.43 + xmax = 27.46 + text = "UW2" + intervals [262]: + xmin = 27.46 + xmax = 27.49 + text = "N" + intervals [263]: + xmin = 27.49 + xmax = 27.54 + text = "AH0" + intervals [264]: + xmin = 27.54 + xmax = 27.65 + text = "K" + intervals [265]: + xmin = 27.65 + xmax = 27.75 + text = "EY1" + intervals [266]: + xmin = 27.75 + xmax = 27.84 + text = "SH" + intervals [267]: + xmin = 27.84 + xmax = 27.89 + text = "AH0" + intervals [268]: + xmin = 27.89 + xmax = 27.95 + text = "N" + intervals [269]: + xmin = 27.95 + xmax = 28.03 + text = "S" + intervals [270]: + xmin = 28.03 + xmax = 28.08 + text = "K" + intervals [271]: + xmin = 28.08 + xmax = 28.15 + text = "IH1" + intervals [272]: + xmin = 28.15 + xmax = 28.27 + text = "L" + intervals [273]: + xmin = 28.27 + xmax = 28.36 + text = "Z" + intervals [274]: + xmin = 28.36 + xmax = 28.42 + text = "AA1" + intervals [275]: + xmin = 28.42 + xmax = 28.51 + text = "R" + intervals [276]: + xmin = 28.51 + xmax = 28.56 + text = "V" + intervals [277]: + xmin = 28.56 + xmax = 28.61 + text = "EH1" + intervals [278]: + xmin = 28.61 + xmax = 28.72 + text = "R" + intervals [279]: + xmin = 28.72 + xmax = 28.75 + text = "IY0" + intervals [280]: + xmin = 28.75 + xmax = 28.8 + text = "IH0" + intervals [281]: + xmin = 28.8 + xmax = 28.86 + text = "M" + intervals [282]: + xmin = 28.86 + xmax = 28.97 + text = "P" + intervals [283]: + xmin = 28.97 + xmax = 29.01 + text = "AO1" + intervals [284]: + xmin = 29.01 + xmax = 29.07 + text = "R" + intervals [285]: + xmin = 29.07 + xmax = 29.11 + text = "T" + intervals [286]: + xmin = 29.11 + xmax = 29.14 + text = "AH0" + intervals [287]: + xmin = 29.14 + xmax = 29.18 + text = "N" + intervals [288]: + xmin = 29.18 + xmax = 29.22 + text = "T" + intervals [289]: + xmin = 29.22 + xmax = 29.29 + text = "W" + intervals [290]: + xmin = 29.29 + xmax = 29.32 + text = "EH1" + intervals [291]: + xmin = 29.32 + xmax = 29.39 + text = "N" + intervals [292]: + xmin = 29.39 + xmax = 29.43 + text = "Y" + intervals [293]: + xmin = 29.43 + xmax = 29.46 + text = "UH1" + intervals [294]: + xmin = 29.46 + xmax = 29.51 + text = "R" + intervals [295]: + xmin = 29.51 + 
xmax = 29.6 + text = "D" + intervals [296]: + xmin = 29.6 + xmax = 29.7 + text = "UW1" + intervals [297]: + xmin = 29.7 + xmax = 29.76 + text = "IH0" + intervals [298]: + xmin = 29.76 + xmax = 29.85 + text = "NG" + intervals [299]: + xmin = 29.85 + xmax = 29.88 + text = "IH1" + intervals [300]: + xmin = 29.88 + xmax = 29.99 + text = "N" + intervals [301]: + xmin = 29.99 + xmax = 30.06 + text = "ER0" + intervals [302]: + xmin = 30.06 + xmax = 30.1 + text = "V" + intervals [303]: + xmin = 30.1 + xmax = 30.21 + text = "Y" + intervals [304]: + xmin = 30.21 + xmax = 30.28 + text = "UW2" + intervals [305]: + xmin = 30.28 + xmax = 30.43 + text = "Z" + intervals [306]: + xmin = 30.43 + xmax = 30.71 + text = "" + intervals [307]: + xmin = 30.71 + xmax = 30.99 + text = "AY1" + intervals [308]: + xmin = 30.99 + xmax = 31.11 + text = "W" + intervals [309]: + xmin = 31.11 + xmax = 31.15 + text = "AA1" + intervals [310]: + xmin = 31.15 + xmax = 31.18 + text = "N" + intervals [311]: + xmin = 31.18 + xmax = 31.21 + text = "T" + intervals [312]: + xmin = 31.21 + xmax = 31.24 + text = "T" + intervals [313]: + xmin = 31.24 + xmax = 31.3 + text = "AH0" + intervals [314]: + xmin = 31.3 + xmax = 31.35 + text = "P" + intervals [315]: + xmin = 31.35 + xmax = 31.42 + text = "AH0" + intervals [316]: + xmin = 31.42 + xmax = 31.51 + text = "Z" + intervals [317]: + xmin = 31.51 + xmax = 31.62 + text = "EH1" + intervals [318]: + xmin = 31.62 + xmax = 31.71 + text = "S" + intervals [319]: + xmin = 31.71 + xmax = 31.75 + text = "DH" + intervals [320]: + xmin = 31.75 + xmax = 31.82 + text = "AH0" + intervals [321]: + xmin = 31.82 + xmax = 31.97 + text = "S" + intervals [322]: + xmin = 31.97 + xmax = 32.02 + text = "K" + intervals [323]: + xmin = 32.02 + xmax = 32.09 + text = "IH1" + intervals [324]: + xmin = 32.09 + xmax = 32.16 + text = "L" + intervals [325]: + xmin = 32.16 + xmax = 32.23 + text = "M" + intervals [326]: + xmin = 32.23 + xmax = 32.31 + text = "AY2" + intervals [327]: + xmin = 32.31 + xmax = 32.46 + text = "S" + intervals [328]: + xmin = 32.46 + xmax = 32.53 + text = "EH1" + intervals [329]: + xmin = 32.53 + xmax = 32.76 + text = "L" + intervals [330]: + xmin = 32.76 + xmax = 32.93 + text = "F" + intervals [331]: + xmin = 32.93 + xmax = 33.0 + text = "" + intervals [332]: + xmin = 33.0 + xmax = 33.22 + text = "S" + intervals [333]: + xmin = 33.22 + xmax = 33.53 + text = "OW1" + intervals [334]: + xmin = 33.53 + xmax = 33.56 + text = "" + intervals [335]: + xmin = 33.56 + xmax = 33.67 + text = "DH" + intervals [336]: + xmin = 33.67 + xmax = 33.82 + text = "AE1" + intervals [337]: + xmin = 33.82 + xmax = 33.89 + text = "T" + intervals [338]: + xmin = 33.89 + xmax = 33.95 + text = "S" + intervals [339]: + xmin = 33.95 + xmax = 34.09 + text = "W" + intervals [340]: + xmin = 34.09 + xmax = 34.31 + text = "AY1" + intervals [341]: + xmin = 34.31 + xmax = 34.53 + text = "AY1" + intervals [342]: + xmin = 34.53 + xmax = 34.67 + text = "W" + intervals [343]: + xmin = 34.67 + xmax = 34.77 + text = "AO1" + intervals [344]: + xmin = 34.77 + xmax = 34.8 + text = "N" + intervals [345]: + xmin = 34.8 + xmax = 34.83 + text = "T" + intervals [346]: + xmin = 34.83 + xmax = 34.86 + text = "T" + intervals [347]: + xmin = 34.86 + xmax = 34.89 + text = "IH0" + intervals [348]: + xmin = 34.89 + xmax = 34.93 + text = "B" + intervals [349]: + xmin = 34.93 + xmax = 35.0 + text = "IH0" + intervals [350]: + xmin = 35.0 + xmax = 35.08 + text = "K" + intervals [351]: + xmin = 35.08 + xmax = 35.19 + text = "AH1" + intervals [352]: + xmin 
= 35.19 + xmax = 35.42 + text = "M" + intervals [353]: + xmin = 35.42 + xmax = 35.46 + text = "" + intervals [354]: + xmin = 35.46 + xmax = 35.59 + text = "AH0" + intervals [355]: + xmin = 35.59 + xmax = 35.68 + text = "JH" + intervals [356]: + xmin = 35.68 + xmax = 35.79 + text = "ER1" + intervals [357]: + xmin = 35.79 + xmax = 35.84 + text = "N" + intervals [358]: + xmin = 35.84 + xmax = 35.89 + text = "AH0" + intervals [359]: + xmin = 35.89 + xmax = 35.97 + text = "L" + intervals [360]: + xmin = 35.97 + xmax = 36.07 + text = "AH0" + intervals [361]: + xmin = 36.07 + xmax = 36.27 + text = "S" + intervals [362]: + xmin = 36.27 + xmax = 36.37 + text = "T" + intervals [363]: + xmin = 36.37 + xmax = 36.74 + text = "" + intervals [364]: + xmin = 36.74 + xmax = 36.92 + text = "AH1" + intervals [365]: + xmin = 36.92 + xmax = 36.96 + text = "DH" + intervals [366]: + xmin = 36.96 + xmax = 37.02 + text = "ER0" + intervals [367]: + xmin = 37.02 + xmax = 37.05 + text = "DH" + intervals [368]: + xmin = 37.05 + xmax = 37.13 + text = "AE1" + intervals [369]: + xmin = 37.13 + xmax = 37.18 + text = "N" + intervals [370]: + xmin = 37.18 + xmax = 37.23 + text = "DH" + intervals [371]: + xmin = 37.23 + xmax = 37.47 + text = "AE1" + intervals [372]: + xmin = 37.47 + xmax = 37.73 + text = "T" + intervals [373]: + xmin = 37.73 + xmax = 37.76 + text = "" + intervals [374]: + xmin = 37.76 + xmax = 37.8 + text = "F" + intervals [375]: + xmin = 37.8 + xmax = 37.85 + text = "AH0" + intervals [376]: + xmin = 37.85 + xmax = 37.96 + text = "T" + intervals [377]: + xmin = 37.96 + xmax = 38.08 + text = "AA1" + intervals [378]: + xmin = 38.08 + xmax = 38.13 + text = "G" + intervals [379]: + xmin = 38.13 + xmax = 38.2 + text = "R" + intervals [380]: + xmin = 38.2 + xmax = 38.26 + text = "AH0" + intervals [381]: + xmin = 38.26 + xmax = 38.35 + text = "F" + intervals [382]: + xmin = 38.35 + xmax = 38.72 + text = "IY0" + intervals [383]: + xmin = 38.72 + xmax = 38.96 + text = "" + intervals [384]: + xmin = 38.96 + xmax = 39.25 + text = "AO1" + intervals [385]: + xmin = 39.25 + xmax = 39.3 + text = "F" + intervals [386]: + xmin = 39.3 + xmax = 39.36 + text = "AH0" + intervals [387]: + xmin = 39.36 + xmax = 39.4 + text = "N" + intervals [388]: + xmin = 39.4 + xmax = 39.47 + text = "M" + intervals [389]: + xmin = 39.47 + xmax = 39.54 + text = "EY1" + intervals [390]: + xmin = 39.54 + xmax = 39.59 + text = "K" + intervals [391]: + xmin = 39.59 + xmax = 39.66 + text = "S" + intervals [392]: + xmin = 39.66 + xmax = 39.7 + text = "M" + intervals [393]: + xmin = 39.7 + xmax = 39.8 + text = "IY1" + intervals [394]: + xmin = 39.8 + xmax = 39.94 + text = "F" + intervals [395]: + xmin = 39.94 + xmax = 40.09 + text = "IY1" + intervals [396]: + xmin = 40.09 + xmax = 40.39 + text = "L" + intervals [397]: + xmin = 40.39 + xmax = 40.48 + text = "L" + intervals [398]: + xmin = 40.48 + xmax = 40.62 + text = "AY1" + intervals [399]: + xmin = 40.62 + xmax = 40.72 + text = "K" + intervals [400]: + xmin = 40.72 + xmax = 40.97 + text = "AY1" + intervals [401]: + xmin = 40.97 + xmax = 41.42 + text = "M" + intervals [402]: + xmin = 41.42 + xmax = 41.77 + text = "" + intervals [403]: + xmin = 41.77 + xmax = 41.85 + text = "D" + intervals [404]: + xmin = 41.85 + xmax = 41.95 + text = "UW1" + intervals [405]: + xmin = 41.95 + xmax = 42.02 + text = "IH0" + intervals [406]: + xmin = 42.02 + xmax = 42.06 + text = "NG" + intervals [407]: + xmin = 42.06 + xmax = 42.13 + text = "EY1" + intervals [408]: + xmin = 42.13 + xmax = 42.25 + text = "JH" + intervals 
[409]: + xmin = 42.25 + xmax = 42.36 + text = "AA1" + intervals [410]: + xmin = 42.36 + xmax = 42.41 + text = "B" + intervals [411]: + xmin = 42.41 + xmax = 42.47 + text = "F" + intervals [412]: + xmin = 42.47 + xmax = 42.51 + text = "UH1" + intervals [413]: + xmin = 42.51 + xmax = 42.58 + text = "L" + intervals [414]: + xmin = 42.58 + xmax = 42.61 + text = "AH0" + intervals [415]: + xmin = 42.61 + xmax = 42.64 + text = "V" + intervals [416]: + xmin = 42.64 + xmax = 42.68 + text = "D" + intervals [417]: + xmin = 42.68 + xmax = 42.74 + text = "IH0" + intervals [418]: + xmin = 42.74 + xmax = 42.83 + text = "Z" + intervals [419]: + xmin = 42.83 + xmax = 42.96 + text = "AY1" + intervals [420]: + xmin = 42.96 + xmax = 42.99 + text = "N" + intervals [421]: + xmin = 42.99 + xmax = 43.03 + text = "AE1" + intervals [422]: + xmin = 43.03 + xmax = 43.07 + text = "N" + intervals [423]: + xmin = 43.07 + xmax = 43.12 + text = "D" + intervals [424]: + xmin = 43.12 + xmax = 43.19 + text = "IH0" + intervals [425]: + xmin = 43.19 + xmax = 43.25 + text = "N" + intervals [426]: + xmin = 43.25 + xmax = 43.3 + text = "F" + intervals [427]: + xmin = 43.3 + xmax = 43.52 + text = "ER0" + intervals [428]: + xmin = 43.52 + xmax = 43.68 + text = "IH2" + intervals [429]: + xmin = 43.68 + xmax = 43.74 + text = "N" + intervals [430]: + xmin = 43.74 + xmax = 43.8 + text = "AH0" + intervals [431]: + xmin = 43.8 + xmax = 43.87 + text = "V" + intervals [432]: + xmin = 43.87 + xmax = 44.01 + text = "EY1" + intervals [433]: + xmin = 44.01 + xmax = 44.09 + text = "SH" + intervals [434]: + xmin = 44.09 + xmax = 44.16 + text = "AH0" + intervals [435]: + xmin = 44.16 + xmax = 44.31 + text = "N" + intervals [436]: + xmin = 44.31 + xmax = 45.18 + text = "" + intervals [437]: + xmin = 45.18 + xmax = 45.24 + text = "B" + intervals [438]: + xmin = 45.24 + xmax = 45.3 + text = "IH0" + intervals [439]: + xmin = 45.3 + xmax = 45.39 + text = "K" + intervals [440]: + xmin = 45.39 + xmax = 45.54 + text = "AH1" + intervals [441]: + xmin = 45.54 + xmax = 45.73 + text = "Z" + intervals [442]: + xmin = 45.73 + xmax = 45.81 + text = "F" + intervals [443]: + xmin = 45.81 + xmax = 45.84 + text = "R" + intervals [444]: + xmin = 45.84 + xmax = 45.89 + text = "ER0" + intervals [445]: + xmin = 45.89 + xmax = 45.93 + text = "DH" + intervals [446]: + xmin = 45.93 + xmax = 45.99 + text = "AH0" + intervals [447]: + xmin = 45.99 + xmax = 46.12 + text = "S" + intervals [448]: + xmin = 46.12 + xmax = 46.25 + text = "EY1" + intervals [449]: + xmin = 46.25 + xmax = 46.35 + text = "M" + intervals [450]: + xmin = 46.35 + xmax = 46.48 + text = "S" + intervals [451]: + xmin = 46.48 + xmax = 46.59 + text = "IY1" + intervals [452]: + xmin = 46.59 + xmax = 46.65 + text = "N" + intervals [453]: + xmin = 46.65 + xmax = 46.74 + text = "ER0" + intervals [454]: + xmin = 46.74 + xmax = 46.88 + text = "IY0" + intervals [455]: + xmin = 46.88 + xmax = 46.94 + text = "W" + intervals [456]: + xmin = 46.94 + xmax = 47.01 + text = "IY1" + intervals [457]: + xmin = 47.01 + xmax = 47.06 + text = "K" + intervals [458]: + xmin = 47.06 + xmax = 47.09 + text = "AH0" + intervals [459]: + xmin = 47.09 + xmax = 47.12 + text = "N" + intervals [460]: + xmin = 47.12 + xmax = 47.22 + text = "Y" + intervals [461]: + xmin = 47.22 + xmax = 47.27 + text = "UW1" + intervals [462]: + xmin = 47.27 + xmax = 47.34 + text = "Z" + intervals [463]: + xmin = 47.34 + xmax = 47.4 + text = "D" + intervals [464]: + xmin = 47.4 + xmax = 47.45 + text = "IH1" + intervals [465]: + xmin = 47.45 + xmax = 47.49 + 
text = "F" + intervals [466]: + xmin = 47.49 + xmax = 47.52 + text = "R" + intervals [467]: + xmin = 47.52 + xmax = 47.55 + text = "AH0" + intervals [468]: + xmin = 47.55 + xmax = 47.58 + text = "N" + intervals [469]: + xmin = 47.58 + xmax = 47.61 + text = "T" + intervals [470]: + xmin = 47.61 + xmax = 47.74 + text = "AE1" + intervals [471]: + xmin = 47.74 + xmax = 47.82 + text = "NG" + intervals [472]: + xmin = 47.82 + xmax = 47.85 + text = "G" + intervals [473]: + xmin = 47.85 + xmax = 47.89 + text = "AH0" + intervals [474]: + xmin = 47.89 + xmax = 48.04 + text = "L" + intervals [475]: + xmin = 48.04 + xmax = 48.12 + text = "Z" + intervals [476]: + xmin = 48.12 + xmax = 48.15 + text = "AH0" + intervals [477]: + xmin = 48.15 + xmax = 48.18 + text = "N" + intervals [478]: + xmin = 48.18 + xmax = 48.21 + text = "D" + intervals [479]: + xmin = 48.21 + xmax = 48.26 + text = "D" + intervals [480]: + xmin = 48.26 + xmax = 48.31 + text = "IH1" + intervals [481]: + xmin = 48.31 + xmax = 48.36 + text = "F" + intervals [482]: + xmin = 48.36 + xmax = 48.39 + text = "R" + intervals [483]: + xmin = 48.39 + xmax = 48.42 + text = "AH0" + intervals [484]: + xmin = 48.42 + xmax = 48.45 + text = "N" + intervals [485]: + xmin = 48.45 + xmax = 48.48 + text = "T" + intervals [486]: + xmin = 48.48 + xmax = 48.53 + text = "K" + intervals [487]: + xmin = 48.53 + xmax = 48.6 + text = "AA2" + intervals [488]: + xmin = 48.6 + xmax = 48.64 + text = "M" + intervals [489]: + xmin = 48.64 + xmax = 48.67 + text = "P" + intervals [490]: + xmin = 48.67 + xmax = 48.72 + text = "AH0" + intervals [491]: + xmin = 48.72 + xmax = 48.8 + text = "Z" + intervals [492]: + xmin = 48.8 + xmax = 48.88 + text = "IH1" + intervals [493]: + xmin = 48.88 + xmax = 48.95 + text = "SH" + intervals [494]: + xmin = 48.95 + xmax = 49.02 + text = "AH0" + intervals [495]: + xmin = 49.02 + xmax = 49.12 + text = "N" + intervals [496]: + xmin = 49.12 + xmax = 49.32 + text = "Z" + intervals [497]: + xmin = 49.32 + xmax = 49.68 + text = "" + intervals [498]: + xmin = 49.68 + xmax = 49.91 + text = "F" + intervals [499]: + xmin = 49.91 + xmax = 49.98 + text = "ER0" + intervals [500]: + xmin = 49.98 + xmax = 50.02 + text = "IH0" + intervals [501]: + xmin = 50.02 + xmax = 50.06 + text = "G" + intervals [502]: + xmin = 50.06 + xmax = 50.13 + text = "Z" + intervals [503]: + xmin = 50.13 + xmax = 50.22 + text = "AE1" + intervals [504]: + xmin = 50.22 + xmax = 50.27 + text = "M" + intervals [505]: + xmin = 50.27 + xmax = 50.31 + text = "P" + intervals [506]: + xmin = 50.31 + xmax = 50.34 + text = "AH0" + intervals [507]: + xmin = 50.34 + xmax = 50.4 + text = "L" + intervals [508]: + xmin = 50.4 + xmax = 50.48 + text = "P" + intervals [509]: + xmin = 50.48 + xmax = 50.58 + text = "IY1" + intervals [510]: + xmin = 50.58 + xmax = 50.63 + text = "P" + intervals [511]: + xmin = 50.63 + xmax = 50.69 + text = "AH0" + intervals [512]: + xmin = 50.69 + xmax = 50.91 + text = "L" + intervals [513]: + xmin = 50.91 + xmax = 51.21 + text = "B" + intervals [514]: + xmin = 51.21 + xmax = 51.29 + text = "IY1" + intervals [515]: + xmin = 51.29 + xmax = 51.34 + text = "IH0" + intervals [516]: + xmin = 51.34 + xmax = 51.42 + text = "NG" + intervals [517]: + xmin = 51.42 + xmax = 51.57 + text = "SH" + intervals [518]: + xmin = 51.57 + xmax = 51.62 + text = "IH1" + intervals [519]: + xmin = 51.62 + xmax = 51.68 + text = "F" + intervals [520]: + xmin = 51.68 + xmax = 51.74 + text = "T" + intervals [521]: + xmin = 51.74 + xmax = 51.79 + text = "IH0" + intervals [522]: + xmin = 51.79 
+ xmax = 51.91 + text = "D" + intervals [523]: + xmin = 51.91 + xmax = 51.94 + text = "F" + intervals [524]: + xmin = 51.94 + xmax = 52.02 + text = "ER0" + intervals [525]: + xmin = 52.02 + xmax = 52.06 + text = "M" + intervals [526]: + xmin = 52.06 + xmax = 52.1 + text = "DH" + intervals [527]: + xmin = 52.1 + xmax = 52.15 + text = "AH0" + intervals [528]: + xmin = 52.15 + xmax = 52.27 + text = "S" + intervals [529]: + xmin = 52.27 + xmax = 52.32 + text = "EH1" + intervals [530]: + xmin = 52.32 + xmax = 52.38 + text = "N" + intervals [531]: + xmin = 52.38 + xmax = 52.47 + text = "ER0" + intervals [532]: + xmin = 52.47 + xmax = 52.51 + text = "AH0" + intervals [533]: + xmin = 52.51 + xmax = 52.54 + text = "V" + intervals [534]: + xmin = 52.54 + xmax = 52.57 + text = "DH" + intervals [535]: + xmin = 52.57 + xmax = 52.63 + text = "AH0" + intervals [536]: + xmin = 52.63 + xmax = 52.72 + text = "F" + intervals [537]: + xmin = 52.72 + xmax = 52.8 + text = "R" + intervals [538]: + xmin = 52.8 + xmax = 52.89 + text = "EY1" + intervals [539]: + xmin = 52.89 + xmax = 52.95 + text = "M" + intervals [540]: + xmin = 52.95 + xmax = 53.01 + text = "T" + intervals [541]: + xmin = 53.01 + xmax = 53.1 + text = "AH0" + intervals [542]: + xmin = 53.1 + xmax = 53.19 + text = "L" + intervals [543]: + xmin = 53.19 + xmax = 53.26 + text = "EH1" + intervals [544]: + xmin = 53.26 + xmax = 53.29 + text = "F" + intervals [545]: + xmin = 53.29 + xmax = 53.32 + text = "T" + intervals [546]: + xmin = 53.32 + xmax = 53.4 + text = "S" + intervals [547]: + xmin = 53.4 + xmax = 53.5 + text = "AY1" + intervals [548]: + xmin = 53.5 + xmax = 53.54 + text = "D" + intervals [549]: + xmin = 53.54 + xmax = 53.57 + text = "AH0" + intervals [550]: + xmin = 53.57 + xmax = 53.6 + text = "V" + intervals [551]: + xmin = 53.6 + xmax = 53.63 + text = "DH" + intervals [552]: + xmin = 53.63 + xmax = 53.69 + text = "AH0" + intervals [553]: + xmin = 53.69 + xmax = 53.78 + text = "F" + intervals [554]: + xmin = 53.78 + xmax = 53.92 + text = "R" + intervals [555]: + xmin = 53.92 + xmax = 54.03 + text = "EY1" + intervals [556]: + xmin = 54.03 + xmax = 54.17 + text = "M" + intervals [557]: + xmin = 54.17 + xmax = 54.62 + text = "" + intervals [558]: + xmin = 54.62 + xmax = 54.72 + text = "IH0" + intervals [559]: + xmin = 54.72 + xmax = 54.75 + text = "T" + intervals [560]: + xmin = 54.75 + xmax = 54.82 + text = "K" + intervals [561]: + xmin = 54.82 + xmax = 54.88 + text = "AH0" + intervals [562]: + xmin = 54.88 + xmax = 54.91 + text = "N" + intervals [563]: + xmin = 54.91 + xmax = 54.97 + text = "M" + intervals [564]: + xmin = 54.97 + xmax = 55.1 + text = "EY1" + intervals [565]: + xmin = 55.1 + xmax = 55.13 + text = "K" + intervals [566]: + xmin = 55.13 + xmax = 55.22 + text = "EY1" + intervals [567]: + xmin = 55.22 + xmax = 55.29 + text = "D" + intervals [568]: + xmin = 55.29 + xmax = 55.35 + text = "IH1" + intervals [569]: + xmin = 55.35 + xmax = 55.43 + text = "F" + intervals [570]: + xmin = 55.43 + xmax = 55.46 + text = "R" + intervals [571]: + xmin = 55.46 + xmax = 55.49 + text = "AH0" + intervals [572]: + xmin = 55.49 + xmax = 55.52 + text = "N" + intervals [573]: + xmin = 55.52 + xmax = 55.56 + text = "T" + intervals [574]: + xmin = 55.56 + xmax = 55.64 + text = "F" + intervals [575]: + xmin = 55.64 + xmax = 55.77 + text = "IY1" + intervals [576]: + xmin = 55.77 + xmax = 55.82 + text = "L" + intervals [577]: + xmin = 55.82 + xmax = 55.88 + text = "IH0" + intervals [578]: + xmin = 55.88 + xmax = 56.05 + text = "NG" + intervals [579]: + 
xmin = 56.05 + xmax = 56.41 + text = "" + intervals [580]: + xmin = 56.41 + xmax = 56.48 + text = "W" + intervals [581]: + xmin = 56.48 + xmax = 56.66 + text = "EH1" + intervals [582]: + xmin = 56.66 + xmax = 56.69 + text = "N" + intervals [583]: + xmin = 56.69 + xmax = 57.1 + text = "W" + intervals [584]: + xmin = 57.1 + xmax = 57.25 + text = "IY1" + intervals [585]: + xmin = 57.25 + xmax = 57.5 + text = "" + intervals [586]: + xmin = 57.5 + xmax = 57.68 + text = "W" + intervals [587]: + xmin = 57.68 + xmax = 57.72 + text = "EH1" + intervals [588]: + xmin = 57.72 + xmax = 57.82 + text = "N" + intervals [589]: + xmin = 57.82 + xmax = 57.96 + text = "S" + intervals [590]: + xmin = 57.96 + xmax = 58.02 + text = "IY1" + intervals [591]: + xmin = 58.02 + xmax = 58.07 + text = "N" + intervals [592]: + xmin = 58.07 + xmax = 58.11 + text = "IH1" + intervals [593]: + xmin = 58.11 + xmax = 58.17 + text = "N" + intervals [594]: + xmin = 58.17 + xmax = 58.25 + text = "K" + intervals [595]: + xmin = 58.25 + xmax = 58.41 + text = "AA1" + intervals [596]: + xmin = 58.41 + xmax = 58.47 + text = "N" + intervals [597]: + xmin = 58.47 + xmax = 58.53 + text = "T" + intervals [598]: + xmin = 58.53 + xmax = 58.65 + text = "EH0" + intervals [599]: + xmin = 58.65 + xmax = 58.69 + text = "K" + intervals [600]: + xmin = 58.69 + xmax = 58.74 + text = "S" + intervals [601]: + xmin = 58.74 + xmax = 58.77 + text = "T" + intervals [602]: + xmin = 58.77 + xmax = 58.8 + text = "W" + intervals [603]: + xmin = 58.8 + xmax = 58.83 + text = "IH0" + intervals [604]: + xmin = 58.83 + xmax = 58.86 + text = "DH" + intervals [605]: + xmin = 58.86 + xmax = 58.9 + text = "DH" + intervals [606]: + xmin = 58.9 + xmax = 58.93 + text = "AH1" + intervals [607]: + xmin = 58.93 + xmax = 59.03 + text = "B" + intervals [608]: + xmin = 59.03 + xmax = 59.2 + text = "AE1" + intervals [609]: + xmin = 59.2 + xmax = 59.25 + text = "K" + intervals [610]: + xmin = 59.25 + xmax = 59.29 + text = "G" + intervals [611]: + xmin = 59.29 + xmax = 59.33 + text = "R" + intervals [612]: + xmin = 59.33 + xmax = 59.45 + text = "AW2" + intervals [613]: + xmin = 59.45 + xmax = 59.52 + text = "N" + intervals [614]: + xmin = 59.52 + xmax = 59.61 + text = "D" + intervals [615]: + xmin = 59.61 + xmax = 59.96 + text = "" + intervals [616]: + xmin = 59.96 + xmax = 60.13 + text = "W" + intervals [617]: + xmin = 60.13 + xmax = 60.17 + text = "EH1" + intervals [618]: + xmin = 60.17 + xmax = 60.31 + text = "N" + intervals [619]: + xmin = 60.31 + xmax = 60.41 + text = "EH1" + intervals [620]: + xmin = 60.41 + xmax = 60.45 + text = "V" + intervals [621]: + xmin = 60.45 + xmax = 60.48 + text = "R" + intervals [622]: + xmin = 60.48 + xmax = 60.51 + text = "IY0" + intervals [623]: + xmin = 60.51 + xmax = 60.58 + text = "W" + intervals [624]: + xmin = 60.58 + xmax = 60.65 + text = "AH2" + intervals [625]: + xmin = 60.65 + xmax = 60.69 + text = "N" + intervals [626]: + xmin = 60.69 + xmax = 60.79 + text = "Z" + intervals [627]: + xmin = 60.79 + xmax = 60.86 + text = "T" + intervals [628]: + xmin = 60.86 + xmax = 60.92 + text = "EY1" + intervals [629]: + xmin = 60.92 + xmax = 60.95 + text = "K" + intervals [630]: + xmin = 60.95 + xmax = 61.0 + text = "IH0" + intervals [631]: + xmin = 61.0 + xmax = 61.08 + text = "NG" + intervals [632]: + xmin = 61.08 + xmax = 61.14 + text = "AH0" + intervals [633]: + xmin = 61.14 + xmax = 61.2 + text = "P" + intervals [634]: + xmin = 61.2 + xmax = 61.28 + text = "IH1" + intervals [635]: + xmin = 61.28 + xmax = 61.32 + text = "K" + intervals 
[636]: + xmin = 61.32 + xmax = 61.43 + text = "CH" + intervals [637]: + xmin = 61.43 + xmax = 61.46 + text = "ER0" + intervals [638]: + xmin = 61.46 + xmax = 61.49 + text = "AH0" + intervals [639]: + xmin = 61.49 + xmax = 61.53 + text = "V" + intervals [640]: + xmin = 61.53 + xmax = 61.56 + text = "DH" + intervals [641]: + xmin = 61.56 + xmax = 61.6 + text = "AH0" + intervals [642]: + xmin = 61.6 + xmax = 61.64 + text = "IH0" + intervals [643]: + xmin = 61.64 + xmax = 61.69 + text = "G" + intervals [644]: + xmin = 61.69 + xmax = 61.76 + text = "Z" + intervals [645]: + xmin = 61.76 + xmax = 61.9 + text = "AE1" + intervals [646]: + xmin = 61.9 + xmax = 61.94 + text = "K" + intervals [647]: + xmin = 61.94 + xmax = 61.98 + text = "T" + intervals [648]: + xmin = 61.98 + xmax = 62.05 + text = "S" + intervals [649]: + xmin = 62.05 + xmax = 62.13 + text = "EY1" + intervals [650]: + xmin = 62.13 + xmax = 62.19 + text = "M" + intervals [651]: + xmin = 62.19 + xmax = 62.3 + text = "S" + intervals [652]: + xmin = 62.3 + xmax = 62.34 + text = "IY1" + intervals [653]: + xmin = 62.34 + xmax = 62.39 + text = "N" + intervals [654]: + xmin = 62.39 + xmax = 62.47 + text = "ER0" + intervals [655]: + xmin = 62.47 + xmax = 62.64 + text = "IY0" + intervals [656]: + xmin = 62.64 + xmax = 62.67 + text = "" + intervals [657]: + xmin = 62.67 + xmax = 62.8 + text = "AY1" + intervals [658]: + xmin = 62.8 + xmax = 62.89 + text = "M" + intervals [659]: + xmin = 62.89 + xmax = 63.03 + text = "V" + intervals [660]: + xmin = 63.03 + xmax = 63.12 + text = "EH1" + intervals [661]: + xmin = 63.12 + xmax = 63.22 + text = "R" + intervals [662]: + xmin = 63.22 + xmax = 63.28 + text = "IY0" + intervals [663]: + xmin = 63.28 + xmax = 63.39 + text = "HH" + intervals [664]: + xmin = 63.39 + xmax = 63.56 + text = "AE1" + intervals [665]: + xmin = 63.56 + xmax = 63.62 + text = "P" + intervals [666]: + xmin = 63.62 + xmax = 63.75 + text = "IY0" + intervals [667]: + xmin = 63.75 + xmax = 63.81 + text = "W" + intervals [668]: + xmin = 63.81 + xmax = 63.85 + text = "IH1" + intervals [669]: + xmin = 63.85 + xmax = 63.9 + text = "N" + intervals [670]: + xmin = 63.9 + xmax = 64.03 + text = "P" + intervals [671]: + xmin = 64.03 + xmax = 64.14 + text = "IY1" + intervals [672]: + xmin = 64.14 + xmax = 64.19 + text = "P" + intervals [673]: + xmin = 64.19 + xmax = 64.23 + text = "AH0" + intervals [674]: + xmin = 64.23 + xmax = 64.35 + text = "L" + intervals [675]: + xmin = 64.35 + xmax = 64.48 + text = "S" + intervals [676]: + xmin = 64.48 + xmax = 64.6 + text = "EY1" + intervals [677]: + xmin = 64.6 + xmax = 64.69 + text = "M" + intervals [678]: + xmin = 64.69 + xmax = 64.84 + text = "AY1" + intervals [679]: + xmin = 64.84 + xmax = 64.99 + text = "F" + intervals [680]: + xmin = 64.99 + xmax = 65.07 + text = "OW1" + intervals [681]: + xmin = 65.07 + xmax = 65.1 + text = "T" + intervals [682]: + xmin = 65.1 + xmax = 65.18 + text = "OW2" + intervals [683]: + xmin = 65.18 + xmax = 65.29 + text = "Z" + intervals [684]: + xmin = 65.29 + xmax = 65.37 + text = "L" + intervals [685]: + xmin = 65.37 + xmax = 65.42 + text = "UH1" + intervals [686]: + xmin = 65.42 + xmax = 65.47 + text = "K" + intervals [687]: + xmin = 65.47 + xmax = 65.67 + text = "B" + intervals [688]: + xmin = 65.67 + xmax = 65.79 + text = "EH1" + intervals [689]: + xmin = 65.79 + xmax = 65.88 + text = "T" + intervals [690]: + xmin = 65.88 + xmax = 66.11 + text = "ER0" + intervals [691]: + xmin = 66.11 + xmax = 66.43 + text = "" + intervals [692]: + xmin = 66.43 + xmax = 66.5 + text = 
"DH" + intervals [693]: + xmin = 66.5 + xmax = 66.53 + text = "AH0" + intervals [694]: + xmin = 66.53 + xmax = 66.56 + text = "N" + intervals [695]: + xmin = 66.56 + xmax = 66.6 + text = "DH" + intervals [696]: + xmin = 66.6 + xmax = 66.7 + text = "IY0" + intervals [697]: + xmin = 66.7 + xmax = 66.76 + text = "AH1" + intervals [698]: + xmin = 66.76 + xmax = 66.82 + text = "DH" + intervals [699]: + xmin = 66.82 + xmax = 66.95 + text = "ER0" + intervals [700]: + xmin = 66.95 + xmax = 67.19 + text = "Z" + intervals [701]: + xmin = 67.19 + xmax = 68 + text = "" diff --git a/EMAGE/test_sequences/textgrid/2_scott_0_4_4.TextGrid b/EMAGE/test_sequences/textgrid/2_scott_0_4_4.TextGrid new file mode 100644 index 0000000000000000000000000000000000000000..017574fad30ad90188cab3099371d11ff28c71a2 --- /dev/null +++ b/EMAGE/test_sequences/textgrid/2_scott_0_4_4.TextGrid @@ -0,0 +1,3844 @@ +File type = "ooTextFile" +Object class = "TextGrid" + +xmin = 0.0 +xmax = 67 +tiers? +size = 2 +item []: + item [1]: + class = "IntervalTier" + name = "words" + xmin = 0.0 + xmax = 67 + intervals: size = 235 + intervals [1]: + xmin = 0.0 + xmax = 0.53 + text = "" + intervals [2]: + xmin = 0.53 + xmax = 0.93 + text = "my" + intervals [3]: + xmin = 0.93 + xmax = 1.34 + text = "favorite" + intervals [4]: + xmin = 1.34 + xmax = 1.57 + text = "kind" + intervals [5]: + xmin = 1.57 + xmax = 1.65 + text = "of" + intervals [6]: + xmin = 1.65 + xmax = 2.2 + text = "movies" + intervals [7]: + xmin = 2.2 + xmax = 2.45 + text = "are" + intervals [8]: + xmin = 2.45 + xmax = 3.2 + text = "romantic" + intervals [9]: + xmin = 3.2 + xmax = 3.72 + text = "movies" + intervals [10]: + xmin = 3.72 + xmax = 3.75 + text = "" + intervals [11]: + xmin = 3.75 + xmax = 4.1 + text = "such" + intervals [12]: + xmin = 4.1 + xmax = 4.33 + text = "as" + intervals [13]: + xmin = 4.33 + xmax = 5.23 + text = "titanic" + intervals [14]: + xmin = 5.23 + xmax = 5.78 + text = "" + intervals [15]: + xmin = 5.78 + xmax = 6.19 + text = "it's" + intervals [16]: + xmin = 6.19 + xmax = 6.46 + text = "a" + intervals [17]: + xmin = 6.46 + xmax = 6.49 + text = "" + intervals [18]: + xmin = 6.49 + xmax = 7.13 + text = "fantastic" + intervals [19]: + xmin = 7.13 + xmax = 7.56 + text = "film" + intervals [20]: + xmin = 7.56 + xmax = 7.77 + text = "it" + intervals [21]: + xmin = 7.77 + xmax = 8.28 + text = "captured" + intervals [22]: + xmin = 8.28 + xmax = 8.73 + text = "many" + intervals [23]: + xmin = 8.73 + xmax = 9.1 + text = "young" + intervals [24]: + xmin = 9.1 + xmax = 9.44 + text = "people's" + intervals [25]: + xmin = 9.44 + xmax = 9.79 + text = "hearts" + intervals [26]: + xmin = 9.79 + xmax = 10.02 + text = "with" + intervals [27]: + xmin = 10.02 + xmax = 10.17 + text = "it's" + intervals [28]: + xmin = 10.17 + xmax = 10.65 + text = "amazing" + intervals [29]: + xmin = 10.65 + xmax = 11.12 + text = "music" + intervals [30]: + xmin = 11.12 + xmax = 11.36 + text = "and" + intervals [31]: + xmin = 11.36 + xmax = 11.92 + text = "sentimental" + intervals [32]: + xmin = 11.92 + xmax = 12.47 + text = "plots" + intervals [33]: + xmin = 12.47 + xmax = 12.84 + text = "" + intervals [34]: + xmin = 12.84 + xmax = 12.98 + text = "when" + intervals [35]: + xmin = 12.98 + xmax = 13.12 + text = "i" + intervals [36]: + xmin = 13.12 + xmax = 13.28 + text = "think" + intervals [37]: + xmin = 13.28 + xmax = 13.35 + text = "of" + intervals [38]: + xmin = 13.35 + xmax = 13.42 + text = "the" + intervals [39]: + xmin = 13.42 + xmax = 13.62 + text = "movie" + intervals [40]: + xmin 
= 13.62 + xmax = 14.2 + text = "titanic" + intervals [41]: + xmin = 14.2 + xmax = 14.23 + text = "" + intervals [42]: + xmin = 14.23 + xmax = 14.42 + text = "the" + intervals [43]: + xmin = 14.42 + xmax = 14.92 + text = "word" + intervals [44]: + xmin = 14.92 + xmax = 15.06 + text = "that" + intervals [45]: + xmin = 15.06 + xmax = 15.3 + text = "comes" + intervals [46]: + xmin = 15.3 + xmax = 15.39 + text = "to" + intervals [47]: + xmin = 15.39 + xmax = 15.5 + text = "my" + intervals [48]: + xmin = 15.5 + xmax = 15.91 + text = "mind" + intervals [49]: + xmin = 15.91 + xmax = 16.06 + text = "mind" + intervals [50]: + xmin = 16.06 + xmax = 16.41 + text = "" + intervals [51]: + xmin = 16.41 + xmax = 16.6 + text = "to" + intervals [52]: + xmin = 16.6 + xmax = 17.07 + text = "mises" + intervals [53]: + xmin = 17.07 + xmax = 17.15 + text = "the" + intervals [54]: + xmin = 17.15 + xmax = 17.39 + text = "whole" + intervals [55]: + xmin = 17.39 + xmax = 17.94 + text = "film" + intervals [56]: + xmin = 17.94 + xmax = 17.97 + text = "" + intervals [57]: + xmin = 17.97 + xmax = 18.18 + text = "would" + intervals [58]: + xmin = 18.18 + xmax = 18.62 + text = "be" + intervals [59]: + xmin = 18.62 + xmax = 19.09 + text = "" + intervals [60]: + xmin = 19.09 + xmax = 19.94 + text = "love" + intervals [61]: + xmin = 19.94 + xmax = 20.07 + text = "" + intervals [62]: + xmin = 20.07 + xmax = 20.27 + text = "it's" + intervals [63]: + xmin = 20.27 + xmax = 20.36 + text = "a" + intervals [64]: + xmin = 20.36 + xmax = 20.83 + text = "kind" + intervals [65]: + xmin = 20.83 + xmax = 20.98 + text = "of" + intervals [66]: + xmin = 20.98 + xmax = 21.25 + text = "thing" + intervals [67]: + xmin = 21.25 + xmax = 21.43 + text = "that" + intervals [68]: + xmin = 21.43 + xmax = 21.8 + text = "makes" + intervals [69]: + xmin = 21.8 + xmax = 22.28 + text = "you" + intervals [70]: + xmin = 22.28 + xmax = 22.31 + text = "" + intervals [71]: + xmin = 22.31 + xmax = 22.8 + text = "makes" + intervals [72]: + xmin = 22.8 + xmax = 22.91 + text = "the" + intervals [73]: + xmin = 22.91 + xmax = 23.21 + text = "world" + intervals [74]: + xmin = 23.21 + xmax = 23.38 + text = "go" + intervals [75]: + xmin = 23.38 + xmax = 23.87 + text = "round" + intervals [76]: + xmin = 23.87 + xmax = 24.08 + text = "" + intervals [77]: + xmin = 24.08 + xmax = 24.6 + text = "watching" + intervals [78]: + xmin = 24.6 + xmax = 24.8 + text = "these" + intervals [79]: + xmin = 24.8 + xmax = 25.18 + text = "kinds" + intervals [80]: + xmin = 25.18 + xmax = 25.29 + text = "of" + intervals [81]: + xmin = 25.29 + xmax = 25.83 + text = "romantic" + intervals [82]: + xmin = 25.83 + xmax = 26.23 + text = "movies" + intervals [83]: + xmin = 26.23 + xmax = 26.43 + text = "is" + intervals [84]: + xmin = 26.43 + xmax = 26.86 + text = "just" + intervals [85]: + xmin = 26.86 + xmax = 27.07 + text = "like" + intervals [86]: + xmin = 27.07 + xmax = 27.49 + text = "reading" + intervals [87]: + xmin = 27.49 + xmax = 27.56 + text = "a" + intervals [88]: + xmin = 27.56 + xmax = 27.98 + text = "book" + intervals [89]: + xmin = 27.98 + xmax = 28.11 + text = "" + intervals [90]: + xmin = 28.11 + xmax = 28.29 + text = "that" + intervals [91]: + xmin = 28.29 + xmax = 28.65 + text = "teaches" + intervals [92]: + xmin = 28.65 + xmax = 28.78 + text = "me" + intervals [93]: + xmin = 28.78 + xmax = 28.99 + text = "how" + intervals [94]: + xmin = 28.99 + xmax = 29.19 + text = "to" + intervals [95]: + xmin = 29.19 + xmax = 29.61 + text = "love" + intervals [96]: + xmin = 29.61 + xmax = 
29.93 + text = "and" + intervals [97]: + xmin = 29.93 + xmax = 30.09 + text = "be" + intervals [98]: + xmin = 30.09 + xmax = 30.53 + text = "loved" + intervals [99]: + xmin = 30.53 + xmax = 30.96 + text = "" + intervals [100]: + xmin = 30.96 + xmax = 31.68 + text = "moreover" + intervals [101]: + xmin = 31.68 + xmax = 31.81 + text = "we" + intervals [102]: + xmin = 31.81 + xmax = 32.01 + text = "" + intervals [103]: + xmin = 32.01 + xmax = 32.51 + text = "can" + intervals [104]: + xmin = 32.51 + xmax = 32.56 + text = "" + intervals [105]: + xmin = 32.56 + xmax = 32.72 + text = "learn" + intervals [106]: + xmin = 32.72 + xmax = 33.09 + text = "we" + intervals [107]: + xmin = 33.09 + xmax = 33.25 + text = "can" + intervals [108]: + xmin = 33.25 + xmax = 34.05 + text = "learn" + intervals [109]: + xmin = 34.05 + xmax = 34.2 + text = "" + intervals [110]: + xmin = 34.2 + xmax = 35.12 + text = "more" + intervals [111]: + xmin = 35.12 + xmax = 35.44 + text = "from" + intervals [112]: + xmin = 35.44 + xmax = 35.66 + text = "it" + intervals [113]: + xmin = 35.66 + xmax = 35.98 + text = "such" + intervals [114]: + xmin = 35.98 + xmax = 36.35 + text = "things" + intervals [115]: + xmin = 36.35 + xmax = 36.69 + text = "as" + intervals [116]: + xmin = 36.69 + xmax = 36.89 + text = "" + intervals [117]: + xmin = 36.89 + xmax = 37.59 + text = "loyalty" + intervals [118]: + xmin = 37.59 + xmax = 37.76 + text = "and" + intervals [119]: + xmin = 37.76 + xmax = 37.88 + text = "what" + intervals [120]: + xmin = 37.88 + xmax = 37.99 + text = "we" + intervals [121]: + xmin = 37.99 + xmax = 38.47 + text = "treasure" + intervals [122]: + xmin = 38.47 + xmax = 38.58 + text = "in" + intervals [123]: + xmin = 38.58 + xmax = 38.71 + text = "our" + intervals [124]: + xmin = 38.71 + xmax = 39.11 + text = "lives" + intervals [125]: + xmin = 39.11 + xmax = 39.4 + text = "" + intervals [126]: + xmin = 39.4 + xmax = 39.8 + text = "another" + intervals [127]: + xmin = 39.8 + xmax = 40.13 + text = "movie" + intervals [128]: + xmin = 40.13 + xmax = 40.51 + text = "about" + intervals [129]: + xmin = 40.51 + xmax = 40.83 + text = "love" + intervals [130]: + xmin = 40.83 + xmax = 41.08 + text = "is" + intervals [131]: + xmin = 41.08 + xmax = 41.24 + text = "the" + intervals [132]: + xmin = 41.24 + xmax = 41.3 + text = "" + intervals [133]: + xmin = 41.3 + xmax = 41.9 + text = "secret" + intervals [134]: + xmin = 41.9 + xmax = 42.13 + text = "" + intervals [135]: + xmin = 42.13 + xmax = 42.47 + text = "the" + intervals [136]: + xmin = 42.47 + xmax = 43.01 + text = "movie" + intervals [137]: + xmin = 43.01 + xmax = 43.58 + text = "secret" + intervals [138]: + xmin = 43.58 + xmax = 43.71 + text = "is" + intervals [139]: + xmin = 43.71 + xmax = 44.1 + text = "about" + intervals [140]: + xmin = 44.1 + xmax = 44.15 + text = "a" + intervals [141]: + xmin = 44.15 + xmax = 44.88 + text = "story" + intervals [142]: + xmin = 44.88 + xmax = 44.95 + text = "" + intervals [143]: + xmin = 44.95 + xmax = 45.14 + text = "of" + intervals [144]: + xmin = 45.14 + xmax = 45.21 + text = "a" + intervals [145]: + xmin = 45.21 + xmax = 45.56 + text = "musical" + intervals [146]: + xmin = 45.56 + xmax = 45.96 + text = "prodigy" + intervals [147]: + xmin = 45.96 + xmax = 46.04 + text = "do" + intervals [148]: + xmin = 46.04 + xmax = 46.17 + text = "that" + intervals [149]: + xmin = 46.17 + xmax = 46.42 + text = "falls" + intervals [150]: + xmin = 46.42 + xmax = 46.49 + text = "in" + intervals [151]: + xmin = 46.49 + xmax = 46.62 + text = "love" + 
intervals [152]: + xmin = 46.62 + xmax = 46.72 + text = "with" + intervals [153]: + xmin = 46.72 + xmax = 46.85 + text = "a" + intervals [154]: + xmin = 46.85 + xmax = 46.91 + text = "" + intervals [155]: + xmin = 46.91 + xmax = 47.21 + text = "girl" + intervals [156]: + xmin = 47.21 + xmax = 47.46 + text = "who's" + intervals [157]: + xmin = 47.46 + xmax = 48.01 + text = "dying" + intervals [158]: + xmin = 48.01 + xmax = 49.08 + text = "" + intervals [159]: + xmin = 49.08 + xmax = 49.33 + text = "there" + intervals [160]: + xmin = 49.33 + xmax = 49.39 + text = "are" + intervals [161]: + xmin = 49.39 + xmax = 49.46 + text = "a" + intervals [162]: + xmin = 49.46 + xmax = 49.82 + text = "lot" + intervals [163]: + xmin = 49.82 + xmax = 50.2 + text = "of" + intervals [164]: + xmin = 50.2 + xmax = 50.29 + text = "" + intervals [165]: + xmin = 50.29 + xmax = 50.88 + text = "enviable" + intervals [166]: + xmin = 50.88 + xmax = 51.3 + text = "moments" + intervals [167]: + xmin = 51.3 + xmax = 51.37 + text = "in" + intervals [168]: + xmin = 51.37 + xmax = 51.53 + text = "this" + intervals [169]: + xmin = 51.53 + xmax = 51.77 + text = "film" + intervals [170]: + xmin = 51.77 + xmax = 52.01 + text = "such" + intervals [171]: + xmin = 52.01 + xmax = 52.2 + text = "as" + intervals [172]: + xmin = 52.2 + xmax = 52.3 + text = "the" + intervals [173]: + xmin = 52.3 + xmax = 52.57 + text = "simple" + intervals [174]: + xmin = 52.57 + xmax = 52.74 + text = "love" + intervals [175]: + xmin = 52.74 + xmax = 53.06 + text = "between" + intervals [176]: + xmin = 53.06 + xmax = 53.26 + text = "high" + intervals [177]: + xmin = 53.26 + xmax = 53.52 + text = "school" + intervals [178]: + xmin = 53.52 + xmax = 54.09 + text = "students" + intervals [179]: + xmin = 54.09 + xmax = 54.32 + text = "" + intervals [180]: + xmin = 54.32 + xmax = 54.56 + text = "every" + intervals [181]: + xmin = 54.56 + xmax = 54.93 + text = "time" + intervals [182]: + xmin = 54.93 + xmax = 54.96 + text = "" + intervals [183]: + xmin = 54.96 + xmax = 55.08 + text = "i" + intervals [184]: + xmin = 55.08 + xmax = 55.45 + text = "watch" + intervals [185]: + xmin = 55.45 + xmax = 55.68 + text = "this" + intervals [186]: + xmin = 55.68 + xmax = 55.89 + text = "movie" + intervals [187]: + xmin = 55.89 + xmax = 55.98 + text = "it" + intervals [188]: + xmin = 55.98 + xmax = 56.44 + text = "reminds" + intervals [189]: + xmin = 56.44 + xmax = 56.55 + text = "me" + intervals [190]: + xmin = 56.55 + xmax = 56.63 + text = "of" + intervals [191]: + xmin = 56.63 + xmax = 56.7 + text = "a" + intervals [192]: + xmin = 56.7 + xmax = 56.99 + text = "time" + intervals [193]: + xmin = 56.99 + xmax = 57.08 + text = "that" + intervals [194]: + xmin = 57.08 + xmax = 57.13 + text = "i" + intervals [195]: + xmin = 57.13 + xmax = 57.23 + text = "was" + intervals [196]: + xmin = 57.23 + xmax = 57.29 + text = "in" + intervals [197]: + xmin = 57.29 + xmax = 57.53 + text = "high" + intervals [198]: + xmin = 57.53 + xmax = 57.88 + text = "school" + intervals [199]: + xmin = 57.88 + xmax = 58.03 + text = "and" + intervals [200]: + xmin = 58.03 + xmax = 58.25 + text = "" + intervals [201]: + xmin = 58.25 + xmax = 58.4 + text = "you" + intervals [202]: + xmin = 58.4 + xmax = 58.55 + text = "might" + intervals [203]: + xmin = 58.55 + xmax = 59.1 + text = "remember" + intervals [204]: + xmin = 59.1 + xmax = 59.29 + text = "" + intervals [205]: + xmin = 59.29 + xmax = 59.44 + text = "the" + intervals [206]: + xmin = 59.44 + xmax = 59.83 + text = "crush" + intervals [207]: + 
xmin = 59.83 + xmax = 59.94 + text = "you" + intervals [208]: + xmin = 59.94 + xmax = 60.16 + text = "had" + intervals [209]: + xmin = 60.16 + xmax = 60.27 + text = "in" + intervals [210]: + xmin = 60.27 + xmax = 60.74 + text = "school" + intervals [211]: + xmin = 60.74 + xmax = 60.9 + text = "" + intervals [212]: + xmin = 60.9 + xmax = 61.12 + text = "and" + intervals [213]: + xmin = 61.12 + xmax = 61.24 + text = "how" + intervals [214]: + xmin = 61.24 + xmax = 61.36 + text = "you" + intervals [215]: + xmin = 61.36 + xmax = 61.48 + text = "would" + intervals [216]: + xmin = 61.48 + xmax = 61.7 + text = "look" + intervals [217]: + xmin = 61.7 + xmax = 61.77 + text = "at" + intervals [218]: + xmin = 61.77 + xmax = 62.16 + text = "him" + intervals [219]: + xmin = 62.16 + xmax = 62.37 + text = "" + intervals [220]: + xmin = 62.37 + xmax = 62.54 + text = "while" + intervals [221]: + xmin = 62.54 + xmax = 62.74 + text = "he's" + intervals [222]: + xmin = 62.74 + xmax = 63.02 + text = "at" + intervals [223]: + xmin = 63.02 + xmax = 63.61 + text = "in" + intervals [224]: + xmin = 63.61 + xmax = 64.04 + text = "class" + intervals [225]: + xmin = 64.04 + xmax = 64.38 + text = "without" + intervals [226]: + xmin = 64.38 + xmax = 64.83 + text = "thinking" + intervals [227]: + xmin = 64.83 + xmax = 64.95 + text = "or" + intervals [228]: + xmin = 64.95 + xmax = 64.98 + text = "" + intervals [229]: + xmin = 64.98 + xmax = 65.27 + text = "wanting" + intervals [230]: + xmin = 65.27 + xmax = 65.36 + text = "to" + intervals [231]: + xmin = 65.36 + xmax = 65.54 + text = "go" + intervals [232]: + xmin = 65.54 + xmax = 65.95 + text = "places" + intervals [233]: + xmin = 65.95 + xmax = 66.12 + text = "with" + intervals [234]: + xmin = 66.12 + xmax = 66.38 + text = "him" + intervals [235]: + xmin = 66.38 + xmax = 67 + text = "" + item [2]: + class = "IntervalTier" + name = "phones" + xmin = 0.0 + xmax = 67 + intervals: size = 721 + intervals [1]: + xmin = 0.0 + xmax = 0.53 + text = "" + intervals [2]: + xmin = 0.53 + xmax = 0.75 + text = "M" + intervals [3]: + xmin = 0.75 + xmax = 0.93 + text = "AY1" + intervals [4]: + xmin = 0.93 + xmax = 1.06 + text = "F" + intervals [5]: + xmin = 1.06 + xmax = 1.16 + text = "EY1" + intervals [6]: + xmin = 1.16 + xmax = 1.23 + text = "V" + intervals [7]: + xmin = 1.23 + xmax = 1.26 + text = "ER0" + intervals [8]: + xmin = 1.26 + xmax = 1.31 + text = "IH0" + intervals [9]: + xmin = 1.31 + xmax = 1.34 + text = "T" + intervals [10]: + xmin = 1.34 + xmax = 1.42 + text = "K" + intervals [11]: + xmin = 1.42 + xmax = 1.51 + text = "AY1" + intervals [12]: + xmin = 1.51 + xmax = 1.54 + text = "N" + intervals [13]: + xmin = 1.54 + xmax = 1.57 + text = "D" + intervals [14]: + xmin = 1.57 + xmax = 1.61 + text = "AH0" + intervals [15]: + xmin = 1.61 + xmax = 1.65 + text = "V" + intervals [16]: + xmin = 1.65 + xmax = 1.74 + text = "M" + intervals [17]: + xmin = 1.74 + xmax = 1.8 + text = "UW1" + intervals [18]: + xmin = 1.8 + xmax = 1.9 + text = "V" + intervals [19]: + xmin = 1.9 + xmax = 2.01 + text = "IY0" + intervals [20]: + xmin = 2.01 + xmax = 2.2 + text = "Z" + intervals [21]: + xmin = 2.2 + xmax = 2.35 + text = "AA1" + intervals [22]: + xmin = 2.35 + xmax = 2.45 + text = "R" + intervals [23]: + xmin = 2.45 + xmax = 2.55 + text = "R" + intervals [24]: + xmin = 2.55 + xmax = 2.6 + text = "OW0" + intervals [25]: + xmin = 2.6 + xmax = 2.81 + text = "M" + intervals [26]: + xmin = 2.81 + xmax = 2.93 + text = "AE1" + intervals [27]: + xmin = 2.93 + xmax = 2.99 + text = "N" + intervals [28]: 
+ xmin = 2.99 + xmax = 3.06 + text = "T" + intervals [29]: + xmin = 3.06 + xmax = 3.13 + text = "IH0" + intervals [30]: + xmin = 3.13 + xmax = 3.2 + text = "K" + intervals [31]: + xmin = 3.2 + xmax = 3.28 + text = "M" + intervals [32]: + xmin = 3.28 + xmax = 3.38 + text = "UW1" + intervals [33]: + xmin = 3.38 + xmax = 3.42 + text = "V" + intervals [34]: + xmin = 3.42 + xmax = 3.55 + text = "IY0" + intervals [35]: + xmin = 3.55 + xmax = 3.72 + text = "Z" + intervals [36]: + xmin = 3.72 + xmax = 3.75 + text = "" + intervals [37]: + xmin = 3.75 + xmax = 3.92 + text = "S" + intervals [38]: + xmin = 3.92 + xmax = 3.99 + text = "AH1" + intervals [39]: + xmin = 3.99 + xmax = 4.1 + text = "CH" + intervals [40]: + xmin = 4.1 + xmax = 4.22 + text = "EH1" + intervals [41]: + xmin = 4.22 + xmax = 4.33 + text = "Z" + intervals [42]: + xmin = 4.33 + xmax = 4.47 + text = "T" + intervals [43]: + xmin = 4.47 + xmax = 4.58 + text = "AY0" + intervals [44]: + xmin = 4.58 + xmax = 4.77 + text = "T" + intervals [45]: + xmin = 4.77 + xmax = 4.88 + text = "AE1" + intervals [46]: + xmin = 4.88 + xmax = 4.94 + text = "N" + intervals [47]: + xmin = 4.94 + xmax = 5.04 + text = "IH0" + intervals [48]: + xmin = 5.04 + xmax = 5.23 + text = "K" + intervals [49]: + xmin = 5.23 + xmax = 5.78 + text = "" + intervals [50]: + xmin = 5.78 + xmax = 5.95 + text = "IH1" + intervals [51]: + xmin = 5.95 + xmax = 6.05 + text = "T" + intervals [52]: + xmin = 6.05 + xmax = 6.19 + text = "S" + intervals [53]: + xmin = 6.19 + xmax = 6.46 + text = "AH0" + intervals [54]: + xmin = 6.46 + xmax = 6.49 + text = "" + intervals [55]: + xmin = 6.49 + xmax = 6.55 + text = "F" + intervals [56]: + xmin = 6.55 + xmax = 6.6 + text = "AE0" + intervals [57]: + xmin = 6.6 + xmax = 6.66 + text = "N" + intervals [58]: + xmin = 6.66 + xmax = 6.79 + text = "T" + intervals [59]: + xmin = 6.79 + xmax = 6.92 + text = "AE1" + intervals [60]: + xmin = 6.92 + xmax = 6.99 + text = "S" + intervals [61]: + xmin = 6.99 + xmax = 7.03 + text = "T" + intervals [62]: + xmin = 7.03 + xmax = 7.08 + text = "IH0" + intervals [63]: + xmin = 7.08 + xmax = 7.13 + text = "K" + intervals [64]: + xmin = 7.13 + xmax = 7.21 + text = "F" + intervals [65]: + xmin = 7.21 + xmax = 7.31 + text = "IH1" + intervals [66]: + xmin = 7.31 + xmax = 7.51 + text = "L" + intervals [67]: + xmin = 7.51 + xmax = 7.56 + text = "M" + intervals [68]: + xmin = 7.56 + xmax = 7.71 + text = "IH1" + intervals [69]: + xmin = 7.71 + xmax = 7.77 + text = "T" + intervals [70]: + xmin = 7.77 + xmax = 7.86 + text = "K" + intervals [71]: + xmin = 7.86 + xmax = 7.96 + text = "AE1" + intervals [72]: + xmin = 7.96 + xmax = 8.02 + text = "P" + intervals [73]: + xmin = 8.02 + xmax = 8.1 + text = "CH" + intervals [74]: + xmin = 8.1 + xmax = 8.16 + text = "ER0" + intervals [75]: + xmin = 8.16 + xmax = 8.28 + text = "D" + intervals [76]: + xmin = 8.28 + xmax = 8.4 + text = "M" + intervals [77]: + xmin = 8.4 + xmax = 8.46 + text = "EH1" + intervals [78]: + xmin = 8.46 + xmax = 8.53 + text = "N" + intervals [79]: + xmin = 8.53 + xmax = 8.73 + text = "IY0" + intervals [80]: + xmin = 8.73 + xmax = 8.81 + text = "Y" + intervals [81]: + xmin = 8.81 + xmax = 8.99 + text = "AH1" + intervals [82]: + xmin = 8.99 + xmax = 9.1 + text = "NG" + intervals [83]: + xmin = 9.1 + xmax = 9.18 + text = "P" + intervals [84]: + xmin = 9.18 + xmax = 9.26 + text = "IY1" + intervals [85]: + xmin = 9.26 + xmax = 9.31 + text = "P" + intervals [86]: + xmin = 9.31 + xmax = 9.34 + text = "AH0" + intervals [87]: + xmin = 9.34 + xmax = 9.39 + text = "L" 
+ intervals [88]: + xmin = 9.39 + xmax = 9.44 + text = "Z" + intervals [89]: + xmin = 9.44 + xmax = 9.5 + text = "HH" + intervals [90]: + xmin = 9.5 + xmax = 9.57 + text = "AA1" + intervals [91]: + xmin = 9.57 + xmax = 9.68 + text = "R" + intervals [92]: + xmin = 9.68 + xmax = 9.73 + text = "T" + intervals [93]: + xmin = 9.73 + xmax = 9.79 + text = "S" + intervals [94]: + xmin = 9.79 + xmax = 9.88 + text = "W" + intervals [95]: + xmin = 9.88 + xmax = 9.94 + text = "IH1" + intervals [96]: + xmin = 9.94 + xmax = 10.02 + text = "DH" + intervals [97]: + xmin = 10.02 + xmax = 10.1 + text = "IH0" + intervals [98]: + xmin = 10.1 + xmax = 10.13 + text = "T" + intervals [99]: + xmin = 10.13 + xmax = 10.17 + text = "S" + intervals [100]: + xmin = 10.17 + xmax = 10.28 + text = "AH0" + intervals [101]: + xmin = 10.28 + xmax = 10.36 + text = "M" + intervals [102]: + xmin = 10.36 + xmax = 10.47 + text = "EY1" + intervals [103]: + xmin = 10.47 + xmax = 10.53 + text = "Z" + intervals [104]: + xmin = 10.53 + xmax = 10.59 + text = "IH0" + intervals [105]: + xmin = 10.59 + xmax = 10.65 + text = "NG" + intervals [106]: + xmin = 10.65 + xmax = 10.7 + text = "M" + intervals [107]: + xmin = 10.7 + xmax = 10.78 + text = "Y" + intervals [108]: + xmin = 10.78 + xmax = 10.81 + text = "UW1" + intervals [109]: + xmin = 10.81 + xmax = 10.9 + text = "Z" + intervals [110]: + xmin = 10.9 + xmax = 10.98 + text = "IH0" + intervals [111]: + xmin = 10.98 + xmax = 11.12 + text = "K" + intervals [112]: + xmin = 11.12 + xmax = 11.21 + text = "AE1" + intervals [113]: + xmin = 11.21 + xmax = 11.27 + text = "N" + intervals [114]: + xmin = 11.27 + xmax = 11.36 + text = "D" + intervals [115]: + xmin = 11.36 + xmax = 11.47 + text = "S" + intervals [116]: + xmin = 11.47 + xmax = 11.52 + text = "EH2" + intervals [117]: + xmin = 11.52 + xmax = 11.59 + text = "N" + intervals [118]: + xmin = 11.59 + xmax = 11.64 + text = "T" + intervals [119]: + xmin = 11.64 + xmax = 11.69 + text = "AH0" + intervals [120]: + xmin = 11.69 + xmax = 11.74 + text = "M" + intervals [121]: + xmin = 11.74 + xmax = 11.78 + text = "EH1" + intervals [122]: + xmin = 11.78 + xmax = 11.81 + text = "N" + intervals [123]: + xmin = 11.81 + xmax = 11.85 + text = "T" + intervals [124]: + xmin = 11.85 + xmax = 11.88 + text = "AH0" + intervals [125]: + xmin = 11.88 + xmax = 11.92 + text = "L" + intervals [126]: + xmin = 11.92 + xmax = 11.99 + text = "P" + intervals [127]: + xmin = 11.99 + xmax = 12.04 + text = "L" + intervals [128]: + xmin = 12.04 + xmax = 12.23 + text = "AA1" + intervals [129]: + xmin = 12.23 + xmax = 12.31 + text = "T" + intervals [130]: + xmin = 12.31 + xmax = 12.47 + text = "S" + intervals [131]: + xmin = 12.47 + xmax = 12.84 + text = "" + intervals [132]: + xmin = 12.84 + xmax = 12.92 + text = "W" + intervals [133]: + xmin = 12.92 + xmax = 12.95 + text = "EH1" + intervals [134]: + xmin = 12.95 + xmax = 12.98 + text = "N" + intervals [135]: + xmin = 12.98 + xmax = 13.12 + text = "AY1" + intervals [136]: + xmin = 13.12 + xmax = 13.16 + text = "TH" + intervals [137]: + xmin = 13.16 + xmax = 13.22 + text = "IH1" + intervals [138]: + xmin = 13.22 + xmax = 13.25 + text = "NG" + intervals [139]: + xmin = 13.25 + xmax = 13.28 + text = "K" + intervals [140]: + xmin = 13.28 + xmax = 13.32 + text = "AH0" + intervals [141]: + xmin = 13.32 + xmax = 13.35 + text = "V" + intervals [142]: + xmin = 13.35 + xmax = 13.38 + text = "DH" + intervals [143]: + xmin = 13.38 + xmax = 13.42 + text = "AH0" + intervals [144]: + xmin = 13.42 + xmax = 13.46 + text = "M" + intervals 
[145]: + xmin = 13.46 + xmax = 13.52 + text = "UW1" + intervals [146]: + xmin = 13.52 + xmax = 13.58 + text = "V" + intervals [147]: + xmin = 13.58 + xmax = 13.62 + text = "IY0" + intervals [148]: + xmin = 13.62 + xmax = 13.69 + text = "T" + intervals [149]: + xmin = 13.69 + xmax = 13.79 + text = "AY0" + intervals [150]: + xmin = 13.79 + xmax = 13.89 + text = "T" + intervals [151]: + xmin = 13.89 + xmax = 13.96 + text = "AE1" + intervals [152]: + xmin = 13.96 + xmax = 14.02 + text = "N" + intervals [153]: + xmin = 14.02 + xmax = 14.17 + text = "IH0" + intervals [154]: + xmin = 14.17 + xmax = 14.2 + text = "K" + intervals [155]: + xmin = 14.2 + xmax = 14.23 + text = "" + intervals [156]: + xmin = 14.23 + xmax = 14.34 + text = "DH" + intervals [157]: + xmin = 14.34 + xmax = 14.42 + text = "AH1" + intervals [158]: + xmin = 14.42 + xmax = 14.61 + text = "W" + intervals [159]: + xmin = 14.61 + xmax = 14.85 + text = "ER1" + intervals [160]: + xmin = 14.85 + xmax = 14.92 + text = "D" + intervals [161]: + xmin = 14.92 + xmax = 14.95 + text = "DH" + intervals [162]: + xmin = 14.95 + xmax = 15.02 + text = "AH0" + intervals [163]: + xmin = 15.02 + xmax = 15.06 + text = "T" + intervals [164]: + xmin = 15.06 + xmax = 15.13 + text = "K" + intervals [165]: + xmin = 15.13 + xmax = 15.19 + text = "AH1" + intervals [166]: + xmin = 15.19 + xmax = 15.24 + text = "M" + intervals [167]: + xmin = 15.24 + xmax = 15.3 + text = "Z" + intervals [168]: + xmin = 15.3 + xmax = 15.35 + text = "T" + intervals [169]: + xmin = 15.35 + xmax = 15.39 + text = "AH0" + intervals [170]: + xmin = 15.39 + xmax = 15.44 + text = "M" + intervals [171]: + xmin = 15.44 + xmax = 15.5 + text = "AY1" + intervals [172]: + xmin = 15.5 + xmax = 15.61 + text = "M" + intervals [173]: + xmin = 15.61 + xmax = 15.85 + text = "AY1" + intervals [174]: + xmin = 15.85 + xmax = 15.88 + text = "N" + intervals [175]: + xmin = 15.88 + xmax = 15.91 + text = "D" + intervals [176]: + xmin = 15.91 + xmax = 15.94 + text = "M" + intervals [177]: + xmin = 15.94 + xmax = 15.97 + text = "AY1" + intervals [178]: + xmin = 15.97 + xmax = 16.0 + text = "N" + intervals [179]: + xmin = 16.0 + xmax = 16.06 + text = "D" + intervals [180]: + xmin = 16.06 + xmax = 16.41 + text = "" + intervals [181]: + xmin = 16.41 + xmax = 16.54 + text = "T" + intervals [182]: + xmin = 16.54 + xmax = 16.6 + text = "IH0" + intervals [183]: + xmin = 16.6 + xmax = 16.69 + text = "M" + intervals [184]: + xmin = 16.69 + xmax = 16.85 + text = "AY1" + intervals [185]: + xmin = 16.85 + xmax = 16.95 + text = "Z" + intervals [186]: + xmin = 16.95 + xmax = 16.99 + text = "IH0" + intervals [187]: + xmin = 16.99 + xmax = 17.07 + text = "Z" + intervals [188]: + xmin = 17.07 + xmax = 17.11 + text = "DH" + intervals [189]: + xmin = 17.11 + xmax = 17.15 + text = "AH1" + intervals [190]: + xmin = 17.15 + xmax = 17.24 + text = "HH" + intervals [191]: + xmin = 17.24 + xmax = 17.32 + text = "OW1" + intervals [192]: + xmin = 17.32 + xmax = 17.39 + text = "L" + intervals [193]: + xmin = 17.39 + xmax = 17.48 + text = "F" + intervals [194]: + xmin = 17.48 + xmax = 17.57 + text = "IH1" + intervals [195]: + xmin = 17.57 + xmax = 17.68 + text = "L" + intervals [196]: + xmin = 17.68 + xmax = 17.94 + text = "M" + intervals [197]: + xmin = 17.94 + xmax = 17.97 + text = "" + intervals [198]: + xmin = 17.97 + xmax = 18.06 + text = "W" + intervals [199]: + xmin = 18.06 + xmax = 18.12 + text = "UH1" + intervals [200]: + xmin = 18.12 + xmax = 18.18 + text = "D" + intervals [201]: + xmin = 18.18 + xmax = 18.24 + text = "B" + 
intervals [202]: + xmin = 18.24 + xmax = 18.62 + text = "IY1" + intervals [203]: + xmin = 18.62 + xmax = 19.09 + text = "" + intervals [204]: + xmin = 19.09 + xmax = 19.34 + text = "L" + intervals [205]: + xmin = 19.34 + xmax = 19.51 + text = "AH1" + intervals [206]: + xmin = 19.51 + xmax = 19.94 + text = "V" + intervals [207]: + xmin = 19.94 + xmax = 20.07 + text = "" + intervals [208]: + xmin = 20.07 + xmax = 20.18 + text = "IH1" + intervals [209]: + xmin = 20.18 + xmax = 20.24 + text = "T" + intervals [210]: + xmin = 20.24 + xmax = 20.27 + text = "S" + intervals [211]: + xmin = 20.27 + xmax = 20.36 + text = "AH0" + intervals [212]: + xmin = 20.36 + xmax = 20.59 + text = "K" + intervals [213]: + xmin = 20.59 + xmax = 20.74 + text = "AY1" + intervals [214]: + xmin = 20.74 + xmax = 20.79 + text = "N" + intervals [215]: + xmin = 20.79 + xmax = 20.83 + text = "D" + intervals [216]: + xmin = 20.83 + xmax = 20.87 + text = "AH0" + intervals [217]: + xmin = 20.87 + xmax = 20.98 + text = "V" + intervals [218]: + xmin = 20.98 + xmax = 21.04 + text = "TH" + intervals [219]: + xmin = 21.04 + xmax = 21.13 + text = "IH1" + intervals [220]: + xmin = 21.13 + xmax = 21.25 + text = "NG" + intervals [221]: + xmin = 21.25 + xmax = 21.31 + text = "DH" + intervals [222]: + xmin = 21.31 + xmax = 21.37 + text = "AE1" + intervals [223]: + xmin = 21.37 + xmax = 21.43 + text = "T" + intervals [224]: + xmin = 21.43 + xmax = 21.54 + text = "M" + intervals [225]: + xmin = 21.54 + xmax = 21.64 + text = "EY1" + intervals [226]: + xmin = 21.64 + xmax = 21.68 + text = "K" + intervals [227]: + xmin = 21.68 + xmax = 21.8 + text = "S" + intervals [228]: + xmin = 21.8 + xmax = 21.87 + text = "Y" + intervals [229]: + xmin = 21.87 + xmax = 22.28 + text = "UW1" + intervals [230]: + xmin = 22.28 + xmax = 22.31 + text = "" + intervals [231]: + xmin = 22.31 + xmax = 22.63 + text = "M" + intervals [232]: + xmin = 22.63 + xmax = 22.7 + text = "EY1" + intervals [233]: + xmin = 22.7 + xmax = 22.75 + text = "K" + intervals [234]: + xmin = 22.75 + xmax = 22.8 + text = "S" + intervals [235]: + xmin = 22.8 + xmax = 22.84 + text = "DH" + intervals [236]: + xmin = 22.84 + xmax = 22.91 + text = "AH0" + intervals [237]: + xmin = 22.91 + xmax = 23.0 + text = "W" + intervals [238]: + xmin = 23.0 + xmax = 23.08 + text = "ER1" + intervals [239]: + xmin = 23.08 + xmax = 23.18 + text = "L" + intervals [240]: + xmin = 23.18 + xmax = 23.21 + text = "D" + intervals [241]: + xmin = 23.21 + xmax = 23.27 + text = "G" + intervals [242]: + xmin = 23.27 + xmax = 23.38 + text = "OW1" + intervals [243]: + xmin = 23.38 + xmax = 23.48 + text = "R" + intervals [244]: + xmin = 23.48 + xmax = 23.68 + text = "AW1" + intervals [245]: + xmin = 23.68 + xmax = 23.75 + text = "N" + intervals [246]: + xmin = 23.75 + xmax = 23.87 + text = "D" + intervals [247]: + xmin = 23.87 + xmax = 24.08 + text = "" + intervals [248]: + xmin = 24.08 + xmax = 24.27 + text = "W" + intervals [249]: + xmin = 24.27 + xmax = 24.36 + text = "AA1" + intervals [250]: + xmin = 24.36 + xmax = 24.46 + text = "CH" + intervals [251]: + xmin = 24.46 + xmax = 24.54 + text = "IH0" + intervals [252]: + xmin = 24.54 + xmax = 24.6 + text = "NG" + intervals [253]: + xmin = 24.6 + xmax = 24.65 + text = "DH" + intervals [254]: + xmin = 24.65 + xmax = 24.72 + text = "IY1" + intervals [255]: + xmin = 24.72 + xmax = 24.8 + text = "Z" + intervals [256]: + xmin = 24.8 + xmax = 24.9 + text = "K" + intervals [257]: + xmin = 24.9 + xmax = 25.05 + text = "AY1" + intervals [258]: + xmin = 25.05 + xmax = 25.12 + text = 
"N" + intervals [259]: + xmin = 25.12 + xmax = 25.18 + text = "Z" + intervals [260]: + xmin = 25.18 + xmax = 25.21 + text = "AH0" + intervals [261]: + xmin = 25.21 + xmax = 25.29 + text = "V" + intervals [262]: + xmin = 25.29 + xmax = 25.36 + text = "R" + intervals [263]: + xmin = 25.36 + xmax = 25.39 + text = "OW0" + intervals [264]: + xmin = 25.39 + xmax = 25.5 + text = "M" + intervals [265]: + xmin = 25.5 + xmax = 25.56 + text = "AE1" + intervals [266]: + xmin = 25.56 + xmax = 25.6 + text = "N" + intervals [267]: + xmin = 25.6 + xmax = 25.65 + text = "T" + intervals [268]: + xmin = 25.65 + xmax = 25.75 + text = "IH0" + intervals [269]: + xmin = 25.75 + xmax = 25.83 + text = "K" + intervals [270]: + xmin = 25.83 + xmax = 25.9 + text = "M" + intervals [271]: + xmin = 25.9 + xmax = 25.99 + text = "UW1" + intervals [272]: + xmin = 25.99 + xmax = 26.06 + text = "V" + intervals [273]: + xmin = 26.06 + xmax = 26.14 + text = "IY0" + intervals [274]: + xmin = 26.14 + xmax = 26.23 + text = "Z" + intervals [275]: + xmin = 26.23 + xmax = 26.34 + text = "IH1" + intervals [276]: + xmin = 26.34 + xmax = 26.43 + text = "Z" + intervals [277]: + xmin = 26.43 + xmax = 26.64 + text = "JH" + intervals [278]: + xmin = 26.64 + xmax = 26.73 + text = "IH0" + intervals [279]: + xmin = 26.73 + xmax = 26.82 + text = "S" + intervals [280]: + xmin = 26.82 + xmax = 26.86 + text = "T" + intervals [281]: + xmin = 26.86 + xmax = 26.92 + text = "L" + intervals [282]: + xmin = 26.92 + xmax = 27.04 + text = "AY1" + intervals [283]: + xmin = 27.04 + xmax = 27.07 + text = "K" + intervals [284]: + xmin = 27.07 + xmax = 27.26 + text = "R" + intervals [285]: + xmin = 27.26 + xmax = 27.33 + text = "IY1" + intervals [286]: + xmin = 27.33 + xmax = 27.36 + text = "D" + intervals [287]: + xmin = 27.36 + xmax = 27.43 + text = "IH0" + intervals [288]: + xmin = 27.43 + xmax = 27.49 + text = "NG" + intervals [289]: + xmin = 27.49 + xmax = 27.56 + text = "AH0" + intervals [290]: + xmin = 27.56 + xmax = 27.63 + text = "B" + intervals [291]: + xmin = 27.63 + xmax = 27.77 + text = "UH1" + intervals [292]: + xmin = 27.77 + xmax = 27.98 + text = "K" + intervals [293]: + xmin = 27.98 + xmax = 28.11 + text = "" + intervals [294]: + xmin = 28.11 + xmax = 28.21 + text = "DH" + intervals [295]: + xmin = 28.21 + xmax = 28.25 + text = "AH0" + intervals [296]: + xmin = 28.25 + xmax = 28.29 + text = "T" + intervals [297]: + xmin = 28.29 + xmax = 28.38 + text = "T" + intervals [298]: + xmin = 28.38 + xmax = 28.44 + text = "IY1" + intervals [299]: + xmin = 28.44 + xmax = 28.52 + text = "CH" + intervals [300]: + xmin = 28.52 + xmax = 28.57 + text = "IH0" + intervals [301]: + xmin = 28.57 + xmax = 28.65 + text = "Z" + intervals [302]: + xmin = 28.65 + xmax = 28.69 + text = "M" + intervals [303]: + xmin = 28.69 + xmax = 28.78 + text = "IY1" + intervals [304]: + xmin = 28.78 + xmax = 28.91 + text = "HH" + intervals [305]: + xmin = 28.91 + xmax = 28.99 + text = "AW1" + intervals [306]: + xmin = 28.99 + xmax = 29.08 + text = "T" + intervals [307]: + xmin = 29.08 + xmax = 29.19 + text = "AH0" + intervals [308]: + xmin = 29.19 + xmax = 29.27 + text = "L" + intervals [309]: + xmin = 29.27 + xmax = 29.52 + text = "AH1" + intervals [310]: + xmin = 29.52 + xmax = 29.61 + text = "V" + intervals [311]: + xmin = 29.61 + xmax = 29.78 + text = "AE1" + intervals [312]: + xmin = 29.78 + xmax = 29.86 + text = "N" + intervals [313]: + xmin = 29.86 + xmax = 29.93 + text = "D" + intervals [314]: + xmin = 29.93 + xmax = 29.97 + text = "B" + intervals [315]: + xmin = 29.97 + 
xmax = 30.09 + text = "IY1" + intervals [316]: + xmin = 30.09 + xmax = 30.21 + text = "L" + intervals [317]: + xmin = 30.21 + xmax = 30.33 + text = "AH1" + intervals [318]: + xmin = 30.33 + xmax = 30.43 + text = "V" + intervals [319]: + xmin = 30.43 + xmax = 30.53 + text = "D" + intervals [320]: + xmin = 30.53 + xmax = 30.96 + text = "" + intervals [321]: + xmin = 30.96 + xmax = 31.09 + text = "M" + intervals [322]: + xmin = 31.09 + xmax = 31.14 + text = "AO0" + intervals [323]: + xmin = 31.14 + xmax = 31.22 + text = "R" + intervals [324]: + xmin = 31.22 + xmax = 31.5 + text = "OW1" + intervals [325]: + xmin = 31.5 + xmax = 31.53 + text = "V" + intervals [326]: + xmin = 31.53 + xmax = 31.68 + text = "ER0" + intervals [327]: + xmin = 31.68 + xmax = 31.74 + text = "W" + intervals [328]: + xmin = 31.74 + xmax = 31.81 + text = "IY1" + intervals [329]: + xmin = 31.81 + xmax = 32.01 + text = "" + intervals [330]: + xmin = 32.01 + xmax = 32.13 + text = "K" + intervals [331]: + xmin = 32.13 + xmax = 32.39 + text = "AE1" + intervals [332]: + xmin = 32.39 + xmax = 32.51 + text = "N" + intervals [333]: + xmin = 32.51 + xmax = 32.56 + text = "" + intervals [334]: + xmin = 32.56 + xmax = 32.65 + text = "L" + intervals [335]: + xmin = 32.65 + xmax = 32.68 + text = "ER1" + intervals [336]: + xmin = 32.68 + xmax = 32.72 + text = "N" + intervals [337]: + xmin = 32.72 + xmax = 33.02 + text = "W" + intervals [338]: + xmin = 33.02 + xmax = 33.09 + text = "IY1" + intervals [339]: + xmin = 33.09 + xmax = 33.16 + text = "K" + intervals [340]: + xmin = 33.16 + xmax = 33.21 + text = "AH0" + intervals [341]: + xmin = 33.21 + xmax = 33.25 + text = "N" + intervals [342]: + xmin = 33.25 + xmax = 33.4 + text = "L" + intervals [343]: + xmin = 33.4 + xmax = 33.58 + text = "ER1" + intervals [344]: + xmin = 33.58 + xmax = 34.05 + text = "N" + intervals [345]: + xmin = 34.05 + xmax = 34.2 + text = "" + intervals [346]: + xmin = 34.2 + xmax = 34.91 + text = "M" + intervals [347]: + xmin = 34.91 + xmax = 35.05 + text = "AO1" + intervals [348]: + xmin = 35.05 + xmax = 35.12 + text = "R" + intervals [349]: + xmin = 35.12 + xmax = 35.23 + text = "F" + intervals [350]: + xmin = 35.23 + xmax = 35.32 + text = "R" + intervals [351]: + xmin = 35.32 + xmax = 35.36 + text = "AH1" + intervals [352]: + xmin = 35.36 + xmax = 35.44 + text = "M" + intervals [353]: + xmin = 35.44 + xmax = 35.56 + text = "IH0" + intervals [354]: + xmin = 35.56 + xmax = 35.66 + text = "T" + intervals [355]: + xmin = 35.66 + xmax = 35.76 + text = "S" + intervals [356]: + xmin = 35.76 + xmax = 35.84 + text = "AH1" + intervals [357]: + xmin = 35.84 + xmax = 35.98 + text = "CH" + intervals [358]: + xmin = 35.98 + xmax = 36.06 + text = "TH" + intervals [359]: + xmin = 36.06 + xmax = 36.16 + text = "IH1" + intervals [360]: + xmin = 36.16 + xmax = 36.24 + text = "NG" + intervals [361]: + xmin = 36.24 + xmax = 36.35 + text = "Z" + intervals [362]: + xmin = 36.35 + xmax = 36.56 + text = "AE1" + intervals [363]: + xmin = 36.56 + xmax = 36.69 + text = "Z" + intervals [364]: + xmin = 36.69 + xmax = 36.89 + text = "" + intervals [365]: + xmin = 36.89 + xmax = 37.03 + text = "L" + intervals [366]: + xmin = 37.03 + xmax = 37.16 + text = "OY1" + intervals [367]: + xmin = 37.16 + xmax = 37.23 + text = "AH0" + intervals [368]: + xmin = 37.23 + xmax = 37.32 + text = "L" + intervals [369]: + xmin = 37.32 + xmax = 37.4 + text = "T" + intervals [370]: + xmin = 37.4 + xmax = 37.59 + text = "IY0" + intervals [371]: + xmin = 37.59 + xmax = 37.68 + text = "AE1" + intervals [372]: + 
xmin = 37.68 + xmax = 37.73 + text = "N" + intervals [373]: + xmin = 37.73 + xmax = 37.76 + text = "D" + intervals [374]: + xmin = 37.76 + xmax = 37.8 + text = "W" + intervals [375]: + xmin = 37.8 + xmax = 37.83 + text = "AH1" + intervals [376]: + xmin = 37.83 + xmax = 37.88 + text = "T" + intervals [377]: + xmin = 37.88 + xmax = 37.94 + text = "W" + intervals [378]: + xmin = 37.94 + xmax = 37.99 + text = "IY1" + intervals [379]: + xmin = 37.99 + xmax = 38.15 + text = "T" + intervals [380]: + xmin = 38.15 + xmax = 38.21 + text = "R" + intervals [381]: + xmin = 38.21 + xmax = 38.26 + text = "EH1" + intervals [382]: + xmin = 38.26 + xmax = 38.38 + text = "ZH" + intervals [383]: + xmin = 38.38 + xmax = 38.47 + text = "ER0" + intervals [384]: + xmin = 38.47 + xmax = 38.53 + text = "IH0" + intervals [385]: + xmin = 38.53 + xmax = 38.58 + text = "N" + intervals [386]: + xmin = 38.58 + xmax = 38.64 + text = "AA1" + intervals [387]: + xmin = 38.64 + xmax = 38.71 + text = "R" + intervals [388]: + xmin = 38.71 + xmax = 38.77 + text = "L" + intervals [389]: + xmin = 38.77 + xmax = 38.96 + text = "AY1" + intervals [390]: + xmin = 38.96 + xmax = 39.02 + text = "V" + intervals [391]: + xmin = 39.02 + xmax = 39.11 + text = "Z" + intervals [392]: + xmin = 39.11 + xmax = 39.4 + text = "" + intervals [393]: + xmin = 39.4 + xmax = 39.57 + text = "AH0" + intervals [394]: + xmin = 39.57 + xmax = 39.63 + text = "N" + intervals [395]: + xmin = 39.63 + xmax = 39.69 + text = "AH1" + intervals [396]: + xmin = 39.69 + xmax = 39.73 + text = "DH" + intervals [397]: + xmin = 39.73 + xmax = 39.8 + text = "ER0" + intervals [398]: + xmin = 39.8 + xmax = 39.85 + text = "M" + intervals [399]: + xmin = 39.85 + xmax = 39.98 + text = "UW1" + intervals [400]: + xmin = 39.98 + xmax = 40.06 + text = "V" + intervals [401]: + xmin = 40.06 + xmax = 40.13 + text = "IY0" + intervals [402]: + xmin = 40.13 + xmax = 40.25 + text = "AH0" + intervals [403]: + xmin = 40.25 + xmax = 40.31 + text = "B" + intervals [404]: + xmin = 40.31 + xmax = 40.46 + text = "AW1" + intervals [405]: + xmin = 40.46 + xmax = 40.51 + text = "T" + intervals [406]: + xmin = 40.51 + xmax = 40.62 + text = "L" + intervals [407]: + xmin = 40.62 + xmax = 40.76 + text = "AH1" + intervals [408]: + xmin = 40.76 + xmax = 40.83 + text = "V" + intervals [409]: + xmin = 40.83 + xmax = 41.01 + text = "IH1" + intervals [410]: + xmin = 41.01 + xmax = 41.08 + text = "Z" + intervals [411]: + xmin = 41.08 + xmax = 41.11 + text = "DH" + intervals [412]: + xmin = 41.11 + xmax = 41.24 + text = "AH0" + intervals [413]: + xmin = 41.24 + xmax = 41.3 + text = "" + intervals [414]: + xmin = 41.3 + xmax = 41.47 + text = "S" + intervals [415]: + xmin = 41.47 + xmax = 41.53 + text = "IY1" + intervals [416]: + xmin = 41.53 + xmax = 41.61 + text = "K" + intervals [417]: + xmin = 41.61 + xmax = 41.64 + text = "R" + intervals [418]: + xmin = 41.64 + xmax = 41.75 + text = "IH0" + intervals [419]: + xmin = 41.75 + xmax = 41.9 + text = "T" + intervals [420]: + xmin = 41.9 + xmax = 42.13 + text = "" + intervals [421]: + xmin = 42.13 + xmax = 42.35 + text = "DH" + intervals [422]: + xmin = 42.35 + xmax = 42.47 + text = "AH1" + intervals [423]: + xmin = 42.47 + xmax = 42.56 + text = "M" + intervals [424]: + xmin = 42.56 + xmax = 42.62 + text = "UW1" + intervals [425]: + xmin = 42.62 + xmax = 42.71 + text = "V" + intervals [426]: + xmin = 42.71 + xmax = 43.01 + text = "IY0" + intervals [427]: + xmin = 43.01 + xmax = 43.19 + text = "S" + intervals [428]: + xmin = 43.19 + xmax = 43.26 + text = "IY1" + 
intervals [429]: + xmin = 43.26 + xmax = 43.31 + text = "K" + intervals [430]: + xmin = 43.31 + xmax = 43.37 + text = "R" + intervals [431]: + xmin = 43.37 + xmax = 43.53 + text = "IH0" + intervals [432]: + xmin = 43.53 + xmax = 43.58 + text = "T" + intervals [433]: + xmin = 43.58 + xmax = 43.65 + text = "IH1" + intervals [434]: + xmin = 43.65 + xmax = 43.71 + text = "Z" + intervals [435]: + xmin = 43.71 + xmax = 43.77 + text = "AH0" + intervals [436]: + xmin = 43.77 + xmax = 43.84 + text = "B" + intervals [437]: + xmin = 43.84 + xmax = 44.03 + text = "AW1" + intervals [438]: + xmin = 44.03 + xmax = 44.1 + text = "T" + intervals [439]: + xmin = 44.1 + xmax = 44.15 + text = "AH0" + intervals [440]: + xmin = 44.15 + xmax = 44.33 + text = "S" + intervals [441]: + xmin = 44.33 + xmax = 44.4 + text = "T" + intervals [442]: + xmin = 44.4 + xmax = 44.51 + text = "AO1" + intervals [443]: + xmin = 44.51 + xmax = 44.65 + text = "R" + intervals [444]: + xmin = 44.65 + xmax = 44.88 + text = "IY0" + intervals [445]: + xmin = 44.88 + xmax = 44.95 + text = "" + intervals [446]: + xmin = 44.95 + xmax = 45.09 + text = "AH0" + intervals [447]: + xmin = 45.09 + xmax = 45.14 + text = "V" + intervals [448]: + xmin = 45.14 + xmax = 45.21 + text = "AH0" + intervals [449]: + xmin = 45.21 + xmax = 45.27 + text = "M" + intervals [450]: + xmin = 45.27 + xmax = 45.31 + text = "Y" + intervals [451]: + xmin = 45.31 + xmax = 45.34 + text = "UW1" + intervals [452]: + xmin = 45.34 + xmax = 45.41 + text = "Z" + intervals [453]: + xmin = 45.41 + xmax = 45.45 + text = "IH0" + intervals [454]: + xmin = 45.45 + xmax = 45.48 + text = "K" + intervals [455]: + xmin = 45.48 + xmax = 45.51 + text = "AH0" + intervals [456]: + xmin = 45.51 + xmax = 45.56 + text = "L" + intervals [457]: + xmin = 45.56 + xmax = 45.62 + text = "P" + intervals [458]: + xmin = 45.62 + xmax = 45.69 + text = "R" + intervals [459]: + xmin = 45.69 + xmax = 45.81 + text = "AA1" + intervals [460]: + xmin = 45.81 + xmax = 45.84 + text = "D" + intervals [461]: + xmin = 45.84 + xmax = 45.87 + text = "AH0" + intervals [462]: + xmin = 45.87 + xmax = 45.91 + text = "JH" + intervals [463]: + xmin = 45.91 + xmax = 45.96 + text = "IY0" + intervals [464]: + xmin = 45.96 + xmax = 45.99 + text = "D" + intervals [465]: + xmin = 45.99 + xmax = 46.04 + text = "UW1" + intervals [466]: + xmin = 46.04 + xmax = 46.09 + text = "DH" + intervals [467]: + xmin = 46.09 + xmax = 46.12 + text = "AH0" + intervals [468]: + xmin = 46.12 + xmax = 46.17 + text = "T" + intervals [469]: + xmin = 46.17 + xmax = 46.24 + text = "F" + intervals [470]: + xmin = 46.24 + xmax = 46.29 + text = "AO1" + intervals [471]: + xmin = 46.29 + xmax = 46.36 + text = "L" + intervals [472]: + xmin = 46.36 + xmax = 46.42 + text = "Z" + intervals [473]: + xmin = 46.42 + xmax = 46.46 + text = "IH0" + intervals [474]: + xmin = 46.46 + xmax = 46.49 + text = "N" + intervals [475]: + xmin = 46.49 + xmax = 46.53 + text = "L" + intervals [476]: + xmin = 46.53 + xmax = 46.59 + text = "AH1" + intervals [477]: + xmin = 46.59 + xmax = 46.62 + text = "V" + intervals [478]: + xmin = 46.62 + xmax = 46.65 + text = "W" + intervals [479]: + xmin = 46.65 + xmax = 46.68 + text = "IH0" + intervals [480]: + xmin = 46.68 + xmax = 46.72 + text = "DH" + intervals [481]: + xmin = 46.72 + xmax = 46.85 + text = "AH0" + intervals [482]: + xmin = 46.85 + xmax = 46.91 + text = "" + intervals [483]: + xmin = 46.91 + xmax = 47.04 + text = "G" + intervals [484]: + xmin = 47.04 + xmax = 47.13 + text = "ER1" + intervals [485]: + xmin = 47.13 + xmax 
= 47.21 + text = "L" + intervals [486]: + xmin = 47.21 + xmax = 47.27 + text = "HH" + intervals [487]: + xmin = 47.27 + xmax = 47.38 + text = "UW1" + intervals [488]: + xmin = 47.38 + xmax = 47.46 + text = "Z" + intervals [489]: + xmin = 47.46 + xmax = 47.55 + text = "D" + intervals [490]: + xmin = 47.55 + xmax = 47.73 + text = "AY1" + intervals [491]: + xmin = 47.73 + xmax = 47.84 + text = "IH0" + intervals [492]: + xmin = 47.84 + xmax = 48.01 + text = "NG" + intervals [493]: + xmin = 48.01 + xmax = 49.08 + text = "" + intervals [494]: + xmin = 49.08 + xmax = 49.15 + text = "DH" + intervals [495]: + xmin = 49.15 + xmax = 49.18 + text = "EH1" + intervals [496]: + xmin = 49.18 + xmax = 49.33 + text = "R" + intervals [497]: + xmin = 49.33 + xmax = 49.36 + text = "AA1" + intervals [498]: + xmin = 49.36 + xmax = 49.39 + text = "R" + intervals [499]: + xmin = 49.39 + xmax = 49.46 + text = "AH0" + intervals [500]: + xmin = 49.46 + xmax = 49.52 + text = "L" + intervals [501]: + xmin = 49.52 + xmax = 49.69 + text = "AA1" + intervals [502]: + xmin = 49.69 + xmax = 49.82 + text = "T" + intervals [503]: + xmin = 49.82 + xmax = 49.98 + text = "AH1" + intervals [504]: + xmin = 49.98 + xmax = 50.2 + text = "V" + intervals [505]: + xmin = 50.2 + xmax = 50.29 + text = "" + intervals [506]: + xmin = 50.29 + xmax = 50.43 + text = "EH1" + intervals [507]: + xmin = 50.43 + xmax = 50.52 + text = "N" + intervals [508]: + xmin = 50.52 + xmax = 50.56 + text = "V" + intervals [509]: + xmin = 50.56 + xmax = 50.65 + text = "IY0" + intervals [510]: + xmin = 50.65 + xmax = 50.7 + text = "AH0" + intervals [511]: + xmin = 50.7 + xmax = 50.75 + text = "B" + intervals [512]: + xmin = 50.75 + xmax = 50.78 + text = "AH0" + intervals [513]: + xmin = 50.78 + xmax = 50.88 + text = "L" + intervals [514]: + xmin = 50.88 + xmax = 50.94 + text = "M" + intervals [515]: + xmin = 50.94 + xmax = 51.05 + text = "OW1" + intervals [516]: + xmin = 51.05 + xmax = 51.12 + text = "M" + intervals [517]: + xmin = 51.12 + xmax = 51.16 + text = "AH0" + intervals [518]: + xmin = 51.16 + xmax = 51.19 + text = "N" + intervals [519]: + xmin = 51.19 + xmax = 51.23 + text = "T" + intervals [520]: + xmin = 51.23 + xmax = 51.3 + text = "S" + intervals [521]: + xmin = 51.3 + xmax = 51.34 + text = "IH0" + intervals [522]: + xmin = 51.34 + xmax = 51.37 + text = "N" + intervals [523]: + xmin = 51.37 + xmax = 51.4 + text = "DH" + intervals [524]: + xmin = 51.4 + xmax = 51.46 + text = "IH0" + intervals [525]: + xmin = 51.46 + xmax = 51.53 + text = "S" + intervals [526]: + xmin = 51.53 + xmax = 51.58 + text = "F" + intervals [527]: + xmin = 51.58 + xmax = 51.62 + text = "IH1" + intervals [528]: + xmin = 51.62 + xmax = 51.72 + text = "L" + intervals [529]: + xmin = 51.72 + xmax = 51.77 + text = "M" + intervals [530]: + xmin = 51.77 + xmax = 51.86 + text = "S" + intervals [531]: + xmin = 51.86 + xmax = 51.91 + text = "AH1" + intervals [532]: + xmin = 51.91 + xmax = 52.01 + text = "CH" + intervals [533]: + xmin = 52.01 + xmax = 52.09 + text = "EH1" + intervals [534]: + xmin = 52.09 + xmax = 52.2 + text = "Z" + intervals [535]: + xmin = 52.2 + xmax = 52.23 + text = "DH" + intervals [536]: + xmin = 52.23 + xmax = 52.3 + text = "AH0" + intervals [537]: + xmin = 52.3 + xmax = 52.39 + text = "S" + intervals [538]: + xmin = 52.39 + xmax = 52.43 + text = "IH1" + intervals [539]: + xmin = 52.43 + xmax = 52.46 + text = "M" + intervals [540]: + xmin = 52.46 + xmax = 52.5 + text = "P" + intervals [541]: + xmin = 52.5 + xmax = 52.53 + text = "AH0" + intervals [542]: + xmin = 
52.53 + xmax = 52.57 + text = "L" + intervals [543]: + xmin = 52.57 + xmax = 52.66 + text = "L" + intervals [544]: + xmin = 52.66 + xmax = 52.71 + text = "AH1" + intervals [545]: + xmin = 52.71 + xmax = 52.74 + text = "V" + intervals [546]: + xmin = 52.74 + xmax = 52.79 + text = "B" + intervals [547]: + xmin = 52.79 + xmax = 52.83 + text = "IH0" + intervals [548]: + xmin = 52.83 + xmax = 52.91 + text = "T" + intervals [549]: + xmin = 52.91 + xmax = 52.99 + text = "W" + intervals [550]: + xmin = 52.99 + xmax = 53.03 + text = "IY1" + intervals [551]: + xmin = 53.03 + xmax = 53.06 + text = "N" + intervals [552]: + xmin = 53.06 + xmax = 53.18 + text = "HH" + intervals [553]: + xmin = 53.18 + xmax = 53.26 + text = "AY1" + intervals [554]: + xmin = 53.26 + xmax = 53.34 + text = "S" + intervals [555]: + xmin = 53.34 + xmax = 53.42 + text = "K" + intervals [556]: + xmin = 53.42 + xmax = 53.45 + text = "UW1" + intervals [557]: + xmin = 53.45 + xmax = 53.52 + text = "L" + intervals [558]: + xmin = 53.52 + xmax = 53.63 + text = "S" + intervals [559]: + xmin = 53.63 + xmax = 53.68 + text = "T" + intervals [560]: + xmin = 53.68 + xmax = 53.77 + text = "UW1" + intervals [561]: + xmin = 53.77 + xmax = 53.8 + text = "D" + intervals [562]: + xmin = 53.8 + xmax = 53.84 + text = "AH0" + intervals [563]: + xmin = 53.84 + xmax = 53.89 + text = "N" + intervals [564]: + xmin = 53.89 + xmax = 53.95 + text = "T" + intervals [565]: + xmin = 53.95 + xmax = 54.09 + text = "S" + intervals [566]: + xmin = 54.09 + xmax = 54.32 + text = "" + intervals [567]: + xmin = 54.32 + xmax = 54.42 + text = "EH1" + intervals [568]: + xmin = 54.42 + xmax = 54.45 + text = "V" + intervals [569]: + xmin = 54.45 + xmax = 54.51 + text = "R" + intervals [570]: + xmin = 54.51 + xmax = 54.56 + text = "IY0" + intervals [571]: + xmin = 54.56 + xmax = 54.65 + text = "T" + intervals [572]: + xmin = 54.65 + xmax = 54.83 + text = "AY1" + intervals [573]: + xmin = 54.83 + xmax = 54.93 + text = "M" + intervals [574]: + xmin = 54.93 + xmax = 54.96 + text = "" + intervals [575]: + xmin = 54.96 + xmax = 55.08 + text = "AY1" + intervals [576]: + xmin = 55.08 + xmax = 55.28 + text = "W" + intervals [577]: + xmin = 55.28 + xmax = 55.36 + text = "AA1" + intervals [578]: + xmin = 55.36 + xmax = 55.45 + text = "CH" + intervals [579]: + xmin = 55.45 + xmax = 55.53 + text = "DH" + intervals [580]: + xmin = 55.53 + xmax = 55.59 + text = "IH0" + intervals [581]: + xmin = 55.59 + xmax = 55.68 + text = "S" + intervals [582]: + xmin = 55.68 + xmax = 55.73 + text = "M" + intervals [583]: + xmin = 55.73 + xmax = 55.76 + text = "UW1" + intervals [584]: + xmin = 55.76 + xmax = 55.84 + text = "V" + intervals [585]: + xmin = 55.84 + xmax = 55.89 + text = "IY0" + intervals [586]: + xmin = 55.89 + xmax = 55.92 + text = "IH1" + intervals [587]: + xmin = 55.92 + xmax = 55.98 + text = "T" + intervals [588]: + xmin = 55.98 + xmax = 56.05 + text = "R" + intervals [589]: + xmin = 56.05 + xmax = 56.12 + text = "IY0" + intervals [590]: + xmin = 56.12 + xmax = 56.2 + text = "M" + intervals [591]: + xmin = 56.2 + xmax = 56.34 + text = "AY1" + intervals [592]: + xmin = 56.34 + xmax = 56.38 + text = "N" + intervals [593]: + xmin = 56.38 + xmax = 56.41 + text = "D" + intervals [594]: + xmin = 56.41 + xmax = 56.44 + text = "Z" + intervals [595]: + xmin = 56.44 + xmax = 56.49 + text = "M" + intervals [596]: + xmin = 56.49 + xmax = 56.55 + text = "IY1" + intervals [597]: + xmin = 56.55 + xmax = 56.6 + text = "AH0" + intervals [598]: + xmin = 56.6 + xmax = 56.63 + text = "V" + intervals 
[599]: + xmin = 56.63 + xmax = 56.7 + text = "AH0" + intervals [600]: + xmin = 56.7 + xmax = 56.82 + text = "T" + intervals [601]: + xmin = 56.82 + xmax = 56.94 + text = "AY1" + intervals [602]: + xmin = 56.94 + xmax = 56.99 + text = "M" + intervals [603]: + xmin = 56.99 + xmax = 57.02 + text = "DH" + intervals [604]: + xmin = 57.02 + xmax = 57.05 + text = "AH0" + intervals [605]: + xmin = 57.05 + xmax = 57.08 + text = "T" + intervals [606]: + xmin = 57.08 + xmax = 57.13 + text = "AY1" + intervals [607]: + xmin = 57.13 + xmax = 57.17 + text = "W" + intervals [608]: + xmin = 57.17 + xmax = 57.2 + text = "AH0" + intervals [609]: + xmin = 57.2 + xmax = 57.23 + text = "Z" + intervals [610]: + xmin = 57.23 + xmax = 57.26 + text = "IH0" + intervals [611]: + xmin = 57.26 + xmax = 57.29 + text = "N" + intervals [612]: + xmin = 57.29 + xmax = 57.41 + text = "HH" + intervals [613]: + xmin = 57.41 + xmax = 57.53 + text = "AY1" + intervals [614]: + xmin = 57.53 + xmax = 57.66 + text = "S" + intervals [615]: + xmin = 57.66 + xmax = 57.69 + text = "K" + intervals [616]: + xmin = 57.69 + xmax = 57.84 + text = "UW1" + intervals [617]: + xmin = 57.84 + xmax = 57.88 + text = "L" + intervals [618]: + xmin = 57.88 + xmax = 57.91 + text = "AH0" + intervals [619]: + xmin = 57.91 + xmax = 57.94 + text = "N" + intervals [620]: + xmin = 57.94 + xmax = 58.03 + text = "D" + intervals [621]: + xmin = 58.03 + xmax = 58.25 + text = "" + intervals [622]: + xmin = 58.25 + xmax = 58.36 + text = "Y" + intervals [623]: + xmin = 58.36 + xmax = 58.4 + text = "UW1" + intervals [624]: + xmin = 58.4 + xmax = 58.45 + text = "M" + intervals [625]: + xmin = 58.45 + xmax = 58.51 + text = "AY1" + intervals [626]: + xmin = 58.51 + xmax = 58.55 + text = "T" + intervals [627]: + xmin = 58.55 + xmax = 58.59 + text = "R" + intervals [628]: + xmin = 58.59 + xmax = 58.66 + text = "IH0" + intervals [629]: + xmin = 58.66 + xmax = 58.7 + text = "M" + intervals [630]: + xmin = 58.7 + xmax = 58.76 + text = "EH1" + intervals [631]: + xmin = 58.76 + xmax = 58.79 + text = "M" + intervals [632]: + xmin = 58.79 + xmax = 58.84 + text = "B" + intervals [633]: + xmin = 58.84 + xmax = 59.1 + text = "ER0" + intervals [634]: + xmin = 59.1 + xmax = 59.29 + text = "" + intervals [635]: + xmin = 59.29 + xmax = 59.38 + text = "DH" + intervals [636]: + xmin = 59.38 + xmax = 59.44 + text = "AH0" + intervals [637]: + xmin = 59.44 + xmax = 59.56 + text = "K" + intervals [638]: + xmin = 59.56 + xmax = 59.63 + text = "R" + intervals [639]: + xmin = 59.63 + xmax = 59.69 + text = "AH1" + intervals [640]: + xmin = 59.69 + xmax = 59.83 + text = "SH" + intervals [641]: + xmin = 59.83 + xmax = 59.89 + text = "Y" + intervals [642]: + xmin = 59.89 + xmax = 59.94 + text = "UW1" + intervals [643]: + xmin = 59.94 + xmax = 60.03 + text = "HH" + intervals [644]: + xmin = 60.03 + xmax = 60.12 + text = "AE1" + intervals [645]: + xmin = 60.12 + xmax = 60.16 + text = "D" + intervals [646]: + xmin = 60.16 + xmax = 60.2 + text = "IH0" + intervals [647]: + xmin = 60.2 + xmax = 60.27 + text = "N" + intervals [648]: + xmin = 60.27 + xmax = 60.37 + text = "S" + intervals [649]: + xmin = 60.37 + xmax = 60.43 + text = "K" + intervals [650]: + xmin = 60.43 + xmax = 60.54 + text = "UW1" + intervals [651]: + xmin = 60.54 + xmax = 60.74 + text = "L" + intervals [652]: + xmin = 60.74 + xmax = 60.9 + text = "" + intervals [653]: + xmin = 60.9 + xmax = 61.06 + text = "AE1" + intervals [654]: + xmin = 61.06 + xmax = 61.09 + text = "N" + intervals [655]: + xmin = 61.09 + xmax = 61.12 + text = "D" + 
intervals [656]: + xmin = 61.12 + xmax = 61.2 + text = "HH" + intervals [657]: + xmin = 61.2 + xmax = 61.24 + text = "AW1" + intervals [658]: + xmin = 61.24 + xmax = 61.29 + text = "Y" + intervals [659]: + xmin = 61.29 + xmax = 61.36 + text = "UW1" + intervals [660]: + xmin = 61.36 + xmax = 61.4 + text = "W" + intervals [661]: + xmin = 61.4 + xmax = 61.43 + text = "UH1" + intervals [662]: + xmin = 61.43 + xmax = 61.48 + text = "D" + intervals [663]: + xmin = 61.48 + xmax = 61.57 + text = "L" + intervals [664]: + xmin = 61.57 + xmax = 61.63 + text = "UH1" + intervals [665]: + xmin = 61.63 + xmax = 61.7 + text = "K" + intervals [666]: + xmin = 61.7 + xmax = 61.73 + text = "AE1" + intervals [667]: + xmin = 61.73 + xmax = 61.77 + text = "T" + intervals [668]: + xmin = 61.77 + xmax = 61.84 + text = "HH" + intervals [669]: + xmin = 61.84 + xmax = 61.98 + text = "IH1" + intervals [670]: + xmin = 61.98 + xmax = 62.16 + text = "M" + intervals [671]: + xmin = 62.16 + xmax = 62.37 + text = "" + intervals [672]: + xmin = 62.37 + xmax = 62.42 + text = "HH" + intervals [673]: + xmin = 62.42 + xmax = 62.46 + text = "W" + intervals [674]: + xmin = 62.46 + xmax = 62.49 + text = "AY1" + intervals [675]: + xmin = 62.49 + xmax = 62.54 + text = "L" + intervals [676]: + xmin = 62.54 + xmax = 62.59 + text = "HH" + intervals [677]: + xmin = 62.59 + xmax = 62.64 + text = "IY1" + intervals [678]: + xmin = 62.64 + xmax = 62.74 + text = "Z" + intervals [679]: + xmin = 62.74 + xmax = 62.93 + text = "AE1" + intervals [680]: + xmin = 62.93 + xmax = 63.02 + text = "T" + intervals [681]: + xmin = 63.02 + xmax = 63.5 + text = "IH1" + intervals [682]: + xmin = 63.5 + xmax = 63.61 + text = "N" + intervals [683]: + xmin = 63.61 + xmax = 63.7 + text = "K" + intervals [684]: + xmin = 63.7 + xmax = 63.77 + text = "L" + intervals [685]: + xmin = 63.77 + xmax = 63.93 + text = "AE1" + intervals [686]: + xmin = 63.93 + xmax = 64.04 + text = "S" + intervals [687]: + xmin = 64.04 + xmax = 64.13 + text = "W" + intervals [688]: + xmin = 64.13 + xmax = 64.16 + text = "IH0" + intervals [689]: + xmin = 64.16 + xmax = 64.2 + text = "DH" + intervals [690]: + xmin = 64.2 + xmax = 64.3 + text = "AW1" + intervals [691]: + xmin = 64.3 + xmax = 64.38 + text = "T" + intervals [692]: + xmin = 64.38 + xmax = 64.45 + text = "TH" + intervals [693]: + xmin = 64.45 + xmax = 64.52 + text = "IH1" + intervals [694]: + xmin = 64.52 + xmax = 64.58 + text = "NG" + intervals [695]: + xmin = 64.58 + xmax = 64.64 + text = "K" + intervals [696]: + xmin = 64.64 + xmax = 64.76 + text = "IH0" + intervals [697]: + xmin = 64.76 + xmax = 64.83 + text = "NG" + intervals [698]: + xmin = 64.83 + xmax = 64.88 + text = "AO1" + intervals [699]: + xmin = 64.88 + xmax = 64.95 + text = "R" + intervals [700]: + xmin = 64.95 + xmax = 64.98 + text = "" + intervals [701]: + xmin = 64.98 + xmax = 65.13 + text = "W" + intervals [702]: + xmin = 65.13 + xmax = 65.17 + text = "AA1" + intervals [703]: + xmin = 65.17 + xmax = 65.21 + text = "N" + intervals [704]: + xmin = 65.21 + xmax = 65.24 + text = "IH0" + intervals [705]: + xmin = 65.24 + xmax = 65.27 + text = "NG" + intervals [706]: + xmin = 65.27 + xmax = 65.31 + text = "T" + intervals [707]: + xmin = 65.31 + xmax = 65.36 + text = "AH0" + intervals [708]: + xmin = 65.36 + xmax = 65.42 + text = "G" + intervals [709]: + xmin = 65.42 + xmax = 65.54 + text = "OW1" + intervals [710]: + xmin = 65.54 + xmax = 65.63 + text = "P" + intervals [711]: + xmin = 65.63 + xmax = 65.7 + text = "L" + intervals [712]: + xmin = 65.7 + xmax = 65.76 + 
text = "EY1" + intervals [713]: + xmin = 65.76 + xmax = 65.83 + text = "S" + intervals [714]: + xmin = 65.83 + xmax = 65.88 + text = "IH0" + intervals [715]: + xmin = 65.88 + xmax = 65.95 + text = "Z" + intervals [716]: + xmin = 65.95 + xmax = 66.0 + text = "W" + intervals [717]: + xmin = 66.0 + xmax = 66.03 + text = "IH1" + intervals [718]: + xmin = 66.03 + xmax = 66.12 + text = "DH" + intervals [719]: + xmin = 66.12 + xmax = 66.2 + text = "IH0" + intervals [720]: + xmin = 66.2 + xmax = 66.38 + text = "M" + intervals [721]: + xmin = 66.38 + xmax = 67 + text = "" diff --git a/EMAGE/test_sequences/wave16k/2_scott_0_1_1.wav b/EMAGE/test_sequences/wave16k/2_scott_0_1_1.wav new file mode 100644 index 0000000000000000000000000000000000000000..ce126b220378ad565fb546a29aa3551325e16b8e Binary files /dev/null and b/EMAGE/test_sequences/wave16k/2_scott_0_1_1.wav differ diff --git a/EMAGE/test_sequences/wave16k/2_scott_0_2_2.wav b/EMAGE/test_sequences/wave16k/2_scott_0_2_2.wav new file mode 100644 index 0000000000000000000000000000000000000000..237bef4d0fc331b2aad7aedefe8df8ccacaa3692 Binary files /dev/null and b/EMAGE/test_sequences/wave16k/2_scott_0_2_2.wav differ diff --git a/EMAGE/test_sequences/wave16k/2_scott_0_3_3.wav b/EMAGE/test_sequences/wave16k/2_scott_0_3_3.wav new file mode 100644 index 0000000000000000000000000000000000000000..136de1d7f9d24b0e35758c701379b578d5dd960f Binary files /dev/null and b/EMAGE/test_sequences/wave16k/2_scott_0_3_3.wav differ diff --git a/EMAGE/test_sequences/wave16k/2_scott_0_4_4.wav b/EMAGE/test_sequences/wave16k/2_scott_0_4_4.wav new file mode 100644 index 0000000000000000000000000000000000000000..4d76914f5a9c269442eccc129e351de8ad14d7aa Binary files /dev/null and b/EMAGE/test_sequences/wave16k/2_scott_0_4_4.wav differ diff --git a/EMAGE/test_sequences/weights/AESKConv_240_100.bin b/EMAGE/test_sequences/weights/AESKConv_240_100.bin new file mode 100644 index 0000000000000000000000000000000000000000..1d1ea36ecd9582802176c499eba43969144ad9fe --- /dev/null +++ b/EMAGE/test_sequences/weights/AESKConv_240_100.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cd9566b24264f34d44003b3de62cdfd50aa85b7cdde2d369214599023c40f55 +size 17558653 diff --git a/EMAGE/test_sequences/weights/mean_vel_smplxflame_30.npy b/EMAGE/test_sequences/weights/mean_vel_smplxflame_30.npy new file mode 100644 index 0000000000000000000000000000000000000000..0789238537103f051a6c51a3c4725e8afe85e140 --- /dev/null +++ b/EMAGE/test_sequences/weights/mean_vel_smplxflame_30.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53b5e48f2a7bf78c41a6de6395d6bb4f29018465ca5d0ee2820a2be3eebb7137 +size 348 diff --git a/EMAGE/test_sequences/weights/vocab.pkl b/EMAGE/test_sequences/weights/vocab.pkl new file mode 100644 index 0000000000000000000000000000000000000000..3daf14aa7f32d7823b19fd765ff5739bb9a1bd32 --- /dev/null +++ b/EMAGE/test_sequences/weights/vocab.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54fbcea7b19e0ee9b5c5836c85087a682d3a9513041091ce3e95d83eed0b2acd +size 13821361 diff --git a/README.md b/README.md index bc4835a5ecc6342624354a1bd051209326282e81..290eb050c2de0c8a7305fe604ca61d5cda6a61a9 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ --- title: EMAGE -emoji: 💻 -colorFrom: green +emoji: ⚡ +colorFrom: yellow colorTo: green sdk: gradio sdk_version: 4.24.0 app_file: app.py pinned: false +license: apache-2.0 --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference 
diff --git a/ae_trainer.py b/ae_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc27c34aa8c3544671cc3db0a529d73dab3134f --- /dev/null +++ b/ae_trainer.py @@ -0,0 +1,375 @@ +import train +import os +import time +import csv +import sys +import warnings +import random +import numpy as np +import time +import pprint +import pickle + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parallel import DistributedDataParallel as DDP +from loguru import logger +import smplx + +from utils import config, logger_tools, other_tools, metric +from utils import rotation_conversions as rc +from dataloaders import data_tools +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func +from scipy.spatial.transform import Rotation + + +class CustomTrainer(train.BaseTrainer): + """ + motion representation learning + """ + def __init__(self, args): + super().__init__(args) + self.joints = self.train_data.joints + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + self.tracker = other_tools.EpochTracker(["rec", "vel", "ver", "com", "kl", "acc"], [False, False, False, False, False, False]) + if not self.args.rot6d: #"rot6d" not in args.pose_rep: + logger.error(f"this script is for rot6d, your pose rep. is {args.pose_rep}") + self.rec_loss = get_loss_func("GeodesicLoss") + self.vel_loss = torch.nn.L1Loss(reduction='mean') + self.vectices_loss = torch.nn.MSELoss(reduction='mean') + + def inverse_selection(self, filtered_t, selection_array, n): + # 创建一个全为零的数组,形状为 n*165 + original_shape_t = np.zeros((n, selection_array.size)) + + # 找到选择数组中为1的索引位置 + selected_indices = np.where(selection_array == 1)[0] + + # 将 filtered_t 的值填充到 original_shape_t 中相应的位置 + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + + return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + # 创建一个全为零的数组,形状为 n*165 + selection_array = torch.from_numpy(selection_array).cuda() + original_shape_t = torch.zeros((n, 165)).cuda() + + # 找到选择数组中为1的索引位置 + selected_indices = torch.where(selection_array == 1)[0] + + # 将 filtered_t 的值填充到 original_shape_t 中相应的位置 + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + + return original_shape_t + + def train(self, epoch): + self.model.train() + t_start = time.time() + self.tracker.reset() + for its, dict_data in enumerate(self.train_loader): + tar_pose = dict_data["pose"] + tar_beta = dict_data["beta"].cuda() + tar_trans = dict_data["trans"].cuda() + tar_pose = tar_pose.cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_exps = torch.zeros((bs, n, 100)).cuda() + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + t_data = time.time() - t_start + + self.opt.zero_grad() + g_loss_final = 0 + net_out = self.model(tar_pose) + rec_pose = net_out["rec_pose"] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight + 
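The conversion pattern used just above — axis-angle pose to rotation matrix to the continuous 6D representation, and back after reconstruction — appears in every trainer in this diff. A minimal sketch of that round trip, reusing the repo's `rotation_conversions` module imported as `rc`; the batch, frame, and joint sizes are illustrative:

```python
import torch
from utils import rotation_conversions as rc  # same module the trainers import

# Illustrative shapes: 2 clips, 64 frames, 55 SMPL-X joints in axis-angle.
bs, n, j = 2, 64, 55
pose_aa = torch.randn(bs, n, j * 3) * 0.1            # axis-angle, radians

# axis-angle -> rotation matrix -> 6D representation, flattened to (bs, n, j*6)
pose_mat = rc.axis_angle_to_matrix(pose_aa.reshape(bs, n, j, 3))
pose_6d = rc.matrix_to_rotation_6d(pose_mat).reshape(bs, n, j * 6)

# inverse path used after the network: 6D -> matrix -> axis-angle, (bs*n, j*3)
rec_mat = rc.rotation_6d_to_matrix(pose_6d.reshape(bs, n, j, 6))
rec_aa = rc.matrix_to_axis_angle(rec_mat).reshape(bs * n, j * 3)

print(pose_6d.shape, rec_aa.shape)  # torch.Size([2, 64, 330]) torch.Size([128, 165])
```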
self.tracker.update_meter("rec", "train", loss_rec.item()) + g_loss_final += loss_rec + + velocity_loss = self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1]) * self.args.rec_weight + acceleration_loss = self.vel_loss(rec_pose[:, 2:] + rec_pose[:, :-2] - 2 * rec_pose[:, 1:-1], tar_pose[:, 2:] + tar_pose[:, :-2] - 2 * tar_pose[:, 1:-1]) * self.args.rec_weight + self.tracker.update_meter("vel", "train", velocity_loss.item()) + self.tracker.update_meter("acc", "train", acceleration_loss.item()) + g_loss_final += velocity_loss + g_loss_final += acceleration_loss + # vertices loss + if self.args.rec_ver_weight > 0: + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + rec_pose = self.inverse_selection_tensor(rec_pose, self.train_data.joint_mask, rec_pose.shape[0]) + tar_pose = self.inverse_selection_tensor(tar_pose, self.train_data.joint_mask, tar_pose.shape[0]) + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:,:3], + body_pose=rec_pose[:,3:21*3+3], + left_hand_pose=rec_pose[:,25*3:40*3], + right_hand_pose=rec_pose[:,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + vertices_tar = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=tar_pose[:, 66:69], + global_orient=tar_pose[:,:3], + body_pose=tar_pose[:,3:21*3+3], + left_hand_pose=tar_pose[:,25*3:40*3], + right_hand_pose=tar_pose[:,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + vectices_loss = self.vectices_loss(vertices_rec['vertices'], vertices_tar['vertices']) + self.tracker.update_meter("ver", "train", vectices_loss.item()*self.args.rec_weight * self.args.rec_ver_weight) + g_loss_final += vectices_loss*self.args.rec_weight*self.args.rec_ver_weight + + vertices_vel_loss = self.vel_loss(vertices_rec['vertices'][:, 1:] - vertices_rec['vertices'][:, :-1], vertices_tar['vertices'][:, 1:] - vertices_tar['vertices'][:, :-1]) * self.args.rec_weight + vertices_acc_loss = self.vel_loss(vertices_rec['vertices'][:, 2:] + vertices_rec['vertices'][:, :-2] - 2 * vertices_rec['vertices'][:, 1:-1], vertices_tar['vertices'][:, 2:] + vertices_tar['vertices'][:, :-2] - 2 * vertices_tar['vertices'][:, 1:-1]) * self.args.rec_weight + g_loss_final += vertices_vel_loss * self.args.rec_weight * self.args.rec_ver_weight + g_loss_final += vertices_acc_loss * self.args.rec_weight * self.args.rec_ver_weight + + # if self.args.vel_weight > 0: + # pos_rec_vel = other_tools.estimate_linear_velocity(vertices_rec['joints'], 1/self.pose_fps) + # pos_tar_vel = other_tools.estimate_linear_velocity(vertices_tar['joints'], 1/self.pose_fps) + # vel_rec_loss = self.vel_loss(pos_rec_vel, pos_tar_vel) + # tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + # rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, j, 3)) + # rot_rec_vel = other_tools.estimate_angular_velocity(rec_pose, 1/self.pose_fps) + # rot_tar_vel = other_tools.estimate_angular_velocity(tar_pose, 1/self.pose_fps) + # vel_rec_loss += self.vel_loss(pos_rec_vel, pos_tar_vel) + # self.tracker.update_meter("vel", "train", vel_rec_loss.item()*self.args.vel_weight) + # loss += 
(vel_rec_loss*self.args.vel_weight) + + # ---------------------- vae -------------------------- # + if "VQVAE" in self.args.g_name: + loss_embedding = net_out["embedding_loss"] + g_loss_final += loss_embedding + self.tracker.update_meter("com", "train", loss_embedding.item()) + # elif "VAE" in self.args.g_name: + # pose_mu, pose_logvar = net_out["pose_mu"], net_out["pose_logvar"] + # KLD = -0.5 * torch.sum(1 + pose_logvar - pose_mu.pow(2) - pose_logvar.exp()) + # if epoch < 0: + # KLD_weight = 0 + # else: + # KLD_weight = min(1.0, (epoch - 0) * 0.05) * 0.01 + # loss += KLD_weight * KLD + # self.tracker.update_meter("kl", "train", KLD_weight * KLD.item()) + g_loss_final.backward() + if self.args.grad_norm != 0: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm) + self.opt.step() + t_train = time.time() - t_start - t_data + t_start = time.time() + mem_cost = torch.cuda.memory_cached() / 1E9 + lr_g = self.opt.param_groups[0]['lr'] + if its % self.args.log_period == 0: + self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g) + if self.args.debug: + if its == 1: break + self.opt_s.step(epoch) + + def val(self, epoch): + self.model.eval() + t_start = time.time() + with torch.no_grad(): + for its, dict_data in enumerate(self.val_loader): + tar_pose = dict_data["pose"] + tar_beta = dict_data["beta"].cuda() + tar_trans = dict_data["trans"].cuda() + tar_pose = tar_pose.cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_exps = torch.zeros((bs, n, 100)).cuda() + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + t_data = time.time() - t_start + + #self.opt.zero_grad() + #g_loss_final = 0 + net_out = self.model(tar_pose) + rec_pose = net_out["rec_pose"] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight + self.tracker.update_meter("rec", "val", loss_rec.item()) + #g_loss_final += loss_rec + + # vertices loss + if self.args.rec_ver_weight > 0: + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + rec_pose = self.inverse_selection_tensor(rec_pose, self.train_data.joint_mask, rec_pose.shape[0]) + tar_pose = self.inverse_selection_tensor(tar_pose, self.train_data.joint_mask, tar_pose.shape[0]) + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:,:3], + body_pose=rec_pose[:,3:21*3+3], + left_hand_pose=rec_pose[:,25*3:40*3], + right_hand_pose=rec_pose[:,40*3:55*3], + return_verts=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + vertices_tar = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=tar_pose[:, 66:69], + global_orient=tar_pose[:,:3], + body_pose=tar_pose[:,3:21*3+3], + left_hand_pose=tar_pose[:,25*3:40*3], + right_hand_pose=tar_pose[:,40*3:55*3], + return_verts=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + vectices_loss = self.vectices_loss(vertices_rec['vertices'], vertices_tar['vertices']) + self.tracker.update_meter("ver", "val", vectices_loss.item()*self.args.rec_weight * 
self.args.rec_ver_weight) + if "VQVAE" in self.args.g_name: + loss_embedding = net_out["embedding_loss"] + self.tracker.update_meter("com", "val", loss_embedding.item()) + #g_loss_final += vectices_loss*self.args.rec_weight*self.args.rec_ver_weight + self.val_recording(epoch) + + def test(self, epoch): + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + return 0 + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + self.model.eval() + with torch.no_grad(): + for its, dict_data in enumerate(self.test_loader): + tar_pose = dict_data["pose"] + tar_pose = tar_pose.cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + remain = n%self.args.pose_length + tar_pose = tar_pose[:, :n-remain, :] + #print(tar_pose.shape) + if True: + net_out = self.model(tar_pose) + rec_pose = net_out["rec_pose"] + n = rec_pose.shape[1] + tar_pose = tar_pose[:, :n, :] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + rec_pose = rec_pose.cpu().numpy() + else: + pass +# for i in range(tar_pose.shape[1]//(self.args.vae_test_len)): +# tar_pose_new = tar_pose[:,i*(self.args.vae_test_len):i*(self.args.vae_test_len)+self.args.vae_test_len,:] +# net_out = self.model(**dict(inputs=tar_pose_new)) +# rec_pose = net_out["rec_pose"] +# rec_pose = (rec_pose.reshape(rec_pose.shape[0], rec_pose.shape[1], -1, 6) * self.joint_level_mask_cuda).reshape(rec_pose.shape[0], rec_pose.shape[1], -1) +# if "rot6d" in self.args.pose_rep: +# rec_pose = data_transfer.rotation_6d_to_matrix(rec_pose.reshape(tar_pose.shape[0], self.args.vae_test_len, -1, 6)) +# rec_pose = data_transfer.matrix_to_euler_angles(rec_pose, "XYZ").reshape(rec_pose.shape[0], rec_pose.shape[1], -1) +# if "smplx" not in self.args.pose_rep: +# rec_pose = torch.rad2deg(rec_pose) +# rec_pose = rec_pose * self.joint_mask_cuda + +# out_sub = rec_pose.cpu().numpy().reshape(-1, rec_pose.shape[2]) +# if i != 0: +# out_final = np.concatenate((out_final,out_sub), 0) +# else: +# out_final = out_sub + + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + tar_pose = tar_pose.cpu().numpy() + + total_length += n + # --- save --- # + if 'smplx' in self.args.pose_rep: + gt_npz = np.load(self.args.data_path+self.args.pose_rep+"/"+test_seq_list.iloc[its]['id']+'.npz', allow_pickle=True) + stride = int(30 / self.args.pose_fps) + tar_pose = self.inverse_selection(tar_pose, self.test_data.joint_mask, tar_pose.shape[0]) + np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose[:n], + expressions=gt_npz["expressions"]-gt_npz["expressions"], + trans=gt_npz["trans"][::stride][:n] - gt_npz["trans"][::stride][:n], + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + rec_pose = self.inverse_selection(rec_pose, self.test_data.joint_mask, rec_pose.shape[0]) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose, + expressions=gt_npz["expressions"]-gt_npz["expressions"], + trans=gt_npz["trans"][::stride][:n] - gt_npz["trans"][::stride][:n], + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + 
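`inverse_selection` (and its tensor variant) scatters the masked joint channels back into the full 165-dimensional SMPL-X axis-angle vector before the poses are saved. A self-contained numpy restatement of that mapping; the toy mask below is made up for illustration, and the per-frame loop in the trainer can equivalently be written as a single vectorized assignment:

```python
import numpy as np

def inverse_selection(filtered_t: np.ndarray, selection_array: np.ndarray) -> np.ndarray:
    """Scatter masked channels back into a zero-filled full pose array.

    filtered_t:      (n_frames, n_selected) values for the channels kept by the mask
    selection_array: (n_channels,) 0/1 mask over the full pose vector (165 for SMPL-X)
    """
    n_frames = filtered_t.shape[0]
    full = np.zeros((n_frames, selection_array.size), dtype=filtered_t.dtype)
    selected = np.where(selection_array == 1)[0]
    # vectorized equivalent of the per-frame loop in the trainer
    full[:, selected] = filtered_t
    return full

# Toy example (mask and sizes are illustrative, not the real joint mask):
mask = np.zeros(165)
mask[[0, 3, 7]] = 1                      # pretend only 3 channels were modeled
restored = inverse_selection(np.ones((4, 3)), mask)
print(restored.shape, restored[:, [0, 3, 7]].sum())  # (4, 165) 12.0
```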
else: + rec_pose = rc.axis_angle_to_matrix(torch.from_numpy(rec_pose.reshape(bs*n, j, 3))) + rec_pose = np.rad2deg(rc.matrix_to_euler_angles(rec_pose, "XYZ")).reshape(bs*n, j*3).numpy() + tar_pose = rc.axis_angle_to_matrix(torch.from_numpy(tar_pose.reshape(bs*n, j, 3))) + tar_pose = np.rad2deg(rc.matrix_to_euler_angles(tar_pose, "XYZ")).reshape(bs*n, j*3).numpy() + #trans="0.000000 0.000000 0.000000" + + with open(f"{self.args.data_path}{self.args.pose_rep}/{test_seq_list.iloc[its]['id']}.bvh", "r") as f_demo: + with open(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.bvh', 'w+') as f_gt: + with open(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.bvh', 'w+') as f_real: + for i, line_data in enumerate(f_demo.readlines()): + if i < 431: + f_real.write(line_data) + f_gt.write(line_data) + else: break + for line_id in range(n): #,args.pre_frames, args.pose_length + line_data = np.array2string(rec_pose[line_id], max_line_width=np.inf, precision=6, suppress_small=False, separator=' ') + f_real.write(line_data[1:-2]+'\n') + for line_id in range(n): #,args.pre_frames, args.pose_length + line_data = np.array2string(tar_pose[line_id], max_line_width=np.inf, precision=6, suppress_small=False, separator=' ') + f_gt.write(line_data[1:-2]+'\n') + # with open(results_save_path+"gt_"+test_seq_list[its]+'.pkl', 'wb') as fw: + # pickle.dump(new_dict, fw) + # #new_dict2["fullpose"] = out_final + # with open(results_save_path+"res_"+test_seq_list[its]+'.pkl', 'wb') as fw1: + # pickle.dump(new_dict2, fw1) + + # other_tools.render_one_sequence( + # results_save_path+"res_"+test_seq_list[its]+'.pkl', + # results_save_path+"gt_"+test_seq_list[its]+'.pkl', + # results_save_path, + # self.args.data_path + self.args.test_data_path + 'wave16k/' + test_seq_list[its]+'.npy', + # ) + + #if its == 1:break + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") \ No newline at end of file diff --git a/aeface_trainer.py b/aeface_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c174bc5c94d1240b4b1494ee57922a3c903ae73 --- /dev/null +++ b/aeface_trainer.py @@ -0,0 +1,388 @@ +import train +import os +import time +import csv +import sys +import warnings +import random +import numpy as np +import time +import pprint +import pickle + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parallel import DistributedDataParallel as DDP +from loguru import logger +import smplx + +from utils import config, logger_tools, other_tools, metric +from utils import rotation_conversions as rc +from dataloaders import data_tools +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func +from scipy.spatial.transform import Rotation + + +class CustomTrainer(train.BaseTrainer): + """ + motion representation learning + """ + def __init__(self, args): + super().__init__(args) + self.joints = self.train_data.joints + self.tracker = other_tools.EpochTracker(["rec", "vel", "acc", "com", "face", "face_vel", "face_acc", "ver", "ver_vel", "ver_acc"], [False, False, False, False, False, False, False, False, False, False]) + self.rec_loss = get_loss_func("GeodesicLoss") + self.mse_loss = torch.nn.MSELoss(reduction='mean') + self.vel_loss = torch.nn.MSELoss(reduction='mean') #torch.nn.L1Loss(reduction='mean') + self.vectices_loss = 
torch.nn.MSELoss(reduction='mean') + + def inverse_selection(self, filtered_t, selection_array, n): + # 创建一个全为零的数组,形状为 n*165 + original_shape_t = np.zeros((n, selection_array.size)) + + # 找到选择数组中为1的索引位置 + selected_indices = np.where(selection_array == 1)[0] + + # 将 filtered_t 的值填充到 original_shape_t 中相应的位置 + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + + return original_shape_t + + def train(self, epoch): + self.model.train() + t_start = time.time() + self.tracker.reset() + for its, dict_data in enumerate(self.train_loader): + tar_pose = dict_data["pose"] + tar_beta = dict_data["beta"].cuda() + tar_trans = dict_data["trans"].cuda() + tar_pose = tar_pose.cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_exps = dict_data["facial"].to(self.rank) + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + in_tar_pose = torch.cat([tar_pose, tar_exps], -1) # 103 + t_data = time.time() - t_start + + self.opt.zero_grad() + g_loss_final = 0 + net_out = self.model(in_tar_pose) + # jaw open 6d loss + rec_pose = net_out["rec_pose"][:, :, :j*6] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight + self.tracker.update_meter("rec", "train", loss_rec.item()) + g_loss_final += loss_rec + # jaw open 6d vel and acc loss + velocity_loss = self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1]) * self.args.rec_weight + acceleration_loss = self.vel_loss(rec_pose[:, 2:] + rec_pose[:, :-2] - 2 * rec_pose[:, 1:-1], tar_pose[:, 2:] + tar_pose[:, :-2] - 2 * tar_pose[:, 1:-1]) * self.args.rec_weight + self.tracker.update_meter("vel", "train", velocity_loss.item()) + self.tracker.update_meter("acc", "train", acceleration_loss.item()) + g_loss_final += velocity_loss + g_loss_final += acceleration_loss + # face parameter l1 loss + rec_exps = net_out["rec_pose"][:, :, j*6:] + loss_face = self.mse_loss(rec_exps, tar_exps) * self.args.rec_weight + self.tracker.update_meter("face", "train", loss_face.item()) + g_loss_final += loss_face + # face parameter l1 vel and acc loss + face_velocity_loss = self.vel_loss(rec_exps[:, 1:] - rec_exps[:, :-1], tar_exps[:, 1:] - tar_exps[:, :-1]) * self.args.rec_weight + face_acceleration_loss = self.vel_loss(rec_exps[:, 2:] + rec_exps[:, :-2] - 2 * rec_exps[:, 1:-1], tar_exps[:, 2:] + tar_exps[:, :-2] - 2 * tar_exps[:, 1:-1]) * self.args.rec_weight + self.tracker.update_meter("face_vel", "train", face_velocity_loss.item()) + self.tracker.update_meter("face_acc", "train", face_acceleration_loss.item()) + g_loss_final += face_velocity_loss + g_loss_final += face_acceleration_loss + + # vertices loss + if self.args.rec_ver_weight > 0: + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=rec_pose, + global_orient=torch.zeros(bs*n, 3).cuda(), + body_pose=torch.zeros(bs*n, 21*3).cuda(), + left_hand_pose=torch.zeros(bs*n, 15*3).cuda(), + right_hand_pose=torch.zeros(bs*n, 15*3).cuda(), + return_verts=True, + # return_joints=True, + leye_pose=torch.zeros(bs*n, 3).cuda(), + 
reye_pose=torch.zeros(bs*n, 3).cuda(), + ) + vertices_tar = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), + expression=rec_exps.reshape(bs*n, 100), + jaw_pose=tar_pose, + global_orient=torch.zeros(bs*n, 3).cuda(), + body_pose=torch.zeros(bs*n, 21*3).cuda(), + left_hand_pose=torch.zeros(bs*n, 15*3).cuda(), + right_hand_pose=torch.zeros(bs*n, 15*3).cuda(), + return_verts=True, + # return_joints=True, + leye_pose=torch.zeros(bs*n, 3).cuda(), + reye_pose=torch.zeros(bs*n, 3).cuda(), + ) + vectices_loss = self.mse_loss(vertices_rec['vertices'], vertices_tar['vertices']) + self.tracker.update_meter("ver", "train", vectices_loss.item()*self.args.rec_weight * self.args.rec_ver_weight) + g_loss_final += vectices_loss*self.args.rec_weight*self.args.rec_ver_weight + # vertices vel and acc loss + vert_velocity_loss = self.vel_loss(vertices_rec['vertices'][:, 1:] - vertices_rec['vertices'][:, :-1], vertices_tar['vertices'][:, 1:] - vertices_tar['vertices'][:, :-1]) * self.args.rec_weight * self.args.rec_ver_weight + vert_acceleration_loss = self.vel_loss(vertices_rec['vertices'][:, 2:] + vertices_rec['vertices'][:, :-2] - 2 * vertices_rec['vertices'][:, 1:-1], vertices_tar['vertices'][:, 2:] + vertices_tar['vertices'][:, :-2] - 2 * vertices_tar['vertices'][:, 1:-1]) * self.args.rec_weight * self.args.rec_ver_weight + self.tracker.update_meter("ver_vel", "train", vert_velocity_loss.item()) + self.tracker.update_meter("ver_acc", "train", vert_acceleration_loss.item()) + g_loss_final += vert_velocity_loss + g_loss_final += vert_acceleration_loss + + # ---------------------- vae -------------------------- # + if "VQVAE" in self.args.g_name: + loss_embedding = net_out["embedding_loss"] + g_loss_final += loss_embedding + self.tracker.update_meter("com", "train", loss_embedding.item()) + # elif "VAE" in self.args.g_name: + # pose_mu, pose_logvar = net_out["pose_mu"], net_out["pose_logvar"] + # KLD = -0.5 * torch.sum(1 + pose_logvar - pose_mu.pow(2) - pose_logvar.exp()) + # if epoch < 0: + # KLD_weight = 0 + # else: + # KLD_weight = min(1.0, (epoch - 0) * 0.05) * 0.01 + # loss += KLD_weight * KLD + # self.tracker.update_meter("kl", "train", KLD_weight * KLD.item()) + g_loss_final.backward() + if self.args.grad_norm != 0: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm) + self.opt.step() + t_train = time.time() - t_start - t_data + t_start = time.time() + mem_cost = torch.cuda.memory_cached() / 1E9 + lr_g = self.opt.param_groups[0]['lr'] + if its % self.args.log_period == 0: + self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g) + if self.args.debug: + if its == 1: break + self.opt_s.step(epoch) + + def val(self, epoch): + self.model.eval() + t_start = time.time() + with torch.no_grad(): + for its, dict_data in enumerate(self.val_loader): + tar_pose = dict_data["pose"] + tar_beta = dict_data["beta"].cuda() + tar_trans = dict_data["trans"].cuda() + tar_pose = tar_pose.cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_exps = dict_data["facial"].to(self.rank) + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + in_tar_pose = torch.cat([tar_pose, tar_exps], -1) # 103 + # print(tar_pose.shape, in_tar_pose.shape, tar_exps.shape) + t_data = time.time() - t_start + + #self.opt.zero_grad() + #g_loss_final = 0 + net_out = self.model(in_tar_pose) + # jaw open 6d loss + rec_pose = net_out["rec_pose"][:, :, 
:j*6] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight + self.tracker.update_meter("rec", "val", loss_rec.item()) + # g_loss_final += loss_rec + # jaw open 6d vel and acc loss + velocity_loss = self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1]) * self.args.rec_weight + acceleration_loss = self.vel_loss(rec_pose[:, 2:] + rec_pose[:, :-2] - 2 * rec_pose[:, 1:-1], tar_pose[:, 2:] + tar_pose[:, :-2] - 2 * tar_pose[:, 1:-1]) * self.args.rec_weight + self.tracker.update_meter("vel", "val", velocity_loss.item()) + self.tracker.update_meter("acc", "val", acceleration_loss.item()) + # g_loss_final += velocity_loss + # g_loss_final += acceleration_loss + # face parameter l1 loss + rec_exps = net_out["rec_pose"][:, :, j*6:] + loss_face = self.vel_loss(rec_exps, tar_exps) * self.args.rec_weight + self.tracker.update_meter("face", "val", loss_face.item()) + # g_loss_final += loss_face + # face parameter l1 vel and acc loss + face_velocity_loss = self.vel_loss(rec_exps[:, 1:] - rec_exps[:, :-1], tar_exps[:, 1:] - tar_exps[:, :-1]) * self.args.rec_weight + face_acceleration_loss = self.vel_loss(rec_exps[:, 2:] + rec_exps[:, :-2] - 2 * rec_exps[:, 1:-1], tar_exps[:, 2:] + tar_exps[:, :-2] - 2 * tar_exps[:, 1:-1]) * self.args.rec_weight + self.tracker.update_meter("face_vel", "val", face_velocity_loss.item()) + self.tracker.update_meter("face_acc", "val", face_acceleration_loss.item()) + # g_loss_final += face_velocity_loss + # g_loss_final += face_acceleration_loss + + # vertices loss + if self.args.rec_ver_weight > 0: + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=rec_pose, + global_orient=torch.zeros(bs*n, 3).cuda(), + body_pose=torch.zeros(bs*n, 21*3).cuda(), + left_hand_pose=torch.zeros(bs*n, 15*3).cuda(), + right_hand_pose=torch.zeros(bs*n, 15*3).cuda(), + return_verts=True, + # return_joints=True, + leye_pose=torch.zeros(bs*n, 3).cuda(), + reye_pose=torch.zeros(bs*n, 3).cuda(), + ) + vertices_tar = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), + expression=rec_exps.reshape(bs*n, 100), + jaw_pose=tar_pose, + global_orient=torch.zeros(bs*n, 3).cuda(), + body_pose=torch.zeros(bs*n, 21*3).cuda(), + left_hand_pose=torch.zeros(bs*n, 15*3).cuda(), + right_hand_pose=torch.zeros(bs*n, 15*3).cuda(), + return_verts=True, + # return_joints=True, + leye_pose=torch.zeros(bs*n, 3).cuda(), + reye_pose=torch.zeros(bs*n, 3).cuda(), + ) + vectices_loss = self.mse_loss(vertices_rec['vertices'], vertices_tar['vertices']) + self.tracker.update_meter("ver", "val", vectices_loss.item()*self.args.rec_weight * self.args.rec_ver_weight) + # g_loss_final += vectices_loss*self.args.rec_weight*self.args.rec_ver_weight + # vertices vel and acc loss + vert_velocity_loss = self.vel_loss(vertices_rec['vertices'][:, 1:] - vertices_rec['vertices'][:, :-1], vertices_tar['vertices'][:, 1:] - vertices_tar['vertices'][:, :-1]) * self.args.rec_weight * self.args.rec_ver_weight + vert_acceleration_loss = self.vel_loss(vertices_rec['vertices'][:, 2:] + 
vertices_rec['vertices'][:, :-2] - 2 * vertices_rec['vertices'][:, 1:-1], vertices_tar['vertices'][:, 2:] + vertices_tar['vertices'][:, :-2] - 2 * vertices_tar['vertices'][:, 1:-1]) * self.args.rec_weight * self.args.rec_ver_weight + self.tracker.update_meter("ver_vel", "val", vert_velocity_loss.item()) + self.tracker.update_meter("ver_acc", "val", vert_acceleration_loss.item()) + # g_loss_final += vert_velocity_loss + # g_loss_final += vert_acceleration_loss + if "VQVAE" in self.args.g_name: + loss_embedding = net_out["embedding_loss"] + self.tracker.update_meter("com", "val", loss_embedding.item()) + #g_loss_final += vectices_loss*self.args.rec_weight*self.args.rec_ver_weight + self.val_recording(epoch) + + def test(self, epoch): + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + return 0 + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + self.model.eval() + with torch.no_grad(): + for its, dict_data in enumerate(self.test_loader): + tar_pose = dict_data["pose"] + tar_pose = tar_pose.cuda() + tar_exps = dict_data["facial"].to(self.rank) + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + remain = n%self.args.pose_length + tar_pose = tar_pose[:, :n-remain, :] + # print(tar_exps.shape) + in_tar_pose = torch.cat([tar_pose, tar_exps[:, :n-remain, :]], -1) # 103 + #print(tar_pose.shape) + if True: + net_out = self.model(in_tar_pose) + rec_pose = net_out["rec_pose"][:, :, :j*6] + n = rec_pose.shape[1] + tar_pose = tar_pose[:, :n, :] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + rec_pose = rec_pose.cpu().numpy() + rec_exps = net_out["rec_pose"][:, :, j*6:] + rec_exps = rec_exps.cpu().numpy().reshape(bs*n, 100) + else: + pass +# for i in range(tar_pose.shape[1]//(self.args.vae_test_len)): +# tar_pose_new = tar_pose[:,i*(self.args.vae_test_len):i*(self.args.vae_test_len)+self.args.vae_test_len,:] +# net_out = self.model(**dict(inputs=tar_pose_new)) +# rec_pose = net_out["rec_pose"] +# rec_pose = (rec_pose.reshape(rec_pose.shape[0], rec_pose.shape[1], -1, 6) * self.joint_level_mask_cuda).reshape(rec_pose.shape[0], rec_pose.shape[1], -1) +# if "rot6d" in self.args.pose_rep: +# rec_pose = data_transfer.rotation_6d_to_matrix(rec_pose.reshape(tar_pose.shape[0], self.args.vae_test_len, -1, 6)) +# rec_pose = data_transfer.matrix_to_euler_angles(rec_pose, "XYZ").reshape(rec_pose.shape[0], rec_pose.shape[1], -1) +# if "smplx" not in self.args.pose_rep: +# rec_pose = torch.rad2deg(rec_pose) +# rec_pose = rec_pose * self.joint_mask_cuda + +# out_sub = rec_pose.cpu().numpy().reshape(-1, rec_pose.shape[2]) +# if i != 0: +# out_final = np.concatenate((out_final,out_sub), 0) +# else: +# out_final = out_sub + + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + tar_pose = tar_pose.cpu().numpy() + + total_length += n + # --- save --- # + if 'smplx' in self.args.pose_rep: + gt_npz = np.load(self.args.data_path+self.args.pose_rep+"/"+test_seq_list.iloc[its]['id']+'.npz', allow_pickle=True) + stride = int(30 / self.args.pose_fps) + tar_pose = self.inverse_selection(tar_pose, self.test_data.joint_mask, tar_pose.shape[0]) + 
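The velocity and acceleration terms used throughout these trainers, for joint rotations, facial parameters, and vertices alike, are first and second finite differences along the time axis compared under an L1 (or MSE) loss. A minimal sketch of that pattern with placeholder tensors and weights:

```python
import torch

def finite_difference_losses(rec: torch.Tensor, tar: torch.Tensor, weight: float = 1.0):
    """Velocity/acceleration terms as used in the trainers above.

    rec, tar: (batch, frames, channels) reconstructed and target sequences.
    Velocity is the first temporal difference, acceleration the second.
    """
    l1 = torch.nn.L1Loss(reduction="mean")
    vel = l1(rec[:, 1:] - rec[:, :-1], tar[:, 1:] - tar[:, :-1]) * weight
    acc = l1(rec[:, 2:] + rec[:, :-2] - 2 * rec[:, 1:-1],
             tar[:, 2:] + tar[:, :-2] - 2 * tar[:, 1:-1]) * weight
    return vel, acc

# Toy check: identical sequences give zero velocity/acceleration loss.
x = torch.randn(2, 16, 330)
print(finite_difference_losses(x, x.clone()))  # (tensor(0.), tensor(0.))
```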
np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose[:n], + expressions=gt_npz["expressions"], + trans=gt_npz["trans"][::stride][:n] - gt_npz["trans"][::stride][:n], + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + rec_pose = self.inverse_selection(rec_pose, self.test_data.joint_mask, rec_pose.shape[0]) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose, + expressions=rec_exps, + trans=gt_npz["trans"][::stride][:n] - gt_npz["trans"][::stride][:n], + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + else: + rec_pose = rc.axis_angle_to_matrix(torch.from_numpy(rec_pose.reshape(bs*n, j, 3))) + rec_pose = np.rad2deg(rc.matrix_to_euler_angles(rec_pose, "XYZ")).reshape(bs*n, j*3).numpy() + tar_pose = rc.axis_angle_to_matrix(torch.from_numpy(tar_pose.reshape(bs*n, j, 3))) + tar_pose = np.rad2deg(rc.matrix_to_euler_angles(tar_pose, "XYZ")).reshape(bs*n, j*3).numpy() + #trans="0.000000 0.000000 0.000000" + + with open(f"{self.args.data_path}{self.args.pose_rep}/{test_seq_list.iloc[its]['id']}.bvh", "r") as f_demo: + with open(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.bvh', 'w+') as f_gt: + with open(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.bvh', 'w+') as f_real: + for i, line_data in enumerate(f_demo.readlines()): + if i < 431: + f_real.write(line_data) + f_gt.write(line_data) + else: break + for line_id in range(n): #,args.pre_frames, args.pose_length + line_data = np.array2string(rec_pose[line_id], max_line_width=np.inf, precision=6, suppress_small=False, separator=' ') + f_real.write(line_data[1:-2]+'\n') + for line_id in range(n): #,args.pre_frames, args.pose_length + line_data = np.array2string(tar_pose[line_id], max_line_width=np.inf, precision=6, suppress_small=False, separator=' ') + f_gt.write(line_data[1:-2]+'\n') + # with open(results_save_path+"gt_"+test_seq_list[its]+'.pkl', 'wb') as fw: + # pickle.dump(new_dict, fw) + # #new_dict2["fullpose"] = out_final + # with open(results_save_path+"res_"+test_seq_list[its]+'.pkl', 'wb') as fw1: + # pickle.dump(new_dict2, fw1) + + # other_tools.render_one_sequence( + # results_save_path+"res_"+test_seq_list[its]+'.pkl', + # results_save_path+"gt_"+test_seq_list[its]+'.pkl', + # results_save_path, + # self.args.data_path + self.args.test_data_path + 'wave16k/' + test_seq_list[its]+'.npy', + # ) + + #if its == 1:break + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") \ No newline at end of file diff --git a/aelower_trainer.py b/aelower_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3af3d647bcf63cd15e859cff3e619ac441e087e5 --- /dev/null +++ b/aelower_trainer.py @@ -0,0 +1,494 @@ +import train +import os +import time +import csv +import sys +import warnings +import random +import numpy as np +import time +import pprint +import pickle + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parallel import DistributedDataParallel as DDP +from loguru import logger +import smplx + +from utils import config, logger_tools, other_tools, metric +from utils import rotation_conversions as rc +from dataloaders import data_tools +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from 
optimizers.loss_factory import get_loss_func +from scipy.spatial.transform import Rotation + + +class CustomTrainer(train.BaseTrainer): + """ + motion representation learning + """ + def __init__(self, args): + super().__init__(args) + self.joints = self.train_data.joints + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + self.tracker = other_tools.EpochTracker(["rec", "contact", "vel", "foot", "ver", "com", "kl", "acc", "trans", "transv"], [False,False, False, False, False, False, False, False, False, False]) + if not self.args.rot6d: #"rot6d" not in args.pose_rep: + logger.error(f"this script is for rot6d, your pose rep. is {args.pose_rep}") + self.rec_loss = get_loss_func("GeodesicLoss") + self.vel_loss = torch.nn.L1Loss(reduction='mean') + self.vectices_loss = torch.nn.MSELoss(reduction='mean') + + def inverse_selection(self, filtered_t, selection_array, n): + # 创建一个全为零的数组,形状为 n*165 + original_shape_t = np.zeros((n, selection_array.size)) + + # 找到选择数组中为1的索引位置 + selected_indices = np.where(selection_array == 1)[0] + + # 将 filtered_t 的值填充到 original_shape_t 中相应的位置 + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + + return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + # 创建一个全为零的数组,形状为 n*165 + selection_array = torch.from_numpy(selection_array).cuda() + original_shape_t = torch.zeros((n, 165)).cuda() + + # 找到选择数组中为1的索引位置 + selected_indices = torch.where(selection_array == 1)[0] + + # 将 filtered_t 的值填充到 original_shape_t 中相应的位置 + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + + return original_shape_t + + + def train(self, epoch): + self.model.train() + t_start = time.time() + self.tracker.reset() + for its, dict_data in enumerate(self.train_loader): + tar_pose_raw = dict_data["pose"] + tar_beta = dict_data["beta"].cuda() + tar_trans = dict_data["trans"].cuda() + tar_trans_vel_x = other_tools.estimate_linear_velocity(tar_trans[:, :, 0:1], dt=1/self.args.pose_fps) + tar_trans_vel_z = other_tools.estimate_linear_velocity(tar_trans[:, :, 2:3], dt=1/self.args.pose_fps) + tar_pose = tar_pose_raw[:, :, :27].cuda() + tar_contact = tar_pose_raw[:, :, 27:31].cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_exps = torch.zeros((bs, n, 100)).cuda() + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + tar_trans_copy = tar_trans-tar_trans + tar_contact_copy = tar_contact-tar_contact + in_tar_pose = torch.cat((tar_pose, tar_trans_copy, tar_contact_copy), dim=-1) + + t_data = time.time() - t_start + + self.opt.zero_grad() + g_loss_final = 0 + net_out = self.model(in_tar_pose) + rec_pose = tar_pose#net_out["rec_pose"][:, :, :j*6] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + # loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight + # self.tracker.update_meter("rec", "train", loss_rec.item()) + # g_loss_final += loss_rec + + rec_contact = net_out["rec_pose"][:, :, j*6+3:j*6+7] + loss_contact = self.vectices_loss(rec_contact, tar_contact) * self.args.rec_weight * self.args.rec_pos_weight + self.tracker.update_meter("contact", "train", loss_contact.item()) + 
g_loss_final += loss_contact + + # velocity_loss = self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1]) * self.args.rec_weight + # acceleration_loss = self.vel_loss(rec_pose[:, 2:] + rec_pose[:, :-2] - 2 * rec_pose[:, 1:-1], tar_pose[:, 2:] + tar_pose[:, :-2] - 2 * tar_pose[:, 1:-1]) * self.args.rec_weight + # self.tracker.update_meter("vel", "train", velocity_loss.item()) + # self.tracker.update_meter("acc", "train", acceleration_loss.item()) + # g_loss_final += velocity_loss + # g_loss_final += acceleration_loss + + rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3] + rec_x_trans = other_tools.velocity2position(rec_trans[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools.velocity2position(rec_trans[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans[:,:,1:2] + rec_xyz_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + loss_trans_vel = self.vel_loss(rec_trans[:, :, 0:1], tar_trans_vel_x) * self.args.rec_weight \ + + self.vel_loss(rec_trans[:, :, 2:3], tar_trans_vel_z) * self.args.rec_weight + v3 = self.vel_loss(rec_trans[:, :, 0:1][:, 1:] - rec_trans[:, :, 0:1][:, :-1], tar_trans_vel_x[:, 1:] - tar_trans_vel_x[:, :-1]) * self.args.rec_weight \ + + self.vel_loss(rec_trans[:, :, 2:3][:, 1:] - rec_trans[:, :, 2:3][:, :-1], tar_trans_vel_z[:, 1:] - tar_trans_vel_z[:, :-1]) * self.args.rec_weight + a3 = self.vel_loss(rec_trans[:, :, 0:1][:, 2:] + rec_trans[:, :, 0:1][:, :-2] - 2 * rec_trans[:, :, 0:1][:, 1:-1], tar_trans_vel_x[:, 2:] + tar_trans_vel_x[:, :-2] - 2 * tar_trans_vel_x[:, 1:-1]) * self.args.rec_weight \ + + self.vel_loss(rec_trans[:, :, 2:3][:, 2:] + rec_trans[:, :, 2:3][:, :-2] - 2 * rec_trans[:, :, 2:3][:, 1:-1], tar_trans_vel_z[:, 2:] + tar_trans_vel_z[:, :-2] - 2 * tar_trans_vel_z[:, 1:-1]) * self.args.rec_weight + g_loss_final += 5*v3 + g_loss_final += 5*a3 + v2 = self.vel_loss(rec_xyz_trans[:, 1:] - rec_xyz_trans[:, :-1], tar_trans[:, 1:] - tar_trans[:, :-1]) * self.args.rec_weight + a2 = self.vel_loss(rec_xyz_trans[:, 2:] + rec_xyz_trans[:, :-2] - 2 * rec_xyz_trans[:, 1:-1], tar_trans[:, 2:] + tar_trans[:, :-2] - 2 * tar_trans[:, 1:-1]) * self.args.rec_weight + g_loss_final += 5*v2 + g_loss_final += 5*a2 + self.tracker.update_meter("transv", "train", loss_trans_vel.item()) + g_loss_final += loss_trans_vel + loss_trans = self.vel_loss(rec_xyz_trans, tar_trans) * self.args.rec_weight + self.tracker.update_meter("trans", "train", loss_trans.item()) + g_loss_final += loss_trans + + # vertices loss + if self.args.rec_ver_weight > 0: + # print(tar_pose.shape, j) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + rec_pose = self.inverse_selection_tensor(rec_pose, self.train_data.joint_mask, rec_pose.shape[0]) + tar_pose = self.inverse_selection_tensor(tar_pose, self.train_data.joint_mask, tar_pose.shape[0]) + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=rec_xyz_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:,:3], + body_pose=rec_pose[:,3:21*3+3], + left_hand_pose=rec_pose[:,25*3:40*3], + right_hand_pose=rec_pose[:,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + vertices_tar = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + 
jaw_pose=tar_pose[:, 66:69], + global_orient=tar_pose[:,:3], + body_pose=tar_pose[:,3:21*3+3], + left_hand_pose=tar_pose[:,25*3:40*3], + right_hand_pose=tar_pose[:,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + joints_rec = vertices_rec['joints'] + # print(joints_rec.shape) + joints_rec = joints_rec.reshape(bs, n, -1, 3) + vectices_loss = self.vectices_loss(vertices_rec['vertices'], vertices_tar['vertices']) + vertices_vel_loss = self.vectices_loss( + vertices_rec['vertices'][:, 1:] - vertices_rec['vertices'][:, :-1], + vertices_tar['vertices'][:, 1:] - vertices_tar['vertices'][:, :-1]) + vertices_acc_loss = self.vectices_loss( + vertices_rec['vertices'][:, 2:] + vertices_rec['vertices'][:, :-2] - 2 * vertices_rec['vertices'][:, 1:-1], + vertices_tar['vertices'][:, 2:] + vertices_tar['vertices'][:, :-2] - 2 * vertices_tar['vertices'][:, 1:-1]) + foot_idx = [7, 8, 10, 11] + model_contact = net_out["rec_pose"][:, :, j*6+3:j*6+7] + # find static indices consistent with model's own predictions + static_idx = model_contact > 0.95 # N x S x 4 + # print(model_contact,static_idx) + model_feet = joints_rec[:, :, foot_idx] # foot positions (N, S, 4, 3) + model_foot_v = torch.zeros_like(model_feet) + model_foot_v[:, :-1] = ( + model_feet[:, 1:, :, :] - model_feet[:, :-1, :, :] + ) # (N, S-1, 4, 3) + model_foot_v[~static_idx] = 0 + foot_loss = self.vel_loss( + model_foot_v, torch.zeros_like(model_foot_v) + ) + self.tracker.update_meter("foot", "train", foot_loss.item()*self.args.rec_weight * self.args.rec_ver_weight*1000) + self.tracker.update_meter("ver", "train", vectices_loss.item()*self.args.rec_weight * self.args.rec_ver_weight) + g_loss_final += (vectices_loss+5*vertices_vel_loss+5*vertices_acc_loss)*self.args.rec_weight*self.args.rec_ver_weight + g_loss_final += foot_loss*self.args.rec_weight*self.args.rec_ver_weight*20 + + # ---------------------- vae -------------------------- # + if "VQVAE" in self.args.g_name: + loss_embedding = net_out["embedding_loss"] + g_loss_final += loss_embedding + self.tracker.update_meter("com", "train", loss_embedding.item()) + # elif "VAE" in self.args.g_name: + # pose_mu, pose_logvar = net_out["pose_mu"], net_out["pose_logvar"] + # KLD = -0.5 * torch.sum(1 + pose_logvar - pose_mu.pow(2) - pose_logvar.exp()) + # if epoch < 0: + # KLD_weight = 0 + # else: + # KLD_weight = min(1.0, (epoch - 0) * 0.05) * 0.01 + # loss += KLD_weight * KLD + # self.tracker.update_meter("kl", "train", KLD_weight * KLD.item()) + g_loss_final.backward() + if self.args.grad_norm != 0: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm) + self.opt.step() + t_train = time.time() - t_start - t_data + t_start = time.time() + mem_cost = torch.cuda.memory_cached() / 1E9 + lr_g = self.opt.param_groups[0]['lr'] + if its % self.args.log_period == 0: + self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g) + if self.args.debug: + if its == 1: break + self.opt_s.step(epoch) + + def val(self, epoch): + self.model.eval() + t_start = time.time() + with torch.no_grad(): + for its, dict_data in enumerate(self.val_loader): + tar_pose_raw = dict_data["pose"] + tar_beta = dict_data["beta"].cuda() + tar_trans = dict_data["trans"].cuda() + tar_trans_vel_x = other_tools.estimate_linear_velocity(tar_trans[:, :, 0:1], dt=1/self.args.pose_fps) + tar_trans_vel_z = other_tools.estimate_linear_velocity(tar_trans[:, :, 2:3], dt=1/self.args.pose_fps) + #print(tar_pose.shape) + tar_pose = tar_pose_raw[:, 
:, :27].cuda() + + tar_contact = tar_pose_raw[:, :, 27:31].cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_exps = torch.zeros((bs, n, 100)).cuda() + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + tar_trans_copy = tar_trans-tar_trans + tar_contact_copy = tar_contact-tar_contact + in_tar_pose = torch.cat((tar_pose, tar_trans_copy, tar_contact_copy), dim=-1) + t_data = time.time() - t_start + + #self.opt.zero_grad() + #g_loss_final = 0 + net_out = self.model(in_tar_pose) + rec_pose = tar_pose + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + # loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight + # self.tracker.update_meter("rec", "val", loss_rec.item()) + rec_contact = net_out["rec_pose"][:, :, j*6+3:j*6+7] + # print(rec_contact.shape, tar_contact.shape) + loss_contact = self.vel_loss(rec_contact, tar_contact) * self.args.rec_weight * self.args.rec_pos_weight + self.tracker.update_meter("contact", "val", loss_contact.item()) + #g_loss_final += loss_rec + # rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3] + # rec_x_trans = other_tools.velocity2position(rec_trans[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + # rec_z_trans = other_tools.velocity2position(rec_trans[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + # rec_y_trans = rec_trans[:,:,1:2] + # rec_xyz_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + + rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3] + rec_x_trans = other_tools.velocity2position(rec_trans[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools.velocity2position(rec_trans[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans[:,:,1:2] + rec_xyz_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + loss_trans_vel = self.vel_loss(rec_trans[:, :, 0:1], tar_trans_vel_x) * self.args.rec_weight \ + + self.vel_loss(rec_trans[:, :, 2:3], tar_trans_vel_z) * self.args.rec_weight + # v3 = self.vel_loss(rec_trans[:, :, 0:1][:, 1:] - rec_trans[:, :, 0:1][:, :-1], tar_trans_vel_x[:, 1:] - tar_trans_vel_x[:, :-1]) * self.args.rec_weight \ + # + self.vel_loss(rec_trans[:, :, 2:3][:, 1:] - rec_trans[:, :, 2:3][:, :-1], tar_trans_vel_z[:, 1:] - tar_trans_vel_z[:, :-1]) * self.args.rec_weight + # a3 = self.vel_loss(rec_trans[:, :, 0:1][:, 2:] + rec_trans[:, :, 0:1][:, :-2] - 2 * rec_trans[:, :, 0:1][:, 1:-1], tar_trans_vel_x[:, 2:] + tar_trans_vel_x[:, :-2] - 2 * tar_trans_vel_x[:, 1:-1]) * self.args.rec_weight \ + # + self.vel_loss(rec_trans[:, :, 2:3][:, 2:] + rec_trans[:, :, 2:3][:, :-2] - 2 * rec_trans[:, :, 2:3][:, 1:-1], tar_trans_vel_z[:, 2:] + tar_trans_vel_z[:, :-2] - 2 * tar_trans_vel_z[:, 1:-1]) * self.args.rec_weight + # #g_loss_final += 5*v3 + # #g_loss_final += 5*a3 + # v2 = self.vel_loss(rec_xyz_trans[:, 1:] - rec_xyz_trans[:, :-1], tar_trans[:, 1:] - tar_trans[:, :-1]) * self.args.rec_weight + # a2 = self.vel_loss(rec_xyz_trans[:, 2:] + rec_xyz_trans[:, :-2] - 2 * rec_xyz_trans[:, 1:-1], tar_trans[:, 2:] + tar_trans[:, :-2] - 2 * tar_trans[:, 1:-1]) * self.args.rec_weight + #g_loss_final += 5*v2 + #g_loss_final += 5*a2 + self.tracker.update_meter("transv", "val", loss_trans_vel.item()) + #g_loss_final += loss_trans_vel + loss_trans = self.vel_loss(rec_xyz_trans, tar_trans) * self.args.rec_weight + 
self.tracker.update_meter("trans", "val", loss_trans.item()) + #g_loss_final += loss_trans + + # vertices loss + if self.args.rec_ver_weight > 0: + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + rec_pose = self.inverse_selection_tensor(rec_pose, self.train_data.joint_mask, rec_pose.shape[0]) + tar_pose = self.inverse_selection_tensor(tar_pose, self.train_data.joint_mask, tar_pose.shape[0]) + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=rec_xyz_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:,:3], + body_pose=rec_pose[:,3:21*3+3], + left_hand_pose=rec_pose[:,25*3:40*3], + right_hand_pose=rec_pose[:,40*3:55*3], + return_verts=False, + return_joints=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + vertices_tar = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=tar_pose[:, 66:69], + global_orient=tar_pose[:,:3], + body_pose=tar_pose[:,3:21*3+3], + left_hand_pose=tar_pose[:,25*3:40*3], + right_hand_pose=tar_pose[:,40*3:55*3], + return_verts=False, + return_joints=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + joints_rec = vertices_rec['joints'] + joints_rec = joints_rec.reshape(bs, n, -1, 3) + vectices_loss = self.vectices_loss(vertices_rec['joints'], vertices_tar['joints']) + foot_idx = [7, 8, 10, 11] + model_contact = net_out["rec_pose"][:, :, j*6+3:j*6+7] + # find static indices consistent with model's own predictions + static_idx = model_contact > 0.95 # N x S x 4 + # print(model_contact) + model_feet = joints_rec[:, :, foot_idx] # foot positions (N, S, 4, 3) + model_foot_v = torch.zeros_like(model_feet) + model_foot_v[:, :-1] = ( + model_feet[:, 1:, :, :] - model_feet[:, :-1, :, :] + ) # (N, S-1, 4, 3) + model_foot_v[~static_idx] = 0 + foot_loss = self.vectices_loss( + model_foot_v, torch.zeros_like(model_foot_v) + ) + self.tracker.update_meter("foot", "val", foot_loss.item()*self.args.rec_weight * self.args.rec_ver_weight) + self.tracker.update_meter("ver", "val", vectices_loss.item()*self.args.rec_weight * self.args.rec_ver_weight) + if "VQVAE" in self.args.g_name: + loss_embedding = net_out["embedding_loss"] + self.tracker.update_meter("com", "val", loss_embedding.item()) + #g_loss_final += vectices_loss*self.args.rec_weight*self.args.rec_ver_weight + self.val_recording(epoch) + + def test(self, epoch): + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + return 0 + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + self.model.eval() + with torch.no_grad(): + for its, dict_data in enumerate(self.test_loader): + tar_pose_raw = dict_data["pose"] + tar_trans = dict_data["trans"].to(self.rank) + tar_pose = tar_pose_raw[:, :, :27].cuda() + tar_contact = tar_pose_raw[:, :, 27:31].cuda() + # tar_pose = tar_pose.cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + remain = n%self.args.pose_length + tar_pose = tar_pose[:, :n-remain, :] + tar_contact = tar_contact[:, :n-remain, :] + tar_trans_copy = tar_trans[:, :n-remain, :]-tar_trans[:, :n-remain, :] + tar_contact_copy = tar_contact-tar_contact 
+ in_tar_pose = torch.cat([tar_pose, tar_trans_copy, tar_contact_copy], dim=-1) + #print(tar_pose.shape) + if True: + net_out = self.model(in_tar_pose) + rec_pose = tar_pose #net_out["rec_pose"][:, :, :j*6] + rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3] + # print(rec_trans.shape) + rec_x_trans = other_tools.velocity2position(rec_trans[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools.velocity2position(rec_trans[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + n = rec_pose.shape[1] + rec_trans = rec_trans.cpu().numpy().reshape(bs*n, 3) + tar_pose = tar_pose[:, :n, :] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + rec_pose = rec_pose.cpu().numpy() + else: + pass +# for i in range(tar_pose.shape[1]//(self.args.vae_test_len)): +# tar_pose_new = tar_pose[:,i*(self.args.vae_test_len):i*(self.args.vae_test_len)+self.args.vae_test_len,:] +# net_out = self.model(**dict(inputs=tar_pose_new)) +# rec_pose = net_out["rec_pose"] +# rec_pose = (rec_pose.reshape(rec_pose.shape[0], rec_pose.shape[1], -1, 6) * self.joint_level_mask_cuda).reshape(rec_pose.shape[0], rec_pose.shape[1], -1) +# if "rot6d" in self.args.pose_rep: +# rec_pose = data_transfer.rotation_6d_to_matrix(rec_pose.reshape(tar_pose.shape[0], self.args.vae_test_len, -1, 6)) +# rec_pose = data_transfer.matrix_to_euler_angles(rec_pose, "XYZ").reshape(rec_pose.shape[0], rec_pose.shape[1], -1) +# if "smplx" not in self.args.pose_rep: +# rec_pose = torch.rad2deg(rec_pose) +# rec_pose = rec_pose * self.joint_mask_cuda + +# out_sub = rec_pose.cpu().numpy().reshape(-1, rec_pose.shape[2]) +# if i != 0: +# out_final = np.concatenate((out_final,out_sub), 0) +# else: +# out_final = out_sub + + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + tar_pose = tar_pose.cpu().numpy() + + total_length += n + # --- save --- # + if 'smplx' in self.args.pose_rep: + gt_npz = np.load(self.args.data_path+self.args.pose_rep+"/"+test_seq_list.iloc[its]['id']+'.npz', allow_pickle=True) + stride = int(30 / self.args.pose_fps) + tar_pose = self.inverse_selection(tar_pose, self.test_data.joint_mask, tar_pose.shape[0]) + np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose[:n], + expressions=gt_npz["expressions"]-gt_npz["expressions"], + trans=gt_npz["trans"][::stride][:n], + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + rec_pose = self.inverse_selection(rec_pose, self.test_data.joint_mask, rec_pose.shape[0]) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose, + expressions=gt_npz["expressions"]-gt_npz["expressions"], + trans=rec_trans, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + else: + rec_pose = rc.axis_angle_to_matrix(torch.from_numpy(rec_pose.reshape(bs*n, j, 3))) + rec_pose = np.rad2deg(rc.matrix_to_euler_angles(rec_pose, "XYZ")).reshape(bs*n, j*3).numpy() + tar_pose = rc.axis_angle_to_matrix(torch.from_numpy(tar_pose.reshape(bs*n, j, 3))) + tar_pose = np.rad2deg(rc.matrix_to_euler_angles(tar_pose, "XYZ")).reshape(bs*n, j*3).numpy() + #trans="0.000000 0.000000 0.000000" + + with 
open(f"{self.args.data_path}{self.args.pose_rep}/{test_seq_list.iloc[its]['id']}.bvh", "r") as f_demo: + with open(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.bvh', 'w+') as f_gt: + with open(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.bvh', 'w+') as f_real: + for i, line_data in enumerate(f_demo.readlines()): + if i < 431: + f_real.write(line_data) + f_gt.write(line_data) + else: break + for line_id in range(n): #,args.pre_frames, args.pose_length + line_data = np.array2string(rec_pose[line_id], max_line_width=np.inf, precision=6, suppress_small=False, separator=' ') + f_real.write(line_data[1:-2]+'\n') + for line_id in range(n): #,args.pre_frames, args.pose_length + line_data = np.array2string(tar_pose[line_id], max_line_width=np.inf, precision=6, suppress_small=False, separator=' ') + f_gt.write(line_data[1:-2]+'\n') + # with open(results_save_path+"gt_"+test_seq_list[its]+'.pkl', 'wb') as fw: + # pickle.dump(new_dict, fw) + # #new_dict2["fullpose"] = out_final + # with open(results_save_path+"res_"+test_seq_list[its]+'.pkl', 'wb') as fw1: + # pickle.dump(new_dict2, fw1) + + # other_tools.render_one_sequence( + # results_save_path+"res_"+test_seq_list[its]+'.pkl', + # results_save_path+"gt_"+test_seq_list[its]+'.pkl', + # results_save_path, + # self.args.data_path + self.args.test_data_path + 'wave16k/' + test_seq_list[its]+'.npy', + # ) + + #if its == 1:break + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") \ No newline at end of file diff --git a/aelowerfoot_trainer.py b/aelowerfoot_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..38c5f11e2dcebea01aae5a049eaa96b5ad327bb8 --- /dev/null +++ b/aelowerfoot_trainer.py @@ -0,0 +1,491 @@ +import train +import os +import time +import csv +import sys +import warnings +import random +import numpy as np +import time +import pprint +import pickle + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parallel import DistributedDataParallel as DDP +from loguru import logger +import smplx + +from utils import config, logger_tools, other_tools, metric +from utils import rotation_conversions as rc +from dataloaders import data_tools +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func +from scipy.spatial.transform import Rotation + + +class CustomTrainer(train.BaseTrainer): + """ + motion representation learning + """ + def __init__(self, args): + super().__init__(args) + self.joints = self.train_data.joints + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + self.tracker = other_tools.EpochTracker(["rec", "contact", "vel", "foot", "ver", "com", "kl", "acc", "trans", "transv"], [False,False, False, False, False, False, False, False, False, False]) + if not self.args.rot6d: #"rot6d" not in args.pose_rep: + logger.error(f"this script is for rot6d, your pose rep. 
is {args.pose_rep}") + self.rec_loss = get_loss_func("GeodesicLoss") + self.vel_loss = torch.nn.L1Loss(reduction='mean') + self.vectices_loss = torch.nn.MSELoss(reduction='mean') + + def inverse_selection(self, filtered_t, selection_array, n): + # 创建一个全为零的数组,形状为 n*165 + original_shape_t = np.zeros((n, selection_array.size)) + + # 找到选择数组中为1的索引位置 + selected_indices = np.where(selection_array == 1)[0] + + # 将 filtered_t 的值填充到 original_shape_t 中相应的位置 + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + + return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + # 创建一个全为零的数组,形状为 n*165 + selection_array = torch.from_numpy(selection_array).cuda() + original_shape_t = torch.zeros((n, 165)).cuda() + + # 找到选择数组中为1的索引位置 + selected_indices = torch.where(selection_array == 1)[0] + + # 将 filtered_t 的值填充到 original_shape_t 中相应的位置 + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + + return original_shape_t + + + def train(self, epoch): + self.model.train() + t_start = time.time() + self.tracker.reset() + for its, dict_data in enumerate(self.train_loader): + tar_pose_raw = dict_data["pose"] + tar_beta = dict_data["beta"].cuda() + tar_trans = dict_data["trans"].cuda() + # tar_trans_vel_x = other_tools.estimate_linear_velocity(tar_trans[:, :, 0:1], dt=1/self.args.pose_fps) + # tar_trans_vel_z = other_tools.estimate_linear_velocity(tar_trans[:, :, 2:3], dt=1/self.args.pose_fps) + tar_pose = tar_pose_raw[:, :, :27].cuda() + tar_contact = tar_pose_raw[:, :, 27:31].cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_exps = torch.zeros((bs, n, 100)).cuda() + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + tar_trans_copy = tar_trans + tar_contact_copy = tar_contact + in_tar_pose = torch.cat((tar_pose, tar_trans_copy, tar_contact_copy), dim=-1) + + t_data = time.time() - t_start + + self.opt.zero_grad() + g_loss_final = 0 + net_out = self.model(in_tar_pose) + rec_pose = net_out["rec_pose"][:, :, :j*6] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight + self.tracker.update_meter("rec", "train", loss_rec.item()) + g_loss_final += loss_rec + + rec_contact = net_out["rec_pose"][:, :, j*6+3:j*6+7] + loss_contact = self.vectices_loss(rec_contact, tar_contact) * self.args.rec_weight * self.args.rec_pos_weight + self.tracker.update_meter("contact", "train", loss_contact.item()) + g_loss_final += loss_contact + + velocity_loss = self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1]) * self.args.rec_weight + acceleration_loss = self.vel_loss(rec_pose[:, 2:] + rec_pose[:, :-2] - 2 * rec_pose[:, 1:-1], tar_pose[:, 2:] + tar_pose[:, :-2] - 2 * tar_pose[:, 1:-1]) * self.args.rec_weight + self.tracker.update_meter("vel", "train", velocity_loss.item()) + self.tracker.update_meter("acc", "train", acceleration_loss.item()) + g_loss_final += velocity_loss + g_loss_final += acceleration_loss + + # rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3] + # rec_x_trans = other_tools.velocity2position(rec_trans[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + # rec_z_trans = other_tools.velocity2position(rec_trans[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + # rec_y_trans = 
rec_trans[:,:,1:2] + # rec_xyz_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + # loss_trans_vel = self.vel_loss(rec_trans[:, :, 0:1], tar_trans_vel_x) * self.args.rec_weight \ + # + self.vel_loss(rec_trans[:, :, 2:3], tar_trans_vel_z) * self.args.rec_weight + # v3 = self.vel_loss(rec_trans[:, :, 0:1][:, 1:] - rec_trans[:, :, 0:1][:, :-1], tar_trans_vel_x[:, 1:] - tar_trans_vel_x[:, :-1]) * self.args.rec_weight \ + # + self.vel_loss(rec_trans[:, :, 2:3][:, 1:] - rec_trans[:, :, 2:3][:, :-1], tar_trans_vel_z[:, 1:] - tar_trans_vel_z[:, :-1]) * self.args.rec_weight + # a3 = self.vel_loss(rec_trans[:, :, 0:1][:, 2:] + rec_trans[:, :, 0:1][:, :-2] - 2 * rec_trans[:, :, 0:1][:, 1:-1], tar_trans_vel_x[:, 2:] + tar_trans_vel_x[:, :-2] - 2 * tar_trans_vel_x[:, 1:-1]) * self.args.rec_weight \ + # + self.vel_loss(rec_trans[:, :, 2:3][:, 2:] + rec_trans[:, :, 2:3][:, :-2] - 2 * rec_trans[:, :, 2:3][:, 1:-1], tar_trans_vel_z[:, 2:] + tar_trans_vel_z[:, :-2] - 2 * tar_trans_vel_z[:, 1:-1]) * self.args.rec_weight + # g_loss_final += 5*v3 + # g_loss_final += 5*a3 + # v2 = self.vel_loss(rec_xyz_trans[:, 1:] - rec_xyz_trans[:, :-1], tar_trans[:, 1:] - tar_trans[:, :-1]) * self.args.rec_weight + # a2 = self.vel_loss(rec_xyz_trans[:, 2:] + rec_xyz_trans[:, :-2] - 2 * rec_xyz_trans[:, 1:-1], tar_trans[:, 2:] + tar_trans[:, :-2] - 2 * tar_trans[:, 1:-1]) * self.args.rec_weight + # g_loss_final += 5*v2 + # g_loss_final += 5*a2 + # self.tracker.update_meter("transv", "train", loss_trans_vel.item()) + # g_loss_final += loss_trans_vel + # loss_trans = self.vel_loss(rec_xyz_trans, tar_trans) * self.args.rec_weight + # self.tracker.update_meter("trans", "train", loss_trans.item()) + # g_loss_final += loss_trans + + # vertices loss + if self.args.rec_ver_weight > 0: + # print(tar_pose.shape, bs, n, j) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + rec_pose = self.inverse_selection_tensor(rec_pose, self.train_data.joint_mask, rec_pose.shape[0]) + tar_pose = self.inverse_selection_tensor(tar_pose, self.train_data.joint_mask, tar_pose.shape[0]) + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:,:3], + body_pose=rec_pose[:,3:21*3+3], + left_hand_pose=rec_pose[:,25*3:40*3], + right_hand_pose=rec_pose[:,40*3:55*3], + return_verts=False, + return_joints=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + vertices_tar = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=tar_pose[:, 66:69], + global_orient=tar_pose[:,:3], + body_pose=tar_pose[:,3:21*3+3], + left_hand_pose=tar_pose[:,25*3:40*3], + right_hand_pose=tar_pose[:,40*3:55*3], + return_verts=False, + return_joints=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + joints_rec = vertices_rec['joints'] + # print(joints_rec.shape) + joints_rec = joints_rec.reshape(bs, n, -1, 3) + vectices_loss = self.vectices_loss(vertices_rec['joints'], vertices_tar['joints']) + foot_idx = [7, 8, 10, 11] + model_contact = net_out["rec_pose"][:, :, j*6+3:j*6+7] + # find static indices consistent with model's own predictions + static_idx = model_contact > 0.95 # N x S x 4 + # print(model_contact,static_idx) + model_feet = joints_rec[:, :, foot_idx] 
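# Anti-skating term computed just below: per-frame foot-joint velocities are kept only where
# the model's own contact prediction exceeds 0.95 and are then regressed to zero, so predicted
# contact implies a planted foot. A self-contained sketch of that loss, assuming only torch;
# the function name is illustrative, while the 0.95 threshold matches the code in this block.
def foot_skating_loss(foot_xyz, contact_prob, thresh=0.95):
    """foot_xyz: (N, S, 4, 3) foot-joint positions; contact_prob: (N, S, 4) predicted contact."""
    vel = torch.zeros_like(foot_xyz)
    vel[:, :-1] = foot_xyz[:, 1:] - foot_xyz[:, :-1]   # finite-difference velocity per foot joint
    vel[~(contact_prob > thresh)] = 0.0                # discard frames predicted as off the ground
    return vel.abs().mean()                            # L1 toward zero, matching self.vel_loss here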
# foot positions (N, S, 4, 3) + model_foot_v = torch.zeros_like(model_feet) + model_foot_v[:, :-1] = ( + model_feet[:, 1:, :, :] - model_feet[:, :-1, :, :] + ) # (N, S-1, 4, 3) + model_foot_v[~static_idx] = 0 + foot_loss = self.vel_loss( + model_foot_v, torch.zeros_like(model_foot_v) + ) + self.tracker.update_meter("foot", "train", foot_loss.item()*self.args.rec_weight * self.args.rec_ver_weight*20) + self.tracker.update_meter("ver", "train", vectices_loss.item()*self.args.rec_weight * self.args.rec_ver_weight) + g_loss_final += (vectices_loss)*self.args.rec_weight*self.args.rec_ver_weight + g_loss_final += foot_loss*self.args.rec_weight*self.args.rec_ver_weight*20 + + # ---------------------- vae -------------------------- # + if "VQVAE" in self.args.g_name: + loss_embedding = net_out["embedding_loss"] + g_loss_final += loss_embedding + self.tracker.update_meter("com", "train", loss_embedding.item()) + # elif "VAE" in self.args.g_name: + # pose_mu, pose_logvar = net_out["pose_mu"], net_out["pose_logvar"] + # KLD = -0.5 * torch.sum(1 + pose_logvar - pose_mu.pow(2) - pose_logvar.exp()) + # if epoch < 0: + # KLD_weight = 0 + # else: + # KLD_weight = min(1.0, (epoch - 0) * 0.05) * 0.01 + # loss += KLD_weight * KLD + # self.tracker.update_meter("kl", "train", KLD_weight * KLD.item()) + g_loss_final.backward() + if self.args.grad_norm != 0: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm) + self.opt.step() + t_train = time.time() - t_start - t_data + t_start = time.time() + mem_cost = torch.cuda.memory_cached() / 1E9 + lr_g = self.opt.param_groups[0]['lr'] + if its % self.args.log_period == 0: + self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g) + if self.args.debug: + if its == 1: break + self.opt_s.step(epoch) + + def val(self, epoch): + self.model.eval() + t_start = time.time() + with torch.no_grad(): + for its, dict_data in enumerate(self.val_loader): + tar_pose_raw = dict_data["pose"] + tar_beta = dict_data["beta"].cuda() + tar_trans = dict_data["trans"].cuda() + tar_trans_vel_x = other_tools.estimate_linear_velocity(tar_trans[:, :, 0:1], dt=1/self.args.pose_fps) + tar_trans_vel_z = other_tools.estimate_linear_velocity(tar_trans[:, :, 2:3], dt=1/self.args.pose_fps) + #print(tar_pose.shape) + tar_pose = tar_pose_raw[:, :, :27].cuda() + + tar_contact = tar_pose_raw[:, :, 27:31].cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_exps = torch.zeros((bs, n, 100)).cuda() + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + tar_trans_copy = tar_trans + tar_contact_copy = tar_contact + in_tar_pose = torch.cat((tar_pose, tar_trans_copy, tar_contact_copy), dim=-1) + t_data = time.time() - t_start + + #self.opt.zero_grad() + #g_loss_final = 0 + net_out = self.model(in_tar_pose) + rec_pose = net_out["rec_pose"][:, :, :j*6] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight + self.tracker.update_meter("rec", "val", loss_rec.item()) + rec_contact = net_out["rec_pose"][:, :, j*6+3:j*6+7] + # print(rec_contact.shape, tar_contact.shape) + loss_contact = self.vel_loss(rec_contact, tar_contact) * self.args.rec_weight * self.args.rec_pos_weight + self.tracker.update_meter("contact", "val", loss_contact.item()) + #g_loss_final += loss_rec + rec_trans = 
net_out["rec_pose"][:, :, j*6:j*6+3] + rec_x_trans = other_tools.velocity2position(rec_trans[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools.velocity2position(rec_trans[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans[:,:,1:2] + rec_xyz_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + + # rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3] + # rec_x_trans = other_tools.velocity2position(rec_trans[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + # rec_z_trans = other_tools.velocity2position(rec_trans[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + # rec_y_trans = rec_trans[:,:,1:2] + # rec_xyz_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + # loss_trans_vel = self.vel_loss(rec_trans[:, :, 0:1], tar_trans_vel_x) * self.args.rec_weight \ + # + self.vel_loss(rec_trans[:, :, 2:3], tar_trans_vel_z) * self.args.rec_weight + # v3 = self.vel_loss(rec_trans[:, :, 0:1][:, 1:] - rec_trans[:, :, 0:1][:, :-1], tar_trans_vel_x[:, 1:] - tar_trans_vel_x[:, :-1]) * self.args.rec_weight \ + # + self.vel_loss(rec_trans[:, :, 2:3][:, 1:] - rec_trans[:, :, 2:3][:, :-1], tar_trans_vel_z[:, 1:] - tar_trans_vel_z[:, :-1]) * self.args.rec_weight + # a3 = self.vel_loss(rec_trans[:, :, 0:1][:, 2:] + rec_trans[:, :, 0:1][:, :-2] - 2 * rec_trans[:, :, 0:1][:, 1:-1], tar_trans_vel_x[:, 2:] + tar_trans_vel_x[:, :-2] - 2 * tar_trans_vel_x[:, 1:-1]) * self.args.rec_weight \ + # + self.vel_loss(rec_trans[:, :, 2:3][:, 2:] + rec_trans[:, :, 2:3][:, :-2] - 2 * rec_trans[:, :, 2:3][:, 1:-1], tar_trans_vel_z[:, 2:] + tar_trans_vel_z[:, :-2] - 2 * tar_trans_vel_z[:, 1:-1]) * self.args.rec_weight + # #g_loss_final += 5*v3 + # #g_loss_final += 5*a3 + # v2 = self.vel_loss(rec_xyz_trans[:, 1:] - rec_xyz_trans[:, :-1], tar_trans[:, 1:] - tar_trans[:, :-1]) * self.args.rec_weight + # a2 = self.vel_loss(rec_xyz_trans[:, 2:] + rec_xyz_trans[:, :-2] - 2 * rec_xyz_trans[:, 1:-1], tar_trans[:, 2:] + tar_trans[:, :-2] - 2 * tar_trans[:, 1:-1]) * self.args.rec_weight + #g_loss_final += 5*v2 + #g_loss_final += 5*a2 + # self.tracker.update_meter("transv", "val", loss_trans_vel.item()) + # #g_loss_final += loss_trans_vel + # loss_trans = self.vel_loss(rec_xyz_trans, tar_trans) * self.args.rec_weight + # self.tracker.update_meter("trans", "val", loss_trans.item()) + #g_loss_final += loss_trans + + # vertices loss + if self.args.rec_ver_weight > 0: + # print(tar_pose.shape, bs, n, j) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + rec_pose = self.inverse_selection_tensor(rec_pose, self.train_data.joint_mask, rec_pose.shape[0]) + tar_pose = self.inverse_selection_tensor(tar_pose, self.train_data.joint_mask, tar_pose.shape[0]) + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:,:3], + body_pose=rec_pose[:,3:21*3+3], + left_hand_pose=rec_pose[:,25*3:40*3], + right_hand_pose=rec_pose[:,40*3:55*3], + return_verts=False, + return_joints=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + vertices_tar = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=tar_pose[:, 66:69], + global_orient=tar_pose[:,:3], + 
body_pose=tar_pose[:,3:21*3+3], + left_hand_pose=tar_pose[:,25*3:40*3], + right_hand_pose=tar_pose[:,40*3:55*3], + return_verts=False, + return_joints=True, + leye_pose=tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75], + ) + joints_rec = vertices_rec['joints'] + joints_rec = joints_rec.reshape(bs, n, -1, 3) + vectices_loss = self.vectices_loss(vertices_rec['joints'], vertices_tar['joints']) + foot_idx = [7, 8, 10, 11] + model_contact = net_out["rec_pose"][:, :, j*6+3:j*6+7] + # find static indices consistent with model's own predictions + static_idx = model_contact > 0.95 # N x S x 4 + # print(model_contact) + model_feet = joints_rec[:, :, foot_idx] # foot positions (N, S, 4, 3) + model_foot_v = torch.zeros_like(model_feet) + model_foot_v[:, :-1] = ( + model_feet[:, 1:, :, :] - model_feet[:, :-1, :, :] + ) # (N, S-1, 4, 3) + model_foot_v[~static_idx] = 0 + foot_loss = self.vectices_loss( + model_foot_v, torch.zeros_like(model_foot_v) + ) + self.tracker.update_meter("foot", "val", foot_loss.item()*self.args.rec_weight * self.args.rec_ver_weight) + self.tracker.update_meter("ver", "val", vectices_loss.item()*self.args.rec_weight * self.args.rec_ver_weight) + if "VQVAE" in self.args.g_name: + loss_embedding = net_out["embedding_loss"] + self.tracker.update_meter("com", "val", loss_embedding.item()) + #g_loss_final += vectices_loss*self.args.rec_weight*self.args.rec_ver_weight + if self.args.debug: + if its == 1: break + self.val_recording(epoch) + + def test(self, epoch): + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + return 0 + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + self.model.eval() + with torch.no_grad(): + for its, dict_data in enumerate(self.test_loader): + tar_pose_raw = dict_data["pose"] + tar_trans = dict_data["trans"].to(self.rank) + tar_pose = tar_pose_raw[:, :, :27].cuda() + tar_contact = tar_pose_raw[:, :, 27:31].cuda() + # tar_pose = tar_pose.cuda() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + remain = n%self.args.pose_length + tar_pose = tar_pose[:, :n-remain, :] + tar_contact = tar_contact[:, :n-remain, :] + tar_trans_copy = tar_trans[:, :n-remain, :] + tar_contact_copy = tar_contact + in_tar_pose = torch.cat([tar_pose, tar_trans_copy, tar_contact_copy], dim=-1) + #print(tar_pose.shape) + if True: + net_out = self.model(in_tar_pose) + rec_pose = net_out["rec_pose"][:, :, :j*6] + rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3] - net_out["rec_pose"][:, :, j*6:j*6+3] + # print(rec_trans.shape) + rec_x_trans = other_tools.velocity2position(rec_trans[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools.velocity2position(rec_trans[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + n = rec_pose.shape[1] + rec_trans = rec_trans.cpu().numpy().reshape(bs*n, 3) + tar_pose = tar_pose[:, :n, :] + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose)# + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + rec_pose = rec_pose.cpu().numpy() + else: + pass +# for i in range(tar_pose.shape[1]//(self.args.vae_test_len)): +# tar_pose_new = 
tar_pose[:,i*(self.args.vae_test_len):i*(self.args.vae_test_len)+self.args.vae_test_len,:] +# net_out = self.model(**dict(inputs=tar_pose_new)) +# rec_pose = net_out["rec_pose"] +# rec_pose = (rec_pose.reshape(rec_pose.shape[0], rec_pose.shape[1], -1, 6) * self.joint_level_mask_cuda).reshape(rec_pose.shape[0], rec_pose.shape[1], -1) +# if "rot6d" in self.args.pose_rep: +# rec_pose = data_transfer.rotation_6d_to_matrix(rec_pose.reshape(tar_pose.shape[0], self.args.vae_test_len, -1, 6)) +# rec_pose = data_transfer.matrix_to_euler_angles(rec_pose, "XYZ").reshape(rec_pose.shape[0], rec_pose.shape[1], -1) +# if "smplx" not in self.args.pose_rep: +# rec_pose = torch.rad2deg(rec_pose) +# rec_pose = rec_pose * self.joint_mask_cuda + +# out_sub = rec_pose.cpu().numpy().reshape(-1, rec_pose.shape[2]) +# if i != 0: +# out_final = np.concatenate((out_final,out_sub), 0) +# else: +# out_final = out_sub + + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + tar_pose = tar_pose.cpu().numpy() + + total_length += n + # --- save --- # + if 'smplx' in self.args.pose_rep: + gt_npz = np.load(self.args.data_path+self.args.pose_rep+"/"+test_seq_list.iloc[its]['id']+'.npz', allow_pickle=True) + stride = int(30 / self.args.pose_fps) + tar_pose = self.inverse_selection(tar_pose, self.test_data.joint_mask, tar_pose.shape[0]) + np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose[:n], + expressions=gt_npz["expressions"]-gt_npz["expressions"], + trans=rec_trans-rec_trans, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + rec_pose = self.inverse_selection(rec_pose, self.test_data.joint_mask, rec_pose.shape[0]) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose, + expressions=gt_npz["expressions"]-gt_npz["expressions"], + trans=rec_trans-rec_trans, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + else: + rec_pose = rc.axis_angle_to_matrix(torch.from_numpy(rec_pose.reshape(bs*n, j, 3))) + rec_pose = np.rad2deg(rc.matrix_to_euler_angles(rec_pose, "XYZ")).reshape(bs*n, j*3).numpy() + tar_pose = rc.axis_angle_to_matrix(torch.from_numpy(tar_pose.reshape(bs*n, j, 3))) + tar_pose = np.rad2deg(rc.matrix_to_euler_angles(tar_pose, "XYZ")).reshape(bs*n, j*3).numpy() + #trans="0.000000 0.000000 0.000000" + + with open(f"{self.args.data_path}{self.args.pose_rep}/{test_seq_list.iloc[its]['id']}.bvh", "r") as f_demo: + with open(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.bvh', 'w+') as f_gt: + with open(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.bvh', 'w+') as f_real: + for i, line_data in enumerate(f_demo.readlines()): + if i < 431: + f_real.write(line_data) + f_gt.write(line_data) + else: break + for line_id in range(n): #,args.pre_frames, args.pose_length + line_data = np.array2string(rec_pose[line_id], max_line_width=np.inf, precision=6, suppress_small=False, separator=' ') + f_real.write(line_data[1:-2]+'\n') + for line_id in range(n): #,args.pre_frames, args.pose_length + line_data = np.array2string(tar_pose[line_id], max_line_width=np.inf, precision=6, suppress_small=False, separator=' ') + f_gt.write(line_data[1:-2]+'\n') + # with open(results_save_path+"gt_"+test_seq_list[its]+'.pkl', 'wb') as fw: + # pickle.dump(new_dict, fw) + # #new_dict2["fullpose"] = out_final + # with open(results_save_path+"res_"+test_seq_list[its]+'.pkl', 'wb') as fw1: + # 
pickle.dump(new_dict2, fw1) + + # other_tools.render_one_sequence( + # results_save_path+"res_"+test_seq_list[its]+'.pkl', + # results_save_path+"gt_"+test_seq_list[its]+'.pkl', + # results_save_path, + # self.args.data_path + self.args.test_data_path + 'wave16k/' + test_seq_list[its]+'.npy', + # ) + + #if its == 1:break + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..11a141959a6f3b4638139e16772a0fa4431e69e4 --- /dev/null +++ b/app.py @@ -0,0 +1,664 @@ +import spaces +import os +# os.system("Xvfb :99 -ac &") +# os.environ["DISPLAY"] = ":99" +import OpenGL.GL as gl +os.environ["PYOPENGL_PLATFORM"] = "egl" +os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1" +import signal +import time +import csv +import sys +import warnings +import random +import gradio as gr +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp +import numpy as np +import time +import pprint +from loguru import logger +import smplx +from torch.utils.tensorboard import SummaryWriter +import wandb +import matplotlib.pyplot as plt +from utils import config, logger_tools, other_tools_hf, metric, data_transfer +from dataloaders import data_tools +from dataloaders.build_vocab import Vocab +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func +from dataloaders.data_tools import joints_list +from utils import rotation_conversions as rc +import soundfile as sf +import librosa + +def inverse_selection_tensor(filtered_t, selection_array, n): + selection_array = torch.from_numpy(selection_array).cuda() + original_shape_t = torch.zeros((n, 165)).cuda() + selected_indices = torch.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + +@spaces.GPU(duration=120) +def test_demo_gpu( + model, vq_model_face, vq_model_upper, vq_model_hands, vq_model_lower, global_motion, smplx_model, + dict_data, + args, + joints, joint_mask_upper, joint_mask_lower, joint_mask_hands, + log_softmax, +): + rank = 0 + other_tools_hf.load_checkpoints(vq_model_face, args.data_path_1 + "pretrained_vq/last_790_face_v2.bin", args.e_name) + other_tools_hf.load_checkpoints(vq_model_upper, args.data_path_1 + "pretrained_vq/upper_vertex_1layer_710.bin", args.e_name) + other_tools_hf.load_checkpoints(vq_model_hands, args.data_path_1 + "pretrained_vq/hands_vertex_1layer_710.bin", args.e_name) + other_tools_hf.load_checkpoints(vq_model_lower, args.data_path_1 + "pretrained_vq/lower_foot_600.bin", args.e_name) + other_tools_hf.load_checkpoints(global_motion, args.data_path_1 + "pretrained_vq/last_1700_foot.bin", args.e_name) + other_tools_hf.load_checkpoints(model, args.test_ckpt, args.g_name) + model.to(rank).eval() + smplx_model.to(rank).eval() + vq_model_face.to(rank).eval() + vq_model_upper.to(rank).eval() + vq_model_hands.to(rank).eval() + vq_model_lower.to(rank).eval() + global_motion.to(rank).eval() + + with torch.no_grad(): + tar_pose_raw = dict_data["pose"] + tar_pose = tar_pose_raw[:, :, :165].to(rank) + tar_contact = tar_pose_raw[:, :, 165:169].to(rank) + tar_trans = dict_data["trans"].to(rank) + tar_exps = 
dict_data["facial"].to(rank) + in_audio = dict_data["audio"].to(rank) + in_word = None# dict_data["word"].to(rank) + tar_beta = dict_data["beta"].to(rank) + tar_id = dict_data["id"].to(rank).long() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], joints + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) + tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + # tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + # tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2) + + tar_index_value_face_top = vq_model_face.map2index(tar_pose_face) # bs*n/4 + tar_index_value_upper_top = vq_model_upper.map2index(tar_pose_upper) # bs*n/4 + tar_index_value_hands_top = vq_model_hands.map2index(tar_pose_hands) # bs*n/4 + tar_index_value_lower_top = vq_model_lower.map2index(tar_pose_lower) # bs*n/4 + + latent_face_top = vq_model_face.map2latent(tar_pose_face) # bs*n/4 + latent_upper_top = vq_model_upper.map2latent(tar_pose_upper) # bs*n/4 + latent_hands_top = vq_model_hands.map2latent(tar_pose_hands) # bs*n/4 + latent_lower_top = vq_model_lower.map2latent(tar_pose_lower) # bs*n/4 + + latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) + + index_in = torch.stack([tar_index_value_upper_top, tar_index_value_hands_top, tar_index_value_lower_top], dim=-1).long() + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + + loaded_data = { + "tar_pose_jaw": tar_pose_jaw, + "tar_pose_face": tar_pose_face, + "tar_pose_upper": tar_pose_upper, + "tar_pose_lower": tar_pose_lower, + "tar_pose_hands": tar_pose_hands, + 'tar_pose_leg': tar_pose_leg, + "in_audio": in_audio, + "in_word": in_word, + "tar_trans": tar_trans, + "tar_exps": tar_exps, + "tar_beta": tar_beta, + "tar_pose": tar_pose, + "tar4dis": tar4dis, + "tar_index_value_face_top": tar_index_value_face_top, + "tar_index_value_upper_top": tar_index_value_upper_top, + "tar_index_value_hands_top": tar_index_value_hands_top, + "tar_index_value_lower_top": tar_index_value_lower_top, + "latent_face_top": latent_face_top, + "latent_upper_top": latent_upper_top, + "latent_hands_top": latent_hands_top, + "latent_lower_top": latent_lower_top, + "latent_in": latent_in, + "index_in": index_in, + "tar_id": tar_id, + "latent_all": latent_all, + "tar_pose_6d": tar_pose_6d, + "tar_contact": tar_contact, + } + + mode = 'test' + bs, n, j = loaded_data["tar_pose"].shape[0], 
loaded_data["tar_pose"].shape[1], joints + tar_pose = loaded_data["tar_pose"] + tar_beta = loaded_data["tar_beta"] + in_word =None# loaded_data["in_word"] + tar_exps = loaded_data["tar_exps"] + tar_contact = loaded_data["tar_contact"] + in_audio = loaded_data["in_audio"] + tar_trans = loaded_data["tar_trans"] + + remain = n%8 + if remain != 0: + tar_pose = tar_pose[:, :-remain, :] + tar_beta = tar_beta[:, :-remain, :] + tar_trans = tar_trans[:, :-remain, :] + # in_word = in_word[:, :-remain] + tar_exps = tar_exps[:, :-remain, :] + tar_contact = tar_contact[:, :-remain, :] + n = n - remain + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) + tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + + rec_index_all_face = [] + rec_index_all_upper = [] + rec_index_all_lower = [] + rec_index_all_hands = [] + + roundt = (n - args.pre_frames) // (args.pose_length - args.pre_frames) + remain = (n - args.pre_frames) % (args.pose_length - args.pre_frames) + round_l = args.pose_length - args.pre_frames + + for i in range(0, roundt): + # in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+args.pre_frames] + # audio fps is 16000 and pose fps is 30 + in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*args.pre_frames] + in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+args.pre_frames] + mask_val = torch.ones(bs, args.pose_length, args.pose_dims+3+4).float().cuda() + mask_val[:, :args.pre_frames, :] = 0.0 + if i == 0: + latent_all_tmp = latent_all[:, i*(round_l):(i+1)*(round_l)+args.pre_frames, :] + else: + latent_all_tmp = latent_all[:, i*(round_l):(i+1)*(round_l)+args.pre_frames, :] + # print(latent_all_tmp.shape, latent_last.shape) + latent_all_tmp[:, :args.pre_frames, :] = latent_last[:, -args.pre_frames:, :] + + net_out_val = model( + in_audio = in_audio_tmp, + in_word=None, #in_word_tmp, + mask=mask_val, + in_motion = latent_all_tmp, + in_id = in_id_tmp, + use_attentions=True,) + + if args.cu != 0: + rec_index_upper = log_softmax(net_out_val["cls_upper"]).reshape(-1, args.vae_codebook_size) + _, rec_index_upper = torch.max(rec_index_upper.reshape(-1, args.pose_length, args.vae_codebook_size), dim=2) + #rec_upper = vq_model_upper.decode(rec_index_upper) + else: + _, rec_index_upper, _, _ = vq_model_upper.quantizer(net_out_val["rec_upper"]) + #rec_upper = vq_model_upper.decoder(rec_index_upper) + if args.cl != 0: + rec_index_lower = 
log_softmax(net_out_val["cls_lower"]).reshape(-1, args.vae_codebook_size) + _, rec_index_lower = torch.max(rec_index_lower.reshape(-1, args.pose_length, args.vae_codebook_size), dim=2) + #rec_lower = vq_model_lower.decode(rec_index_lower) + else: + _, rec_index_lower, _, _ = vq_model_lower.quantizer(net_out_val["rec_lower"]) + #rec_lower = vq_model_lower.decoder(rec_index_lower) + if args.ch != 0: + rec_index_hands = log_softmax(net_out_val["cls_hands"]).reshape(-1, args.vae_codebook_size) + _, rec_index_hands = torch.max(rec_index_hands.reshape(-1, args.pose_length, args.vae_codebook_size), dim=2) + #rec_hands = vq_model_hands.decode(rec_index_hands) + else: + _, rec_index_hands, _, _ = vq_model_hands.quantizer(net_out_val["rec_hands"]) + #rec_hands = vq_model_hands.decoder(rec_index_hands) + if args.cf != 0: + rec_index_face = log_softmax(net_out_val["cls_face"]).reshape(-1, args.vae_codebook_size) + _, rec_index_face = torch.max(rec_index_face.reshape(-1, args.pose_length, args.vae_codebook_size), dim=2) + #rec_face = vq_model_face.decoder(rec_index_face) + else: + _, rec_index_face, _, _ = vq_model_face.quantizer(net_out_val["rec_face"]) + #rec_face = vq_model_face.decoder(rec_index_face) + + if i == 0: + rec_index_all_face.append(rec_index_face) + rec_index_all_upper.append(rec_index_upper) + rec_index_all_lower.append(rec_index_lower) + rec_index_all_hands.append(rec_index_hands) + else: + rec_index_all_face.append(rec_index_face[:, args.pre_frames:]) + rec_index_all_upper.append(rec_index_upper[:, args.pre_frames:]) + rec_index_all_lower.append(rec_index_lower[:, args.pre_frames:]) + rec_index_all_hands.append(rec_index_hands[:, args.pre_frames:]) + + if args.cu != 0: + rec_upper_last = vq_model_upper.decode(rec_index_upper) + else: + rec_upper_last = vq_model_upper.decoder(rec_index_upper) + if args.cl != 0: + rec_lower_last = vq_model_lower.decode(rec_index_lower) + else: + rec_lower_last = vq_model_lower.decoder(rec_index_lower) + if args.ch != 0: + rec_hands_last = vq_model_hands.decode(rec_index_hands) + else: + rec_hands_last = vq_model_hands.decoder(rec_index_hands) + # if args.cf != 0: + # rec_face_last = vq_model_face.decode(rec_index_face) + # else: + # rec_face_last = vq_model_face.decoder(rec_index_face) + + rec_pose_legs = rec_lower_last[:, :, :54] + bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1] + rec_pose_upper = rec_upper_last.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = inverse_selection_tensor(rec_pose_upper, joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = inverse_selection_tensor(rec_pose_lower, joint_mask_lower, bs*n) + rec_pose_hands = rec_hands_last.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = inverse_selection_tensor(rec_pose_hands, joint_mask_hands, bs*n) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + rec_trans_v_s = rec_lower_last[:, :, 54:57] + rec_x_trans = 
other_tools_hf.velocity2position(rec_trans_v_s[:, :, 0:1], 1/args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools_hf.velocity2position(rec_trans_v_s[:, :, 2:3], 1/args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + latent_last = torch.cat([rec_pose, rec_trans, rec_lower_last[:, :, 57:61]], dim=-1) + + rec_index_face = torch.cat(rec_index_all_face, dim=1) + rec_index_upper = torch.cat(rec_index_all_upper, dim=1) + rec_index_lower = torch.cat(rec_index_all_lower, dim=1) + rec_index_hands = torch.cat(rec_index_all_hands, dim=1) + if args.cu != 0: + rec_upper = vq_model_upper.decode(rec_index_upper) + else: + rec_upper = vq_model_upper.decoder(rec_index_upper) + if args.cl != 0: + rec_lower = vq_model_lower.decode(rec_index_lower) + else: + rec_lower = vq_model_lower.decoder(rec_index_lower) + if args.ch != 0: + rec_hands = vq_model_hands.decode(rec_index_hands) + else: + rec_hands = vq_model_hands.decoder(rec_index_hands) + if args.cf != 0: + rec_face = vq_model_face.decode(rec_index_face) + else: + rec_face = vq_model_face.decoder(rec_index_face) + + rec_exps = rec_face[:, :, 6:] + rec_pose_jaw = rec_face[:, :, :6] + rec_pose_legs = rec_lower[:, :, :54] + bs, n = rec_pose_jaw.shape[0], rec_pose_jaw.shape[1] + rec_pose_upper = rec_upper.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = inverse_selection_tensor(rec_pose_upper, joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = inverse_selection_tensor(rec_pose_lower, joint_mask_lower, bs*n) + rec_pose_hands = rec_hands.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = inverse_selection_tensor(rec_pose_hands, joint_mask_hands, bs*n) + rec_pose_jaw = rec_pose_jaw.reshape(bs*n, 6) + rec_pose_jaw = rc.rotation_6d_to_matrix(rec_pose_jaw) + rec_pose_jaw = rc.matrix_to_axis_angle(rec_pose_jaw).reshape(bs*n, 1*3) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose[:, 66:69] = rec_pose_jaw + + to_global = rec_lower + to_global[:, :, 54:57] = 0.0 + to_global[:, :, :54] = rec_lower2global + rec_global = global_motion(to_global) + + rec_trans_v_s = rec_global["rec_pose"][:, :, 54:57] + rec_x_trans = other_tools_hf.velocity2position(rec_trans_v_s[:, :, 0:1], 1/args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools_hf.velocity2position(rec_trans_v_s[:, :, 2:3], 1/args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + tar_pose = tar_pose[:, :n, :] + tar_exps = tar_exps[:, :n, :] + tar_trans = tar_trans[:, :n, :] + tar_beta = tar_beta[:, :n, :] + + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + + net_out = { + 'rec_pose': rec_pose, + 
'rec_trans': rec_trans, + 'tar_pose': tar_pose, + 'tar_exps': tar_exps, + 'tar_beta': tar_beta, + 'tar_trans': tar_trans, + 'rec_exps': rec_exps, + } + + + tar_pose = net_out['tar_pose'] + rec_pose = net_out['rec_pose'] + tar_exps = net_out['tar_exps'] + tar_beta = net_out['tar_beta'] + rec_trans = net_out['rec_trans'] + tar_trans = net_out['tar_trans'] + rec_exps = net_out['rec_exps'] + # print(rec_pose.shape, tar_pose.shape) + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], joints + # interpolate to 30fps + if (30/args.pose_fps) != 1: + assert 30%args.pose_fps == 0 + n *= int(30/args.pose_fps) + tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/args.pose_fps, mode='linear').permute(0,2,1) + rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/args.pose_fps, mode='linear').permute(0,2,1) + + # print(rec_pose.shape, tar_pose.shape) + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + + return tar_pose, rec_pose, tar_exps, tar_beta, rec_trans, tar_trans, rec_exps, bs, n, j + + +class BaseTrainer(object): + def __init__(self, args, sp, ap, tp): + hf_dir = "hf" + if not os.path.exists(args.out_path + "custom/" + hf_dir + "/"): + os.makedirs(args.out_path + "custom/" + hf_dir + "/") + sf.write(args.out_path + "custom/" + hf_dir + "/tmp.wav", ap[1][:ap[0]*8], ap[0]) + self.audio_path = args.out_path + "custom/" + hf_dir + "/tmp.wav" + audio, ssr = librosa.load(self.audio_path) + ap = (ssr, audio) + self.args = args + self.rank = 0 # dist.get_rank() + + #self.checkpoint_path = args.out_path + "custom/" + args.name + args.notes + "/" #wandb.run.dir #args.cache_path+args.out_path+"/"+args.name + self.checkpoint_path = args.out_path + "custom/" + hf_dir + "/" + if self.rank == 0: + self.test_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "test", smplx_path=sp, audio_path=ap, text_path=tp) + self.test_loader = torch.utils.data.DataLoader( + self.test_data, + batch_size=1, + shuffle=False, + num_workers=args.loader_workers, + drop_last=False, + ) + logger.info(f"Init test dataloader success") + model_module = __import__(f"models.{args.model}", fromlist=["something"]) + + if args.ddp: + self.model = getattr(model_module, args.g_name)(args).to(self.rank) + process_group = torch.distributed.new_group() + self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model, process_group) + self.model = DDP(self.model, device_ids=[self.rank], output_device=self.rank, + broadcast_buffers=False, find_unused_parameters=False) + else: + self.model = torch.nn.DataParallel(getattr(model_module, args.g_name)(args), args.gpus).cpu() + + if self.rank == 0: + logger.info(self.model) + logger.info(f"init {args.g_name} success") + + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ) + + self.args = args + self.joints = self.test_data.joints + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list_face = joints_list["beat_smplx_face"] + self.tar_joint_list_upper = joints_list["beat_smplx_upper"] + self.tar_joint_list_hands = joints_list["beat_smplx_hands"] + self.tar_joint_list_lower = 
joints_list["beat_smplx_lower"] + + self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = 55 + for joint_name in self.tar_joint_list_face: + self.joint_mask_face[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_upper: + self.joint_mask_upper[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_hands: + self.joint_mask_hands[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_lower: + self.joint_mask_lower[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + + self.tracker = other_tools_hf.EpochTracker(["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'exp', 'lvd', 'mse', "cls", "rec_face", "latent", "cls_full", "cls_self", "cls_word", "latent_word","latent_self"], [False,True,True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False,False,False,False]) + + vq_model_module = __import__(f"models.motion_representation", fromlist=["something"]) + self.args.vae_layer = 2 + self.args.vae_length = 256 + self.args.vae_test_dim = 106 + self.vq_model_face = getattr(vq_model_module, "VQVAEConvZero")(self.args).cpu() + # print(self.vq_model_face) + # other_tools_hf.load_checkpoints(self.vq_model_face, self.args.data_path_1 + "pretrained_vq/last_790_face_v2.bin", args.e_name) + self.args.vae_test_dim = 78 + self.vq_model_upper = getattr(vq_model_module, "VQVAEConvZero")(self.args).cpu() + # other_tools_hf.load_checkpoints(self.vq_model_upper, self.args.data_path_1 + "pretrained_vq/upper_vertex_1layer_710.bin", args.e_name) + self.args.vae_test_dim = 180 + self.vq_model_hands = getattr(vq_model_module, "VQVAEConvZero")(self.args).cpu() + # other_tools_hf.load_checkpoints(self.vq_model_hands, self.args.data_path_1 + "pretrained_vq/hands_vertex_1layer_710.bin", args.e_name) + self.args.vae_test_dim = 61 + self.args.vae_layer = 4 + self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).cpu() + # other_tools_hf.load_checkpoints(self.vq_model_lower, self.args.data_path_1 + "pretrained_vq/lower_foot_600.bin", args.e_name) + self.args.vae_test_dim = 61 + self.args.vae_layer = 4 + self.global_motion = getattr(vq_model_module, "VAEConvZero")(self.args).cpu() + # other_tools_hf.load_checkpoints(self.global_motion, self.args.data_path_1 + "pretrained_vq/last_1700_foot.bin", args.e_name) + self.args.vae_test_dim = 330 + self.args.vae_layer = 4 + self.args.vae_length = 240 + + # self.cls_loss = nn.NLLLoss().to(self.rank) + # self.reclatent_loss = nn.MSELoss().to(self.rank) + # self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank) + # self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank) + self.log_softmax = nn.LogSoftmax(dim=2) + + + def inverse_selection(self, filtered_t, selection_array, n): + original_shape_t = np.zeros((n, selection_array.size)) + selected_indices = np.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = 
filtered_t[i] + return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + selection_array = torch.from_numpy(selection_array).cuda() + original_shape_t = torch.zeros((n, 165)).cuda() + selected_indices = torch.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + + def test_demo(self, epoch): + ''' + input audio and text, output motion + do not calculate loss and metric + save video + ''' + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + import shutil + shutil.rmtree(results_save_path) + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + align = 0 + latent_out = [] + latent_ori = [] + l2_all = 0 + lvel = 0 + for its, batch_data in enumerate(self.test_loader): + tar_pose, rec_pose, tar_exps, tar_beta, rec_trans, tar_trans, rec_exps, bs, n, j = test_demo_gpu( + self.model, self.vq_model_face, self.vq_model_upper, self.vq_model_hands, self.vq_model_lower, self.global_motion, self.smplx, + batch_data, + self.args, + self.joints, self.joint_mask_upper, self.joint_mask_lower, self.joint_mask_hands, + self.log_softmax, + ) + + tar_pose_np = tar_pose.detach().cpu().numpy() + rec_pose_np = rec_pose.detach().cpu().numpy() + rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3) + rec_exp_np = rec_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3) + #''' + # its = 0 + gt_npz = np.load(self.args.data_path+self.args.pose_rep +"/"+test_seq_list.iloc[its]['id']+".npz", allow_pickle=True) + np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose_np, + expressions=tar_exp_np, + trans=tar_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30, + ) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose_np, + expressions=rec_exp_np, + trans=rec_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30, + ) + + total_length += n + render_vid_path = other_tools_hf.render_one_sequence_no_gt( + results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + # results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + results_save_path, + self.audio_path, + self.args.data_path_1+"smplx_models/", + use_matplotlib = False, + args = self.args, + ) + result = gr.Video(value=render_vid_path, visible=True) + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") + return result + + +@logger.catch +def emage(audio_path): + smplx_path = None + text_path = None + rank = 0 + world_size = 1 + args = config.parse_args() + #os.environ['TRANSFORMERS_CACHE'] = args.data_path_1 + "hub/" + if not sys.warnoptions: + warnings.simplefilter("ignore") + # dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + + #logger_tools.set_args_and_logger(args, rank) + other_tools_hf.set_random_seed(args) + other_tools_hf.print_exp_info(args) + + # return one intance of trainer + trainer = BaseTrainer(args, sp = smplx_path, ap = audio_path, tp = text_path) + result = trainer.test_demo(999) + return result + +examples = [ + ["./EMAGE/test_sequences/wave16k/2_scott_0_1_1.wav"], + 
["./EMAGE/test_sequences/wave16k/2_scott_0_2_2.wav"], + ["./EMAGE/test_sequences/wave16k/2_scott_0_3_3.wav"], +] + +demo = gr.Interface( + emage, # function + inputs=[ + # gr.File(label="Please upload SMPL-X file with npz format here.", file_types=["npz", "NPZ"]), + gr.Audio(), + # gr.File(label="Please upload textgrid format file here.", file_types=["TextGrid", "Textgrid", "textgrid"]) + ], # input type + outputs=gr.Video(format="mp4", visible=True), + title='\ +
\ + EMAGE: Towards Unified Holistic Co-Speech Gesture Generation via Expressive Masked Audio Gesture Modeling
\ + CVPR 2024
\ +
', + description='\ +
\ + Haiyang Liu1*, Zihao Zhu2*, Giorgio Becherini3, Yichen Peng4, Mingyang Su5,
\ + You Zhou, Xuefei Zhe, Naoya Iwamoto, Bo Zheng, Michael J. Black3
\ + (*Equal Contribution)
\ + 1The University of Tokyo, 2Keio University, 4Japan Advanced Institute of Science and Technology,
\ + 3Max Planck Institute for Intelligent Systems, 5Tsinghua University
\ +
\ + ', + article="\ + Due to the limited resources in this space, we process the first 8s of your uploaded audio.
\ + Try running this space locally for longer motion generation, e.g., 60s.
\ + Relevant links: [Project Page (https://pantomatrix.github.io/EMAGE/)\ + ", + examples=examples, +) + + +if __name__ == "__main__": + os.environ["MASTER_ADDR"]='127.0.0.1' + os.environ["MASTER_PORT"]='8675' + #os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + demo.launch(share=True) \ No newline at end of file diff --git a/camn_trainer.py b/camn_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c62fc207d2e58d27669a476ec90cadc10958baad --- /dev/null +++ b/camn_trainer.py @@ -0,0 +1,361 @@ +import train +import os +import time +import csv +import sys +import warnings +import random +import numpy as np +import time +import pprint +import pickle + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parallel import DistributedDataParallel as DDP +from loguru import logger +import smplx +import librosa + +from utils import config, logger_tools, other_tools, metric +from utils import rotation_conversions as rc +from dataloaders import data_tools +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func +from scipy.spatial.transform import Rotation + + +class CustomTrainer(train.BaseTrainer): + def __init__(self, args): + super().__init__(args) + self.joints = self.train_data.joints + self.tracker = other_tools.EpochTracker(["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'div_reg', "kl"], [False,True,True, False, False, False, False, False, False, False, False, False, False]) + if not self.args.rot6d: #"rot6d" not in args.pose_rep: + logger.error(f"this script is for rot6d, your pose rep. is {args.pose_rep}") + self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank) + self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank) + + def _load_data(self, dict_data): + tar_pose = dict_data["pose"].to(self.rank) + tar_trans = dict_data["trans"].to(self.rank) + tar_exps = dict_data["facial"].to(self.rank) + tar_beta = dict_data["beta"].to(self.rank) + tar_id = dict_data["id"].to(self.rank).long() + tar_word = dict_data["word"].to(self.rank) + in_audio = dict_data["audio"].to(self.rank) + in_emo = dict_data["emo"].to(self.rank) + #in_sem = dict_data["sem"].to(self.rank) + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + in_pre_pose_cat = torch.cat([tar_pose[:, 0:self.args.pre_frames], tar_trans[:, :self.args.pre_frames]], dim=2).to(self.rank) + + in_pre_pose = tar_pose.new_zeros((bs, n, j*6+1+3)).to(self.rank) + in_pre_pose[:, 0:self.args.pre_frames, :-1] = in_pre_pose_cat[:, 0:self.args.pre_frames] + in_pre_pose[:, 0:self.args.pre_frames, -1] = 1 + return { + "tar_pose": tar_pose, + "in_audio": in_audio, + "in_motion": in_pre_pose, + "tar_trans": tar_trans, + "tar_exps": tar_exps, + "tar_beta": tar_beta, + "tar_word": tar_word, + 'tar_id': tar_id, + 'in_emo': in_emo, + #'in_sem': in_sem, + } + + def _d_training(self, loaded_data): + bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints + net_out = self.model(in_audio = loaded_data['in_audio'], pre_seq = loaded_data["in_motion"], in_text=loaded_data["tar_word"], in_id=loaded_data["tar_id"], in_emo=loaded_data["in_emo"], in_facial = loaded_data["tar_exps"]) + rec_pose = net_out["rec_pose"][:, :, :j*6] + # 
rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3] + + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + tar_pose = rc.rotation_6d_to_matrix(loaded_data["tar_pose"].reshape(bs, n, j, 6)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + out_d_fake = self.d_model(rec_pose) + out_d_real = self.d_model(tar_pose) + + d_loss_adv = torch.sum(-torch.mean(torch.log(out_d_real + 1e-8) + torch.log(1 - out_d_fake + 1e-8))) + self.tracker.update_meter("dis", "train", d_loss_adv.item()) + return d_loss_adv + + def _g_training(self, loaded_data, use_adv, mode="train"): + bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints + net_out = self.model(in_audio = loaded_data['in_audio'], pre_seq = loaded_data["in_motion"], in_text=loaded_data["tar_word"], in_id=loaded_data["tar_id"], in_emo=loaded_data["in_emo"], in_facial = loaded_data["tar_exps"]) + rec_pose = net_out["rec_pose"][:, :, :j*6] + rec_trans = net_out["rec_pose"][:, :, j*6:j*6+3] + # print(rec_pose.shape, bs, n, j, loaded_data['in_audio'].shape, loaded_data["in_motion"].shape) + rec_pose = rec_pose.reshape(bs, n, j, 6) + rec_pose = rc.rotation_6d_to_matrix(rec_pose) + tar_pose = rc.rotation_6d_to_matrix(loaded_data["tar_pose"].reshape(bs, n, j, 6)) + + rec_loss = self.rec_loss(tar_pose, rec_pose) + rec_loss *= self.args.rec_weight + self.tracker.update_meter("rec", mode, rec_loss.item()) + # rec_loss_vel = self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1]) + # self.tracker.update_meter("vel", mode, rec_loss_vel.item()) + # rec_loss_acc = self.vel_loss(rec_pose[:, 2:] - 2*rec_pose[:, 1:-1] + rec_pose[:, :-2], tar_pose[:, 2:] - 2*tar_pose[:, 1:-1] + tar_pose[:, :-2]) + # self.tracker.update_meter("acc", mode, rec_loss_acc.item()) + + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + if self.args.pose_dims < 330 and mode != "train": + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs, n, j, 6)) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs, n, j*3) + rec_pose = self.inverse_selection_tensor(rec_pose, self.train_data.joint_mask, rec_pose.shape[0]) + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, 55, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, 55*6) + + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs, n, j*3) + tar_pose = self.inverse_selection_tensor(tar_pose, self.train_data.joint_mask, tar_pose.shape[0]) + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, 55*6) + if use_adv and mode == 'train': + out_d_fake = self.d_model(rec_pose) + d_loss_adv = -torch.mean(torch.log(out_d_fake + 1e-8)) + self.tracker.update_meter("gen", mode, d_loss_adv.item()) + else: + d_loss_adv = 0 + + if self.args.train_trans: + trans_loss = self.vel_loss(rec_trans, loaded_data["tar_trans"]) + trans_loss *= self.args.rec_weight + self.tracker.update_meter("trans", mode, trans_loss.item()) + else: + trans_loss = 0 + # trans_loss_vel = self.vel_loss(rec_trans[:, 1:] - rec_trans[:, :-1], loaded_data["tar_trans"][:, 1:] - loaded_data["tar_trans"][:, :-1]) + # self.tracker.update_meter("transv", mode, trans_loss_vel.item()) + # trans_loss_acc = self.vel_loss(rec_trans[:, 2:] - 2*rec_trans[:, 
1:-1] + rec_trans[:, :-2], loaded_data["tar_trans"][:, 2:] - 2*loaded_data["tar_trans"][:, 1:-1] + loaded_data["tar_trans"][:, :-2]) + # self.tracker.update_meter("transa", mode, trans_loss_acc.item()) + + if mode == 'train': + return d_loss_adv + rec_loss + trans_loss # + rec_loss_vel + rec_loss_acc + trans_loss_vel + trans_loss_acc + elif mode == 'val': + return { + 'rec_pose': rec_pose, + 'rec_trans': rec_trans, + 'tar_pose': tar_pose, + } + else: + return { + 'rec_pose': rec_pose, + 'rec_trans': rec_trans, + 'tar_pose': tar_pose, + 'tar_exps': loaded_data["tar_exps"], + 'tar_beta': loaded_data["tar_beta"], + 'tar_trans': loaded_data["tar_trans"], + } + + def train(self, epoch): + use_adv = bool(epoch>=self.args.no_adv_epoch) + self.model.train() + self.d_model.train() + self.tracker.reset() + t_start = time.time() + for its, batch_data in enumerate(self.train_loader): + loaded_data = self._load_data(batch_data) + t_data = time.time() - t_start + + if use_adv: + d_loss_final = 0 + self.opt_d.zero_grad() + d_loss_adv = self._d_training(loaded_data) + d_loss_final += d_loss_adv + d_loss_final.backward() + self.opt_d.step() + + self.opt.zero_grad() + g_loss_final = 0 + g_loss_final += self._g_training(loaded_data, use_adv, 'train') + g_loss_final.backward() + self.opt.step() + + mem_cost = torch.cuda.memory_cached() / 1E9 + lr_g = self.opt.param_groups[0]['lr'] + lr_d = self.opt_d.param_groups[0]['lr'] + t_train = time.time() - t_start - t_data + t_start = time.time() + if its % self.args.log_period == 0: + self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g, lr_d=lr_d) + if self.args.debug: + if its == 1: break + self.opt_s.step(epoch) + self.opt_d_s.step(epoch) + + + def val(self, epoch): + self.model.eval() + self.d_model.eval() + with torch.no_grad(): + for its, batch_data in enumerate(self.train_loader): + loaded_data = self._load_data(batch_data) + net_out = self._g_training(loaded_data, False, 'val') + tar_pose = net_out['tar_pose'] + rec_pose = net_out['rec_pose'] + n = tar_pose.shape[1] + if (30/self.args.pose_fps) != 1: + assert 30%self.args.pose_fps == 0 + n *= int(30/self.args.pose_fps) + tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + n = tar_pose.shape[1] + remain = n%self.args.vae_test_len + tar_pose = tar_pose[:, :n-remain, :] + rec_pose = rec_pose[:, :n-remain, :] + latent_out = self.eval_copy.map2latent(rec_pose).reshape(-1, self.args.vae_length).cpu().numpy() + latent_ori = self.eval_copy.map2latent(tar_pose).reshape(-1, self.args.vae_length).cpu().numpy() + if its == 0: + latent_out_motion_all = latent_out + latent_ori_all = latent_ori + else: + latent_out_motion_all = np.concatenate([latent_out_motion_all, latent_out], axis=0) + latent_ori_all = np.concatenate([latent_ori_all, latent_ori], axis=0) + if self.args.debug: + if its == 1: break + fid_motion = data_tools.FIDCalculator.frechet_distance(latent_out_motion_all, latent_ori_all) + self.tracker.update_meter("fid", "val", fid_motion) + self.val_recording(epoch) + + def test(self, epoch): + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + return 0 + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + align = 0 + latent_out = [] + latent_ori = [] + 
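+        # Descriptive note (added for clarity, not part of the original logic): this test loop
+        #   1) switches the generator, the SMPL-X layer and the pretrained evaluator (self.eval_copy) to eval mode,
+        #   2) runs _g_training in 'test' mode per clip and upsamples outputs to 30 fps when pose_fps != 30,
+        #   3) collects eval_copy latents for the FID score, SMPL-X joints for the L1-diversity metric,
+        #      and librosa onsets vs. motion beats for the alignment (BC) score,
+        #   4) saves paired gt_*.npz / res_*.npz files in the SMPL-X 2020 layout for later visualization.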
self.model.eval() + self.smplx.eval() + self.eval_copy.eval() + with torch.no_grad(): + for its, batch_data in enumerate(self.test_loader): + loaded_data = self._load_data(batch_data) + net_out = self._g_training(loaded_data, False, 'test') + tar_pose = net_out['tar_pose'] + rec_pose = net_out['rec_pose'] + tar_exps = net_out['tar_exps'] + tar_beta = net_out['tar_beta'] + rec_trans = net_out['rec_trans'] + tar_trans = net_out['tar_trans'] + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], 55 + if (30/self.args.pose_fps) != 1: + assert 30%self.args.pose_fps == 0 + n *= int(30/self.args.pose_fps) + tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + tar_beta = torch.nn.functional.interpolate(tar_beta.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + tar_exps = torch.nn.functional.interpolate(tar_exps.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + tar_trans = torch.nn.functional.interpolate(tar_trans.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + rec_trans = torch.nn.functional.interpolate(rec_trans.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + + # print(rec_pose.shape, tar_pose.shape) + # rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + # rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + # tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + # tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + remain = n%self.args.vae_test_len + latent_out.append(self.eval_copy.map2latent(rec_pose[:, :n-remain]).reshape(-1, self.args.vae_length).detach().cpu().numpy()) # bs * n/8 * 240 + latent_ori.append(self.eval_copy.map2latent(tar_pose[:, :n-remain]).reshape(-1, self.args.vae_length).detach().cpu().numpy()) + + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=rec_trans.reshape(bs*n, 3)-rec_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100)-tar_exps.reshape(bs*n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:,:3], + body_pose=rec_pose[:,3:21*3+3], + left_hand_pose=rec_pose[:,25*3:40*3], + right_hand_pose=rec_pose[:,40*3:55*3], + return_joints=True, + leye_pose=rec_pose[:, 69:72], + reye_pose=rec_pose[:, 72:75], + ) + # vertices_tar = self.smplx( + # betas=tar_beta.reshape(bs*n, 300), + # transl=rec_trans.reshape(bs*n, 3)-rec_trans.reshape(bs*n, 3), + # expression=tar_exps.reshape(bs*n, 100)-tar_exps.reshape(bs*n, 100), + # jaw_pose=tar_pose[:, 66:69], + # global_orient=tar_pose[:,:3], + # body_pose=tar_pose[:,3:21*3+3], + # left_hand_pose=tar_pose[:,25*3:40*3], + # right_hand_pose=tar_pose[:,40*3:55*3], + # return_joints=True, + # leye_pose=tar_pose[:, 69:72], + # reye_pose=tar_pose[:, 72:75], + # ) + joints_rec = vertices_rec["joints"].detach().cpu().numpy().reshape(1, n, 127*3)[0, :n, :55*3] + # joints_tar = vertices_tar["joints"].detach().cpu().numpy().reshape(1, n, 127*3)[0, :n, :55*3] + _ = self.l1_calculator.run(joints_rec) + if 
self.alignmenter is not None: + in_audio_eval, sr = librosa.load(self.args.data_path+"wave16k/"+test_seq_list.iloc[its]['id']+".wav") + in_audio_eval = librosa.resample(in_audio_eval, orig_sr=sr, target_sr=self.args.audio_sr) + a_offset = int(self.align_mask * (self.args.audio_sr / self.args.pose_fps)) + onset_bt = self.alignmenter.load_audio(in_audio_eval[:int(self.args.audio_sr / self.args.pose_fps*n)], a_offset, len(in_audio_eval)-a_offset, True) + beat_vel = self.alignmenter.load_pose(joints_rec, self.align_mask, n-self.align_mask, 30, True) + # print(beat_vel) + align += (self.alignmenter.calculate_align(onset_bt, beat_vel, 30) * (n-2*self.align_mask)) + + tar_pose_axis_np = tar_pose.detach().cpu().numpy() + rec_pose_axis_np = rec_pose.detach().cpu().numpy() + rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3) + rec_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100) - tar_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100) - tar_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3) + gt_npz = np.load(self.args.data_path+self.args.pose_rep +"/"+test_seq_list.iloc[its]['id']+".npz", allow_pickle=True) + if not self.args.train_trans: + tar_trans_np = tar_trans_np - tar_trans_np + rec_trans_np = rec_trans_np - rec_trans_np + np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose_axis_np, + expressions=tar_exp_np, + trans=tar_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose_axis_np, + expressions=rec_exp_np, + trans=rec_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30, + ) + total_length += n + + latent_out_all = np.concatenate(latent_out, axis=0) + latent_ori_all = np.concatenate(latent_ori, axis=0) + fid = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all) + logger.info(f"fid score: {fid}") + self.test_recording("fid", fid, epoch) + + align_avg = align/(total_length-2*len(self.test_loader)*self.align_mask) + logger.info(f"align score: {align_avg}") + self.test_recording("bc", align_avg, epoch) + + l1div = self.l1_calculator.avg() + logger.info(f"l1div score: {l1div}") + self.test_recording("l1div", l1div, epoch) + + # data_tools.result2target_vis(self.args.pose_version, results_save_path, results_save_path, self.test_demo, False) + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") \ No newline at end of file diff --git a/configs/.ipynb_checkpoints/emage_test_hf-checkpoint.yaml b/configs/.ipynb_checkpoints/emage_test_hf-checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..564aaf431479e2d25825507d2819b8e247e857a1 --- /dev/null +++ b/configs/.ipynb_checkpoints/emage_test_hf-checkpoint.yaml @@ -0,0 +1,101 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +data_path: ./EMAGE/test_sequences/ +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +test_ckpt: ./EMAGE/emage_audio_175.bin +data_path_1: ./EMAGE/ +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config 
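+# note (added): this demo/test config keeps only speaker 2, loads sequences through the
+# beat_testonly_hf dataset class, and rebuilds the LMDB cache on every run (new_cache: True).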
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: ./datasets/beat_cache/beat_smplx_en_emage_test/ +dataset: beat_testonly_hf +new_cache: True + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 64 +stride: 20 +test_length: 64 +motion_f: 256 +m_pre_encoder: null +m_encoder: null +m_fix_pre: False + +# audio config +audio_rep: wave16k +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: none +# a_fix_pre: False + +# text config +# word_rep: textgrid +# word_index_num: 11195 +# word_dims: 300 +# freeze_wordembed: False +# word_f: 256 +# t_pre_encoder: fasttext +# t_encoder: null +# t_fix_pre: False + +# facial config +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 +f_pre_encoder: null +f_encoder: null +f_fix_pre: False + +# speaker config +id_rep: onehot +speaker_f: 0 + +# model config +batch_size: 64 +# warmup_epochs: 1 +# warmup_lr: 1e-6 +lr_base: 5e-4 +model: emage_audio +g_name: MAGE_Transformer +trainer: emage +hidden_size: 768 +n_layer: 1 + +rec_weight: 1 +grad_norm: 0.99 +epochs: 400 +test_period: 20 +ll: 3 +lf: 3 +lu: 3 +lh: 3 +cl: 1 +cf: 0 +cu: 1 +ch: 1 diff --git a/configs/camn.yaml b/configs/camn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..177f4864bc8a5d8c1bb50f89675f463646133342 --- /dev/null +++ b/configs/camn.yaml @@ -0,0 +1,101 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +data_path: ./BEAT2/beat_english_v2.0.0/ +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +test_ckpt: ./EMAGE/camn.bin +data_path_1: ./EMAGE/ +vae_test_len: 64 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: datasets/beat_cache/beat_smplx_en_camn/ +dataset: beat_sep +new_cache: False + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 15 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 32 +stride: 10 +test_length: 32 +motion_f: 256 +m_pre_encoder: null +m_encoder: null +m_fix_pre: False + +# audio config +audio_rep: wave16k +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 128 +# a_pre_encoder: tcn_camn +# a_encoder: none +# a_fix_pre: False + +# text config +word_rep: textgrid +word_index_num: 11195 +word_dims: 300 +freeze_wordembed: False +word_f: 128 +t_pre_encoder: fasttext +t_encoder: null +t_fix_pre: False + +# facial config +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 64 +f_pre_encoder: null +f_encoder: null +f_fix_pre: False + +# speaker config +id_rep: onehot +speaker_f: 16 +emo_rep: emo +emotion_f: 8 +# sem_rep: sem + + +# model config +batch_size: 128 +# warmup_epochs: 1 +# warmup_lr: 1e-6 +lr_base: 3e-4 +model: camn +g_name: CaMN +d_name: ConvDiscriminator +trainer: camn +hidden_size: 512 +n_layer: 4 +rec_weight: 500 +no_adv_epoch: 999 +# rec_pos_weight: 1 +# rec_ver_weight: 0 +# rec_fac_weight: 1 +# grad_norm: 1 +epochs: 100 +test_period: 20 \ No newline at end of file diff 
--git a/configs/cnn_vqvae_face_30.yaml b/configs/cnn_vqvae_face_30.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f05b872ac57106cb57184397a6da5ebb016dffd --- /dev/null +++ b/configs/cnn_vqvae_face_30.yaml @@ -0,0 +1,82 @@ +is_train: True +ddp: False +stat: ts +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +root_path: ./ +out_path: ./outputs/audio2pose/ +cache_path: datasets/beat_cache/beat_smplx_en_face/ +project: mage_smplx +data_path: ./BEAT2/beat_english_v2.0.0/ +e_path: weights/AESKConv_240_100.bin +test_ckpt: weights/multi.bin +data_path_1: ./EMAGE/ +#torch_hub_path: datasets/hub/ +additional_data: False +dataset: beat_sep +new_cache: False +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_face +pose_rep: smplxflame_30 +facial_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 + + +vae_test_len: 64 +vae_test_dim: 106 +vae_test_stride: 20 +vae_length: 256 +vae_codebook_size: 256 +vae_layer: 2 +vae_grow: [1,1,2,1] +variational: False + +pose_dims: 106 +pose_length: 64 +stride: 20 +facial_dims: 100 +word_index_num: 11195 +word_dims: 300 +batch_size: 64 +lr_base: 3e-4 +model: motion_representation +g_name: VQVAEConvZero +#eval_model: motion_autoencoder +#e_name: HalfEmbeddingNet +trainer: aeface +decay_epochs: 780 +# audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: lp +# a_fix_pre: False + +# freeze_wordembed: False +# word_f: 128 +# t_pre_encoder: fasttext +# t_encoder: lp +# t_fix_pre: False + +# motion_f: 256 +# m_pre_encoder: lp +# m_encoder: lp +# m_fix_pre: False + +# facial_f: 128 +# f_pre_encoder: lp +# f_encoder: lp +# f_fix_pre: False + +#m_decoder: lstm +#decode_fusion: cat +#n_layer: 2 +#hidden_size: 512 +rec_weight: 1 +rec_pos_weight: 1 +rec_ver_weight: 1 +# rec_fac_weight: 1 +#ita_weight: 0 +#iwa_weight: 0 +#fusion_mode: sum +# grad_norm: 1 +epochs: 800 +test_period: 100 \ No newline at end of file diff --git a/configs/cnn_vqvae_hands_30.yaml b/configs/cnn_vqvae_hands_30.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e909d39d618bb89cd19333bb5eb427cfa937fc26 --- /dev/null +++ b/configs/cnn_vqvae_hands_30.yaml @@ -0,0 +1,81 @@ +is_train: True +ddp: False +stat: ts +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +root_path: ./ +out_path: ./outputs/audio2pose/ +cache_path: datasets/beat_cache/beat_smplx_en_hands/ +project: mage_smplx +data_path: ./BEAT2/beat_english_v2.0.0/ +e_path: weights/AESKConv_240_100.bin +test_ckpt: weights/multi.bin +data_path_1: ./EMAGE/ +#torch_hub_path: datasets/hub/ +additional_data: False +dataset: beat_sep +new_cache: False +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_hands +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 + + +vae_test_len: 64 +vae_test_dim: 180 +vae_test_stride: 20 +vae_length: 256 +vae_codebook_size: 256 +vae_layer: 2 +vae_grow: [1,1,2,1] +variational: False + +pose_dims: 180 +pose_length: 64 +stride: 20 +facial_dims: 100 +word_index_num: 11195 +word_dims: 300 +batch_size: 64 +lr_base: 3e-4 +model: motion_representation +g_name: VQVAEConvZero +#eval_model: motion_autoencoder +#e_name: HalfEmbeddingNet +trainer: ae +decay_epochs: 780 +# audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: lp +# a_fix_pre: False + +# freeze_wordembed: False +# word_f: 128 +# t_pre_encoder: fasttext +# t_encoder: lp +# t_fix_pre: False + +# motion_f: 256 +# m_pre_encoder: lp +# m_encoder: lp +# m_fix_pre: False + +# facial_f: 128 +# f_pre_encoder: 
lp +# f_encoder: lp +# f_fix_pre: False + +#m_decoder: lstm +#decode_fusion: cat +#n_layer: 2 +#hidden_size: 512 +rec_weight: 1 +rec_pos_weight: 1 +rec_ver_weight: 1 +# rec_fac_weight: 1 +#ita_weight: 0 +#iwa_weight: 0 +#fusion_mode: sum +# grad_norm: 1 +epochs: 800 +test_period: 100 \ No newline at end of file diff --git a/configs/cnn_vqvae_lower_30.yaml b/configs/cnn_vqvae_lower_30.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7c61a8cdc9f0ca6e16d13259285b36c9bd83309 --- /dev/null +++ b/configs/cnn_vqvae_lower_30.yaml @@ -0,0 +1,81 @@ +is_train: True +ddp: False +stat: ts +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +root_path: ./ +out_path: ./outputs/audio2pose/ +cache_path: datasets/beat_cache/beat_smplx_en_lower/ +project: mage_smplx +data_path: ./BEAT2/beat_english_v2.0.0/ +e_path: weights/AESKConv_240_100.bin +test_ckpt: weights/multi.bin +data_path_1: ./EMAGE/ +#torch_hub_path: datasets/hub/ +additional_data: False +dataset: beat_sep_lower +new_cache: False +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_lower +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 + + +vae_test_len: 64 +vae_test_dim: 61 +vae_test_stride: 20 +vae_length: 256 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +pose_dims: 61 +pose_length: 64 +stride: 20 +facial_dims: 100 +word_index_num: 11195 +word_dims: 300 +batch_size: 64 +lr_base: 3e-4 +model: motion_representation +g_name: VAEConvZero +#eval_model: motion_autoencoder +#e_name: HalfEmbeddingNet +trainer: aelower +decay_epochs: 780 +# audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: lp +# a_fix_pre: False + +# freeze_wordembed: False +# word_f: 128 +# t_pre_encoder: fasttext +# t_encoder: lp +# t_fix_pre: False + +# motion_f: 256 +# m_pre_encoder: lp +# m_encoder: lp +# m_fix_pre: False + +# facial_f: 128 +# f_pre_encoder: lp +# f_encoder: lp +# f_fix_pre: False + +#m_decoder: lstm +#decode_fusion: cat +#n_layer: 2 +#hidden_size: 512 +rec_weight: 1 +rec_pos_weight: 1 +rec_ver_weight: 1 +# rec_fac_weight: 1 +#ita_weight: 0 +#iwa_weight: 0 +#fusion_mode: sum +# grad_norm: 1 +epochs: 800 +test_period: 100 \ No newline at end of file diff --git a/configs/cnn_vqvae_lower_foot_30.yaml b/configs/cnn_vqvae_lower_foot_30.yaml new file mode 100644 index 0000000000000000000000000000000000000000..161befcd0344ee38cff19e7e6a09fa26523b083f --- /dev/null +++ b/configs/cnn_vqvae_lower_foot_30.yaml @@ -0,0 +1,81 @@ +is_train: True +ddp: False +stat: ts +training_speakers: [2] +root_path: ./ +out_path: ./outputs/audio2pose/ +cache_path: datasets/beat_cache/beat_smplx_en_lower/ +project: mage_smplx +data_path: ./BEAT2/beat_english_v2.0.0/ +e_path: weights/AESKConv_240_100.bin +test_ckpt: weights/multi.bin +data_path_1: ./EMAGE/ +#torch_hub_path: datasets/hub/ +additional_data: False +dataset: beat_sep_lower +new_cache: False +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_lower +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 + + +vae_test_len: 64 +vae_test_dim: 61 +vae_test_stride: 20 +vae_length: 256 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +pose_dims: 61 +pose_length: 64 +stride: 20 +facial_dims: 100 +word_index_num: 11195 +word_dims: 300 +batch_size: 64 +lr_base: 3e-4 +model: motion_representation +g_name: VQVAEConvZero +#eval_model: motion_autoencoder +#e_name: HalfEmbeddingNet +trainer: aelowerfoot +decay_epochs: 780 +# audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: lp 
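+# (note) the commented a_*/t_*/m_*/f_* entries are unused by this VQ-VAE stage; relative to
+# cnn_vqvae_lower_30.yaml this variant switches g_name to VQVAEConvZero and trainer to aelowerfoot.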
+# a_fix_pre: False + +# freeze_wordembed: False +# word_f: 128 +# t_pre_encoder: fasttext +# t_encoder: lp +# t_fix_pre: False + +# motion_f: 256 +# m_pre_encoder: lp +# m_encoder: lp +# m_fix_pre: False + +# facial_f: 128 +# f_pre_encoder: lp +# f_encoder: lp +# f_fix_pre: False + +#m_decoder: lstm +#decode_fusion: cat +#n_layer: 2 +#hidden_size: 512 +rec_weight: 1 +rec_pos_weight: 1 +rec_ver_weight: 1 +# rec_fac_weight: 1 +#ita_weight: 0 +#iwa_weight: 0 +#fusion_mode: sum +# grad_norm: 1 +epochs: 800 +test_period: 100 \ No newline at end of file diff --git a/configs/cnn_vqvae_upper_30.yaml b/configs/cnn_vqvae_upper_30.yaml new file mode 100644 index 0000000000000000000000000000000000000000..db52e20b72ec2f18c81be58bf5414b99f8315fe5 --- /dev/null +++ b/configs/cnn_vqvae_upper_30.yaml @@ -0,0 +1,82 @@ +is_train: True +ddp: False +stat: ts +training_speakers: [2] +root_path: ./ +out_path: ./outputs/audio2pose/ +cache_path: datasets/beat_cache/beat_smplx_en_upper/ +project: mage_smplx +data_path: ./BEAT2/beat_english_v2.0.0/ +e_path: weights/AESKConv_240_100.bin +test_ckpt: weights/multi.bin +data_path_1: ./EMAGE/ +#torch_hub_path: datasets/hub/ +additional_data: False +dataset: beat_sep +new_cache: False +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_upper +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 + + +vae_test_len: 64 +vae_test_dim: 78 +vae_test_stride: 20 +vae_length: 256 +vae_codebook_size: 256 +vae_layer: 2 +vae_grow: [1,1,2,1] +variational: False + +pose_dims: 78 +pose_length: 64 +stride: 20 +facial_dims: 100 +word_index_num: 11195 +word_dims: 300 +batch_size: 64 +lr_base: 3e-4 +decay_epochs: 9999 +model: motion_representation +g_name: VQVAEConvZero +#eval_model: motion_autoencoder +#e_name: HalfEmbeddingNet +trainer: ae + +# audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: lp +# a_fix_pre: False + +# freeze_wordembed: False +# word_f: 128 +# t_pre_encoder: fasttext +# t_encoder: lp +# t_fix_pre: False + +# motion_f: 256 +# m_pre_encoder: lp +# m_encoder: lp +# m_fix_pre: False + +# facial_f: 128 +# f_pre_encoder: lp +# f_encoder: lp +# f_fix_pre: False + +#m_decoder: lstm +#decode_fusion: cat +#n_layer: 2 +#hidden_size: 512 +rec_weight: 1 +rec_pos_weight: 1 +rec_ver_weight: 1 +# rec_fac_weight: 1 +#ita_weight: 0 +#iwa_weight: 0 +#fusion_mode: sum +# grad_norm: 1 +epochs: 500 +test_period: 100 \ No newline at end of file diff --git a/configs/emage.yaml b/configs/emage.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1160bde10ed93a591ab2fc8903255d08391d9ff7 --- /dev/null +++ b/configs/emage.yaml @@ -0,0 +1,101 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +data_path: ./BEAT2/beat_english_v2.0.0/ +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +test_ckpt: ./EMAGE/emage_240.bin +data_path_1: ./EMAGE/ +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: datasets/beat_cache/beat_smplx_en_emage/ +dataset: beat_sep_lower +new_cache: False + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 64 +stride: 20 +test_length: 64 +motion_f: 256 
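+# note: with rot6d enabled, pose_dims 330 corresponds to 55 SMPL-X joints x 6-d rotations;
+# clips are 64 frames at 30 fps with stride 20, conditioned on 4 seed (pre_frames) frames.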
+m_pre_encoder: null +m_encoder: null +m_fix_pre: False + +# audio config +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: none +# a_fix_pre: False + +# text config +word_rep: textgrid +word_index_num: 11195 +word_dims: 300 +freeze_wordembed: False +word_f: 256 +t_pre_encoder: fasttext +t_encoder: null +t_fix_pre: False + +# facial config +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 +f_pre_encoder: null +f_encoder: null +f_fix_pre: False + +# speaker config +id_rep: onehot +speaker_f: 0 + +# model config +batch_size: 64 +# warmup_epochs: 1 +# warmup_lr: 1e-6 +lr_base: 5e-4 +model: emage +g_name: MAGE_Transformer +trainer: emage +hidden_size: 768 +n_layer: 1 + +rec_weight: 1 +grad_norm: 0.99 +epochs: 400 +test_period: 20 +ll: 3 +lf: 3 +lu: 3 +lh: 3 +cl: 1 +cf: 0 +cu: 1 +ch: 1 diff --git a/configs/emage_test.yaml b/configs/emage_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ca29c88ba4cfe028f73a9f7ae3636b1e82be2a39 --- /dev/null +++ b/configs/emage_test.yaml @@ -0,0 +1,101 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +data_path: ./EMAGE/test_sequences/ +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +test_ckpt: ./EMAGE/emage_240.bin +data_path_1: ./EMAGE/ +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: ./datasets/beat_cache/beat_smplx_en_emage_test/ +dataset: beat_testonly +new_cache: True + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 64 +stride: 20 +test_length: 64 +motion_f: 256 +m_pre_encoder: null +m_encoder: null +m_fix_pre: False + +# audio config +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: none +# a_fix_pre: False + +# text config +word_rep: textgrid +word_index_num: 11195 +word_dims: 300 +freeze_wordembed: False +word_f: 256 +t_pre_encoder: fasttext +t_encoder: null +t_fix_pre: False + +# facial config +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 +f_pre_encoder: null +f_encoder: null +f_fix_pre: False + +# speaker config +id_rep: onehot +speaker_f: 0 + +# model config +batch_size: 64 +# warmup_epochs: 1 +# warmup_lr: 1e-6 +lr_base: 5e-4 +model: emage +g_name: MAGE_Transformer +trainer: emage +hidden_size: 768 +n_layer: 1 + +rec_weight: 1 +grad_norm: 0.99 +epochs: 400 +test_period: 20 +ll: 3 +lf: 3 +lu: 3 +lh: 3 +cl: 1 +cf: 0 +cu: 1 +ch: 1 diff --git a/configs/emage_test_colab.yaml b/configs/emage_test_colab.yaml new file mode 100644 index 0000000000000000000000000000000000000000..79d64b05931bcba2291bea6199c8fff2a1557d71 --- /dev/null +++ b/configs/emage_test_colab.yaml @@ -0,0 +1,101 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +data_path: ./EMAGE/test_sequences/ +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +test_ckpt: ./EMAGE/emage_240.bin +data_path_1: ./EMAGE/ +vae_test_len: 32 +vae_test_dim: 330 
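+# the e_*/vae_* entries configure the pretrained VAESKConv evaluator (weights/AESKConv_240_100.bin)
+# used for the FID-style metrics: 330-d full-body input, 240-d latent, 32-frame test windows.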
+vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: ./datasets/beat_cache/beat_smplx_en_emage_test/ +dataset: beat_testonly_colab +new_cache: True + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 64 +stride: 20 +test_length: 64 +motion_f: 256 +m_pre_encoder: null +m_encoder: null +m_fix_pre: False + +# audio config +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: none +# a_fix_pre: False + +# text config +word_rep: textgrid +word_index_num: 11195 +word_dims: 300 +freeze_wordembed: False +word_f: 256 +t_pre_encoder: fasttext +t_encoder: null +t_fix_pre: False + +# facial config +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 +f_pre_encoder: null +f_encoder: null +f_fix_pre: False + +# speaker config +id_rep: onehot +speaker_f: 0 + +# model config +batch_size: 64 +# warmup_epochs: 1 +# warmup_lr: 1e-6 +lr_base: 5e-4 +model: emage +g_name: MAGE_Transformer +trainer: emage +hidden_size: 768 +n_layer: 1 + +rec_weight: 1 +grad_norm: 0.99 +epochs: 400 +test_period: 20 +ll: 3 +lf: 3 +lu: 3 +lh: 3 +cl: 1 +cf: 0 +cu: 1 +ch: 1 diff --git a/configs/emage_test_hf.yaml b/configs/emage_test_hf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..564aaf431479e2d25825507d2819b8e247e857a1 --- /dev/null +++ b/configs/emage_test_hf.yaml @@ -0,0 +1,101 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +data_path: ./EMAGE/test_sequences/ +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +test_ckpt: ./EMAGE/emage_audio_175.bin +data_path_1: ./EMAGE/ +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: ./datasets/beat_cache/beat_smplx_en_emage_test/ +dataset: beat_testonly_hf +new_cache: True + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 64 +stride: 20 +test_length: 64 +motion_f: 256 +m_pre_encoder: null +m_encoder: null +m_fix_pre: False + +# audio config +audio_rep: wave16k +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: none +# a_fix_pre: False + +# text config +# word_rep: textgrid +# word_index_num: 11195 +# word_dims: 300 +# freeze_wordembed: False +# word_f: 256 +# t_pre_encoder: fasttext +# t_encoder: null +# t_fix_pre: False + +# facial config +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 +f_pre_encoder: null +f_encoder: null +f_fix_pre: False + +# speaker config +id_rep: onehot +speaker_f: 0 + +# model config +batch_size: 64 +# warmup_epochs: 1 +# warmup_lr: 1e-6 +lr_base: 5e-4 +model: emage_audio +g_name: MAGE_Transformer +trainer: emage +hidden_size: 768 +n_layer: 1 + +rec_weight: 1 +grad_norm: 0.99 +epochs: 400 
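+# the remaining ll/lf/lu/lh and cl/cf/cu/ch values below match configs/emage.yaml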
+test_period: 20 +ll: 3 +lf: 3 +lu: 3 +lh: 3 +cl: 1 +cf: 0 +cu: 1 +ch: 1 diff --git a/configs/skcnn_ae.yaml b/configs/skcnn_ae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..19430f8e617d1b069c023684e46eaceabe6fd795 --- /dev/null +++ b/configs/skcnn_ae.yaml @@ -0,0 +1,80 @@ +is_train: True +ddp: False +stat: ts +training_speakers: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +root_path: /home/s24273/ +out_path: /home/s24273/outputs/audio2pose/ +cache_path: datasets/beat_cache/beat_smplx_en/ +project: mage_smplx +data_path: /home/s24273/datasets/beat_v2.0.0/beat_english_v2.0.0/ +e_path: weights/AESKConv_240_100.bin +test_ckpt: weights/multi.bin +data_path_1: /home/s24273/datasets/hub/ +#torch_hub_path: datasets/hub/ +additional_data: False +dataset: beat_smplx2020 +new_cache: False +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_norm: False +pose_fps: 30 + + +vae_test_len: 64 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_layer: 2 +vae_grow: [1,2] +variational: False + +pose_dims: 330 +pose_length: 64 +stride: 20 +facial_dims: 100 +word_index_num: 11195 +word_dims: 300 +batch_size: 32 +lr_base: 1e-4 +model: motion_representation +g_name: VAESKConv +#eval_model: motion_autoencoder +#e_name: HalfEmbeddingNet +trainer: ae +decay_epochs: 950 +# audio_f: 256 +# a_pre_encoder: tcn_camn +# a_encoder: lp +# a_fix_pre: False + +# freeze_wordembed: False +# word_f: 128 +# t_pre_encoder: fasttext +# t_encoder: lp +# t_fix_pre: False + +# motion_f: 256 +# m_pre_encoder: lp +# m_encoder: lp +# m_fix_pre: False + +# facial_f: 128 +# f_pre_encoder: lp +# f_encoder: lp +# f_fix_pre: False + +#m_decoder: lstm +#decode_fusion: cat +#n_layer: 2 +#hidden_size: 512 +rec_weight: 1 +rec_pos_weight: 10 +rec_ver_weight: 0 +# rec_fac_weight: 1 +#ita_weight: 0 +#iwa_weight: 0 +#fusion_mode: sum +# grad_norm: 1 +epochs: 1000 +test_period: 100 \ No newline at end of file diff --git a/dataloaders/.ipynb_checkpoints/beat_testonly_hf-checkpoint.py b/dataloaders/.ipynb_checkpoints/beat_testonly_hf-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..17393517d0d6f662114c19b009fa3471d7c1e91b --- /dev/null +++ b/dataloaders/.ipynb_checkpoints/beat_testonly_hf-checkpoint.py @@ -0,0 +1,740 @@ +import os +import pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +import pyarrow +import librosa +import smplx + +from .build_vocab import Vocab +from .utils.audio_features import Wav2Vec2Model +from .data_tools import joints_list +from .utils import rotation_conversions as rc +from .utils import other_tools_hf + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, smplx_path=None, audio_path=None, text_path=None, augmentation=None, kwargs=None, build_cache=True): + self.args = args + self.loader_type = loader_type + self.smplx_path = "./EMAGE/test_sequences/smplxflame_30/2_scott_0_1_1.npz" + self.audio_path = audio_path + self.text_path = "./EMAGE/test_sequences/textgrid/2_scott_0_1_1.TextGrid" + self.rank = 0 # dist.get_rank() + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + self.alignment = [0,0] # for trinity + + self.ori_joint_list = 
joints_list[self.args.ori_joints] + self.tar_joint_list = joints_list[self.args.tar_joints] + if 'smplx' in self.args.pose_rep: + self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = len(list(self.tar_joint_list.keys())) + for joint_name in self.tar_joint_list: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + else: + self.joints = len(list(self.ori_joint_list.keys()))+1 + self.joint_mask = np.zeros(self.joints*3) + for joint_name in self.tar_joint_list: + if joint_name == "Hips": + self.joint_mask[3:6] = 1 + else: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + # select trainable joints + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).eval() + + split_rule = pd.read_csv(args.data_path+"test.csv") + self.selected_file = split_rule + self.data_dir = args.data_path + + if loader_type == "test": + self.args.multi_length_training = [1.0] + self.max_length = int(args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr) + if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length*self.args.audio_sr + + if args.word_rep is not None: + with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + + preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache" + if build_cache and self.rank == 0: + self.build_cache(preloaded_dir) + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()["entries"] + + + def build_cache(self, preloaded_dir): + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + if self.args.new_cache: + if os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + if os.path.exists(preloaded_dir): + logger.info("Found the cache {}".format(preloaded_dir)) + elif self.loader_type == "test": + self.cache_generation( + preloaded_dir, True, + 0, 0, + is_test=True) + else: + self.cache_generation( + preloaded_dir, self.args.disable_filtering, + self.args.clean_first_seconds, self.args.clean_final_seconds, + is_test=False) + + + def __len__(self): + return self.n_samples + + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + self.n_out_samples = 0 + # create db for samples + if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir) + if len(self.args.training_speakers) == 1: + #dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 0.5))# 500M + else: + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 200))# 200G + n_filtered_out = defaultdict(int) + + #for index, file_name in self.selected_file.iterrows(): + #f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = self.smplx_path#self.data_dir + self.args.pose_rep + "/" + f_name + ext + pose_each_file = [] + trans_each_file = [] + 
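+        # (note) these per-frame buffers collect one full sequence (pose, translation, betas, audio,
+        # facial expressions, word ids, emotion/semantic labels, speaker id); _sample_from_clip later
+        # windows them into fixed-length clips and writes each clip into the LMDB cache.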
shape_each_file = [] + audio_each_file = [] + facial_each_file = [] + word_each_file = [] + emo_each_file = [] + sem_each_file = [] + vid_each_file = [] + id_pose = "dummy 2nd"#f_name + + logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + pose_data = np.load(pose_file, allow_pickle=True) + assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30' + stride = int(30/self.args.pose_fps) + pose_each_file = pose_data["poses"][::stride] + trans_each_file = pose_data["trans"][::stride] + shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0) + + assert self.args.pose_fps == 30, "should 30" + m_data = np.load(pose_file, allow_pickle=True) + betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"] + n, c = poses.shape[0], poses.shape[1] + betas = betas.reshape(1, 300) + betas = np.tile(betas, (n, 1)) + betas = torch.from_numpy(betas).float() + poses = torch.from_numpy(poses.reshape(n, c)).float() + exps = torch.from_numpy(exps.reshape(n, 100)).float() + trans = torch.from_numpy(trans.reshape(n, 3)).float() + max_length = 128 + s, r = n//max_length, n%max_length + #print(n, s, r) + all_tensor = [] + for i in range(s): + with torch.no_grad(): + joints = self.smplx( + betas=betas[i*max_length:(i+1)*max_length], + transl=trans[i*max_length:(i+1)*max_length], + expression=exps[i*max_length:(i+1)*max_length], + jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69], + global_orient=poses[i*max_length:(i+1)*max_length,:3], + body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3], + left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3], + right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[i*max_length:(i+1)*max_length, 69:72], + reye_pose=poses[i*max_length:(i+1)*max_length, 72:75], + )['joints'][:, (7,8,10,11), :].reshape(max_length, 4, 3).cpu() + all_tensor.append(joints) + if r != 0: + with torch.no_grad(): + joints = self.smplx( + betas=betas[s*max_length:s*max_length+r], + transl=trans[s*max_length:s*max_length+r], + expression=exps[s*max_length:s*max_length+r], + jaw_pose=poses[s*max_length:s*max_length+r, 66:69], + global_orient=poses[s*max_length:s*max_length+r,:3], + body_pose=poses[s*max_length:s*max_length+r,3:21*3+3], + left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3], + right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[s*max_length:s*max_length+r, 69:72], + reye_pose=poses[s*max_length:s*max_length+r, 72:75], + )['joints'][:, (7,8,10,11), :].reshape(r, 4, 3).cpu() + all_tensor.append(joints) + joints = torch.cat(all_tensor, axis=0) # all, 4, 3 + # print(joints.shape) + feetv = torch.zeros(joints.shape[1], joints.shape[0]) + joints = joints.permute(1, 0, 2) + #print(joints.shape, feetv.shape) + feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1) + #print(feetv.shape) + contacts = (feetv < 0.01).numpy().astype(float) + # print(contacts.shape, contacts) + contacts = contacts.transpose(1, 0) + pose_each_file = pose_each_file * self.joint_mask + pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)] + pose_each_file = np.concatenate([pose_each_file, contacts], axis=1) + # print(pose_each_file.shape) + + + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_each_file 
= pose_data["expressions"][::stride] + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + else: + assert 120%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(120/self.args.pose_fps) + with open(pose_file, "r") as pose_data: + for j, line in enumerate(pose_data.readlines()): + if j < 431: continue + if j%stride != 0:continue + data = np.fromstring(line, dtype=float, sep=" ") + rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints,3), "XYZ") + rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3) + rot_data = rot_data.numpy() * self.joint_mask + + pose_each_file.append(rot_data) + trans_each_file.append(data[:3]) + + pose_each_file = np.array(pose_each_file) + trans_each_file = np.array(trans_each_file) + shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json") + assert 60%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(60/self.args.pose_fps) + if not os.path.exists(facial_file): + logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #") + #self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + #continue + with open(facial_file, 'r') as facial_data_file: + facial_data = json.load(facial_data_file) + for j, frame_data in enumerate(facial_data['frames']): + if j%stride != 0:continue + facial_each_file.append(frame_data['weights']) + facial_each_file = np.array(facial_each_file) + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + if self.args.id_rep is not None: + int_value = 1 + vid_each_file = np.repeat(np.array(int_value).reshape(1, 1), pose_each_file.shape[0], axis=0) + + if self.args.audio_rep is not None: + logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #") + audio_file = self.audio_path[1]#pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + sr = self.audio_path[0] + print(sr) + #if not os.path.exists(audio_file): + # logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #") + #self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + #continue + #audio_each_file, sr = librosa.load(audio_file) + audio_each_file = audio_file.astype(np.float32) + print(audio_each_file.shape) + audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr) + print(audio_each_file.shape) + if self.args.audio_rep == "onset+amplitude": + from numpy.lib import stride_tricks + frame_length = 1024 + # hop_length = 512 + shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length) + strides = (audio_each_file.strides[-1], audio_each_file.strides[-1]) + rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides) + amplitude_envelope = np.max(np.abs(rolling_view), axis=1) + # pad the last frame_length-1 samples + amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1]) + audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, 
sr=self.args.audio_sr, units='frames') + onset_array = np.zeros(len(audio_each_file), dtype=float) + onset_array[audio_onset_f] = 1.0 + # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape) + audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1) + elif self.args.audio_rep == "mfcc": + audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps)) + audio_each_file = audio_each_file.transpose(1, 0) + # print(audio_each_file.shape, pose_each_file.shape) + if self.args.audio_norm and self.args.audio_rep == "wave16k": + audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio + + time_offset = 0 + if self.args.word_rep is not None: + logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #") + word_file = self.text_path#f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid" + if not os.path.exists(word_file): + logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #") + #self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + #continue + tgrid = tg.TextGrid.fromFile(word_file) + if self.args.t_pre_encoder == "bert": + from transformers import AutoTokenizer, BertModel + tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True) + model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval() + list_word = [] + all_hidden = [] + max_len = 400 + last = 0 + word_token_mapping = [] + first = True + for i, word in enumerate(tgrid[0]): + last = i + if (i%max_len != 0) or (i==0): + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + else: + max_counter = max_len + str_word = ' '.join(map(str, list_word)) + if first: + global_len = 0 + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + #print(i+global_len) + sub_mapping.append(i+global_len) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + global_len = word_token_mapping[-1][-1] + 1 + list_word = [] + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + + #list_word = list_word[:10] + if list_word == []: + pass + else: + if first: + global_len = 0 + str_word = ' '.join(map(str, list_word)) + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and 
int(end_t) <= int(end): + sub_mapping.append(i+global_len) + #print(sub_mapping) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + last_hidden_states = np.concatenate(all_hidden, axis=0) + + for i in range(pose_each_file.shape[0]): + found_flag = False + current_time = i/self.args.pose_fps + time_offset + j_last = 0 + for j, word in enumerate(tgrid[0]): + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + if word_s<=current_time and current_time<=word_e: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + mapping_index = word_token_mapping[j] + #print(mapping_index, word_s, word_e) + s_t = np.linspace(word_s, word_e, len(mapping_index)+1) + #print(s_t) + for tt, t_sep in enumerate(s_t[1:]): + if current_time <= t_sep: + #if len(mapping_index) > 1: print(mapping_index[tt]) + word_each_file.append(last_hidden_states[mapping_index[tt]]) + break + else: + if word_n == " ": + word_each_file.append(self.lang_model.PAD_token) + else: + word_each_file.append(self.lang_model.get_word_index(word_n)) + found_flag = True + j_last = j + break + else: continue + if not found_flag: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + word_each_file.append(last_hidden_states[j_last]) + else: + word_each_file.append(self.lang_model.UNK_token) + word_each_file = np.array(word_each_file) + #print(word_each_file.shape) + + if self.args.emo_rep is not None: + logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #") + rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3]) + if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6: + if start >= 1 and start <= 64: + score = 0 + elif start >= 65 and start <= 72: + score = 1 + elif start >= 73 and start <= 80: + score = 2 + elif start >= 81 and start <= 86: + score = 3 + elif start >= 87 and start <= 94: + score = 4 + elif start >= 95 and start <= 102: + score = 5 + elif start >= 103 and start <= 110: + score = 6 + elif start >= 111 and start <= 118: + score = 7 + else: pass + else: + # you may denote as unknown in the future + score = 0 + emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0) + #print(emo_each_file) + + if self.args.sem_rep is not None: + logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #") + sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt" + sem_all = pd.read_csv(sem_file, + sep='\t', + names=["name", "start_time", "end_time", "duration", "score", "keywords"]) + # we adopt motion-level semantic score here. + for i in range(pose_each_file.shape[0]): + found_flag = False + for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])): + current_time = i/self.args.pose_fps + time_offset + if start<=current_time and current_time<=end: + sem_each_file.append(score) + found_flag=True + break + else: continue + if not found_flag: sem_each_file.append(0.) 
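+            # stack the per-frame semantic scores (frames outside every labelled segment were defaulted to 0.0 above)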
+ sem_each_file = np.array(sem_each_file) + #print(sem_each_file) + + filtered_result = self._sample_from_clip( + dst_lmdb_env, + audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ) + for type in filtered_result.keys(): + n_filtered_out[type] += filtered_result[type] + + with dst_lmdb_env.begin() as txn: + logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan")) + n_total_filtered = 0 + for type, n_filtered in n_filtered_out.items(): + logger.info("{}: {}".format(type, n_filtered)) + n_total_filtered += n_filtered + logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format( + n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan")) + dst_lmdb_env.sync() + dst_lmdb_env.close() + + def _sample_from_clip( + self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ): + """ + for data cleaning, we ignore the data for first and final n s + for test, we return all data + """ + # audio_start = int(self.alignment[0] * self.args.audio_fps) + # pose_start = int(self.alignment[1] * self.args.pose_fps) + #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}") + # audio_each_file = audio_each_file[audio_start:] + # pose_each_file = pose_each_file[pose_start:] + # trans_each_file = + #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}") + #print(pose_each_file.shape) + round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s + print(pose_each_file.shape[0]) + #print(round_seconds_skeleton) + #if audio_each_file != []: + if self.args.audio_rep != "wave16k": + round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s + elif self.args.audio_rep == "mfcc": + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps + else: + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr + # if facial_each_file != []: + round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps + logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + # else: + # logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s") + # round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton) + # max_round = max(round_seconds_audio, round_seconds_skeleton) + # if round_seconds_skeleton != max_round: + # logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + + clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s + clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000] + clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15] + + + for 
ratio in self.args.multi_length_training: + if is_test:# stride = length for test + cut_length = clip_e_f_pose - clip_s_f_pose + self.args.stride = cut_length + self.max_length = cut_length + else: + self.args.stride = int(ratio*self.ori_stride) + cut_length = int(self.ori_length*ratio) + + num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1 + logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}") + logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}") + + # if audio_each_file != []: + audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps) + logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}") + + n_filtered_out = defaultdict(int) + sample_pose_list = [] + sample_audio_list = [] + sample_facial_list = [] + sample_shape_list = [] + sample_word_list = [] + sample_emo_list = [] + sample_sem_list = [] + sample_vid_list = [] + sample_trans_list = [] + + for i in range(num_subdivision): # cut into around 2s chip, (self npose) + start_idx = clip_s_f_pose + i * self.args.stride + fin_idx = start_idx + cut_length + sample_pose = pose_each_file[start_idx:fin_idx] + + sample_trans = trans_each_file[start_idx:fin_idx] + sample_shape = shape_each_file[start_idx:fin_idx] + # print(sample_pose.shape) + if self.args.audio_rep is not None: + audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps) + audio_end = audio_start + audio_short_length + sample_audio = audio_each_file[audio_start:audio_end] + else: + sample_audio = np.array([-1]) + sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1]) + sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1]) + sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1]) + sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1]) + sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1]) + + if sample_pose.any() != None: + # filtering motion skeleton data + sample_pose, filtering_message = MotionPreprocessor(sample_pose).get() + is_correct_motion = True #(sample_pose != []) + if is_correct_motion or disable_filtering: + sample_pose_list.append(sample_pose) + sample_audio_list.append(sample_audio) + sample_facial_list.append(sample_facial) + sample_shape_list.append(sample_shape) + sample_word_list.append(sample_word) + sample_vid_list.append(sample_vid) + sample_emo_list.append(sample_emo) + sample_sem_list.append(sample_sem) + sample_trans_list.append(sample_trans) + else: + n_filtered_out[filtering_message] += 1 + + if len(sample_pose_list) > 0: + with dst_lmdb_env.begin(write=True) as txn: + for pose, audio, facial, shape, word, vid, emo, sem, trans in zip( + sample_pose_list, + sample_audio_list, + sample_facial_list, + sample_shape_list, + sample_word_list, + sample_vid_list, + sample_emo_list, + sample_sem_list, + sample_trans_list,): + k = "{:005}".format(self.n_out_samples).encode("ascii") + v = [pose, audio, facial, shape, word, emo, sem, vid, trans] + # v = pyarrow.serialize(v).to_buffer() + # txn.put(k, v) + # self.n_out_samples += 1 + v = pickle.dumps(v) + txn.put(k, v) + self.n_out_samples += 1 + return n_filtered_out + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = 
"{:005}".format(idx).encode("ascii") + sample = txn.get(key) + # sample = pyarrow.deserialize(sample) + if sample is not None: + sample = pickle.loads(sample) + tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample + #print(in_shape) + #vid = torch.from_numpy(vid).int() + emo = torch.from_numpy(emo).int() + sem = torch.from_numpy(sem).float() + in_audio = torch.from_numpy(in_audio).float() + in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int() + if self.loader_type == "test": + tar_pose = torch.from_numpy(tar_pose).float() + trans = torch.from_numpy(trans).float() + in_facial = torch.from_numpy(in_facial).float() + vid = torch.from_numpy(vid).float() + in_shape = torch.from_numpy(in_shape).float() + else: + in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float() + vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float() + tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float() + in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float() + return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans} + + +class MotionPreprocessor: + def __init__(self, skeletons): + self.skeletons = skeletons + #self.mean_pose = mean_pose + self.filtering_message = "PASS" + + def get(self): + assert (self.skeletons is not None) + + # filtering + # if self.skeletons != []: + # if self.check_pose_diff(): + # self.skeletons = [] + # self.filtering_message = "pose" + # elif self.check_spine_angle(): + # self.skeletons = [] + # self.filtering_message = "spine angle" + # elif self.check_static_motion(): + # self.skeletons = [] + # self.filtering_message = "motion" + + # if self.skeletons != []: + # self.skeletons = self.skeletons.tolist() + # for i, frame in enumerate(self.skeletons): + # assert not np.isnan(self.skeletons[i]).any() # missing joints + + return self.skeletons, self.filtering_message + + def check_static_motion(self, verbose=True): + def get_variance(skeleton, joint_idx): + wrist_pos = skeleton[:, joint_idx] + variance = np.sum(np.var(wrist_pos, axis=0)) + return variance + + left_arm_var = get_variance(self.skeletons, 6) + right_arm_var = get_variance(self.skeletons, 9) + + th = 0.0014 # exclude 13110 + # th = 0.002 # exclude 16905 + if left_arm_var < th and right_arm_var < th: + if verbose: + print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return True + else: + if verbose: + print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return False + + + def check_pose_diff(self, verbose=False): +# diff = np.abs(self.skeletons - self.mean_pose) # 186*1 +# diff = np.mean(diff) + +# # th = 0.017 +# th = 0.02 #0.02 # exclude 3594 +# if diff < th: +# if verbose: +# print("skip - check_pose_diff {:.5f}".format(diff)) +# return True +# # th = 3.5 #0.02 # exclude 3594 +# # if 3.5 < diff < 5: +# # if verbose: +# # print("skip - check_pose_diff {:.5f}".format(diff)) +# # return True +# else: +# if verbose: +# print("pass - check_pose_diff {:.5f}".format(diff)) + return False + + + def check_spine_angle(self, verbose=True): + def angle_between(v1, v2): + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + angles = [] + for i in 
range(self.skeletons.shape[0]): + spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] + angle = angle_between(spine_vec, [0, -1, 0]) + angles.append(angle) + + if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495 + # if np.rad2deg(max(angles)) > 20: # exclude 8270 + if verbose: + print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles))) + return True + else: + if verbose: + print("pass - check_spine_angle {:.5f}".format(max(angles))) + return False \ No newline at end of file diff --git a/dataloaders/__pycache__/beat_testonly_hf.cpython-310.pyc b/dataloaders/__pycache__/beat_testonly_hf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6606ef23e6ff89c968415cbd5938753b2e7b450 Binary files /dev/null and b/dataloaders/__pycache__/beat_testonly_hf.cpython-310.pyc differ diff --git a/dataloaders/__pycache__/beat_testonly_hf.cpython-38.pyc b/dataloaders/__pycache__/beat_testonly_hf.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b4442886fefd170a44a27598eab176fc9abaad5 Binary files /dev/null and b/dataloaders/__pycache__/beat_testonly_hf.cpython-38.pyc differ diff --git a/dataloaders/__pycache__/build_vocab.cpython-310.pyc b/dataloaders/__pycache__/build_vocab.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed333dec7d794bddf8216370abad83cdcd6bbbb6 Binary files /dev/null and b/dataloaders/__pycache__/build_vocab.cpython-310.pyc differ diff --git a/dataloaders/__pycache__/build_vocab.cpython-38.pyc b/dataloaders/__pycache__/build_vocab.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f391643b5e18cafcb145e2d8049141c978a0e26a Binary files /dev/null and b/dataloaders/__pycache__/build_vocab.cpython-38.pyc differ diff --git a/dataloaders/__pycache__/data_tools.cpython-310.pyc b/dataloaders/__pycache__/data_tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdf63580251ed8ad599b546f3defd04e1491f38c Binary files /dev/null and b/dataloaders/__pycache__/data_tools.cpython-310.pyc differ diff --git a/dataloaders/__pycache__/data_tools.cpython-38.pyc b/dataloaders/__pycache__/data_tools.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ebb215579d4b2073878f92efe3fef3604c2fe00 Binary files /dev/null and b/dataloaders/__pycache__/data_tools.cpython-38.pyc differ diff --git a/dataloaders/beat_sep.py b/dataloaders/beat_sep.py new file mode 100644 index 0000000000000000000000000000000000000000..e5dc7dff9f52f0d8c9762946006fb6dd32540f23 --- /dev/null +++ b/dataloaders/beat_sep.py @@ -0,0 +1,771 @@ +import os +import pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +import pyarrow +import librosa +import smplx + +from .build_vocab import Vocab +from .utils.audio_features import Wav2Vec2Model +from .data_tools import joints_list +from .utils import rotation_conversions as rc +from .utils import other_tools + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True): + self.args = args + self.loader_type = loader_type + + self.rank = dist.get_rank() + self.ori_stride = self.args.stride + self.ori_length = 
self.args.pose_length + self.alignment = [0,0] # for trinity + + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list = joints_list[self.args.tar_joints] + if 'smplx' in self.args.pose_rep: + self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = len(list(self.tar_joint_list.keys())) + for joint_name in self.tar_joint_list: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + else: + self.joints = len(list(self.ori_joint_list.keys()))+1 + self.joint_mask = np.zeros(self.joints*3) + for joint_name in self.tar_joint_list: + if joint_name == "Hips": + self.joint_mask[3:6] = 1 + else: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + # select trainable joints + + split_rule = pd.read_csv(args.data_path+"train_test_split.csv") + self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + if args.additional_data and loader_type == 'train': + split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + #self.selected_file = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = pd.concat([self.selected_file, split_b]) + if self.selected_file.empty: + logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead") + self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = self.selected_file.iloc[0:8] + self.data_dir = args.data_path + + if loader_type == "test": + self.args.multi_length_training = [1.0] + self.max_length = int(args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr) + if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length*self.args.audio_sr + + if args.word_rep is not None: + with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + + preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache" + # if args.pose_norm: + # # careful for rotation vectors + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_pose() + # self.mean_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy") + # self.std_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_std.npy") + # if args.audio_norm: + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_audio() + # self.mean_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_mean.npy") + # self.std_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_std.npy") + # if args.facial_norm: + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_face() + # self.mean_facial = 
np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy") + # self.std_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy") + if self.args.beat_align: + if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"): + self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + + if build_cache and self.rank == 0: + self.build_cache(preloaded_dir) + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()["entries"] + + + def calculate_mean_velocity(self, save_path): + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + dir_p = self.data_dir + self.args.pose_rep + "/" + all_list = [] + from tqdm import tqdm + for tar in tqdm(os.listdir(dir_p)): + if tar.endswith(".npz"): + m_data = np.load(dir_p+tar, allow_pickle=True) + betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"] + n, c = poses.shape[0], poses.shape[1] + betas = betas.reshape(1, 300) + betas = np.tile(betas, (n, 1)) + betas = torch.from_numpy(betas).cuda().float() + poses = torch.from_numpy(poses.reshape(n, c)).cuda().float() + exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float() + trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float() + max_length = 128 + s, r = n//max_length, n%max_length + #print(n, s, r) + all_tensor = [] + for i in range(s): + with torch.no_grad(): + joints = self.smplx( + betas=betas[i*max_length:(i+1)*max_length], + transl=trans[i*max_length:(i+1)*max_length], + expression=exps[i*max_length:(i+1)*max_length], + jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69], + global_orient=poses[i*max_length:(i+1)*max_length,:3], + body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3], + left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3], + right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[i*max_length:(i+1)*max_length, 69:72], + reye_pose=poses[i*max_length:(i+1)*max_length, 72:75], + )['joints'][:, :55, :].reshape(max_length, 55*3) + all_tensor.append(joints) + if r != 0: + with torch.no_grad(): + joints = self.smplx( + betas=betas[s*max_length:s*max_length+r], + transl=trans[s*max_length:s*max_length+r], + expression=exps[s*max_length:s*max_length+r], + jaw_pose=poses[s*max_length:s*max_length+r, 66:69], + global_orient=poses[s*max_length:s*max_length+r,:3], + body_pose=poses[s*max_length:s*max_length+r,3:21*3+3], + left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3], + right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[s*max_length:s*max_length+r, 69:72], + reye_pose=poses[s*max_length:s*max_length+r, 72:75], + )['joints'][:, :55, :].reshape(r, 55*3) + all_tensor.append(joints) + joints = torch.cat(all_tensor, axis=0) + joints = joints.permute(1, 0) + dt = 1/30 + # first steps is forward diff (t+1 - t) / dt + init_vel = (joints[:, 1:2] - joints[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (joints[:, -1:] - joints[:, 
-2:-1]) / dt + #print(joints.shape, init_vel.shape, middle_vel.shape, final_vel.shape) + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1).permute(1, 0).reshape(n, 55, 3) + #print(vel_seq.shape) + #.permute(1, 0).reshape(n, 55, 3) + vel_seq_np = vel_seq.cpu().numpy() + vel_joints_np = np.linalg.norm(vel_seq_np, axis=2) # n * 55 + all_list.append(vel_joints_np) + avg_vel = np.mean(np.concatenate(all_list, axis=0),axis=0) # 55 + np.save(save_path, avg_vel) + + + def build_cache(self, preloaded_dir): + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + if self.args.new_cache: + if os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + if os.path.exists(preloaded_dir): + logger.info("Found the cache {}".format(preloaded_dir)) + elif self.loader_type == "test": + self.cache_generation( + preloaded_dir, True, + 0, 0, + is_test=True) + else: + self.cache_generation( + preloaded_dir, self.args.disable_filtering, + self.args.clean_first_seconds, self.args.clean_final_seconds, + is_test=False) + + def __len__(self): + return self.n_samples + + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + # if "wav2vec2" in self.args.audio_rep: + # self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h") + # self.wav2vec_model.feature_extractor._freeze_parameters() + # self.wav2vec_model = self.wav2vec_model.cuda() + # self.wav2vec_model.eval() + + self.n_out_samples = 0 + # create db for samples + if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir) + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G + n_filtered_out = defaultdict(int) + + for index, file_name in self.selected_file.iterrows(): + f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext + pose_each_file = [] + trans_each_file = [] + shape_each_file = [] + audio_each_file = [] + facial_each_file = [] + word_each_file = [] + emo_each_file = [] + sem_each_file = [] + vid_each_file = [] + id_pose = f_name #1_wayne_0_1_1 + + logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + pose_data = np.load(pose_file, allow_pickle=True) + assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30' + stride = int(30/self.args.pose_fps) + pose_each_file = pose_data["poses"][::stride] * self.joint_mask + pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)] + # print(pose_each_file.shape) + trans_each_file = pose_data["trans"][::stride] + shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_each_file = pose_data["expressions"][::stride] + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + else: + assert 120%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(120/self.args.pose_fps) + with open(pose_file, "r") as pose_data: + for j, line in enumerate(pose_data.readlines()): + if j < 431: continue + if j%stride != 0:continue + data = np.fromstring(line, dtype=float, sep=" ") + rot_data = 
rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints,3), "XYZ") + rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3) + rot_data = rot_data.numpy() * self.joint_mask + + pose_each_file.append(rot_data) + trans_each_file.append(data[:3]) + + pose_each_file = np.array(pose_each_file) + # print(pose_each_file.shape) + trans_each_file = np.array(trans_each_file) + shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json") + assert 60%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(60/self.args.pose_fps) + if not os.path.exists(facial_file): + logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + with open(facial_file, 'r') as facial_data_file: + facial_data = json.load(facial_data_file) + for j, frame_data in enumerate(facial_data['frames']): + if j%stride != 0:continue + facial_each_file.append(frame_data['weights']) + facial_each_file = np.array(facial_each_file) + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + if self.args.id_rep is not None: + vid_each_file = np.repeat(np.array(int(f_name.split("_")[0])-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + + if self.args.audio_rep is not None: + logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #") + audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + if not os.path.exists(audio_file): + logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + audio_each_file, sr = librosa.load(audio_file) + audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr) + if self.args.audio_rep == "onset+amplitude": + from numpy.lib import stride_tricks + frame_length = 1024 + # hop_length = 512 + shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length) + strides = (audio_each_file.strides[-1], audio_each_file.strides[-1]) + rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides) + amplitude_envelope = np.max(np.abs(rolling_view), axis=1) + # pad the last frame_length-1 samples + amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1]) + audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames') + onset_array = np.zeros(len(audio_each_file), dtype=float) + onset_array[audio_onset_f] = 1.0 + # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape) + audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1) + elif self.args.audio_rep == "mfcc": + audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps)) + audio_each_file = audio_each_file.transpose(1, 0) + # print(audio_each_file.shape, pose_each_file.shape) + if 
self.args.audio_norm and self.args.audio_rep == "wave16k": + audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio + # print(audio_each_file.shape) + time_offset = 0 + if self.args.word_rep is not None: + logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #") + word_file = f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid" + if not os.path.exists(word_file): + logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + tgrid = tg.TextGrid.fromFile(word_file) + if self.args.t_pre_encoder == "bert": + from transformers import AutoTokenizer, BertModel + tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True) + model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval() + list_word = [] + all_hidden = [] + max_len = 400 + last = 0 + word_token_mapping = [] + first = True + for i, word in enumerate(tgrid[0]): + last = i + if (i%max_len != 0) or (i==0): + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + else: + max_counter = max_len + str_word = ' '.join(map(str, list_word)) + if first: + global_len = 0 + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + #print(i+global_len) + sub_mapping.append(i+global_len) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + global_len = word_token_mapping[-1][-1] + 1 + list_word = [] + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + + #list_word = list_word[:10] + if list_word == []: + pass + else: + if first: + global_len = 0 + str_word = ' '.join(map(str, list_word)) + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + sub_mapping.append(i+global_len) + #print(sub_mapping) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + last_hidden_states = np.concatenate(all_hidden, axis=0) + + for i in range(pose_each_file.shape[0]): + found_flag = False + current_time = i/self.args.pose_fps + time_offset + j_last = 0 + for j, word in enumerate(tgrid[0]): + word_n, 
word_s, word_e = word.mark, word.minTime, word.maxTime + if word_s<=current_time and current_time<=word_e: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + mapping_index = word_token_mapping[j] + #print(mapping_index, word_s, word_e) + s_t = np.linspace(word_s, word_e, len(mapping_index)+1) + #print(s_t) + for tt, t_sep in enumerate(s_t[1:]): + if current_time <= t_sep: + #if len(mapping_index) > 1: print(mapping_index[tt]) + word_each_file.append(last_hidden_states[mapping_index[tt]]) + break + else: + if word_n == " ": + word_each_file.append(self.lang_model.PAD_token) + else: + word_each_file.append(self.lang_model.get_word_index(word_n)) + found_flag = True + j_last = j + break + else: continue + if not found_flag: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + word_each_file.append(last_hidden_states[j_last]) + else: + word_each_file.append(self.lang_model.UNK_token) + word_each_file = np.array(word_each_file) + #print(word_each_file.shape) + + if self.args.emo_rep is not None: + logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #") + rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3]) + if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6: + if start >= 1 and start <= 64: + score = 0 + elif start >= 65 and start <= 72: + score = 1 + elif start >= 73 and start <= 80: + score = 2 + elif start >= 81 and start <= 86: + score = 3 + elif start >= 87 and start <= 94: + score = 4 + elif start >= 95 and start <= 102: + score = 5 + elif start >= 103 and start <= 110: + score = 6 + elif start >= 111 and start <= 118: + score = 7 + else: pass + else: + # you may denote as unknown in the future + score = 0 + emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0) + #print(emo_each_file) + + if self.args.sem_rep is not None: + logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #") + sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt" + sem_all = pd.read_csv(sem_file, + sep='\t', + names=["name", "start_time", "end_time", "duration", "score", "keywords"]) + # we adopt motion-level semantic score here. + for i in range(pose_each_file.shape[0]): + found_flag = False + for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])): + current_time = i/self.args.pose_fps + time_offset + if start<=current_time and current_time<=end: + sem_each_file.append(score) + found_flag=True + break + else: continue + if not found_flag: sem_each_file.append(0.) + sem_each_file = np.array(sem_each_file) + #print(sem_each_file) + + filtered_result = self._sample_from_clip( + dst_lmdb_env, + audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ) + for type in filtered_result.keys(): + n_filtered_out[type] += filtered_result[type] + + with dst_lmdb_env.begin() as txn: + logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan")) + n_total_filtered = 0 + for type, n_filtered in n_filtered_out.items(): + logger.info("{}: {}".format(type, n_filtered)) + n_total_filtered += n_filtered + logger.info(colored("no. 
of excluded samples: {} ({:.1f}%)".format( + n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan")) + dst_lmdb_env.sync() + dst_lmdb_env.close() + + def _sample_from_clip( + self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ): + """ + for data cleaning, we ignore the data for first and final n s + for test, we return all data + """ + # audio_start = int(self.alignment[0] * self.args.audio_fps) + # pose_start = int(self.alignment[1] * self.args.pose_fps) + #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}") + # audio_each_file = audio_each_file[audio_start:] + # pose_each_file = pose_each_file[pose_start:] + # trans_each_file = + #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}") + #print(pose_each_file.shape) + round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s + #print(round_seconds_skeleton) + if audio_each_file != []: + if self.args.audio_rep != "wave16k": + round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s + elif self.args.audio_rep == "mfcc": + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps + else: + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr + if facial_each_file != []: + round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps + logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + else: + logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton) + max_round = max(round_seconds_audio, round_seconds_skeleton) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + + clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s + clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000] + clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15] + + + for ratio in self.args.multi_length_training: + if is_test:# stride = length for test + cut_length = clip_e_f_pose - clip_s_f_pose + self.args.stride = cut_length + self.max_length = cut_length + else: + self.args.stride = int(ratio*self.ori_stride) + cut_length = int(self.ori_length*ratio) + + num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1 + logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}") + logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}") + + if audio_each_file != []: + audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps) + """ + for audio sr = 16000, fps = 15, pose_length = 34, + audio short length = 36266.7 -> 36266 + 
this error is fine. + """ + logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}") + + n_filtered_out = defaultdict(int) + sample_pose_list = [] + sample_audio_list = [] + sample_facial_list = [] + sample_shape_list = [] + sample_word_list = [] + sample_emo_list = [] + sample_sem_list = [] + sample_vid_list = [] + sample_trans_list = [] + + for i in range(num_subdivision): # cut into around 2s chip, (self npose) + start_idx = clip_s_f_pose + i * self.args.stride + fin_idx = start_idx + cut_length + sample_pose = pose_each_file[start_idx:fin_idx] + sample_trans = trans_each_file[start_idx:fin_idx] + sample_shape = shape_each_file[start_idx:fin_idx] + # print(sample_pose.shape) + if self.args.audio_rep is not None: + audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps) + audio_end = audio_start + audio_short_length + sample_audio = audio_each_file[audio_start:audio_end] + else: + sample_audio = np.array([-1]) + sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1]) + sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1]) + sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1]) + sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1]) + sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1]) + + if sample_pose.any() != None: + # filtering motion skeleton data + sample_pose, filtering_message = MotionPreprocessor(sample_pose).get() + is_correct_motion = (sample_pose != []) + if is_correct_motion or disable_filtering: + sample_pose_list.append(sample_pose) + sample_audio_list.append(sample_audio) + sample_facial_list.append(sample_facial) + sample_shape_list.append(sample_shape) + sample_word_list.append(sample_word) + sample_vid_list.append(sample_vid) + sample_emo_list.append(sample_emo) + sample_sem_list.append(sample_sem) + sample_trans_list.append(sample_trans) + else: + n_filtered_out[filtering_message] += 1 + + if len(sample_pose_list) > 0: + with dst_lmdb_env.begin(write=True) as txn: + for pose, audio, facial, shape, word, vid, emo, sem, trans in zip( + sample_pose_list, + sample_audio_list, + sample_facial_list, + sample_shape_list, + sample_word_list, + sample_vid_list, + sample_emo_list, + sample_sem_list, + sample_trans_list,): + k = "{:005}".format(self.n_out_samples).encode("ascii") + v = [pose, audio, facial, shape, word, emo, sem, vid, trans] + v = pyarrow.serialize(v).to_buffer() + txn.put(k, v) + self.n_out_samples += 1 + return n_filtered_out + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = "{:005}".format(idx).encode("ascii") + sample = txn.get(key) + sample = pyarrow.deserialize(sample) + tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample + #print(in_shape) + #vid = torch.from_numpy(vid).int() + emo = torch.from_numpy(emo).int() + sem = torch.from_numpy(sem).float() + in_audio = torch.from_numpy(in_audio).float() + in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int() + if self.loader_type == "test": + tar_pose = torch.from_numpy(tar_pose).float() + trans = torch.from_numpy(trans).float() + in_facial = torch.from_numpy(in_facial).float() + vid = torch.from_numpy(vid).float() + in_shape = torch.from_numpy(in_shape).float() + else: + 
in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float() + vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float() + tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float() + in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float() + return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans} + + +class MotionPreprocessor: + def __init__(self, skeletons): + self.skeletons = skeletons + #self.mean_pose = mean_pose + self.filtering_message = "PASS" + + def get(self): + assert (self.skeletons is not None) + + # filtering + if self.skeletons != []: + if self.check_pose_diff(): + self.skeletons = [] + self.filtering_message = "pose" + # elif self.check_spine_angle(): + # self.skeletons = [] + # self.filtering_message = "spine angle" + # elif self.check_static_motion(): + # self.skeletons = [] + # self.filtering_message = "motion" + + # if self.skeletons != []: + # self.skeletons = self.skeletons.tolist() + # for i, frame in enumerate(self.skeletons): + # assert not np.isnan(self.skeletons[i]).any() # missing joints + + return self.skeletons, self.filtering_message + + def check_static_motion(self, verbose=True): + def get_variance(skeleton, joint_idx): + wrist_pos = skeleton[:, joint_idx] + variance = np.sum(np.var(wrist_pos, axis=0)) + return variance + + left_arm_var = get_variance(self.skeletons, 6) + right_arm_var = get_variance(self.skeletons, 9) + + th = 0.0014 # exclude 13110 + # th = 0.002 # exclude 16905 + if left_arm_var < th and right_arm_var < th: + if verbose: + print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return True + else: + if verbose: + print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return False + + + def check_pose_diff(self, verbose=False): +# diff = np.abs(self.skeletons - self.mean_pose) # 186*1 +# diff = np.mean(diff) + +# # th = 0.017 +# th = 0.02 #0.02 # exclude 3594 +# if diff < th: +# if verbose: +# print("skip - check_pose_diff {:.5f}".format(diff)) +# return True +# # th = 3.5 #0.02 # exclude 3594 +# # if 3.5 < diff < 5: +# # if verbose: +# # print("skip - check_pose_diff {:.5f}".format(diff)) +# # return True +# else: +# if verbose: +# print("pass - check_pose_diff {:.5f}".format(diff)) + return False + + + def check_spine_angle(self, verbose=True): + def angle_between(v1, v2): + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + angles = [] + for i in range(self.skeletons.shape[0]): + spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] + angle = angle_between(spine_vec, [0, -1, 0]) + angles.append(angle) + + if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495 + # if np.rad2deg(max(angles)) > 20: # exclude 8270 + if verbose: + print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles))) + return True + else: + if verbose: + print("pass - check_spine_angle {:.5f}".format(max(angles))) + return False \ No newline at end of file diff --git a/dataloaders/beat_sep_lower.py b/dataloaders/beat_sep_lower.py new file mode 100644 index 0000000000000000000000000000000000000000..01c8b228d0436cfbf3b988e8443460af702c02e8 --- /dev/null +++ b/dataloaders/beat_sep_lower.py @@ -0,0 +1,855 @@ +import os +import 
pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +import pyarrow +import librosa +import smplx + +from .build_vocab import Vocab +from .utils.audio_features import Wav2Vec2Model +from .data_tools import joints_list +from .utils import rotation_conversions as rc +from .utils import other_tools + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True): + self.args = args + self.loader_type = loader_type + + self.rank = dist.get_rank() + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + self.alignment = [0,0] # for trinity + + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list = joints_list[self.args.tar_joints] + if 'smplx' in self.args.pose_rep: + self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = len(list(self.tar_joint_list.keys())) + for joint_name in self.tar_joint_list: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + else: + self.joints = len(list(self.ori_joint_list.keys()))+1 + self.joint_mask = np.zeros(self.joints*3) + for joint_name in self.tar_joint_list: + if joint_name == "Hips": + self.joint_mask[3:6] = 1 + else: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + # select trainable joints + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + + split_rule = pd.read_csv(args.data_path+"train_test_split.csv") + self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + if args.additional_data and loader_type == 'train': + split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + #self.selected_file = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = pd.concat([self.selected_file, split_b]) + if self.selected_file.empty: + logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead") + self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = self.selected_file.iloc[0:8] + self.data_dir = args.data_path + + if loader_type == "test": + self.args.multi_length_training = [1.0] + self.max_length = int(args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr) + if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length*self.args.audio_sr + + if args.word_rep is not None: + with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + + preloaded_dir = 
self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache" + # if args.pose_norm: + # # careful for rotation vectors + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_pose() + # self.mean_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy") + # self.std_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_std.npy") + # if args.audio_norm: + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_audio() + # self.mean_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_mean.npy") + # self.std_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_std.npy") + # if args.facial_norm: + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_face() + # self.mean_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy") + # self.std_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy") + if self.args.beat_align: + if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"): + self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + + if build_cache and self.rank == 0: + self.build_cache(preloaded_dir) + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()["entries"] + + + def calculate_mean_velocity(self, save_path): + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + dir_p = self.data_dir + self.args.pose_rep + "/" + all_list = [] + from tqdm import tqdm + for tar in tqdm(os.listdir(dir_p)): + if tar.endswith(".npz"): + m_data = np.load(dir_p+tar, allow_pickle=True) + betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"] + n, c = poses.shape[0], poses.shape[1] + betas = betas.reshape(1, 300) + betas = np.tile(betas, (n, 1)) + betas = torch.from_numpy(betas).cuda().float() + poses = torch.from_numpy(poses.reshape(n, c)).cuda().float() + exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float() + trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float() + max_length = 128 + s, r = n//max_length, n%max_length + #print(n, s, r) + all_tensor = [] + for i in range(s): + with torch.no_grad(): + joints = self.smplx( + betas=betas[i*max_length:(i+1)*max_length], + transl=trans[i*max_length:(i+1)*max_length], + expression=exps[i*max_length:(i+1)*max_length], + jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69], + global_orient=poses[i*max_length:(i+1)*max_length,:3], + body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3], + left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3], + right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[i*max_length:(i+1)*max_length, 69:72], + reye_pose=poses[i*max_length:(i+1)*max_length, 72:75], + )['joints'][:, :55, :].reshape(max_length, 55*3) + all_tensor.append(joints) + if r != 0: + 
with torch.no_grad(): + joints = self.smplx( + betas=betas[s*max_length:s*max_length+r], + transl=trans[s*max_length:s*max_length+r], + expression=exps[s*max_length:s*max_length+r], + jaw_pose=poses[s*max_length:s*max_length+r, 66:69], + global_orient=poses[s*max_length:s*max_length+r,:3], + body_pose=poses[s*max_length:s*max_length+r,3:21*3+3], + left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3], + right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[s*max_length:s*max_length+r, 69:72], + reye_pose=poses[s*max_length:s*max_length+r, 72:75], + )['joints'][:, :55, :].reshape(r, 55*3) + all_tensor.append(joints) + joints = torch.cat(all_tensor, axis=0) + joints = joints.permute(1, 0) + dt = 1/30 + # first steps is forward diff (t+1 - t) / dt + init_vel = (joints[:, 1:2] - joints[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt + #print(joints.shape, init_vel.shape, middle_vel.shape, final_vel.shape) + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1).permute(1, 0).reshape(n, 55, 3) + #print(vel_seq.shape) + #.permute(1, 0).reshape(n, 55, 3) + vel_seq_np = vel_seq.cpu().numpy() + vel_joints_np = np.linalg.norm(vel_seq_np, axis=2) # n * 55 + all_list.append(vel_joints_np) + avg_vel = np.mean(np.concatenate(all_list, axis=0),axis=0) # 55 + np.save(save_path, avg_vel) + + + def build_cache(self, preloaded_dir): + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + if self.args.new_cache: + if os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + if os.path.exists(preloaded_dir): + logger.info("Found the cache {}".format(preloaded_dir)) + elif self.loader_type == "test": + self.cache_generation( + preloaded_dir, True, + 0, 0, + is_test=True) + else: + self.cache_generation( + preloaded_dir, self.args.disable_filtering, + self.args.clean_first_seconds, self.args.clean_final_seconds, + is_test=False) + + def __len__(self): + return self.n_samples + + def idmapping(self, id): + # map 1,2,3,4,5, 6,7,9,10,11, 12,13,15,16,17, 18,20,21,22,23, 24,25,27,28,30 to 0-24 + if id == 30: id = 8 + if id == 28: id = 14 + if id == 27: id = 19 + return id - 1 + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + # if "wav2vec2" in self.args.audio_rep: + # self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h") + # self.wav2vec_model.feature_extractor._freeze_parameters() + # self.wav2vec_model = self.wav2vec_model.cuda() + # self.wav2vec_model.eval() + + self.n_out_samples = 0 + # create db for samples + if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir) + if len(self.args.training_speakers) == 1: + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G + else: + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 200))# 200G + n_filtered_out = defaultdict(int) + + for index, file_name in self.selected_file.iterrows(): + f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext + pose_each_file = [] + trans_each_file = [] + shape_each_file = [] + audio_each_file = [] 
+ facial_each_file = [] + word_each_file = [] + emo_each_file = [] + sem_each_file = [] + vid_each_file = [] + id_pose = f_name #1_wayne_0_1_1 + + logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + pose_data = np.load(pose_file, allow_pickle=True) + assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30' + stride = int(30/self.args.pose_fps) + pose_each_file = pose_data["poses"][::stride] + trans_each_file = pose_data["trans"][::stride] + shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0) + + assert self.args.pose_fps == 30, "should 30" + m_data = np.load(pose_file, allow_pickle=True) + betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"] + n, c = poses.shape[0], poses.shape[1] + betas = betas.reshape(1, 300) + betas = np.tile(betas, (n, 1)) + betas = torch.from_numpy(betas).cuda().float() + poses = torch.from_numpy(poses.reshape(n, c)).cuda().float() + exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float() + trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float() + max_length = 128 + s, r = n//max_length, n%max_length + #print(n, s, r) + all_tensor = [] + for i in range(s): + with torch.no_grad(): + joints = self.smplx( + betas=betas[i*max_length:(i+1)*max_length], + transl=trans[i*max_length:(i+1)*max_length], + expression=exps[i*max_length:(i+1)*max_length], + jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69], + global_orient=poses[i*max_length:(i+1)*max_length,:3], + body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3], + left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3], + right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[i*max_length:(i+1)*max_length, 69:72], + reye_pose=poses[i*max_length:(i+1)*max_length, 72:75], + )['joints'][:, (7,8,10,11), :].reshape(max_length, 4, 3).cpu() + all_tensor.append(joints) + if r != 0: + with torch.no_grad(): + joints = self.smplx( + betas=betas[s*max_length:s*max_length+r], + transl=trans[s*max_length:s*max_length+r], + expression=exps[s*max_length:s*max_length+r], + jaw_pose=poses[s*max_length:s*max_length+r, 66:69], + global_orient=poses[s*max_length:s*max_length+r,:3], + body_pose=poses[s*max_length:s*max_length+r,3:21*3+3], + left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3], + right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[s*max_length:s*max_length+r, 69:72], + reye_pose=poses[s*max_length:s*max_length+r, 72:75], + )['joints'][:, (7,8,10,11), :].reshape(r, 4, 3).cpu() + all_tensor.append(joints) + joints = torch.cat(all_tensor, axis=0) # all, 4, 3 + # print(joints.shape) + feetv = torch.zeros(joints.shape[1], joints.shape[0]) + joints = joints.permute(1, 0, 2) + #print(joints.shape, feetv.shape) + feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1) + #print(feetv.shape) + contacts = (feetv < 0.01).numpy().astype(float) + # print(contacts.shape, contacts) + contacts = contacts.transpose(1, 0) + pose_each_file = pose_each_file * self.joint_mask + pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)] + pose_each_file = np.concatenate([pose_each_file, contacts], axis=1) + # print(pose_each_file.shape) + + + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_each_file = 
pose_data["expressions"][::stride] + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + else: + assert 120%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(120/self.args.pose_fps) + with open(pose_file, "r") as pose_data: + for j, line in enumerate(pose_data.readlines()): + if j < 431: continue + if j%stride != 0:continue + data = np.fromstring(line, dtype=float, sep=" ") + rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints,3), "XYZ") + rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3) + rot_data = rot_data.numpy() * self.joint_mask + + pose_each_file.append(rot_data) + trans_each_file.append(data[:3]) + + pose_each_file = np.array(pose_each_file) + # print(pose_each_file.shape) + trans_each_file = np.array(trans_each_file) + shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json") + assert 60%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(60/self.args.pose_fps) + if not os.path.exists(facial_file): + logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + with open(facial_file, 'r') as facial_data_file: + facial_data = json.load(facial_data_file) + for j, frame_data in enumerate(facial_data['frames']): + if j%stride != 0:continue + facial_each_file.append(frame_data['weights']) + facial_each_file = np.array(facial_each_file) + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + if self.args.id_rep is not None: + int_value = self.idmapping(int(f_name.split("_")[0])) + vid_each_file = np.repeat(np.array(int_value).reshape(1, 1), pose_each_file.shape[0], axis=0) + + if self.args.audio_rep is not None: + logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #") + audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + if not os.path.exists(audio_file): + logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + audio_each_file, sr = librosa.load(audio_file) + audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr) + if self.args.audio_rep == "onset+amplitude": + from numpy.lib import stride_tricks + frame_length = 1024 + # hop_length = 512 + shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length) + strides = (audio_each_file.strides[-1], audio_each_file.strides[-1]) + rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides) + amplitude_envelope = np.max(np.abs(rolling_view), axis=1) + # pad the last frame_length-1 samples + amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1]) + audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames') + onset_array = np.zeros(len(audio_each_file), dtype=float) + 
onset_array[audio_onset_f] = 1.0 + # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape) + audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1) + elif self.args.audio_rep == "mfcc": + audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps)) + audio_each_file = audio_each_file.transpose(1, 0) + # print(audio_each_file.shape, pose_each_file.shape) + if self.args.audio_norm and self.args.audio_rep == "wave16k": + audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio + + time_offset = 0 + if self.args.word_rep is not None: + logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #") + word_file = f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid" + if not os.path.exists(word_file): + logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + tgrid = tg.TextGrid.fromFile(word_file) + if self.args.t_pre_encoder == "bert": + from transformers import AutoTokenizer, BertModel + tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True) + model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval() + list_word = [] + all_hidden = [] + max_len = 400 + last = 0 + word_token_mapping = [] + first = True + for i, word in enumerate(tgrid[0]): + last = i + if (i%max_len != 0) or (i==0): + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + else: + max_counter = max_len + str_word = ' '.join(map(str, list_word)) + if first: + global_len = 0 + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + #print(i+global_len) + sub_mapping.append(i+global_len) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + global_len = word_token_mapping[-1][-1] + 1 + list_word = [] + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + + #list_word = list_word[:10] + if list_word == []: + pass + else: + if first: + global_len = 0 + str_word = ' '.join(map(str, list_word)) + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + sub_mapping.append(i+global_len) + #print(sub_mapping) + word_token_mapping.append(sub_mapping) + 
#print(len(word_token_mapping)) + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + last_hidden_states = np.concatenate(all_hidden, axis=0) + + for i in range(pose_each_file.shape[0]): + found_flag = False + current_time = i/self.args.pose_fps + time_offset + j_last = 0 + for j, word in enumerate(tgrid[0]): + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + if word_s<=current_time and current_time<=word_e: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + mapping_index = word_token_mapping[j] + #print(mapping_index, word_s, word_e) + s_t = np.linspace(word_s, word_e, len(mapping_index)+1) + #print(s_t) + for tt, t_sep in enumerate(s_t[1:]): + if current_time <= t_sep: + #if len(mapping_index) > 1: print(mapping_index[tt]) + word_each_file.append(last_hidden_states[mapping_index[tt]]) + break + else: + if word_n == " ": + word_each_file.append(self.lang_model.PAD_token) + else: + word_each_file.append(self.lang_model.get_word_index(word_n)) + found_flag = True + j_last = j + break + else: continue + if not found_flag: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + word_each_file.append(last_hidden_states[j_last]) + else: + word_each_file.append(self.lang_model.UNK_token) + word_each_file = np.array(word_each_file) + #print(word_each_file.shape) + + if self.args.emo_rep is not None: + logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #") + rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3]) + if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6: + if start >= 1 and start <= 64: + score = 0 + elif start >= 65 and start <= 72: + score = 1 + elif start >= 73 and start <= 80: + score = 2 + elif start >= 81 and start <= 86: + score = 3 + elif start >= 87 and start <= 94: + score = 4 + elif start >= 95 and start <= 102: + score = 5 + elif start >= 103 and start <= 110: + score = 6 + elif start >= 111 and start <= 118: + score = 7 + else: pass + else: + # you may denote as unknown in the future + score = 0 + emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0) + #print(emo_each_file) + + if self.args.sem_rep is not None: + logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #") + sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt" + sem_all = pd.read_csv(sem_file, + sep='\t', + names=["name", "start_time", "end_time", "duration", "score", "keywords"]) + # we adopt motion-level semantic score here. + for i in range(pose_each_file.shape[0]): + found_flag = False + for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])): + current_time = i/self.args.pose_fps + time_offset + if start<=current_time and current_time<=end: + sem_each_file.append(score) + found_flag=True + break + else: continue + if not found_flag: sem_each_file.append(0.) 
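# --- Illustrative sketch, not part of the patch: the loop that just ended scans every
# (start_time, end_time, score) row of `sem_all` for every pose frame and falls back to 0.0
# when no annotated segment covers the frame. A hypothetical helper with the same behaviour,
# assuming `sem_all` is the pandas DataFrame loaded from the semantic annotation file above:
def frame_semantic_score(frame_idx, pose_fps, sem_all, time_offset=0.0):
    t = frame_idx / pose_fps + time_offset
    hit = sem_all[(sem_all["start_time"] <= t) & (t <= sem_all["end_time"])]
    # take the score of the first covering segment, 0.0 when none matches
    return float(hit["score"].iloc[0]) if len(hit) else 0.0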
+ sem_each_file = np.array(sem_each_file) + #print(sem_each_file) + + filtered_result = self._sample_from_clip( + dst_lmdb_env, + audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ) + for type in filtered_result.keys(): + n_filtered_out[type] += filtered_result[type] + + with dst_lmdb_env.begin() as txn: + logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan")) + n_total_filtered = 0 + for type, n_filtered in n_filtered_out.items(): + logger.info("{}: {}".format(type, n_filtered)) + n_total_filtered += n_filtered + logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format( + n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan")) + dst_lmdb_env.sync() + dst_lmdb_env.close() + + def _sample_from_clip( + self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ): + """ + for data cleaning, we ignore the data for first and final n s + for test, we return all data + """ + # audio_start = int(self.alignment[0] * self.args.audio_fps) + # pose_start = int(self.alignment[1] * self.args.pose_fps) + #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}") + # audio_each_file = audio_each_file[audio_start:] + # pose_each_file = pose_each_file[pose_start:] + # trans_each_file = + #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}") + #print(pose_each_file.shape) + round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s + #print(round_seconds_skeleton) + if audio_each_file != []: + if self.args.audio_rep != "wave16k": + round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s + elif self.args.audio_rep == "mfcc": + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps + else: + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr + if facial_each_file != []: + round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps + logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + else: + logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton) + max_round = max(round_seconds_audio, round_seconds_skeleton) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + + clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s + clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000] + clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15] + + + for ratio in self.args.multi_length_training: + if 
is_test:# stride = length for test + cut_length = clip_e_f_pose - clip_s_f_pose + self.args.stride = cut_length + self.max_length = cut_length + else: + self.args.stride = int(ratio*self.ori_stride) + cut_length = int(self.ori_length*ratio) + + num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1 + logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}") + logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}") + + if audio_each_file != []: + audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps) + """ + for audio sr = 16000, fps = 15, pose_length = 34, + audio short length = 36266.7 -> 36266 + this error is fine. + """ + logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}") + + n_filtered_out = defaultdict(int) + sample_pose_list = [] + sample_audio_list = [] + sample_facial_list = [] + sample_shape_list = [] + sample_word_list = [] + sample_emo_list = [] + sample_sem_list = [] + sample_vid_list = [] + sample_trans_list = [] + + for i in range(num_subdivision): # cut into around 2s chip, (self npose) + start_idx = clip_s_f_pose + i * self.args.stride + fin_idx = start_idx + cut_length + sample_pose = pose_each_file[start_idx:fin_idx] + + sample_trans = trans_each_file[start_idx:fin_idx] + sample_shape = shape_each_file[start_idx:fin_idx] + # print(sample_pose.shape) + if self.args.audio_rep is not None: + audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps) + audio_end = audio_start + audio_short_length + sample_audio = audio_each_file[audio_start:audio_end] + else: + sample_audio = np.array([-1]) + sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1]) + sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1]) + sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1]) + sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1]) + sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1]) + + if sample_pose.any() != None: + # filtering motion skeleton data + sample_pose, filtering_message = MotionPreprocessor(sample_pose).get() + is_correct_motion = (sample_pose != []) + if is_correct_motion or disable_filtering: + sample_pose_list.append(sample_pose) + sample_audio_list.append(sample_audio) + sample_facial_list.append(sample_facial) + sample_shape_list.append(sample_shape) + sample_word_list.append(sample_word) + sample_vid_list.append(sample_vid) + sample_emo_list.append(sample_emo) + sample_sem_list.append(sample_sem) + sample_trans_list.append(sample_trans) + else: + n_filtered_out[filtering_message] += 1 + + if len(sample_pose_list) > 0: + with dst_lmdb_env.begin(write=True) as txn: + for pose, audio, facial, shape, word, vid, emo, sem, trans in zip( + sample_pose_list, + sample_audio_list, + sample_facial_list, + sample_shape_list, + sample_word_list, + sample_vid_list, + sample_emo_list, + sample_sem_list, + sample_trans_list,): + k = "{:005}".format(self.n_out_samples).encode("ascii") + v = [pose, audio, facial, shape, word, emo, sem, vid, trans] + v = pyarrow.serialize(v).to_buffer() + txn.put(k, v) + self.n_out_samples += 1 + return n_filtered_out + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = 
"{:005}".format(idx).encode("ascii") + sample = txn.get(key) + sample = pyarrow.deserialize(sample) + tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample + #print(in_shape) + #vid = torch.from_numpy(vid).int() + emo = torch.from_numpy(emo).int() + sem = torch.from_numpy(sem).float() + in_audio = torch.from_numpy(in_audio).float() + in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int() + if self.loader_type == "test": + tar_pose = torch.from_numpy(tar_pose).float() + trans = torch.from_numpy(trans).float() + in_facial = torch.from_numpy(in_facial).float() + vid = torch.from_numpy(vid).float() + in_shape = torch.from_numpy(in_shape).float() + else: + in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float() + vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float() + tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float() + in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float() + return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans} + + +class MotionPreprocessor: + def __init__(self, skeletons): + self.skeletons = skeletons + #self.mean_pose = mean_pose + self.filtering_message = "PASS" + + def get(self): + assert (self.skeletons is not None) + + # filtering + if self.skeletons != []: + if self.check_pose_diff(): + self.skeletons = [] + self.filtering_message = "pose" + # elif self.check_spine_angle(): + # self.skeletons = [] + # self.filtering_message = "spine angle" + # elif self.check_static_motion(): + # self.skeletons = [] + # self.filtering_message = "motion" + + # if self.skeletons != []: + # self.skeletons = self.skeletons.tolist() + # for i, frame in enumerate(self.skeletons): + # assert not np.isnan(self.skeletons[i]).any() # missing joints + + return self.skeletons, self.filtering_message + + def check_static_motion(self, verbose=True): + def get_variance(skeleton, joint_idx): + wrist_pos = skeleton[:, joint_idx] + variance = np.sum(np.var(wrist_pos, axis=0)) + return variance + + left_arm_var = get_variance(self.skeletons, 6) + right_arm_var = get_variance(self.skeletons, 9) + + th = 0.0014 # exclude 13110 + # th = 0.002 # exclude 16905 + if left_arm_var < th and right_arm_var < th: + if verbose: + print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return True + else: + if verbose: + print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return False + + + def check_pose_diff(self, verbose=False): +# diff = np.abs(self.skeletons - self.mean_pose) # 186*1 +# diff = np.mean(diff) + +# # th = 0.017 +# th = 0.02 #0.02 # exclude 3594 +# if diff < th: +# if verbose: +# print("skip - check_pose_diff {:.5f}".format(diff)) +# return True +# # th = 3.5 #0.02 # exclude 3594 +# # if 3.5 < diff < 5: +# # if verbose: +# # print("skip - check_pose_diff {:.5f}".format(diff)) +# # return True +# else: +# if verbose: +# print("pass - check_pose_diff {:.5f}".format(diff)) + return False + + + def check_spine_angle(self, verbose=True): + def angle_between(v1, v2): + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + angles = [] + for i in range(self.skeletons.shape[0]): + spine_vec = self.skeletons[i, 1] - 
self.skeletons[i, 0] + angle = angle_between(spine_vec, [0, -1, 0]) + angles.append(angle) + + if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495 + # if np.rad2deg(max(angles)) > 20: # exclude 8270 + if verbose: + print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles))) + return True + else: + if verbose: + print("pass - check_spine_angle {:.5f}".format(max(angles))) + return False \ No newline at end of file diff --git a/dataloaders/beat_smplx2020.py b/dataloaders/beat_smplx2020.py new file mode 100644 index 0000000000000000000000000000000000000000..7647c73d5962dd8856584ebdd6cfe26902ef5070 --- /dev/null +++ b/dataloaders/beat_smplx2020.py @@ -0,0 +1,763 @@ +import os +import pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +import pyarrow +import librosa +import smplx + +from .build_vocab import Vocab +from .utils.audio_features import Wav2Vec2Model +from .data_tools import joints_list +from .utils import rotation_conversions as rc +from .utils import other_tools + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True): + self.args = args + self.loader_type = loader_type + + self.rank = dist.get_rank() + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + self.alignment = [0,0] # for trinity + + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list = joints_list[self.args.tar_joints] + if 'smplx' in self.args.pose_rep: + self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = len(list(self.ori_joint_list.keys())) + for joint_name in self.tar_joint_list: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + else: + self.joints = len(list(self.ori_joint_list.keys()))+1 + self.joint_mask = np.zeros(self.joints*3) + for joint_name in self.tar_joint_list: + if joint_name == "Hips": + self.joint_mask[3:6] = 1 + else: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + # select trainable joints + + split_rule = pd.read_csv(args.data_path+"train_test_split.csv") + self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + if args.additional_data and loader_type == 'train': + split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + #self.selected_file = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = pd.concat([self.selected_file, split_b]) + if self.selected_file.empty: + logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead") + self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = self.selected_file.iloc[0:8] + self.data_dir = args.data_path + + if loader_type == "test": + 
self.args.multi_length_training = [1.0] + self.max_length = int(args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr) + if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length*self.args.audio_sr + + if args.word_rep is not None: + with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + + preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache" + # if args.pose_norm: + # # careful for rotation vectors + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_pose() + # self.mean_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy") + # self.std_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_std.npy") + # if args.audio_norm: + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_audio() + # self.mean_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_mean.npy") + # self.std_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_std.npy") + # if args.facial_norm: + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_face() + # self.mean_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy") + # self.std_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy") + if self.args.beat_align: + if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"): + self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + + if build_cache and self.rank == 0: + self.build_cache(preloaded_dir) + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()["entries"] + + + def calculate_mean_velocity(self, save_path): + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + dir_p = self.data_dir + self.args.pose_rep + "/" + all_list = [] + from tqdm import tqdm + for tar in tqdm(os.listdir(dir_p)): + if tar.endswith(".npz"): + m_data = np.load(dir_p+tar, allow_pickle=True) + betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"] + n, c = poses.shape[0], poses.shape[1] + betas = betas.reshape(1, 300) + betas = np.tile(betas, (n, 1)) + betas = torch.from_numpy(betas).cuda().float() + poses = torch.from_numpy(poses.reshape(n, c)).cuda().float() + exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float() + trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float() + max_length = 128 + s, r = n//max_length, n%max_length + #print(n, s, r) + all_tensor = [] + for i in range(s): + with torch.no_grad(): + joints = self.smplx( + betas=betas[i*max_length:(i+1)*max_length], + transl=trans[i*max_length:(i+1)*max_length], + expression=exps[i*max_length:(i+1)*max_length], + jaw_pose=poses[i*max_length:(i+1)*max_length, 
66:69], + global_orient=poses[i*max_length:(i+1)*max_length,:3], + body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3], + left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3], + right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[i*max_length:(i+1)*max_length, 69:72], + reye_pose=poses[i*max_length:(i+1)*max_length, 72:75], + )['joints'][:, :55, :].reshape(max_length, 55*3) + all_tensor.append(joints) + if r != 0: + with torch.no_grad(): + joints = self.smplx( + betas=betas[s*max_length:s*max_length+r], + transl=trans[s*max_length:s*max_length+r], + expression=exps[s*max_length:s*max_length+r], + jaw_pose=poses[s*max_length:s*max_length+r, 66:69], + global_orient=poses[s*max_length:s*max_length+r,:3], + body_pose=poses[s*max_length:s*max_length+r,3:21*3+3], + left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3], + right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[s*max_length:s*max_length+r, 69:72], + reye_pose=poses[s*max_length:s*max_length+r, 72:75], + )['joints'][:, :55, :].reshape(r, 55*3) + all_tensor.append(joints) + joints = torch.cat(all_tensor, axis=0) + joints = joints.permute(1, 0) + dt = 1/30 + # first steps is forward diff (t+1 - t) / dt + init_vel = (joints[:, 1:2] - joints[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt + #print(joints.shape, init_vel.shape, middle_vel.shape, final_vel.shape) + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1).permute(1, 0).reshape(n, 55, 3) + #print(vel_seq.shape) + #.permute(1, 0).reshape(n, 55, 3) + vel_seq_np = vel_seq.cpu().numpy() + vel_joints_np = np.linalg.norm(vel_seq_np, axis=2) # n * 55 + all_list.append(vel_joints_np) + avg_vel = np.mean(np.concatenate(all_list, axis=0),axis=0) # 55 + np.save(save_path, avg_vel) + + + def build_cache(self, preloaded_dir): + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + if self.args.new_cache: + if os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + if os.path.exists(preloaded_dir): + logger.info("Found the cache {}".format(preloaded_dir)) + elif self.loader_type == "test": + self.cache_generation( + preloaded_dir, True, + 0, 0, + is_test=True) + else: + self.cache_generation( + preloaded_dir, self.args.disable_filtering, + self.args.clean_first_seconds, self.args.clean_final_seconds, + is_test=False) + + def __len__(self): + return self.n_samples + + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + # if "wav2vec2" in self.args.audio_rep: + # self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h") + # self.wav2vec_model.feature_extractor._freeze_parameters() + # self.wav2vec_model = self.wav2vec_model.cuda() + # self.wav2vec_model.eval() + + self.n_out_samples = 0 + # create db for samples + if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir) + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G + n_filtered_out = defaultdict(int) + + for index, file_name in self.selected_file.iterrows(): + f_name = file_name["id"] + ext = ".npz" if "smplx" in 
self.args.pose_rep else ".bvh" + pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext + pose_each_file = [] + trans_each_file = [] + shape_each_file = [] + audio_each_file = [] + facial_each_file = [] + word_each_file = [] + emo_each_file = [] + sem_each_file = [] + vid_each_file = [] + id_pose = f_name #1_wayne_0_1_1 + + logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + pose_data = np.load(pose_file, allow_pickle=True) + assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30' + stride = int(30/self.args.pose_fps) + pose_each_file = pose_data["poses"][::stride] * self.joint_mask + trans_each_file = pose_data["trans"][::stride] + shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_each_file = pose_data["expressions"][::stride] + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + else: + assert 120%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(120/self.args.pose_fps) + with open(pose_file, "r") as pose_data: + for j, line in enumerate(pose_data.readlines()): + if j < 431: continue + if j%stride != 0:continue + data = np.fromstring(line, dtype=float, sep=" ") + rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints,3), "XYZ") + rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3) + rot_data = rot_data.numpy() * self.joint_mask + + pose_each_file.append(rot_data) + trans_each_file.append(data[:3]) + + pose_each_file = np.array(pose_each_file) + # print(pose_each_file.shape) + trans_each_file = np.array(trans_each_file) + shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json") + assert 60%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(60/self.args.pose_fps) + if not os.path.exists(facial_file): + logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + with open(facial_file, 'r') as facial_data_file: + facial_data = json.load(facial_data_file) + for j, frame_data in enumerate(facial_data['frames']): + if j%stride != 0:continue + facial_each_file.append(frame_data['weights']) + facial_each_file = np.array(facial_each_file) + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + if self.args.id_rep is not None: + vid_each_file = np.repeat(np.array(int(f_name.split("_")[0])-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + + if self.args.audio_rep is not None: + logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #") + audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + if not os.path.exists(audio_file): + logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == 
id_pose].index) + continue + audio_each_file, sr = librosa.load(audio_file) + audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr) + if self.args.audio_rep == "onset+amplitude": + from numpy.lib import stride_tricks + frame_length = 1024 + # hop_length = 512 + shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length) + strides = (audio_each_file.strides[-1], audio_each_file.strides[-1]) + rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides) + amplitude_envelope = np.max(np.abs(rolling_view), axis=1) + # pad the last frame_length-1 samples + amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1]) + audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames') + onset_array = np.zeros(len(audio_each_file), dtype=float) + onset_array[audio_onset_f] = 1.0 + # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape) + audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1) + elif self.args.audio_rep == "mfcc": + audio_each_file = librosa.feature.mfcc(audio_each_file, sr=self.args.audio_sr, n_mfcc=13, hop_length=int(self.args.audio_sr/self.args.audio_fps)) + + if self.args.audio_norm and self.args.audio_rep == "wave16k": + audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio + + time_offset = 0 + if self.args.word_rep is not None: + logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #") + word_file = f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid" + if not os.path.exists(word_file): + logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + tgrid = tg.TextGrid.fromFile(word_file) + if self.args.t_pre_encoder == "bert": + from transformers import AutoTokenizer, BertModel + tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True) + model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval() + list_word = [] + all_hidden = [] + max_len = 400 + last = 0 + word_token_mapping = [] + first = True + for i, word in enumerate(tgrid[0]): + last = i + if (i%max_len != 0) or (i==0): + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + else: + max_counter = max_len + str_word = ' '.join(map(str, list_word)) + if first: + global_len = 0 + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + #print(i+global_len) + sub_mapping.append(i+global_len) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + global_len = word_token_mapping[-1][-1] + 1 + list_word = [] + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = 
outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + + #list_word = list_word[:10] + if list_word == []: + pass + else: + if first: + global_len = 0 + str_word = ' '.join(map(str, list_word)) + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + sub_mapping.append(i+global_len) + #print(sub_mapping) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + last_hidden_states = np.concatenate(all_hidden, axis=0) + + for i in range(pose_each_file.shape[0]): + found_flag = False + current_time = i/self.args.pose_fps + time_offset + j_last = 0 + for j, word in enumerate(tgrid[0]): + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + if word_s<=current_time and current_time<=word_e: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + mapping_index = word_token_mapping[j] + #print(mapping_index, word_s, word_e) + s_t = np.linspace(word_s, word_e, len(mapping_index)+1) + #print(s_t) + for tt, t_sep in enumerate(s_t[1:]): + if current_time <= t_sep: + #if len(mapping_index) > 1: print(mapping_index[tt]) + word_each_file.append(last_hidden_states[mapping_index[tt]]) + break + else: + if word_n == " ": + word_each_file.append(self.lang_model.PAD_token) + else: + word_each_file.append(self.lang_model.get_word_index(word_n)) + found_flag = True + j_last = j + break + else: continue + if not found_flag: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + word_each_file.append(last_hidden_states[j_last]) + else: + word_each_file.append(self.lang_model.UNK_token) + word_each_file = np.array(word_each_file) + #print(word_each_file.shape) + + if self.args.emo_rep is not None: + logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #") + rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3]) + if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6: + if start >= 1 and start <= 64: + score = 0 + elif start >= 65 and start <= 72: + score = 1 + elif start >= 73 and start <= 80: + score = 2 + elif start >= 81 and start <= 86: + score = 3 + elif start >= 87 and start <= 94: + score = 4 + elif start >= 95 and start <= 102: + score = 5 + elif start >= 103 and start <= 110: + score = 6 + elif start >= 111 and start <= 118: + score = 7 + else: pass + else: + # you may denote as unknown in the future + score = 0 + emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0) + #print(emo_each_file) + + if self.args.sem_rep is not None: + logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #") + sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt" + sem_all = pd.read_csv(sem_file, + sep='\t', + names=["name", "start_time", "end_time", "duration", "score", "keywords"]) + # we adopt motion-level semantic score here. 
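# --- Illustrative note, not part of the patch: the semantic annotation file read above is a
# tab-separated table whose column names are supplied via `names=...`; a hypothetical row
# (values made up for illustration only) could look like:
#   boring_1    12.40    13.10    0.70    0.35    "so"
# The loop below walks every pose frame, converts the frame index to seconds via pose_fps,
# and keeps the score of the first segment whose [start_time, end_time] interval covers that
# time, appending 0.0 when no segment matches.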
+ for i in range(pose_each_file.shape[0]): + found_flag = False + for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])): + current_time = i/self.args.pose_fps + time_offset + if start<=current_time and current_time<=end: + sem_each_file.append(score) + found_flag=True + break + else: continue + if not found_flag: sem_each_file.append(0.) + sem_each_file = np.array(sem_each_file) + #print(sem_each_file) + + filtered_result = self._sample_from_clip( + dst_lmdb_env, + audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ) + for type in filtered_result.keys(): + n_filtered_out[type] += filtered_result[type] + + with dst_lmdb_env.begin() as txn: + logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan")) + n_total_filtered = 0 + for type, n_filtered in n_filtered_out.items(): + logger.info("{}: {}".format(type, n_filtered)) + n_total_filtered += n_filtered + logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format( + n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan")) + dst_lmdb_env.sync() + dst_lmdb_env.close() + + def _sample_from_clip( + self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ): + """ + for data cleaning, we ignore the data for first and final n s + for test, we return all data + """ + # audio_start = int(self.alignment[0] * self.args.audio_fps) + # pose_start = int(self.alignment[1] * self.args.pose_fps) + #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}") + # audio_each_file = audio_each_file[audio_start:] + # pose_each_file = pose_each_file[pose_start:] + # trans_each_file = + #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}") + #print(pose_each_file.shape) + round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s + #print(round_seconds_skeleton) + if audio_each_file != []: + round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s + if facial_each_file != []: + round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps + logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + else: + logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton) + max_round = max(round_seconds_audio, round_seconds_skeleton) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + + clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s + clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000] + 
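# --- Illustrative worked example, not part of the patch (all numbers are assumed): with
# pose_fps = 30, a 100 s clip and no seconds cleaned at either end, the pose window below is
# frames [0, 3000). For a training ratio of 1.0 with an assumed ori_length = 64 and
# ori_stride = 20, the subdivision count computed in the loop that follows would be:
import math
cut_length, stride = 64, 20                                           # assumed values
num_subdivision = math.floor((3000 - 0 - cut_length) / stride) + 1    # -> 147 windows
# the matching audio window is floor(cut_length / pose_fps * audio_fps) audio frames,
# i.e. the same 64/30 seconds of audio expressed at the audio feature rate.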
clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15] + + + for ratio in self.args.multi_length_training: + if is_test:# stride = length for test + cut_length = clip_e_f_pose - clip_s_f_pose + self.args.stride = cut_length + self.max_length = cut_length + else: + self.args.stride = int(ratio*self.ori_stride) + cut_length = int(self.ori_length*ratio) + + num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1 + logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}") + logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}") + + if audio_each_file != []: + audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps) + """ + for audio sr = 16000, fps = 15, pose_length = 34, + audio short length = 36266.7 -> 36266 + this error is fine. + """ + logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}") + + n_filtered_out = defaultdict(int) + sample_pose_list = [] + sample_audio_list = [] + sample_facial_list = [] + sample_shape_list = [] + sample_word_list = [] + sample_emo_list = [] + sample_sem_list = [] + sample_vid_list = [] + sample_trans_list = [] + + for i in range(num_subdivision): # cut into around 2s chip, (self npose) + start_idx = clip_s_f_pose + i * self.args.stride + fin_idx = start_idx + cut_length + sample_pose = pose_each_file[start_idx:fin_idx] + sample_trans = trans_each_file[start_idx:fin_idx] + sample_shape = shape_each_file[start_idx:fin_idx] + # print(sample_pose.shape) + if self.args.audio_rep is not None: + audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps) + audio_end = audio_start + audio_short_length + sample_audio = audio_each_file[audio_start:audio_end] + else: + sample_audio = np.array([-1]) + sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1]) + sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1]) + sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1]) + sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1]) + sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1]) + + if sample_pose.any() != None: + # filtering motion skeleton data + sample_pose, filtering_message = MotionPreprocessor(sample_pose).get() + is_correct_motion = (sample_pose != []) + if is_correct_motion or disable_filtering: + sample_pose_list.append(sample_pose) + sample_audio_list.append(sample_audio) + sample_facial_list.append(sample_facial) + sample_shape_list.append(sample_shape) + sample_word_list.append(sample_word) + sample_vid_list.append(sample_vid) + sample_emo_list.append(sample_emo) + sample_sem_list.append(sample_sem) + sample_trans_list.append(sample_trans) + else: + n_filtered_out[filtering_message] += 1 + + if len(sample_pose_list) > 0: + with dst_lmdb_env.begin(write=True) as txn: + for pose, audio, facial, shape, word, vid, emo, sem, trans in zip( + sample_pose_list, + sample_audio_list, + sample_facial_list, + sample_shape_list, + sample_word_list, + sample_vid_list, + sample_emo_list, + sample_sem_list, + sample_trans_list,): + k = "{:005}".format(self.n_out_samples).encode("ascii") + v = [pose, audio, facial, shape, word, emo, sem, vid, trans] + v = 
pyarrow.serialize(v).to_buffer() + txn.put(k, v) + self.n_out_samples += 1 + return n_filtered_out + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = "{:005}".format(idx).encode("ascii") + sample = txn.get(key) + sample = pyarrow.deserialize(sample) + tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample + #print(in_shape) + #vid = torch.from_numpy(vid).int() + emo = torch.from_numpy(emo).int() + sem = torch.from_numpy(sem).float() + in_audio = torch.from_numpy(in_audio).float() + in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int() + if self.loader_type == "test": + tar_pose = torch.from_numpy(tar_pose).float() + trans = torch.from_numpy(trans).float() + in_facial = torch.from_numpy(in_facial).float() + vid = torch.from_numpy(vid).float() + in_shape = torch.from_numpy(in_shape).float() + else: + in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float() + vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float() + tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float() + in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float() + return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans} + + +class MotionPreprocessor: + def __init__(self, skeletons): + self.skeletons = skeletons + #self.mean_pose = mean_pose + self.filtering_message = "PASS" + + def get(self): + assert (self.skeletons is not None) + + # filtering + if self.skeletons != []: + if self.check_pose_diff(): + self.skeletons = [] + self.filtering_message = "pose" + # elif self.check_spine_angle(): + # self.skeletons = [] + # self.filtering_message = "spine angle" + # elif self.check_static_motion(): + # self.skeletons = [] + # self.filtering_message = "motion" + + # if self.skeletons != []: + # self.skeletons = self.skeletons.tolist() + # for i, frame in enumerate(self.skeletons): + # assert not np.isnan(self.skeletons[i]).any() # missing joints + + return self.skeletons, self.filtering_message + + def check_static_motion(self, verbose=True): + def get_variance(skeleton, joint_idx): + wrist_pos = skeleton[:, joint_idx] + variance = np.sum(np.var(wrist_pos, axis=0)) + return variance + + left_arm_var = get_variance(self.skeletons, 6) + right_arm_var = get_variance(self.skeletons, 9) + + th = 0.0014 # exclude 13110 + # th = 0.002 # exclude 16905 + if left_arm_var < th and right_arm_var < th: + if verbose: + print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return True + else: + if verbose: + print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return False + + + def check_pose_diff(self, verbose=False): +# diff = np.abs(self.skeletons - self.mean_pose) # 186*1 +# diff = np.mean(diff) + +# # th = 0.017 +# th = 0.02 #0.02 # exclude 3594 +# if diff < th: +# if verbose: +# print("skip - check_pose_diff {:.5f}".format(diff)) +# return True +# # th = 3.5 #0.02 # exclude 3594 +# # if 3.5 < diff < 5: +# # if verbose: +# # print("skip - check_pose_diff {:.5f}".format(diff)) +# # return True +# else: +# if verbose: +# print("pass - check_pose_diff {:.5f}".format(diff)) + return False + + + def check_spine_angle(self, verbose=True): + def angle_between(v1, v2): + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / 
np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + angles = [] + for i in range(self.skeletons.shape[0]): + spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] + angle = angle_between(spine_vec, [0, -1, 0]) + angles.append(angle) + + if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495 + # if np.rad2deg(max(angles)) > 20: # exclude 8270 + if verbose: + print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles))) + return True + else: + if verbose: + print("pass - check_spine_angle {:.5f}".format(max(angles))) + return False \ No newline at end of file diff --git a/dataloaders/beat_testonly.py b/dataloaders/beat_testonly.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d8d8fd524d63754ae68ba999f12085208f91bb --- /dev/null +++ b/dataloaders/beat_testonly.py @@ -0,0 +1,731 @@ +import os +import pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +import pyarrow +import librosa +import smplx + +from .build_vocab import Vocab +from .utils.audio_features import Wav2Vec2Model +from .data_tools import joints_list +from .utils import rotation_conversions as rc +from .utils import other_tools + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True): + self.args = args + self.loader_type = loader_type + + self.rank = dist.get_rank() + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + self.alignment = [0,0] # for trinity + + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list = joints_list[self.args.tar_joints] + if 'smplx' in self.args.pose_rep: + self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = len(list(self.tar_joint_list.keys())) + for joint_name in self.tar_joint_list: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + else: + self.joints = len(list(self.ori_joint_list.keys()))+1 + self.joint_mask = np.zeros(self.joints*3) + for joint_name in self.tar_joint_list: + if joint_name == "Hips": + self.joint_mask[3:6] = 1 + else: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + # select trainable joints + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + + split_rule = pd.read_csv(args.data_path+"test.csv") + self.selected_file = split_rule + self.data_dir = args.data_path + + if loader_type == "test": + self.args.multi_length_training = [1.0] + self.max_length = int(args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr) + if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length*self.args.audio_sr + + if args.word_rep is not None: + with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + + preloaded_dir = self.args.root_path + 
self.args.cache_path + loader_type + f"/{args.pose_rep}_cache" + if build_cache and self.rank == 0: + self.build_cache(preloaded_dir) + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()["entries"] + + + def build_cache(self, preloaded_dir): + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + if self.args.new_cache: + if os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + if os.path.exists(preloaded_dir): + logger.info("Found the cache {}".format(preloaded_dir)) + elif self.loader_type == "test": + self.cache_generation( + preloaded_dir, True, + 0, 0, + is_test=True) + else: + self.cache_generation( + preloaded_dir, self.args.disable_filtering, + self.args.clean_first_seconds, self.args.clean_final_seconds, + is_test=False) + + + def __len__(self): + return self.n_samples + + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + self.n_out_samples = 0 + # create db for samples + if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir) + if len(self.args.training_speakers) == 1: + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G + else: + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 200))# 200G + n_filtered_out = defaultdict(int) + + for index, file_name in self.selected_file.iterrows(): + f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext + pose_each_file = [] + trans_each_file = [] + shape_each_file = [] + audio_each_file = [] + facial_each_file = [] + word_each_file = [] + emo_each_file = [] + sem_each_file = [] + vid_each_file = [] + id_pose = f_name #1_wayne_0_1_1 + + logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + pose_data = np.load(pose_file, allow_pickle=True) + assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30' + stride = int(30/self.args.pose_fps) + pose_each_file = pose_data["poses"][::stride] + trans_each_file = pose_data["trans"][::stride] + shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0) + + assert self.args.pose_fps == 30, "should 30" + m_data = np.load(pose_file, allow_pickle=True) + betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"] + n, c = poses.shape[0], poses.shape[1] + betas = betas.reshape(1, 300) + betas = np.tile(betas, (n, 1)) + betas = torch.from_numpy(betas).cuda().float() + poses = torch.from_numpy(poses.reshape(n, c)).cuda().float() + exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float() + trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float() + max_length = 128 + s, r = n//max_length, n%max_length + #print(n, s, r) + all_tensor = [] + for i in range(s): + with torch.no_grad(): + joints = self.smplx( + betas=betas[i*max_length:(i+1)*max_length], + transl=trans[i*max_length:(i+1)*max_length], + expression=exps[i*max_length:(i+1)*max_length], + jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69], + global_orient=poses[i*max_length:(i+1)*max_length,:3], + body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3], + left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3], + 
right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[i*max_length:(i+1)*max_length, 69:72], + reye_pose=poses[i*max_length:(i+1)*max_length, 72:75], + )['joints'][:, (7,8,10,11), :].reshape(max_length, 4, 3).cpu() + all_tensor.append(joints) + if r != 0: + with torch.no_grad(): + joints = self.smplx( + betas=betas[s*max_length:s*max_length+r], + transl=trans[s*max_length:s*max_length+r], + expression=exps[s*max_length:s*max_length+r], + jaw_pose=poses[s*max_length:s*max_length+r, 66:69], + global_orient=poses[s*max_length:s*max_length+r,:3], + body_pose=poses[s*max_length:s*max_length+r,3:21*3+3], + left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3], + right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[s*max_length:s*max_length+r, 69:72], + reye_pose=poses[s*max_length:s*max_length+r, 72:75], + )['joints'][:, (7,8,10,11), :].reshape(r, 4, 3).cpu() + all_tensor.append(joints) + joints = torch.cat(all_tensor, axis=0) # all, 4, 3 + # print(joints.shape) + feetv = torch.zeros(joints.shape[1], joints.shape[0]) + joints = joints.permute(1, 0, 2) + #print(joints.shape, feetv.shape) + feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1) + #print(feetv.shape) + contacts = (feetv < 0.01).numpy().astype(float) + # print(contacts.shape, contacts) + contacts = contacts.transpose(1, 0) + pose_each_file = pose_each_file * self.joint_mask + pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)] + pose_each_file = np.concatenate([pose_each_file, contacts], axis=1) + # print(pose_each_file.shape) + + + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_each_file = pose_data["expressions"][::stride] + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + else: + assert 120%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(120/self.args.pose_fps) + with open(pose_file, "r") as pose_data: + for j, line in enumerate(pose_data.readlines()): + if j < 431: continue + if j%stride != 0:continue + data = np.fromstring(line, dtype=float, sep=" ") + rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints,3), "XYZ") + rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3) + rot_data = rot_data.numpy() * self.joint_mask + + pose_each_file.append(rot_data) + trans_each_file.append(data[:3]) + + pose_each_file = np.array(pose_each_file) + trans_each_file = np.array(trans_each_file) + shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json") + assert 60%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(60/self.args.pose_fps) + if not os.path.exists(facial_file): + logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + with open(facial_file, 'r') as facial_data_file: + facial_data = json.load(facial_data_file) + for j, frame_data in enumerate(facial_data['frames']): + if j%stride != 
0:continue + facial_each_file.append(frame_data['weights']) + facial_each_file = np.array(facial_each_file) + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + if self.args.id_rep is not None: + int_value = 1 + vid_each_file = np.repeat(np.array(int_value).reshape(1, 1), pose_each_file.shape[0], axis=0) + + if self.args.audio_rep is not None: + logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #") + audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + if not os.path.exists(audio_file): + logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + audio_each_file, sr = librosa.load(audio_file) + audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr) + if self.args.audio_rep == "onset+amplitude": + from numpy.lib import stride_tricks + frame_length = 1024 + # hop_length = 512 + shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length) + strides = (audio_each_file.strides[-1], audio_each_file.strides[-1]) + rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides) + amplitude_envelope = np.max(np.abs(rolling_view), axis=1) + # pad the last frame_length-1 samples + amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1]) + audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames') + onset_array = np.zeros(len(audio_each_file), dtype=float) + onset_array[audio_onset_f] = 1.0 + # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape) + audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1) + elif self.args.audio_rep == "mfcc": + audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps)) + audio_each_file = audio_each_file.transpose(1, 0) + # print(audio_each_file.shape, pose_each_file.shape) + if self.args.audio_norm and self.args.audio_rep == "wave16k": + audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio + + time_offset = 0 + if self.args.word_rep is not None: + logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #") + word_file = f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid" + if not os.path.exists(word_file): + logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + tgrid = tg.TextGrid.fromFile(word_file) + if self.args.t_pre_encoder == "bert": + from transformers import AutoTokenizer, BertModel + tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True) + model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval() + list_word = [] + all_hidden = [] + max_len = 400 + last = 0 + word_token_mapping = [] + first = True + for i, word in enumerate(tgrid[0]): + last = i + if (i%max_len != 0) or (i==0): + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + else: + max_counter = max_len + str_word = ' 
'.join(map(str, list_word)) + if first: + global_len = 0 + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + #print(i+global_len) + sub_mapping.append(i+global_len) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + global_len = word_token_mapping[-1][-1] + 1 + list_word = [] + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + + #list_word = list_word[:10] + if list_word == []: + pass + else: + if first: + global_len = 0 + str_word = ' '.join(map(str, list_word)) + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + sub_mapping.append(i+global_len) + #print(sub_mapping) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + last_hidden_states = np.concatenate(all_hidden, axis=0) + + for i in range(pose_each_file.shape[0]): + found_flag = False + current_time = i/self.args.pose_fps + time_offset + j_last = 0 + for j, word in enumerate(tgrid[0]): + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + if word_s<=current_time and current_time<=word_e: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + mapping_index = word_token_mapping[j] + #print(mapping_index, word_s, word_e) + s_t = np.linspace(word_s, word_e, len(mapping_index)+1) + #print(s_t) + for tt, t_sep in enumerate(s_t[1:]): + if current_time <= t_sep: + #if len(mapping_index) > 1: print(mapping_index[tt]) + word_each_file.append(last_hidden_states[mapping_index[tt]]) + break + else: + if word_n == " ": + word_each_file.append(self.lang_model.PAD_token) + else: + word_each_file.append(self.lang_model.get_word_index(word_n)) + found_flag = True + j_last = j + break + else: continue + if not found_flag: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + word_each_file.append(last_hidden_states[j_last]) + else: + word_each_file.append(self.lang_model.UNK_token) + word_each_file = np.array(word_each_file) + #print(word_each_file.shape) + + if self.args.emo_rep is not None: + logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #") + rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3]) + if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6: + if start >= 1 and start <= 64: + score = 0 + elif start >= 65 
and start <= 72: + score = 1 + elif start >= 73 and start <= 80: + score = 2 + elif start >= 81 and start <= 86: + score = 3 + elif start >= 87 and start <= 94: + score = 4 + elif start >= 95 and start <= 102: + score = 5 + elif start >= 103 and start <= 110: + score = 6 + elif start >= 111 and start <= 118: + score = 7 + else: pass + else: + # you may denote as unknown in the future + score = 0 + emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0) + #print(emo_each_file) + + if self.args.sem_rep is not None: + logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #") + sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt" + sem_all = pd.read_csv(sem_file, + sep='\t', + names=["name", "start_time", "end_time", "duration", "score", "keywords"]) + # we adopt motion-level semantic score here. + for i in range(pose_each_file.shape[0]): + found_flag = False + for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])): + current_time = i/self.args.pose_fps + time_offset + if start<=current_time and current_time<=end: + sem_each_file.append(score) + found_flag=True + break + else: continue + if not found_flag: sem_each_file.append(0.) + sem_each_file = np.array(sem_each_file) + #print(sem_each_file) + + filtered_result = self._sample_from_clip( + dst_lmdb_env, + audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ) + for type in filtered_result.keys(): + n_filtered_out[type] += filtered_result[type] + + with dst_lmdb_env.begin() as txn: + logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan")) + n_total_filtered = 0 + for type, n_filtered in n_filtered_out.items(): + logger.info("{}: {}".format(type, n_filtered)) + n_total_filtered += n_filtered + logger.info(colored("no. 
of excluded samples: {} ({:.1f}%)".format( + n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan")) + dst_lmdb_env.sync() + dst_lmdb_env.close() + + def _sample_from_clip( + self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ): + """ + for data cleaning, we ignore the data for first and final n s + for test, we return all data + """ + # audio_start = int(self.alignment[0] * self.args.audio_fps) + # pose_start = int(self.alignment[1] * self.args.pose_fps) + #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}") + # audio_each_file = audio_each_file[audio_start:] + # pose_each_file = pose_each_file[pose_start:] + # trans_each_file = + #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}") + #print(pose_each_file.shape) + round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s + #print(round_seconds_skeleton) + if audio_each_file != []: + if self.args.audio_rep != "wave16k": + round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s + elif self.args.audio_rep == "mfcc": + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps + else: + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr + if facial_each_file != []: + round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps + logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + else: + logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton) + max_round = max(round_seconds_audio, round_seconds_skeleton) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + + clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s + clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000] + clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15] + + + for ratio in self.args.multi_length_training: + if is_test:# stride = length for test + cut_length = clip_e_f_pose - clip_s_f_pose + self.args.stride = cut_length + self.max_length = cut_length + else: + self.args.stride = int(ratio*self.ori_stride) + cut_length = int(self.ori_length*ratio) + + num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1 + logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}") + logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}") + + if audio_each_file != []: + audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps) + """ + for audio sr = 16000, fps = 15, pose_length = 34, + audio short length = 36266.7 -> 36266 + 
this error is fine. + """ + logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}") + + n_filtered_out = defaultdict(int) + sample_pose_list = [] + sample_audio_list = [] + sample_facial_list = [] + sample_shape_list = [] + sample_word_list = [] + sample_emo_list = [] + sample_sem_list = [] + sample_vid_list = [] + sample_trans_list = [] + + for i in range(num_subdivision): # cut into around 2s chip, (self npose) + start_idx = clip_s_f_pose + i * self.args.stride + fin_idx = start_idx + cut_length + sample_pose = pose_each_file[start_idx:fin_idx] + + sample_trans = trans_each_file[start_idx:fin_idx] + sample_shape = shape_each_file[start_idx:fin_idx] + # print(sample_pose.shape) + if self.args.audio_rep is not None: + audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps) + audio_end = audio_start + audio_short_length + sample_audio = audio_each_file[audio_start:audio_end] + else: + sample_audio = np.array([-1]) + sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1]) + sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1]) + sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1]) + sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1]) + sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1]) + + if sample_pose.any() != None: + # filtering motion skeleton data + sample_pose, filtering_message = MotionPreprocessor(sample_pose).get() + is_correct_motion = (sample_pose != []) + if is_correct_motion or disable_filtering: + sample_pose_list.append(sample_pose) + sample_audio_list.append(sample_audio) + sample_facial_list.append(sample_facial) + sample_shape_list.append(sample_shape) + sample_word_list.append(sample_word) + sample_vid_list.append(sample_vid) + sample_emo_list.append(sample_emo) + sample_sem_list.append(sample_sem) + sample_trans_list.append(sample_trans) + else: + n_filtered_out[filtering_message] += 1 + + if len(sample_pose_list) > 0: + with dst_lmdb_env.begin(write=True) as txn: + for pose, audio, facial, shape, word, vid, emo, sem, trans in zip( + sample_pose_list, + sample_audio_list, + sample_facial_list, + sample_shape_list, + sample_word_list, + sample_vid_list, + sample_emo_list, + sample_sem_list, + sample_trans_list,): + k = "{:005}".format(self.n_out_samples).encode("ascii") + v = [pose, audio, facial, shape, word, emo, sem, vid, trans] + v = pyarrow.serialize(v).to_buffer() + txn.put(k, v) + self.n_out_samples += 1 + return n_filtered_out + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = "{:005}".format(idx).encode("ascii") + sample = txn.get(key) + sample = pyarrow.deserialize(sample) + tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample + #print(in_shape) + #vid = torch.from_numpy(vid).int() + emo = torch.from_numpy(emo).int() + sem = torch.from_numpy(sem).float() + in_audio = torch.from_numpy(in_audio).float() + in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int() + if self.loader_type == "test": + tar_pose = torch.from_numpy(tar_pose).float() + trans = torch.from_numpy(trans).float() + in_facial = torch.from_numpy(in_facial).float() + vid = torch.from_numpy(vid).float() + in_shape = torch.from_numpy(in_shape).float() + else: + 
in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float() + vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float() + tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float() + in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float() + return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans} + + +class MotionPreprocessor: + def __init__(self, skeletons): + self.skeletons = skeletons + #self.mean_pose = mean_pose + self.filtering_message = "PASS" + + def get(self): + assert (self.skeletons is not None) + + # filtering + if self.skeletons != []: + if self.check_pose_diff(): + self.skeletons = [] + self.filtering_message = "pose" + # elif self.check_spine_angle(): + # self.skeletons = [] + # self.filtering_message = "spine angle" + # elif self.check_static_motion(): + # self.skeletons = [] + # self.filtering_message = "motion" + + # if self.skeletons != []: + # self.skeletons = self.skeletons.tolist() + # for i, frame in enumerate(self.skeletons): + # assert not np.isnan(self.skeletons[i]).any() # missing joints + + return self.skeletons, self.filtering_message + + def check_static_motion(self, verbose=True): + def get_variance(skeleton, joint_idx): + wrist_pos = skeleton[:, joint_idx] + variance = np.sum(np.var(wrist_pos, axis=0)) + return variance + + left_arm_var = get_variance(self.skeletons, 6) + right_arm_var = get_variance(self.skeletons, 9) + + th = 0.0014 # exclude 13110 + # th = 0.002 # exclude 16905 + if left_arm_var < th and right_arm_var < th: + if verbose: + print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return True + else: + if verbose: + print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return False + + + def check_pose_diff(self, verbose=False): +# diff = np.abs(self.skeletons - self.mean_pose) # 186*1 +# diff = np.mean(diff) + +# # th = 0.017 +# th = 0.02 #0.02 # exclude 3594 +# if diff < th: +# if verbose: +# print("skip - check_pose_diff {:.5f}".format(diff)) +# return True +# # th = 3.5 #0.02 # exclude 3594 +# # if 3.5 < diff < 5: +# # if verbose: +# # print("skip - check_pose_diff {:.5f}".format(diff)) +# # return True +# else: +# if verbose: +# print("pass - check_pose_diff {:.5f}".format(diff)) + return False + + + def check_spine_angle(self, verbose=True): + def angle_between(v1, v2): + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + angles = [] + for i in range(self.skeletons.shape[0]): + spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] + angle = angle_between(spine_vec, [0, -1, 0]) + angles.append(angle) + + if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495 + # if np.rad2deg(max(angles)) > 20: # exclude 8270 + if verbose: + print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles))) + return True + else: + if verbose: + print("pass - check_spine_angle {:.5f}".format(max(angles))) + return False \ No newline at end of file diff --git a/dataloaders/beat_testonly_hf.py b/dataloaders/beat_testonly_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..17393517d0d6f662114c19b009fa3471d7c1e91b --- /dev/null +++ b/dataloaders/beat_testonly_hf.py @@ -0,0 +1,740 @@ +import os 
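# --- Hedged sketch (not part of the original loaders): the LMDB sample layout. ---
# Both BEAT cache writers store each clip as a list
# [pose, audio, facial, shape, word, emo, sem, vid, trans] under a zero-padded
# ASCII key.  The older loader serialized with pyarrow.serialize, which is
# deprecated in recent pyarrow releases, so this sketch (like beat_testonly_hf.py)
# falls back to pickle.  The cache directory, map_size and array shapes below are
# illustrative assumptions only.
def _lmdb_roundtrip_sketch(cache_dir="/tmp/beat_cache_demo"):
    import os, pickle
    import lmdb
    import numpy as np

    os.makedirs(cache_dir, exist_ok=True)
    env = lmdb.open(cache_dir, map_size=int(1024 ** 3 * 0.5))  # ~500 MB, as in the single-speaker branch

    # write two dummy samples, mirroring the "{:005}" key convention
    with env.begin(write=True) as txn:
        for idx in range(2):
            key = "{:005}".format(idx).encode("ascii")          # b'00000', b'00001'
            value = [np.zeros((64, 169)), np.zeros((2048, 2))]  # e.g. a pose window and an audio window
            txn.put(key, pickle.dumps(value))

    # read one back, mirroring __getitem__
    with env.begin(write=False) as txn:
        n_entries = txn.stat()["entries"]
        sample = pickle.loads(txn.get("{:005}".format(0).encode("ascii")))
    env.close()
    return n_entries, [a.shape for a in sample]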
+import pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +import pyarrow +import librosa +import smplx + +from .build_vocab import Vocab +from .utils.audio_features import Wav2Vec2Model +from .data_tools import joints_list +from .utils import rotation_conversions as rc +from .utils import other_tools_hf + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, smplx_path=None, audio_path=None, text_path=None, augmentation=None, kwargs=None, build_cache=True): + self.args = args + self.loader_type = loader_type + self.smplx_path = "./EMAGE/test_sequences/smplxflame_30/2_scott_0_1_1.npz" + self.audio_path = audio_path + self.text_path = "./EMAGE/test_sequences/textgrid/2_scott_0_1_1.TextGrid" + self.rank = 0 # dist.get_rank() + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + self.alignment = [0,0] # for trinity + + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list = joints_list[self.args.tar_joints] + if 'smplx' in self.args.pose_rep: + self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = len(list(self.tar_joint_list.keys())) + for joint_name in self.tar_joint_list: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + else: + self.joints = len(list(self.ori_joint_list.keys()))+1 + self.joint_mask = np.zeros(self.joints*3) + for joint_name in self.tar_joint_list: + if joint_name == "Hips": + self.joint_mask[3:6] = 1 + else: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + # select trainable joints + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).eval() + + split_rule = pd.read_csv(args.data_path+"test.csv") + self.selected_file = split_rule + self.data_dir = args.data_path + + if loader_type == "test": + self.args.multi_length_training = [1.0] + self.max_length = int(args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr) + if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length*self.args.audio_sr + + if args.word_rep is not None: + with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + + preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache" + if build_cache and self.rank == 0: + self.build_cache(preloaded_dir) + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()["entries"] + + + def build_cache(self, preloaded_dir): + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + if self.args.new_cache: + if os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + if os.path.exists(preloaded_dir): + logger.info("Found the cache {}".format(preloaded_dir)) + 
elif self.loader_type == "test": + self.cache_generation( + preloaded_dir, True, + 0, 0, + is_test=True) + else: + self.cache_generation( + preloaded_dir, self.args.disable_filtering, + self.args.clean_first_seconds, self.args.clean_final_seconds, + is_test=False) + + + def __len__(self): + return self.n_samples + + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + self.n_out_samples = 0 + # create db for samples + if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir) + if len(self.args.training_speakers) == 1: + #dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 0.5))# 500M + else: + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 200))# 200G + n_filtered_out = defaultdict(int) + + #for index, file_name in self.selected_file.iterrows(): + #f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = self.smplx_path#self.data_dir + self.args.pose_rep + "/" + f_name + ext + pose_each_file = [] + trans_each_file = [] + shape_each_file = [] + audio_each_file = [] + facial_each_file = [] + word_each_file = [] + emo_each_file = [] + sem_each_file = [] + vid_each_file = [] + id_pose = "dummy 2nd"#f_name + + logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + pose_data = np.load(pose_file, allow_pickle=True) + assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30' + stride = int(30/self.args.pose_fps) + pose_each_file = pose_data["poses"][::stride] + trans_each_file = pose_data["trans"][::stride] + shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0) + + assert self.args.pose_fps == 30, "should 30" + m_data = np.load(pose_file, allow_pickle=True) + betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"] + n, c = poses.shape[0], poses.shape[1] + betas = betas.reshape(1, 300) + betas = np.tile(betas, (n, 1)) + betas = torch.from_numpy(betas).float() + poses = torch.from_numpy(poses.reshape(n, c)).float() + exps = torch.from_numpy(exps.reshape(n, 100)).float() + trans = torch.from_numpy(trans.reshape(n, 3)).float() + max_length = 128 + s, r = n//max_length, n%max_length + #print(n, s, r) + all_tensor = [] + for i in range(s): + with torch.no_grad(): + joints = self.smplx( + betas=betas[i*max_length:(i+1)*max_length], + transl=trans[i*max_length:(i+1)*max_length], + expression=exps[i*max_length:(i+1)*max_length], + jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69], + global_orient=poses[i*max_length:(i+1)*max_length,:3], + body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3], + left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3], + right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[i*max_length:(i+1)*max_length, 69:72], + reye_pose=poses[i*max_length:(i+1)*max_length, 72:75], + )['joints'][:, (7,8,10,11), :].reshape(max_length, 4, 3).cpu() + all_tensor.append(joints) + if r != 0: + with torch.no_grad(): + joints = self.smplx( + betas=betas[s*max_length:s*max_length+r], + transl=trans[s*max_length:s*max_length+r], + expression=exps[s*max_length:s*max_length+r], + jaw_pose=poses[s*max_length:s*max_length+r, 66:69], + global_orient=poses[s*max_length:s*max_length+r,:3], + 
body_pose=poses[s*max_length:s*max_length+r,3:21*3+3], + left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3], + right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[s*max_length:s*max_length+r, 69:72], + reye_pose=poses[s*max_length:s*max_length+r, 72:75], + )['joints'][:, (7,8,10,11), :].reshape(r, 4, 3).cpu() + all_tensor.append(joints) + joints = torch.cat(all_tensor, axis=0) # all, 4, 3 + # print(joints.shape) + feetv = torch.zeros(joints.shape[1], joints.shape[0]) + joints = joints.permute(1, 0, 2) + #print(joints.shape, feetv.shape) + feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1) + #print(feetv.shape) + contacts = (feetv < 0.01).numpy().astype(float) + # print(contacts.shape, contacts) + contacts = contacts.transpose(1, 0) + pose_each_file = pose_each_file * self.joint_mask + pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)] + pose_each_file = np.concatenate([pose_each_file, contacts], axis=1) + # print(pose_each_file.shape) + + + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_each_file = pose_data["expressions"][::stride] + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + else: + assert 120%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(120/self.args.pose_fps) + with open(pose_file, "r") as pose_data: + for j, line in enumerate(pose_data.readlines()): + if j < 431: continue + if j%stride != 0:continue + data = np.fromstring(line, dtype=float, sep=" ") + rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints,3), "XYZ") + rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3) + rot_data = rot_data.numpy() * self.joint_mask + + pose_each_file.append(rot_data) + trans_each_file.append(data[:3]) + + pose_each_file = np.array(pose_each_file) + trans_each_file = np.array(trans_each_file) + shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json") + assert 60%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(60/self.args.pose_fps) + if not os.path.exists(facial_file): + logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #") + #self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + #continue + with open(facial_file, 'r') as facial_data_file: + facial_data = json.load(facial_data_file) + for j, frame_data in enumerate(facial_data['frames']): + if j%stride != 0:continue + facial_each_file.append(frame_data['weights']) + facial_each_file = np.array(facial_each_file) + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + if self.args.id_rep is not None: + int_value = 1 + vid_each_file = np.repeat(np.array(int_value).reshape(1, 1), pose_each_file.shape[0], axis=0) + + if self.args.audio_rep is not None: + logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #") + audio_file = self.audio_path[1]#pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + sr = self.audio_path[0] + 
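# --- Hedged sketch (not part of the loader): the "onset+amplitude" audio rep. ---
# For audio_rep == "onset+amplitude" the cache stores, per audio sample, a
# 2-channel feature: a sliding-window amplitude envelope and a binary onset
# track.  The helper below mirrors that recipe on an arbitrary waveform; the
# 16 kHz target rate and 1024-sample window match the values used here, and
# np.lib.stride_tricks.sliding_window_view stands in for the as_strided call.
def _onset_amplitude_sketch(waveform, orig_sr, target_sr=16000, frame_length=1024):
    import numpy as np
    import librosa

    y = librosa.resample(waveform.astype(np.float32), orig_sr=orig_sr, target_sr=target_sr)

    # amplitude envelope: max |y| over a sliding window, padded back to len(y)
    windows = np.lib.stride_tricks.sliding_window_view(y, frame_length)
    envelope = np.max(np.abs(windows), axis=1)
    envelope = np.pad(envelope, (0, frame_length - 1), mode="constant",
                      constant_values=envelope[-1])

    # onset track: as in the loader, onset *frame* indices are written straight
    # into a sample-length array (librosa's default onset hop is 512 samples)
    onset_frames = librosa.onset.onset_detect(y=y, sr=target_sr, units="frames")
    onsets = np.zeros(len(y), dtype=float)
    onsets[onset_frames] = 1.0

    return np.stack([envelope, onsets], axis=1)  # shape (num_samples, 2)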
print(sr) + #if not os.path.exists(audio_file): + # logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #") + #self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + #continue + #audio_each_file, sr = librosa.load(audio_file) + audio_each_file = audio_file.astype(np.float32) + print(audio_each_file.shape) + audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr) + print(audio_each_file.shape) + if self.args.audio_rep == "onset+amplitude": + from numpy.lib import stride_tricks + frame_length = 1024 + # hop_length = 512 + shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length) + strides = (audio_each_file.strides[-1], audio_each_file.strides[-1]) + rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides) + amplitude_envelope = np.max(np.abs(rolling_view), axis=1) + # pad the last frame_length-1 samples + amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1]) + audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames') + onset_array = np.zeros(len(audio_each_file), dtype=float) + onset_array[audio_onset_f] = 1.0 + # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape) + audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1) + elif self.args.audio_rep == "mfcc": + audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps)) + audio_each_file = audio_each_file.transpose(1, 0) + # print(audio_each_file.shape, pose_each_file.shape) + if self.args.audio_norm and self.args.audio_rep == "wave16k": + audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio + + time_offset = 0 + if self.args.word_rep is not None: + logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #") + word_file = self.text_path#f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid" + if not os.path.exists(word_file): + logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #") + #self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + #continue + tgrid = tg.TextGrid.fromFile(word_file) + if self.args.t_pre_encoder == "bert": + from transformers import AutoTokenizer, BertModel + tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True) + model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval() + list_word = [] + all_hidden = [] + max_len = 400 + last = 0 + word_token_mapping = [] + first = True + for i, word in enumerate(tgrid[0]): + last = i + if (i%max_len != 0) or (i==0): + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + else: + max_counter = max_len + str_word = ' '.join(map(str, list_word)) + if first: + global_len = 0 + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in 
enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + #print(i+global_len) + sub_mapping.append(i+global_len) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + global_len = word_token_mapping[-1][-1] + 1 + list_word = [] + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + + #list_word = list_word[:10] + if list_word == []: + pass + else: + if first: + global_len = 0 + str_word = ' '.join(map(str, list_word)) + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + sub_mapping.append(i+global_len) + #print(sub_mapping) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + last_hidden_states = np.concatenate(all_hidden, axis=0) + + for i in range(pose_each_file.shape[0]): + found_flag = False + current_time = i/self.args.pose_fps + time_offset + j_last = 0 + for j, word in enumerate(tgrid[0]): + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + if word_s<=current_time and current_time<=word_e: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + mapping_index = word_token_mapping[j] + #print(mapping_index, word_s, word_e) + s_t = np.linspace(word_s, word_e, len(mapping_index)+1) + #print(s_t) + for tt, t_sep in enumerate(s_t[1:]): + if current_time <= t_sep: + #if len(mapping_index) > 1: print(mapping_index[tt]) + word_each_file.append(last_hidden_states[mapping_index[tt]]) + break + else: + if word_n == " ": + word_each_file.append(self.lang_model.PAD_token) + else: + word_each_file.append(self.lang_model.get_word_index(word_n)) + found_flag = True + j_last = j + break + else: continue + if not found_flag: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + word_each_file.append(last_hidden_states[j_last]) + else: + word_each_file.append(self.lang_model.UNK_token) + word_each_file = np.array(word_each_file) + #print(word_each_file.shape) + + if self.args.emo_rep is not None: + logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #") + rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3]) + if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6: + if start >= 1 and start <= 64: + score = 0 + elif start >= 65 and start <= 72: + score = 1 + elif start >= 73 and start <= 80: + score = 2 + elif start >= 81 and start <= 86: + score = 3 + elif start >= 87 and start <= 94: + score = 4 + elif start >= 95 and start <= 102: + score = 5 + elif start >= 103 and start <= 110: + score = 6 + elif start >= 111 and start <= 118: + score = 7 + else: pass + else: + # you may denote as unknown in the future + score = 0 + emo_each_file = 
np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0) + #print(emo_each_file) + + if self.args.sem_rep is not None: + logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #") + sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt" + sem_all = pd.read_csv(sem_file, + sep='\t', + names=["name", "start_time", "end_time", "duration", "score", "keywords"]) + # we adopt motion-level semantic score here. + for i in range(pose_each_file.shape[0]): + found_flag = False + for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])): + current_time = i/self.args.pose_fps + time_offset + if start<=current_time and current_time<=end: + sem_each_file.append(score) + found_flag=True + break + else: continue + if not found_flag: sem_each_file.append(0.) + sem_each_file = np.array(sem_each_file) + #print(sem_each_file) + + filtered_result = self._sample_from_clip( + dst_lmdb_env, + audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ) + for type in filtered_result.keys(): + n_filtered_out[type] += filtered_result[type] + + with dst_lmdb_env.begin() as txn: + logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan")) + n_total_filtered = 0 + for type, n_filtered in n_filtered_out.items(): + logger.info("{}: {}".format(type, n_filtered)) + n_total_filtered += n_filtered + logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format( + n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan")) + dst_lmdb_env.sync() + dst_lmdb_env.close() + + def _sample_from_clip( + self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ): + """ + for data cleaning, we ignore the data for first and final n s + for test, we return all data + """ + # audio_start = int(self.alignment[0] * self.args.audio_fps) + # pose_start = int(self.alignment[1] * self.args.pose_fps) + #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}") + # audio_each_file = audio_each_file[audio_start:] + # pose_each_file = pose_each_file[pose_start:] + # trans_each_file = + #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}") + #print(pose_each_file.shape) + round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s + print(pose_each_file.shape[0]) + #print(round_seconds_skeleton) + #if audio_each_file != []: + if self.args.audio_rep != "wave16k": + round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s + elif self.args.audio_rep == "mfcc": + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps + else: + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr + # if facial_each_file != []: + round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps + logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + if round_seconds_skeleton != 
max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + # else: + # logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s") + # round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton) + # max_round = max(round_seconds_audio, round_seconds_skeleton) + # if round_seconds_skeleton != max_round: + # logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + + clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s + clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000] + clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15] + + + for ratio in self.args.multi_length_training: + if is_test:# stride = length for test + cut_length = clip_e_f_pose - clip_s_f_pose + self.args.stride = cut_length + self.max_length = cut_length + else: + self.args.stride = int(ratio*self.ori_stride) + cut_length = int(self.ori_length*ratio) + + num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1 + logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}") + logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}") + + # if audio_each_file != []: + audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps) + logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}") + + n_filtered_out = defaultdict(int) + sample_pose_list = [] + sample_audio_list = [] + sample_facial_list = [] + sample_shape_list = [] + sample_word_list = [] + sample_emo_list = [] + sample_sem_list = [] + sample_vid_list = [] + sample_trans_list = [] + + for i in range(num_subdivision): # cut into around 2s chip, (self npose) + start_idx = clip_s_f_pose + i * self.args.stride + fin_idx = start_idx + cut_length + sample_pose = pose_each_file[start_idx:fin_idx] + + sample_trans = trans_each_file[start_idx:fin_idx] + sample_shape = shape_each_file[start_idx:fin_idx] + # print(sample_pose.shape) + if self.args.audio_rep is not None: + audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps) + audio_end = audio_start + audio_short_length + sample_audio = audio_each_file[audio_start:audio_end] + else: + sample_audio = np.array([-1]) + sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1]) + sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1]) + sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1]) + sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1]) + sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1]) + + if sample_pose.any() != None: + # filtering motion skeleton data + sample_pose, filtering_message = MotionPreprocessor(sample_pose).get() + is_correct_motion = True #(sample_pose != []) + if is_correct_motion or disable_filtering: + sample_pose_list.append(sample_pose) + sample_audio_list.append(sample_audio) + sample_facial_list.append(sample_facial) + sample_shape_list.append(sample_shape) + sample_word_list.append(sample_word) + sample_vid_list.append(sample_vid) + sample_emo_list.append(sample_emo) + 
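# --- Hedged sketch (not part of the loader): window/stride arithmetic. ---
# Each cached sample is a fixed-length pose window plus the matching audio
# slice; for testing, stride == cut_length so windows do not overlap.  The
# helper below reproduces the arithmetic of the loop above for illustrative
# rates (30 fps pose, 16 kHz audio) and returns the index ranges it would cut.
def _clip_windows_sketch(clip_s_f_pose, clip_e_f_pose, cut_length, stride,
                         pose_fps=30, audio_fps=16000):
    import math
    clip_s_f_audio = clip_s_f_pose * audio_fps // pose_fps
    num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / stride) + 1
    audio_short_length = math.floor(cut_length / pose_fps * audio_fps)
    windows = []
    for i in range(num_subdivision):
        start_idx = clip_s_f_pose + i * stride
        fin_idx = start_idx + cut_length
        audio_start = clip_s_f_audio + math.floor(i * stride * audio_fps / pose_fps)
        audio_end = audio_start + audio_short_length
        windows.append((start_idx, fin_idx, audio_start, audio_end))
    return windows

# e.g. _clip_windows_sketch(0, 300, 64, 64) -> 4 non-overlapping 64-frame pose windows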
sample_sem_list.append(sample_sem) + sample_trans_list.append(sample_trans) + else: + n_filtered_out[filtering_message] += 1 + + if len(sample_pose_list) > 0: + with dst_lmdb_env.begin(write=True) as txn: + for pose, audio, facial, shape, word, vid, emo, sem, trans in zip( + sample_pose_list, + sample_audio_list, + sample_facial_list, + sample_shape_list, + sample_word_list, + sample_vid_list, + sample_emo_list, + sample_sem_list, + sample_trans_list,): + k = "{:005}".format(self.n_out_samples).encode("ascii") + v = [pose, audio, facial, shape, word, emo, sem, vid, trans] + # v = pyarrow.serialize(v).to_buffer() + # txn.put(k, v) + # self.n_out_samples += 1 + v = pickle.dumps(v) + txn.put(k, v) + self.n_out_samples += 1 + return n_filtered_out + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = "{:005}".format(idx).encode("ascii") + sample = txn.get(key) + # sample = pyarrow.deserialize(sample) + if sample is not None: + sample = pickle.loads(sample) + tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample + #print(in_shape) + #vid = torch.from_numpy(vid).int() + emo = torch.from_numpy(emo).int() + sem = torch.from_numpy(sem).float() + in_audio = torch.from_numpy(in_audio).float() + in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int() + if self.loader_type == "test": + tar_pose = torch.from_numpy(tar_pose).float() + trans = torch.from_numpy(trans).float() + in_facial = torch.from_numpy(in_facial).float() + vid = torch.from_numpy(vid).float() + in_shape = torch.from_numpy(in_shape).float() + else: + in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float() + vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float() + tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float() + in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float() + return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans} + + +class MotionPreprocessor: + def __init__(self, skeletons): + self.skeletons = skeletons + #self.mean_pose = mean_pose + self.filtering_message = "PASS" + + def get(self): + assert (self.skeletons is not None) + + # filtering + # if self.skeletons != []: + # if self.check_pose_diff(): + # self.skeletons = [] + # self.filtering_message = "pose" + # elif self.check_spine_angle(): + # self.skeletons = [] + # self.filtering_message = "spine angle" + # elif self.check_static_motion(): + # self.skeletons = [] + # self.filtering_message = "motion" + + # if self.skeletons != []: + # self.skeletons = self.skeletons.tolist() + # for i, frame in enumerate(self.skeletons): + # assert not np.isnan(self.skeletons[i]).any() # missing joints + + return self.skeletons, self.filtering_message + + def check_static_motion(self, verbose=True): + def get_variance(skeleton, joint_idx): + wrist_pos = skeleton[:, joint_idx] + variance = np.sum(np.var(wrist_pos, axis=0)) + return variance + + left_arm_var = get_variance(self.skeletons, 6) + right_arm_var = get_variance(self.skeletons, 9) + + th = 0.0014 # exclude 13110 + # th = 0.002 # exclude 16905 + if left_arm_var < th and right_arm_var < th: + if verbose: + print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return True + else: + if verbose: + print("pass - check_static_motion left var 
{}, right var {}".format(left_arm_var, right_arm_var)) + return False + + + def check_pose_diff(self, verbose=False): +# diff = np.abs(self.skeletons - self.mean_pose) # 186*1 +# diff = np.mean(diff) + +# # th = 0.017 +# th = 0.02 #0.02 # exclude 3594 +# if diff < th: +# if verbose: +# print("skip - check_pose_diff {:.5f}".format(diff)) +# return True +# # th = 3.5 #0.02 # exclude 3594 +# # if 3.5 < diff < 5: +# # if verbose: +# # print("skip - check_pose_diff {:.5f}".format(diff)) +# # return True +# else: +# if verbose: +# print("pass - check_pose_diff {:.5f}".format(diff)) + return False + + + def check_spine_angle(self, verbose=True): + def angle_between(v1, v2): + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + angles = [] + for i in range(self.skeletons.shape[0]): + spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] + angle = angle_between(spine_vec, [0, -1, 0]) + angles.append(angle) + + if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495 + # if np.rad2deg(max(angles)) > 20: # exclude 8270 + if verbose: + print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles))) + return True + else: + if verbose: + print("pass - check_spine_angle {:.5f}".format(max(angles))) + return False \ No newline at end of file diff --git a/dataloaders/build_vocab.py b/dataloaders/build_vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..649204d0d502eaf671e40d2bbb08ca324e5e02f3 --- /dev/null +++ b/dataloaders/build_vocab.py @@ -0,0 +1,199 @@ +import numpy as np +import glob +import os +import pickle +import lmdb +import pyarrow +import fasttext +from loguru import logger +from scipy import linalg + + +class Vocab: + PAD_token = 0 + SOS_token = 1 + EOS_token = 2 + UNK_token = 3 + + def __init__(self, name, insert_default_tokens=True): + self.name = name + self.trimmed = False + self.word_embedding_weights = None + self.reset_dictionary(insert_default_tokens) + + def reset_dictionary(self, insert_default_tokens=True): + self.word2index = {} + self.word2count = {} + if insert_default_tokens: + self.index2word = {self.PAD_token: "", self.SOS_token: "", + self.EOS_token: "", self.UNK_token: ""} + else: + self.index2word = {self.UNK_token: ""} + self.n_words = len(self.index2word) # count default tokens + + def index_word(self, word): + if word not in self.word2index: + self.word2index[word] = self.n_words + self.word2count[word] = 1 + self.index2word[self.n_words] = word + self.n_words += 1 + else: + self.word2count[word] += 1 + + def add_vocab(self, other_vocab): + for word, _ in other_vocab.word2count.items(): + self.index_word(word) + + # remove words below a certain count threshold + def trim(self, min_count): + if self.trimmed: + return + self.trimmed = True + + keep_words = [] + + for k, v in self.word2count.items(): + if v >= min_count: + keep_words.append(k) + + print(' word trimming, kept %s / %s = %.4f' % ( + len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) + )) + + # reinitialize dictionary + self.reset_dictionary() + for word in keep_words: + self.index_word(word) + + def get_word_index(self, word): + if word in self.word2index: + return self.word2index[word] + else: + return self.UNK_token + + def load_word_vectors(self, pretrained_path, embedding_dim=300): + print(" loading word vectors from '{}'...".format(pretrained_path)) + + # initialize embeddings to random values for special words + init_sd = 1 / 
np.sqrt(embedding_dim) + weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim]) + weights = weights.astype(np.float32) + + # read word vectors + word_model = fasttext.load_model(pretrained_path) + for word, id in self.word2index.items(): + vec = word_model.get_word_vector(word) + weights[id] = vec + self.word_embedding_weights = weights + + def __get_embedding_weight(self, pretrained_path, embedding_dim=300): + """ function modified from http://ronny.rest/blog/post_2017_08_04_glove/ """ + print("Loading word embedding '{}'...".format(pretrained_path)) + cache_path = pretrained_path + weights = None + + # use cached file if it exists + if os.path.exists(cache_path): # + with open(cache_path, 'rb') as f: + print(' using cached result from {}'.format(cache_path)) + weights = pickle.load(f) + if weights.shape != (self.n_words, embedding_dim): + logging.warning(' failed to load word embedding weights. reinitializing...') + weights = None + + if weights is None: + # initialize embeddings to random values for special and OOV words + init_sd = 1 / np.sqrt(embedding_dim) + weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim]) + weights = weights.astype(np.float32) + + with open(pretrained_path, encoding="utf-8", mode="r") as textFile: + num_embedded_words = 0 + for line_raw in textFile: + # extract the word, and embeddings vector + line = line_raw.split() + try: + word, vector = (line[0], np.array(line[1:], dtype=np.float32)) + # if word == 'love': # debugging + # print(word, vector) + + # if it is in our vocab, then update the corresponding weights + id = self.word2index.get(word, None) + if id is not None: + weights[id] = vector + num_embedded_words += 1 + except ValueError: + print(' parsing error at {}...'.format(line_raw[:50])) + continue + print(' {} / {} word vectors are found in the embedding'.format(num_embedded_words, len(self.word2index))) + + with open(cache_path, 'wb') as f: + pickle.dump(weights, f) + return weights + + +def build_vocab(name, data_path, cache_path, word_vec_path=None, feat_dim=None): + print(' building a language model...') + #if not os.path.exists(cache_path): + lang_model = Vocab(name) + print(' indexing words from {}'.format(data_path)) + index_words_from_textgrid(lang_model, data_path) + + if word_vec_path is not None: + lang_model.load_word_vectors(word_vec_path, feat_dim) + else: + print(' loaded from {}'.format(cache_path)) + with open(cache_path, 'rb') as f: + lang_model = pickle.load(f) + if word_vec_path is None: + lang_model.word_embedding_weights = None + elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words: + logging.warning(' failed to load word embedding weights. check this') + assert False + + with open(cache_path, 'wb') as f: + pickle.dump(lang_model, f) + + + return lang_model + + +def index_words(lang_model, data_path): + #index words form text + with open(data_path, "r") as f: + for line in f.readlines(): + line = line.replace(",", " ") + line = line.replace(".", " ") + line = line.replace("?", " ") + line = line.replace("!", " ") + for word in line.split(): + lang_model.index_word(word) + print(' indexed %d words' % lang_model.n_words) + +def index_words_from_textgrid(lang_model, data_path): + import textgrid as tg + from tqdm import tqdm + #trainvaltest=os.listdir(data_path) + # for loadtype in trainvaltest: + # if "." 
in loadtype: continue #ignore .ipynb_checkpoints + texts = os.listdir(data_path+"/textgrid/") + #print(texts) + for textfile in tqdm(texts): + tgrid = tg.TextGrid.fromFile(data_path+"/textgrid/"+textfile) + for word in tgrid[0]: + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + word_n = word_n.replace(",", " ") + word_n = word_n.replace(".", " ") + word_n = word_n.replace("?", " ") + word_n = word_n.replace("!", " ") + #print(word_n) + lang_model.index_word(word_n) + print(' indexed %d words' % lang_model.n_words) + print(lang_model.word2index, lang_model.word2count) + +if __name__ == "__main__": + # 11195 for all, 5793 for 4 speakers + # build_vocab("beat_english_15_141", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/vocab.pkl", "/home/ma-user/work/datasets/cc.en.300.bin", 300) + build_vocab("beat_chinese_v1.0.0", "/data/datasets/beat_chinese_v1.0.0/", "/data/datasets/beat_chinese_v1.0.0/weights/vocab.pkl", "/home/ma-user/work/cc.zh.300.bin", 300) + + \ No newline at end of file diff --git a/dataloaders/data_tools.py b/dataloaders/data_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..f7388daf0f9ed7620ba98ccf964d30dc901a7737 --- /dev/null +++ b/dataloaders/data_tools.py @@ -0,0 +1,1756 @@ +import numpy as np +import glob +import os +import pickle +import lmdb +import pyarrow +import fasttext +from loguru import logger +from scipy import linalg +from .pymo.parsers import BVHParser +from .pymo.viz_tools import * +from .pymo.preprocessing import * + + + + +# pose version fpsxx_trinity/japanese_joints(_xxx) +joints_list = { + "trinity_joints":{ + 'Hips': [6,6], + 'Spine': [3,9], + 'Spine1': [3,12], + 'Spine2': [3,15], + 'Spine3': [3,18], + 'Neck': [3,21], + 'Neck1': [3,24], + 'Head': [3,27], + 'RShoulder': [3,30], + 'RArm': [3,33], + 'RArm1': [3,36], + 'RHand': [3,39], + 'RHandT1': [3,42], + 'RHandT2': [3,45], + 'RHandT3': [3,48], + 'RHandI1': [3,51], + 'RHandI2': [3,54], + 'RHandI3': [3,57], + 'RHandM1': [3,60], + 'RHandM2': [3,63], + 'RHandM3': [3,66], + 'RHandR1': [3,69], + 'RHandR2': [3,72], + 'RHandR3': [3,75], + 'RHandP1': [3,78], + 'RHandP2': [3,81], + 'RHandP3': [3,84], + 'LShoulder': [3,87], + 'LArm': [3,90], + 'LArm1': [3,93], + 'LHand': [3,96], + 'LHandT1': [3,99], + 'LHandT2': [3,102], + 'LHandT3': [3,105], + 'LHandI1': [3,108], + 'LHandI2': [3,111], + 'LHandI3': [3,114], + 'LHandM1': [3,117], + 'LHandM2': [3,120], + 'LHandM3': [3,123], + 'LHandR1': [3,126], + 'LHandR2': [3,129], + 'LHandR3': [3,132], + 'LHandP1': [3,135], + 'LHandP2': [3,138], + 'LHandP3': [3,141], + 'RUpLeg': [3,144], + 'RLeg': [3,147], + 'RFoot': [3,150], + 'RFootF': [3,153], + 'RToeBase': [3,156], + 'LUpLeg': [3,159], + 'LLeg': [3,162], + 'LFoot': [3,165], + 'LFootF': [3,168], + 'LToeBase': [3,171],}, + "trinity_joints_123":{ + 'Spine': 3 , + 'Neck': 3 , + 'Neck1': 3 , + 'RShoulder': 3 , + 'RArm': 3 , + 'RArm1': 3 , + 'RHand': 3 , + 'RHandT1': 3 , + 'RHandT2': 3 , + 'RHandT3': 3 , + 'RHandI1': 3 , + 'RHandI2': 3 , + 'RHandI3': 3 , + 'RHandM1': 3 , + 'RHandM2': 3 , + 'RHandM3': 3 , + 'RHandR1': 3 , + 'RHandR2': 3 , + 'RHandR3': 3 , + 'RHandP1': 3 , + 'RHandP2': 3 , + 'RHandP3': 3 , + 'LShoulder': 3 , + 'LArm': 3 , + 'LArm1': 3 , + 'LHand': 3 , + 'LHandT1': 3 , + 'LHandT2': 3 , + 'LHandT3': 3 , + 'LHandI1': 3 , + 'LHandI2': 3 , + 'LHandI3': 3 , + 'LHandM1': 3 , + 'LHandM2': 3 , + 'LHandM3': 3 , + 'LHandR1': 3 , + 'LHandR2': 3 , + 'LHandR3': 3 , + 'LHandP1': 3 , + 'LHandP2': 3 , + 'LHandP3': 3 
,}, + "trinity_joints_168":{ + 'Hips': 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Spine2': 3 , + 'Spine3': 3 , + 'Neck': 3 , + 'Neck1': 3 , + 'Head': 3 , + 'RShoulder': 3 , + 'RArm': 3 , + 'RArm1': 3 , + 'RHand': 3 , + 'RHandT1': 3 , + 'RHandT2': 3 , + 'RHandT3': 3 , + 'RHandI1': 3 , + 'RHandI2': 3 , + 'RHandI3': 3 , + 'RHandM1': 3 , + 'RHandM2': 3 , + 'RHandM3': 3 , + 'RHandR1': 3 , + 'RHandR2': 3 , + 'RHandR3': 3 , + 'RHandP1': 3 , + 'RHandP2': 3 , + 'RHandP3': 3 , + 'LShoulder': 3 , + 'LArm': 3 , + 'LArm1': 3 , + 'LHand': 3 , + 'LHandT1': 3 , + 'LHandT2': 3 , + 'LHandT3': 3 , + 'LHandI1': 3 , + 'LHandI2': 3 , + 'LHandI3': 3 , + 'LHandM1': 3 , + 'LHandM2': 3 , + 'LHandM3': 3 , + 'LHandR1': 3 , + 'LHandR2': 3 , + 'LHandR3': 3 , + 'LHandP1': 3 , + 'LHandP2': 3 , + 'LHandP3': 3 , + 'RUpLeg': 3 , + 'RLeg': 3 , + 'RFoot': 3 , + 'RFootF': 3 , + 'RToeBase': 3 , + 'LUpLeg': 3 , + 'LLeg': 3 , + 'LFoot': 3 , + 'LFootF': 3 , + 'LToeBase': 3 ,}, + "trinity_joints_138":{ + "Hips": 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Spine2': 3 , + 'Spine3': 3 , + 'Neck': 3 , + 'Neck1': 3 , + 'Head': 3 , + 'RShoulder': 3 , + 'RArm': 3 , + 'RArm1': 3 , + 'RHand': 3 , + 'RHandT1': 3 , + 'RHandT2': 3 , + 'RHandT3': 3 , + 'RHandI1': 3 , + 'RHandI2': 3 , + 'RHandI3': 3 , + 'RHandM1': 3 , + 'RHandM2': 3 , + 'RHandM3': 3 , + 'RHandR1': 3 , + 'RHandR2': 3 , + 'RHandR3': 3 , + 'RHandP1': 3 , + 'RHandP2': 3 , + 'RHandP3': 3 , + 'LShoulder': 3 , + 'LArm': 3 , + 'LArm1': 3 , + 'LHand': 3 , + 'LHandT1': 3 , + 'LHandT2': 3 , + 'LHandT3': 3 , + 'LHandI1': 3 , + 'LHandI2': 3 , + 'LHandI3': 3 , + 'LHandM1': 3 , + 'LHandM2': 3 , + 'LHandM3': 3 , + 'LHandR1': 3 , + 'LHandR2': 3 , + 'LHandR3': 3 , + 'LHandP1': 3 , + 'LHandP2': 3 , + 'LHandP3': 3 ,}, + "beat_smplx_joints": { + 'pelvis': [3,3], + 'left_hip': [3,6], + 'right_hip': [3,9], + 'spine1': [3,12], + 'left_knee': [3,15], + 'right_knee': [3,18], + 'spine2': [3,21], + 'left_ankle': [3,24], + 'right_ankle': [3,27], + + 'spine3': [3,30], + 'left_foot': [3,33], + 'right_foot': [3,36], + 'neck': [3,39], + 'left_collar': [3,42], + 'right_collar': [3,45], + 'head': [3,48], + 'left_shoulder': [3,51], + + 'right_shoulder': [3,54], + 'left_elbow': [3,57], + 'right_elbow': [3,60], + 'left_wrist': [3,63], + 'right_wrist': [3,66], + + 'jaw': [3,69], + 'left_eye_smplhf': [3,72], + 'right_eye_smplhf': [3,75], + 'left_index1': [3,78], + 'left_index2': [3,81], + + 'left_index3': [3,84], + 'left_middle1': [3,87], + 'left_middle2': [3,90], + 'left_middle3': [3,93], + 'left_pinky1': [3,96], + + 'left_pinky2': [3,99], + 'left_pinky3': [3,102], + 'left_ring1': [3,105], + 'left_ring2': [3,108], + + 'left_ring3': [3,111], + 'left_thumb1': [3,114], + 'left_thumb2': [3,117], + 'left_thumb3': [3,120], + 'right_index1': [3,123], + 'right_index2': [3,126], + 'right_index3': [3,129], + 'right_middle1': [3,132], + + 'right_middle2': [3,135], + 'right_middle3': [3,138], + 'right_pinky1': [3,141], + 'right_pinky2': [3,144], + 'right_pinky3': [3,147], + + 'right_ring1': [3,150], + 'right_ring2': [3,153], + 'right_ring3': [3,156], + 'right_thumb1': [3,159], + 'right_thumb2': [3,162], + 'right_thumb3': [3,165], + +# 'nose': [3,168], +# 'right_eye': [3,171], +# 'left_eye': [3,174], +# 'right_ear': [3,177], + +# 'left_ear': [3,180], +# 'left_big_toe': [3,183], +# 'left_small_toe': [3,186], +# 'left_heel': [3,189], + +# 'right_big_toe': [3,192], +# 'right_small_toe': [3,195], +# 'right_heel': [3,198], +# 'left_thumb': [3,201], +# 'left_index': [3,204], +# 'left_middle': [3,207], + +# 'left_ring': [3,210], +# 'left_pinky': 
[3,213], +# 'right_thumb': [3,216], +# 'right_index': [3,219], +# 'right_middle': [3,222], +# 'right_ring': [3,225], + +# 'right_pinky': [3,228], +# 'right_eye_brow1': [3,231], +# 'right_eye_brow2': [3,234], +# 'right_eye_brow3': [3,237], + +# 'right_eye_brow4': [3,240], +# 'right_eye_brow5': [3,243], +# 'left_eye_brow5': [3,246], +# 'left_eye_brow4': [3,249], + +# 'left_eye_brow3': [3,252], +# 'left_eye_brow2': [3,255], +# 'left_eye_brow1': [3,258], +# 'nose1': [3,261], +# 'nose2': [3,264], +# 'nose3': [3,267], + +# 'nose4': [3,270], +# 'right_nose_2': [3,273], +# 'right_nose_1': [3,276], +# 'nose_middle': [3,279], +# 'left_nose_1': [3,282], +# 'left_nose_2': [3,285], + +# 'right_eye1': [3,288], +# 'right_eye2': [3,291], +# 'right_eye3': [3,294], +# 'right_eye4': [3,297], + +# 'right_eye5': [3,300], +# 'right_eye6': [3,303], +# 'left_eye4': [3,306], +# 'left_eye3': [3,309], + +# 'left_eye2': [3,312], +# 'left_eye1': [3,315], +# 'left_eye6': [3,318], +# 'left_eye5': [3,321], +# 'right_mouth_1': [3,324], +# 'right_mouth_2': [3,327], +# 'right_mouth_3': [3,330], +# 'mouth_top': [3,333], +# 'left_mouth_3': [3,336], +# 'left_mouth_2': [3,339], +# 'left_mouth_1': [3,342], +# 'left_mouth_5': [3,345], +# 'left_mouth_4': [3,348], +# 'mouth_bottom': [3,351], +# 'right_mouth_4': [3,354], +# 'right_mouth_5': [3,357], +# 'right_lip_1': [3,360], +# 'right_lip_2': [3,363], +# 'lip_top': [3,366], +# 'left_lip_2': [3,369], + +# 'left_lip_1': [3,372], +# 'left_lip_3': [3,375], +# 'lip_bottom': [3,378], +# 'right_lip_3': [3,381], +# 'right_contour_1': [3,384], +# 'right_contour_2': [3,387], +# 'right_contour_3': [3,390], +# 'right_contour_4': [3,393], +# 'right_contour_5': [3,396], +# 'right_contour_6': [3,399], +# 'right_contour_7': [3,402], +# 'right_contour_8': [3,405], +# 'contour_middle': [3,408], +# 'left_contour_8': [3,411], +# 'left_contour_7': [3,414], +# 'left_contour_6': [3,417], +# 'left_contour_5': [3,420], +# 'left_contour_4': [3,423], +# 'left_contour_3': [3,426], +# 'left_contour_2': [3,429], +# 'left_contour_1': [3,432], + }, + + "beat_smplx_no_eyes": { + "pelvis":3, + "left_hip":3, + "right_hip":3, + "spine1":3, + "left_knee":3, + "right_knee":3, + "spine2":3, + "left_ankle":3, + "right_ankle":3, + "spine3":3, + "left_foot":3, + "right_foot":3, + "neck":3, + "left_collar":3, + "right_collar":3, + "head":3, + "left_shoulder":3, + "right_shoulder":3, + "left_elbow":3, + "right_elbow":3, + "left_wrist":3, + "right_wrist":3, + "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + "left_index1":3, + "left_index2":3, + "left_index3":3, + "left_middle1":3, + "left_middle2":3, + "left_middle3":3, + "left_pinky1":3, + "left_pinky2":3, + "left_pinky3":3, + "left_ring1":3, + "left_ring2":3, + "left_ring3":3, + "left_thumb1":3, + "left_thumb2":3, + "left_thumb3":3, + "right_index1":3, + "right_index2":3, + "right_index3":3, + "right_middle1":3, + "right_middle2":3, + "right_middle3":3, + "right_pinky1":3, + "right_pinky2":3, + "right_pinky3":3, + "right_ring1":3, + "right_ring2":3, + "right_ring3":3, + "right_thumb1":3, + "right_thumb2":3, + "right_thumb3":3, + }, + + "beat_smplx_full": { + "pelvis":3, + "left_hip":3, + "right_hip":3, + "spine1":3, + "left_knee":3, + "right_knee":3, + "spine2":3, + "left_ankle":3, + "right_ankle":3, + "spine3":3, + "left_foot":3, + "right_foot":3, + "neck":3, + "left_collar":3, + "right_collar":3, + "head":3, + "left_shoulder":3, + "right_shoulder":3, + "left_elbow":3, + "right_elbow":3, + "left_wrist":3, + "right_wrist":3, + "jaw":3, + "left_eye_smplhf":3, + 
"right_eye_smplhf":3, + "left_index1":3, + "left_index2":3, + "left_index3":3, + "left_middle1":3, + "left_middle2":3, + "left_middle3":3, + "left_pinky1":3, + "left_pinky2":3, + "left_pinky3":3, + "left_ring1":3, + "left_ring2":3, + "left_ring3":3, + "left_thumb1":3, + "left_thumb2":3, + "left_thumb3":3, + "right_index1":3, + "right_index2":3, + "right_index3":3, + "right_middle1":3, + "right_middle2":3, + "right_middle3":3, + "right_pinky1":3, + "right_pinky2":3, + "right_pinky3":3, + "right_ring1":3, + "right_ring2":3, + "right_ring3":3, + "right_thumb1":3, + "right_thumb2":3, + "right_thumb3":3, + }, + + "beat_smplx_upall": { + # "pelvis":3, + # "left_hip":3, + # "right_hip":3, + "spine1":3, + # "left_knee":3, + # "right_knee":3, + "spine2":3, + # "left_ankle":3, + # "right_ankle":3, + "spine3":3, + # "left_foot":3, + # "right_foot":3, + "neck":3, + "left_collar":3, + "right_collar":3, + "head":3, + "left_shoulder":3, + "right_shoulder":3, + "left_elbow":3, + "right_elbow":3, + "left_wrist":3, + "right_wrist":3, + # "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + "left_index1":3, + "left_index2":3, + "left_index3":3, + "left_middle1":3, + "left_middle2":3, + "left_middle3":3, + "left_pinky1":3, + "left_pinky2":3, + "left_pinky3":3, + "left_ring1":3, + "left_ring2":3, + "left_ring3":3, + "left_thumb1":3, + "left_thumb2":3, + "left_thumb3":3, + "right_index1":3, + "right_index2":3, + "right_index3":3, + "right_middle1":3, + "right_middle2":3, + "right_middle3":3, + "right_pinky1":3, + "right_pinky2":3, + "right_pinky3":3, + "right_ring1":3, + "right_ring2":3, + "right_ring3":3, + "right_thumb1":3, + "right_thumb2":3, + "right_thumb3":3, + }, + + "beat_smplx_upper": { + #"pelvis":3, + # "left_hip":3, + # "right_hip":3, + "spine1":3, + # "left_knee":3, + # "right_knee":3, + "spine2":3, + # "left_ankle":3, + # "right_ankle":3, + "spine3":3, + # "left_foot":3, + # "right_foot":3, + "neck":3, + "left_collar":3, + "right_collar":3, + "head":3, + "left_shoulder":3, + "right_shoulder":3, + "left_elbow":3, + "right_elbow":3, + "left_wrist":3, + "right_wrist":3, + # "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + # "left_index1":3, + # "left_index2":3, + # "left_index3":3, + # "left_middle1":3, + # "left_middle2":3, + # "left_middle3":3, + # "left_pinky1":3, + # "left_pinky2":3, + # "left_pinky3":3, + # "left_ring1":3, + # "left_ring2":3, + # "left_ring3":3, + # "left_thumb1":3, + # "left_thumb2":3, + # "left_thumb3":3, + # "right_index1":3, + # "right_index2":3, + # "right_index3":3, + # "right_middle1":3, + # "right_middle2":3, + # "right_middle3":3, + # "right_pinky1":3, + # "right_pinky2":3, + # "right_pinky3":3, + # "right_ring1":3, + # "right_ring2":3, + # "right_ring3":3, + # "right_thumb1":3, + # "right_thumb2":3, + # "right_thumb3":3, + }, + + "beat_smplx_hands": { + #"pelvis":3, + # "left_hip":3, + # "right_hip":3, + # "spine1":3, + # "left_knee":3, + # "right_knee":3, + # "spine2":3, + # "left_ankle":3, + # "right_ankle":3, + # "spine3":3, + # "left_foot":3, + # "right_foot":3, + # "neck":3, + # "left_collar":3, + # "right_collar":3, + # "head":3, + # "left_shoulder":3, + # "right_shoulder":3, + # "left_elbow":3, + # "right_elbow":3, + # "left_wrist":3, + # "right_wrist":3, + # "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + "left_index1":3, + "left_index2":3, + "left_index3":3, + "left_middle1":3, + "left_middle2":3, + "left_middle3":3, + "left_pinky1":3, + "left_pinky2":3, + "left_pinky3":3, + "left_ring1":3, + "left_ring2":3, + 
"left_ring3":3, + "left_thumb1":3, + "left_thumb2":3, + "left_thumb3":3, + "right_index1":3, + "right_index2":3, + "right_index3":3, + "right_middle1":3, + "right_middle2":3, + "right_middle3":3, + "right_pinky1":3, + "right_pinky2":3, + "right_pinky3":3, + "right_ring1":3, + "right_ring2":3, + "right_ring3":3, + "right_thumb1":3, + "right_thumb2":3, + "right_thumb3":3, + }, + + "beat_smplx_lower": { + "pelvis":3, + "left_hip":3, + "right_hip":3, + # "spine1":3, + "left_knee":3, + "right_knee":3, + # "spine2":3, + "left_ankle":3, + "right_ankle":3, + # "spine3":3, + "left_foot":3, + "right_foot":3, + # "neck":3, + # "left_collar":3, + # "right_collar":3, + # "head":3, + # "left_shoulder":3, + # "right_shoulder":3, + # "left_elbow":3, + # "right_elbow":3, + # "left_wrist":3, + # "right_wrist":3, + # "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + # "left_index1":3, + # "left_index2":3, + # "left_index3":3, + # "left_middle1":3, + # "left_middle2":3, + # "left_middle3":3, + # "left_pinky1":3, + # "left_pinky2":3, + # "left_pinky3":3, + # "left_ring1":3, + # "left_ring2":3, + # "left_ring3":3, + # "left_thumb1":3, + # "left_thumb2":3, + # "left_thumb3":3, + # "right_index1":3, + # "right_index2":3, + # "right_index3":3, + # "right_middle1":3, + # "right_middle2":3, + # "right_middle3":3, + # "right_pinky1":3, + # "right_pinky2":3, + # "right_pinky3":3, + # "right_ring1":3, + # "right_ring2":3, + # "right_ring3":3, + # "right_thumb1":3, + # "right_thumb2":3, + # "right_thumb3":3, + }, + + "beat_smplx_face": { + # "pelvis":3, + # "left_hip":3, + # "right_hip":3, + # # "spine1":3, + # "left_knee":3, + # "right_knee":3, + # # "spine2":3, + # "left_ankle":3, + # "right_ankle":3, + # # "spine3":3, + # "left_foot":3, + # "right_foot":3, + # "neck":3, + # "left_collar":3, + # "right_collar":3, + # "head":3, + # "left_shoulder":3, + # "right_shoulder":3, + # "left_elbow":3, + # "right_elbow":3, + # "left_wrist":3, + # "right_wrist":3, + "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + # "left_index1":3, + # "left_index2":3, + # "left_index3":3, + # "left_middle1":3, + # "left_middle2":3, + # "left_middle3":3, + # "left_pinky1":3, + # "left_pinky2":3, + # "left_pinky3":3, + # "left_ring1":3, + # "left_ring2":3, + # "left_ring3":3, + # "left_thumb1":3, + # "left_thumb2":3, + # "left_thumb3":3, + # "right_index1":3, + # "right_index2":3, + # "right_index3":3, + # "right_middle1":3, + # "right_middle2":3, + # "right_middle3":3, + # "right_pinky1":3, + # "right_pinky2":3, + # "right_pinky3":3, + # "right_ring1":3, + # "right_ring2":3, + # "right_ring3":3, + # "right_thumb1":3, + # "right_thumb2":3, + # "right_thumb3":3, + }, + + "beat_joints": { + 'Hips': [6,6], + 'Spine': [3,9], + 'Spine1': [3,12], + 'Spine2': [3,15], + 'Spine3': [3,18], + 'Neck': [3,21], + 'Neck1': [3,24], + 'Head': [3,27], + 'HeadEnd': [3,30], + + 'RShoulder': [3,33], + 'RArm': [3,36], + 'RArm1': [3,39], + 'RHand': [3,42], + 'RHandM1': [3,45], + 'RHandM2': [3,48], + 'RHandM3': [3,51], + 'RHandM4': [3,54], + + 'RHandR': [3,57], + 'RHandR1': [3,60], + 'RHandR2': [3,63], + 'RHandR3': [3,66], + 'RHandR4': [3,69], + + 'RHandP': [3,72], + 'RHandP1': [3,75], + 'RHandP2': [3,78], + 'RHandP3': [3,81], + 'RHandP4': [3,84], + + 'RHandI': [3,87], + 'RHandI1': [3,90], + 'RHandI2': [3,93], + 'RHandI3': [3,96], + 'RHandI4': [3,99], + + 'RHandT1': [3,102], + 'RHandT2': [3,105], + 'RHandT3': [3,108], + 'RHandT4': [3,111], + + 'LShoulder': [3,114], + 'LArm': [3,117], + 'LArm1': [3,120], + 'LHand': [3,123], + 'LHandM1': [3,126], 
+ 'LHandM2': [3,129], + 'LHandM3': [3,132], + 'LHandM4': [3,135], + + 'LHandR': [3,138], + 'LHandR1': [3,141], + 'LHandR2': [3,144], + 'LHandR3': [3,147], + 'LHandR4': [3,150], + + 'LHandP': [3,153], + 'LHandP1': [3,156], + 'LHandP2': [3,159], + 'LHandP3': [3,162], + 'LHandP4': [3,165], + + 'LHandI': [3,168], + 'LHandI1': [3,171], + 'LHandI2': [3,174], + 'LHandI3': [3,177], + 'LHandI4': [3,180], + + 'LHandT1': [3,183], + 'LHandT2': [3,186], + 'LHandT3': [3,189], + 'LHandT4': [3,192], + + 'RUpLeg': [3,195], + 'RLeg': [3,198], + 'RFoot': [3,201], + 'RFootF': [3,204], + 'RToeBase': [3,207], + 'RToeBaseEnd': [3,210], + + 'LUpLeg': [3,213], + 'LLeg': [3,216], + 'LFoot': [3,219], + 'LFootF': [3,222], + 'LToeBase': [3,225], + 'LToeBaseEnd': [3,228],}, + + "beat_full":{ + 'Hips': 3, + 'Spine': 3 , + 'Spine1': 3 , + 'Spine2': 3 , + 'Spine3': 3 , + 'Neck': 3 , + 'Neck1': 3 , + 'Head' : 3, + 'HeadEnd' : 3, + 'RShoulder': 3 , + 'RArm': 3 , + 'RArm1': 3 , + 'RHand': 3 , + 'RHandM1': 3 , + 'RHandM2': 3 , + 'RHandM3': 3 , + 'RHandM4': 3 , + 'RHandR': 3 , + 'RHandR1': 3 , + 'RHandR2': 3 , + 'RHandR3': 3 , + 'RHandR4': 3 , + 'RHandP': 3 , + 'RHandP1': 3 , + 'RHandP2': 3 , + 'RHandP3': 3 , + 'RHandP4': 3 , + 'RHandI': 3 , + 'RHandI1': 3 , + 'RHandI2': 3 , + 'RHandI3': 3 , + 'RHandI4': 3 , + 'RHandT1': 3 , + 'RHandT2': 3 , + 'RHandT3': 3 , + 'RHandT4': 3 , + 'LShoulder': 3 , + 'LArm': 3 , + 'LArm1': 3 , + 'LHand': 3 , + 'LHandM1': 3 , + 'LHandM2': 3 , + 'LHandM3': 3 , + 'LHandM4': 3 , + 'LHandR': 3 , + 'LHandR1': 3 , + 'LHandR2': 3 , + 'LHandR3': 3 , + 'LHandR4': 3 , + 'LHandP': 3 , + 'LHandP1': 3 , + 'LHandP2': 3 , + 'LHandP3': 3 , + 'LHandP4': 3 , + 'LHandI': 3 , + 'LHandI1': 3 , + 'LHandI2': 3 , + 'LHandI3': 3 , + 'LHandI4': 3 , + 'LHandT1': 3 , + 'LHandT2': 3 , + 'LHandT3': 3 , + 'LHandT4': 3 , + 'RUpLeg': 3, + 'RLeg': 3, + 'RFoot': 3, + 'RFootF': 3, + 'RToeBase': 3, + 'RToeBaseEnd': 3, + 'LUpLeg': 3, + 'LLeg': 3, + 'LFoot': 3, + 'LFootF': 3, + 'LToeBase': 3, + 'LToeBaseEnd': 3, + }, + + "japanese_joints":{ + 'Hips': [6,6], + 'Spine': [6,12], + 'Spine1': [6,18], + 'Spine2': [6,24], + 'Spine3': [6,30], + 'Neck': [6,36], + 'Neck1': [6,42], + 'Head': [6,48], + 'RShoulder': [6,54], + 'RArm': [6,60], + 'RArm1': [6,66], + 'RHand': [6,72], + 'RHandM1': [6,78], + 'RHandM2': [6,84], + 'RHandM3': [6,90], + 'RHandR': [6,96], + 'RHandR1': [6,102], + 'RHandR2': [6,108], + 'RHandR3': [6,114], + 'RHandP': [6,120], + 'RHandP1': [6,126], + 'RHandP2': [6,132], + 'RHandP3': [6,138], + 'RHandI': [6,144], + 'RHandI1': [6,150], + 'RHandI2': [6,156], + 'RHandI3': [6,162], + 'RHandT1': [6,168], + 'RHandT2': [6,174], + 'RHandT3': [6,180], + 'LShoulder': [6,186], + 'LArm': [6,192], + 'LArm1': [6,198], + 'LHand': [6,204], + 'LHandM1': [6,210], + 'LHandM2': [6,216], + 'LHandM3': [6,222], + 'LHandR': [6,228], + 'LHandR1': [6,234], + 'LHandR2': [6,240], + 'LHandR3': [6,246], + 'LHandP': [6,252], + 'LHandP1': [6,258], + 'LHandP2': [6,264], + 'LHandP3': [6,270], + 'LHandI': [6,276], + 'LHandI1': [6,282], + 'LHandI2': [6,288], + 'LHandI3': [6,294], + 'LHandT1': [6,300], + 'LHandT2': [6,306], + 'LHandT3': [6,312], + 'RUpLeg': [6,318], + 'RLeg': [6,324], + 'RFoot': [6,330], + 'RFootF': [6,336], + 'RToeBase': [6,342], + 'LUpLeg': [6,348], + 'LLeg': [6,354], + 'LFoot': [6,360], + 'LFootF': [6,366], + 'LToeBase': [6,372],}, + + "yostar":{ + 'Hips': [6,6], + 'Spine': [3,9], + 'Spine1': [3,12], + 'Bone040': [3,15], + 'Bone041': [3,18], + + 'Bone034': [3,21], + 'Bone035': [3,24], + 'Bone036': [3,27], + 'Bone037': [3,30], + 'Bone038': [3,33], + 
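+ # Note: in the [a, b] entries of these dictionaries, a is the joint's channel count and
+ # b is the cumulative end index of that joint's channels within one BVH motion frame
+ # (e.g. 'Hips': [6,6] occupies channels [0:6], 'Spine': [3,9] occupies [6:9]).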
'Bone039': [3,36], + + 'RibbonL1': [3,39], + 'RibbonL1_end': [3,42], + + 'Chest': [3,45], + 'L_eri': [3,48], + 'R_eri': [3,51], + 'Neck': [3,54], + 'Head': [3,57], + 'Head_end': [3,60], + + 'RBackHair_1': [3,63], + 'RBackHair_2': [3,66], + 'RBackHair_3': [3,69], + 'RBackHair_4': [3,72], + 'RBackHair_end': [3,75], + + 'RFrontHair': [3,78], + 'CFrontHair_1': [3,81], + 'CFrontHair_2': [3,84], + 'CFrontHair_3': [3,87], + 'CFrontHair_emd': [3,90], + + 'LFrontHair_1': [3,93], + 'LFrontHair_2': [3,96], + 'LFrontHair_3': [3,99], + + 'LBackHair_1': [3,102], + 'LBackHair_2': [3,105], + 'LBackHair_3': [3,108], + 'LBackHair_4': [3,111], + 'LBackHair_end': [3,114], + + 'LSideHair_1': [3,117], + 'LSideHair_2': [3,120], + 'LSideHair_3': [3,123], + 'LSideHair_4': [3,126], + 'LSideHair_5': [3,129], + 'LSideHair_6': [3,132], + 'LSideHair_7': [3,135], + 'LSideHair_end': [3,138], + + 'CBackHair_1': [3,141], + 'CBackHair_2': [3,144], + 'CBackHair_3': [3,147], + 'CBackHair_4': [3,150], + 'CBackHair_end': [3,153], + + 'RSideHair_1': [3,156], + 'RSideHair_2': [3,159], + 'RSideHair_3': [3,162], + 'RSideHair_4': [3,165], + + 'RibbonR_1': [3,168], + 'RibbonR_2': [3,171], + 'RibbonR_3': [3,174], + + 'RibbonL_1': [3,177], + 'RibbonL_2': [3,180], + 'RibbonL_3': [3,183], + + 'LeftEye': [3,186], + 'LeftEye_end': [3,189], + 'RightEye': [3,192], + 'RightEye_end': [3,195], + + 'LeftShoulder': [3,198], + 'LeftArm': [3,201], + 'LeftForearm': [3,204], + 'LeftHand': [3,207], + 'LeftHandThumb1': [3,210], + 'LeftHandThumb2': [3,213], + 'LeftHandThumb3': [3,216], + 'LeftHandThumb_end': [3,219], + + 'LeftHandIndex1': [3,222], + 'LeftHandIndex2': [3,225], + 'LeftHandIndex3': [3,228], + 'LeftHandIndex_end': [3,231], + + 'LeftHandMiddle1': [3,234], + 'LeftHandMiddle2': [3,237], + 'LeftHandMiddle3': [3,240], + 'LeftHandMiddle_end': [3,243], + + 'LeftHandRing1': [3,246], + 'LeftHandRing2': [3,249], + 'LeftHandRing3': [3,252], + 'LeftHandRing_end': [3,255], + + 'LeftHandPinky1': [3,258], + 'LeftHandPinky2': [3,261], + 'LeftHandPinky3': [3,264], + 'LeftHandPinky_end': [3,267], + + 'RightShoulder': [3,270], + 'RightArm': [3,273], + 'RightForearm': [3,276], + 'RightHand': [3,279], + 'RightHandThumb1': [3,282], + 'RightHandThumb2': [3,285], + 'RightHandThumb3': [3,288], + 'RightHandThumb_end': [3,291], + + 'RightHandIndex1': [3,294], + 'RightHandIndex2': [3,297], + 'RightHandIndex3': [3,300], + 'RightHandIndex_end': [3,303], + + 'RightHandMiddle1': [3,306], + 'RightHandMiddle2': [3,309], + 'RightHandMiddle3': [3,312], + 'RightHandMiddle_end': [3,315], + + 'RightHandRing1': [3,318], + 'RightHandRing2': [3,321], + 'RightHandRing3': [3,324], + 'RightHandRing_end': [3,327], + + 'RightHandPinky1': [3,330], + 'RightHandPinky2': [3,333], + 'RightHandPinky3': [3,336], + 'RightHandPinky_end': [3,339], + + 'RibbonR1': [3,342], + 'RibbonR1_end': [3,345], + 'RibbonR2': [3,348], + 'RibbonR2_end': [3,351], + 'RibbonL2': [3,354], + 'RibbonL2_end': [3,357], + + 'LeftUpLeg': [3,360], + 'LeftLeg': [3,363], + 'LeftFoot': [3,366], + 'LeftToe': [3,369], + 'LeftToe_end': [3,372], + + 'RightUpLeg': [3,375], + 'RightLEg': [3,378], + 'RightFoot': [3,381], + 'RightToe': [3,384], + 'RightToe_end': [3,387], + + 'bone_skirtF00': [3, 390], + 'bone_skirtF01': [3, 393], + 'bone_skirtF02': [3, 396], + 'bone_skirtF03': [3, 399], + 'Bone020': [3, 402], + 'Bone026': [3, 405], + + 'bone_skirtF_R_00': [3, 408], + 'bone_skirtF_R_01': [3, 411], + 'bone_skirtF_R_02': [3, 414], + 'bone_skirtF_R_03': [3, 417], + 'Bone019': [3, 420], + 'Bone028': [3, 423], + + 'bone_skirtR00': [3, 
426], + 'bone_skirtR01': [3, 429], + 'bone_skirtR02': [3, 432], + 'bone_skirtR03': [3, 435], + 'Bone018': [3, 438], + 'Bone029': [3, 441], + + 'bone_skirtF_L_00': [3, 444], + 'bone_skirtF_L_01': [3, 447], + 'bone_skirtF_L_02': [3, 450], + 'bone_skirtF_L_03': [3, 453], + 'Bone021': [3, 456], + 'Bone027': [3, 459], + + 'bone_skirtL00': [3, 462], + 'bone_skirtL01': [3, 465], + 'bone_skirtL02': [3, 468], + 'bone_skirtL03': [3, 471], + 'Bone022': [3, 474], + 'Bone033': [3, 477], + + 'bone_skirtB_L_00': [3, 480], + 'bone_skirtB_L_01': [3, 483], + 'bone_skirtB_L_02': [3, 486], + 'bone_skirtB_L_03': [3, 489], + 'Bone023': [3, 492], + 'Bone032': [3, 495], + + 'bone_skirtB00': [3, 498], + 'bone_skirtB01': [3, 501], + 'bone_skirtB02': [3, 504], + 'bone_skirtB03': [3, 507], + 'Bone024': [3, 510], + 'Bone031': [3, 513], + + 'bone_skirtB_R_00': [3, 516], + 'bone_skirtB_R_01': [3, 519], + 'bone_skirtB_R_02': [3, 521], + 'bone_skirtB_R_03': [3, 524], + 'Bone025': [3, 527], + 'Bone030': [3, 530], + }, + + "yostar_fullbody_213":{ + 'Hips': 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Chest': 3 , + 'L_eri': 3 , + 'R_eri': 3 , + 'Neck': 3 , + 'Head': 3 , + 'Head_end': 3 , + + 'LeftEye': 3, + 'LeftEye_end': 3, + 'RightEye': 3, + 'RightEye_end': 3, + + 'LeftShoulder': 3, + 'LeftArm': 3, + 'LeftForearm': 3, + 'LeftHand': 3, + 'LeftHandThumb1': 3, + 'LeftHandThumb2': 3, + 'LeftHandThumb3': 3, + 'LeftHandThumb_end': 3, + + 'LeftHandIndex1': 3, + 'LeftHandIndex2': 3, + 'LeftHandIndex3': 3, + 'LeftHandIndex_end': 3, + + 'LeftHandMiddle1': 3, + 'LeftHandMiddle2': 3, + 'LeftHandMiddle3': 3, + 'LeftHandMiddle_end': 3, + + 'LeftHandRing1': 3, + 'LeftHandRing2': 3, + 'LeftHandRing3': 3, + 'LeftHandRing_end': 3, + + 'LeftHandPinky1': 3, + 'LeftHandPinky2': 3, + 'LeftHandPinky3': 3, + 'LeftHandPinky_end':3, + + 'RightShoulder': 3, + 'RightArm': 3, + 'RightForearm': 3, + 'RightHand': 3, + 'RightHandThumb1': 3, + 'RightHandThumb2': 3, + 'RightHandThumb3': 3, + 'RightHandThumb_end': 3, + + 'RightHandIndex1': 3, + 'RightHandIndex2': 3, + 'RightHandIndex3': 3, + 'RightHandIndex_end': 3, + + 'RightHandMiddle1': 3, + 'RightHandMiddle2': 3, + 'RightHandMiddle3': 3, + 'RightHandMiddle_end': 3, + + 'RightHandRing1': 3, + 'RightHandRing2': 3, + 'RightHandRing3': 3, + 'RightHandRing_end': 3, + + 'RightHandPinky1': 3, + 'RightHandPinky2': 3, + 'RightHandPinky3': 3, + 'RightHandPinky_end': 3, + + 'LeftUpLeg': 3, + 'LeftLeg': 3, + 'LeftFoot': 3, + 'LeftToe': 3, + 'LeftToe_end': 3, + + 'RightUpLeg': 3, + 'RightLEg': 3, + 'RightFoot': 3, + 'RightToe': 3, + 'RightToe_end': 3, + }, + "yostar_mainbody_48": { + #'Hips': 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Chest': 3 , + 'L_eri': 3 , + 'R_eri': 3 , + 'Neck': 3 , + 'Head': 3 , + 'Head_end': 3 , + + 'LeftShoulder': 3, + 'LeftArm': 3, + 'LeftForearm': 3, + 'LeftHand': 3, + + 'RightShoulder': 3, + 'RightArm': 3, + 'RightForearm': 3, + 'RightHand': 3, + }, + "yostar_mainbody_69": { + 'Hips': 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Chest': 3 , + 'L_eri': 3 , + 'R_eri': 3 , + 'Neck': 3 , + 'Head': 3 , + 'Head_end': 3 , + + 'LeftShoulder': 3, + 'LeftArm': 3, + 'LeftForearm': 3, + 'LeftHand': 3, + + 'RightShoulder': 3, + 'RightArm': 3, + 'RightForearm': 3, + 'RightHand': 3, + + 'LeftUpLeg': 3, + 'LeftLeg': 3, + 'LeftFoot': 3, + + 'RightUpLeg': 3, + 'RightLEg': 3, + 'RightFoot': 3, + }, + + "yostar_upbody_168": { + #'Hips': 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Chest': 3 , + 'L_eri': 3 , + 'R_eri': 3 , + 'Neck': 3 , + 'Head': 3 , + 'Head_end': 3 , + + 'LeftShoulder': 3, + 'LeftArm': 3, + 'LeftForearm': 3, + 
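+ # Note: for the name-to-count dictionaries, the numeric suffix in the key is the total
+ # number of rotation channels, e.g. "yostar_upbody_168" lists 56 joints x 3 channels = 168.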
'LeftHand': 3, + 'LeftHandThumb1': 3, + 'LeftHandThumb2': 3, + 'LeftHandThumb3': 3, + 'LeftHandThumb_end': 3, + + 'LeftHandIndex1': 3, + 'LeftHandIndex2': 3, + 'LeftHandIndex3': 3, + 'LeftHandIndex_end': 3, + + 'LeftHandMiddle1': 3, + 'LeftHandMiddle2': 3, + 'LeftHandMiddle3': 3, + 'LeftHandMiddle_end': 3, + + 'LeftHandRing1': 3, + 'LeftHandRing2': 3, + 'LeftHandRing3': 3, + 'LeftHandRing_end': 3, + + 'LeftHandPinky1': 3, + 'LeftHandPinky2': 3, + 'LeftHandPinky3': 3, + 'LeftHandPinky_end':3, + + 'RightShoulder': 3, + 'RightArm': 3, + 'RightForearm': 3, + 'RightHand': 3, + 'RightHandThumb1': 3, + 'RightHandThumb2': 3, + 'RightHandThumb3': 3, + 'RightHandThumb_end': 3, + + 'RightHandIndex1': 3, + 'RightHandIndex2': 3, + 'RightHandIndex3': 3, + 'RightHandIndex_end': 3, + + 'RightHandMiddle1': 3, + 'RightHandMiddle2': 3, + 'RightHandMiddle3': 3, + 'RightHandMiddle_end': 3, + + 'RightHandRing1': 3, + 'RightHandRing2': 3, + 'RightHandRing3': 3, + 'RightHandRing_end': 3, + + 'RightHandPinky1': 3, + 'RightHandPinky2': 3, + 'RightHandPinky3': 3, + 'RightHandPinky_end': 3, + }, + "spine_neck_141":{ + 'Spine': 3 , + 'Neck': 3 , + 'Neck1': 3 , + 'RShoulder': 3 , + 'RArm': 3 , + 'RArm1': 3 , + 'RHand': 3 , + 'RHandM1': 3 , + 'RHandM2': 3 , + 'RHandM3': 3 , + 'RHandR': 3 , + 'RHandR1': 3 , + 'RHandR2': 3 , + 'RHandR3': 3 , + 'RHandP': 3 , + 'RHandP1': 3 , + 'RHandP2': 3 , + 'RHandP3': 3 , + 'RHandI': 3 , + 'RHandI1': 3 , + 'RHandI2': 3 , + 'RHandI3': 3 , + 'RHandT1': 3 , + 'RHandT2': 3 , + 'RHandT3': 3 , + 'LShoulder': 3 , + 'LArm': 3 , + 'LArm1': 3 , + 'LHand': 3 , + 'LHandM1': 3 , + 'LHandM2': 3 , + 'LHandM3': 3 , + 'LHandR': 3 , + 'LHandR1': 3 , + 'LHandR2': 3 , + 'LHandR3': 3 , + 'LHandP': 3 , + 'LHandP1': 3 , + 'LHandP2': 3 , + 'LHandP3': 3 , + 'LHandI': 3 , + 'LHandI1': 3 , + 'LHandI2': 3 , + 'LHandI3': 3 , + 'LHandT1': 3 , + 'LHandT2': 3 , + 'LHandT3': 3 ,}, +} + + +class FIDCalculator(object): + ''' + todo + ''' + def __init__(self): + self.gt_rot = None # pandas dataframe for n frames * joints * 6 + self.gt_pos = None # n frames * (joints + 13) * 3 + self.op_rot = None # pandas dataframe for n frames * joints * 6 + self.op_pos = None # n frames * (joints + 13) * 3 + + + def load(self, path, load_type, save_pos=False): + ''' + select gt or op for load_type + ''' + parser = BVHParser() + parsed_data = parser.parse(path) + if load_type == 'gt': + self.gt_rot = parsed_data.values + elif load_type == 'op': + self.op_rot = parsed_data.values + else: print('error, select gt or op for load_type') + + if save_pos: + mp = MocapParameterizer('position') + positions = mp.fit_transform([parsed_data]) + if load_type == 'gt': + self.gt_pos = positions[0].values + elif load_type == 'op': + self.op_pos = positions[0].values + else: print('error, select gt or op for load_type') + + + def _joint_selector(self, selected_joints, ori_data): + selected_data = pd.DataFrame(columns=[]) + + for joint_name in selected_joints: + selected_data[joint_name] = ori_data[joint_name] + return selected_data.to_numpy() + + + def cal_vol(self, dtype): + if dtype == 'pos': + gt = self.gt_pos + op = self.op_pos + else: + gt = self.gt_rot + op = self.op_rot + + gt_v = gt.to_numpy()[1:, :] - gt.to_numpy()[0:-1, :] + op_v = op.to_numpy()[1:, :] - op.to_numpy()[0:-1, :] + if dtype == 'pos': + self.gt_vol_pos = pd.DataFrame(gt_v, columns = gt.columns.tolist()) + self.op_vol_pos = pd.DataFrame(op_v, columns = gt.columns.tolist()) + else: + self.gt_vol_rot = pd.DataFrame(gt_v, columns = gt.columns.tolist()) + self.op_vol_rot = 
pd.DataFrame(op_v, columns = gt.columns.tolist()) + + + @staticmethod + def frechet_distance(samples_A, samples_B): + A_mu = np.mean(samples_A, axis=0) + A_sigma = np.cov(samples_A, rowvar=False) + B_mu = np.mean(samples_B, axis=0) + B_sigma = np.cov(samples_B, rowvar=False) + try: + frechet_dist = FIDCalculator.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma) + except ValueError: + frechet_dist = 1e+10 + return frechet_dist + + + @staticmethod + def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """ + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + #print(mu1[0], mu2[0]) + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + #print(sigma1[0], sigma2[0]) + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + #print(diff, covmean[0]) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + + def calculate_fid(self, cal_type, joint_type, high_level_opt): + + if cal_type == 'pos': + if self.gt_pos.shape != self.op_pos.shape: + min_val = min(self.gt_pos.shape[0],self.op_pos.shape[0]) + gt = self.gt_pos[:min_val] + op = self.op_pos[:min_val] + else: + gt = self.gt_pos + op = self.op_pos + full_body = gt.columns.tolist() + elif cal_type == 'rot': + if self.gt_rot.shape != self.op_rot.shape: + min_val = min(self.gt_rot.shape[0],self.op_rot.shape[0]) + gt = self.gt_rot[:min_val] + op = self.op_rot[:min_val] + else: + gt = self.gt_rot + op = self.op_rot + full_body_with_offset = gt.columns.tolist() + full_body = [o for o in full_body_with_offset if ('position' not in o)] + elif cal_type == 'pos_vol': + assert self.gt_vol_pos.shape == self.op_vol_pos.shape + gt = self.gt_vol_pos + op = self.op_vol_pos + full_body_with_offset = gt.columns.tolist() + full_body = gt.columns.tolist() + elif cal_type == 'rot_vol': + assert self.gt_vol_rot.shape == self.op_vol_rot.shape + gt = self.gt_vol_rot + op = self.op_vol_rot + 
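+ # for the rotation-based settings, the '*position' offset columns are dropped just below,
+ # so the Frechet distance is computed on rotation channels only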
full_body_with_offset = gt.columns.tolist() + full_body = [o for o in full_body_with_offset if ('position' not in o)] + #print(f'full_body contains {len(full_body)//3} joints') + + if joint_type == 'full_upper_body': + selected_body = [o for o in full_body if ('Leg' not in o) and ('Foot' not in o) and ('Toe' not in o)] + elif joint_type == 'upper_body': + selected_body = [o for o in full_body if ('Hand' not in o) and ('Leg' not in o) and ('Foot' not in o) and ('Toe' not in o)] + elif joint_type == 'fingers': + selected_body = [o for o in full_body if ('Hand' in o)] + elif joint_type == 'indivdual': + pass + else: print('error, plz select correct joint type') + #print(f'calculate fid for {len(selected_body)//3} joints') + + gt = self._joint_selector(selected_body, gt) + op = self._joint_selector(selected_body, op) + + if high_level_opt == 'fid': + fid = FIDCalculator.frechet_distance(gt, op) + return fid + elif high_level_opt == 'var': + var_gt = gt.var() + var_op = op.var() + return var_gt, var_op + elif high_level_opt == 'mean': + mean_gt = gt.mean() + mean_op = op.mean() + return mean_gt, mean_op + else: return 0 + + +def result2target_vis(pose_version, res_bvhlist, save_path, demo_name, verbose=True): + if "trinity" in pose_version: + ori_list = joints_list[pose_version[6:-4]] + target_list = joints_list[pose_version[6:]] + file_content_length = 336 + elif "beat" in pose_version or "spine_neck_141" in pose_version: + ori_list = joints_list["beat_joints"] + target_list = joints_list["spine_neck_141"] + file_content_length = 431 + elif "yostar" in pose_version: + ori_list = joints_list["yostar"] + target_list = joints_list[pose_version] + file_content_length = 1056 + else: + ori_list = joints_list["japanese_joints"] + target_list = joints_list[pose_version] + file_content_length = 366 + + bvh_files_dirs = sorted(glob.glob(f'{res_bvhlist}*.bvh'), key=str) + #test_seq_list = os.list_dir(demo_name).sort() + + counter = 0 + if not os.path.exists(save_path): + os.makedirs(save_path) + for i, bvh_file_dir in enumerate(bvh_files_dirs): + short_name = bvh_file_dir.split("/")[-1][11:] + #print(short_name) + wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'w+') + with open(f"{demo_name}{short_name}",'r') as pose_data_pre: + pose_data_pre_file = pose_data_pre.readlines() + for j, line in enumerate(pose_data_pre_file[0:file_content_length]): + wirte_file.write(line) + offset_data = pose_data_pre_file[file_content_length] + offset_data = np.fromstring(offset_data, dtype=float, sep=' ') + wirte_file.close() + + wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'r') + ori_lines = wirte_file.readlines() + with open(bvh_file_dir, 'r') as pose_data: + pose_data_file = pose_data.readlines() + ori_lines[file_content_length-2] = 'Frames: ' + str(len(pose_data_file)-1) + '\n' + wirte_file.close() + + wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'w+') + wirte_file.writelines(i for i in ori_lines[:file_content_length]) + wirte_file.close() + + with open(os.path.join(save_path, f'res_{short_name}'),'a+') as wirte_file: + with open(bvh_file_dir, 'r') as pose_data: + data_each_file = [] + pose_data_file = pose_data.readlines() + for j, line in enumerate(pose_data_file): + if not j: + pass + else: + data = np.fromstring(line, dtype=float, sep=' ') + data_rotation = offset_data.copy() + for iii, (k, v) in enumerate(target_list.items()): # here is 147 rotations by 3 + #print(data_rotation[ori_list[k][1]-v:ori_list[k][1]], data[iii*3:iii*3+3]) + 
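+ # write joint k's predicted 3 rotation values into its slot of the full BVH frame:
+ # ori_list[k][1] is the end index and v the channel count, so [end-v:end] is k's slice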
data_rotation[ori_list[k][1]-v:ori_list[k][1]] = data[iii*3:iii*3+3] + data_each_file.append(data_rotation) + + for line_data in data_each_file: + line_data = np.array2string(line_data, max_line_width=np.inf, precision=6, suppress_small=False, separator=' ') + wirte_file.write(line_data[1:-2]+'\n') + + counter += 1 + if verbose: + logger.info('data_shape:', data_rotation.shape, 'process:', counter, '/', len(bvh_files_dirs)) \ No newline at end of file diff --git a/dataloaders/pymo/Quaternions.py b/dataloaders/pymo/Quaternions.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b754871310a264e2bd2675479db9a79d24358e --- /dev/null +++ b/dataloaders/pymo/Quaternions.py @@ -0,0 +1,468 @@ +import numpy as np + +class Quaternions: + """ + Quaternions is a wrapper around a numpy ndarray + that allows it to act as if it were an narray of + a quaternion data type. + + Therefore addition, subtraction, multiplication, + division, negation, absolute, are all defined + in terms of quaternion operations such as quaternion + multiplication. + + This allows for much neater code and many routines + which conceptually do the same thing to be written + in the same way for point data and for rotation data. + + The Quaternions class has been desgined such that it + should support broadcasting and slicing in all of the + usual ways. + """ + + def __init__(self, qs): + if isinstance(qs, np.ndarray): + + if len(qs.shape) == 1: qs = np.array([qs]) + self.qs = qs + return + + if isinstance(qs, Quaternions): + self.qs = qs.qs + return + + raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs)) + + def __str__(self): return "Quaternions("+ str(self.qs) + ")" + def __repr__(self): return "Quaternions("+ repr(self.qs) + ")" + + """ Helper Methods for Broadcasting and Data extraction """ + + @classmethod + def _broadcast(cls, sqs, oqs, scalar=False): + + if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1]) + + ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1]) + os = np.array(oqs.shape) + + if len(ss) != len(os): + raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) + + if np.all(ss == os): return sqs, oqs + + if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))): + raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) + + sqsn, oqsn = sqs.copy(), oqs.copy() + + for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a) + for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a) + + return sqsn, oqsn + + """ Adding Quaterions is just Defined as Multiplication """ + + def __add__(self, other): return self * other + def __sub__(self, other): return self / other + + """ Quaterion Multiplication """ + + def __mul__(self, other): + """ + Quaternion multiplication has three main methods. + + When multiplying a Quaternions array by Quaternions + normal quaternion multiplication is performed. + + When multiplying a Quaternions array by a vector + array of the same shape, where the last axis is 3, + it is assumed to be a Quaternion by 3D-Vector + multiplication and the 3D-Vectors are rotated + in space by the Quaternions. + + When multipplying a Quaternions array by a scalar + or vector of different shape it is assumed to be + a Quaternions by Scalars multiplication and the + Quaternions are scaled using Slerp and the identity + quaternions. 
+ """ + + """ If Quaternions type do Quaternions * Quaternions """ + if isinstance(other, Quaternions): + + sqs, oqs = Quaternions._broadcast(self.qs, other.qs) + + q0 = sqs[...,0]; q1 = sqs[...,1]; + q2 = sqs[...,2]; q3 = sqs[...,3]; + r0 = oqs[...,0]; r1 = oqs[...,1]; + r2 = oqs[...,2]; r3 = oqs[...,3]; + + qs = np.empty(sqs.shape) + qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3 + qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2 + qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1 + qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0 + + return Quaternions(qs) + + """ If array type do Quaternions * Vectors """ + if isinstance(other, np.ndarray) and other.shape[-1] == 3: + vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1)) + return (self * (vs * -self)).imaginaries + + """ If float do Quaternions * Scalars """ + if isinstance(other, np.ndarray) or isinstance(other, float): + return Quaternions.slerp(Quaternions.id_like(self), self, other) + + raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other))) + + def __div__(self, other): + """ + When a Quaternion type is supplied, division is defined + as multiplication by the inverse of that Quaternion. + + When a scalar or vector is supplied it is defined + as multiplicaion of one over the supplied value. + Essentially a scaling. + """ + + if isinstance(other, Quaternions): return self * (-other) + if isinstance(other, np.ndarray): return self * (1.0 / other) + if isinstance(other, float): return self * (1.0 / other) + raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other))) + + def __eq__(self, other): return self.qs == other.qs + def __ne__(self, other): return self.qs != other.qs + + def __neg__(self): + """ Invert Quaternions """ + return Quaternions(self.qs * np.array([[1, -1, -1, -1]])) + + def __abs__(self): + """ Unify Quaternions To Single Pole """ + qabs = self.normalized().copy() + top = np.sum(( qabs.qs) * np.array([1,0,0,0]), axis=-1) + bot = np.sum((-qabs.qs) * np.array([1,0,0,0]), axis=-1) + qabs.qs[top < bot] = -qabs.qs[top < bot] + return qabs + + def __iter__(self): return iter(self.qs) + def __len__(self): return len(self.qs) + + def __getitem__(self, k): return Quaternions(self.qs[k]) + def __setitem__(self, k, v): self.qs[k] = v.qs + + @property + def lengths(self): + return np.sum(self.qs**2.0, axis=-1)**0.5 + + @property + def reals(self): + return self.qs[...,0] + + @property + def imaginaries(self): + return self.qs[...,1:4] + + @property + def shape(self): return self.qs.shape[:-1] + + def repeat(self, n, **kwargs): + return Quaternions(self.qs.repeat(n, **kwargs)) + + def normalized(self): + return Quaternions(self.qs / self.lengths[...,np.newaxis]) + + def log(self): + norm = abs(self.normalized()) + imgs = norm.imaginaries + lens = np.sqrt(np.sum(imgs**2, axis=-1)) + lens = np.arctan2(lens, norm.reals) / (lens + 1e-10) + return imgs * lens[...,np.newaxis] + + def constrained(self, axis): + + rl = self.reals + im = np.sum(axis * self.imaginaries, axis=-1) + + t1 = -2 * np.arctan2(rl, im) + np.pi + t2 = -2 * np.arctan2(rl, im) - np.pi + + top = Quaternions.exp(axis[np.newaxis] * (t1[:,np.newaxis] / 2.0)) + bot = Quaternions.exp(axis[np.newaxis] * (t2[:,np.newaxis] / 2.0)) + img = self.dot(top) > self.dot(bot) + + ret = top.copy() + ret[ img] = top[ img] + ret[~img] = bot[~img] + return ret + + def constrained_x(self): return self.constrained(np.array([1,0,0])) + def constrained_y(self): return self.constrained(np.array([0,1,0])) 
+ def constrained_z(self): return self.constrained(np.array([0,0,1])) + + def dot(self, q): return np.sum(self.qs * q.qs, axis=-1) + + def copy(self): return Quaternions(np.copy(self.qs)) + + def reshape(self, s): + self.qs.reshape(s) + return self + + def interpolate(self, ws): + return Quaternions.exp(np.average(abs(self).log, axis=0, weights=ws)) + + def euler(self, order='xyz'): + + q = self.normalized().qs + q0 = q[...,0] + q1 = q[...,1] + q2 = q[...,2] + q3 = q[...,3] + es = np.zeros(self.shape + (3,)) + + if order == 'xyz': + es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + es[...,1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + elif order == 'yzx': + es[...,0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0) + es[...,1] = np.arctan2(2 * (q2 * q0 - q1 * q3), q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0) + es[...,2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1,1)) + else: + raise NotImplementedError('Cannot convert from ordering %s' % order) + + """ + + # These conversion don't appear to work correctly for Maya. + # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/ + + if order == 'xyz': + es[...,0] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + elif order == 'yzx': + es[...,0] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + elif order == 'zxy': + es[...,0] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + elif order == 'xzy': + es[...,0] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + elif order == 'yxz': + es[...,0] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + elif order == 'zyx': + es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + else: + raise KeyError('Unknown ordering %s' % order) + + """ + + # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp + # Use this class and convert from matrix + + return es + + + def average(self): + + if len(self.shape) == 1: + + import numpy.core.umath_tests as ut + system = ut.matrix_multiply(self.qs[:,:,np.newaxis], self.qs[:,np.newaxis,:]).sum(axis=0) + w, v = np.linalg.eigh(system) + qiT_dot_qref = (self.qs[:,:,np.newaxis] * v[np.newaxis,:,:]).sum(axis=1) + return Quaternions(v[:,np.argmin((1.-qiT_dot_qref**2).sum(axis=0))]) + + else: + + raise NotImplementedError('Cannot average multi-dimensionsal Quaternions') + + def angle_axis(self): + + norm = self.normalized() + s = np.sqrt(1 - 
(norm.reals**2.0)) + s[s == 0] = 0.001 + + angles = 2.0 * np.arccos(norm.reals) + axis = norm.imaginaries / s[...,np.newaxis] + + return angles, axis + + + def transforms(self): + + qw = self.qs[...,0] + qx = self.qs[...,1] + qy = self.qs[...,2] + qz = self.qs[...,3] + + x2 = qx + qx; y2 = qy + qy; z2 = qz + qz; + xx = qx * x2; yy = qy * y2; wx = qw * x2; + xy = qx * y2; yz = qy * z2; wy = qw * y2; + xz = qx * z2; zz = qz * z2; wz = qw * z2; + + m = np.empty(self.shape + (3,3)) + m[...,0,0] = 1.0 - (yy + zz) + m[...,0,1] = xy - wz + m[...,0,2] = xz + wy + m[...,1,0] = xy + wz + m[...,1,1] = 1.0 - (xx + zz) + m[...,1,2] = yz - wx + m[...,2,0] = xz - wy + m[...,2,1] = yz + wx + m[...,2,2] = 1.0 - (xx + yy) + + return m + + def ravel(self): + return self.qs.ravel() + + @classmethod + def id(cls, n): + + if isinstance(n, tuple): + qs = np.zeros(n + (4,)) + qs[...,0] = 1.0 + return Quaternions(qs) + + if isinstance(n, int) or isinstance(n, long): + qs = np.zeros((n,4)) + qs[:,0] = 1.0 + return Quaternions(qs) + + raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n))) + + @classmethod + def id_like(cls, a): + qs = np.zeros(a.shape + (4,)) + qs[...,0] = 1.0 + return Quaternions(qs) + + @classmethod + def exp(cls, ws): + + ts = np.sum(ws**2.0, axis=-1)**0.5 + ts[ts == 0] = 0.001 + ls = np.sin(ts) / ts + + qs = np.empty(ws.shape[:-1] + (4,)) + qs[...,0] = np.cos(ts) + qs[...,1] = ws[...,0] * ls + qs[...,2] = ws[...,1] * ls + qs[...,3] = ws[...,2] * ls + + return Quaternions(qs).normalized() + + @classmethod + def slerp(cls, q0s, q1s, a): + + fst, snd = cls._broadcast(q0s.qs, q1s.qs) + fst, a = cls._broadcast(fst, a, scalar=True) + snd, a = cls._broadcast(snd, a, scalar=True) + + len = np.sum(fst * snd, axis=-1) + + neg = len < 0.0 + len[neg] = -len[neg] + snd[neg] = -snd[neg] + + amount0 = np.zeros(a.shape) + amount1 = np.zeros(a.shape) + + linear = (1.0 - len) < 0.01 + omegas = np.arccos(len[~linear]) + sinoms = np.sin(omegas) + + amount0[ linear] = 1.0 - a[linear] + amount1[ linear] = a[linear] + amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms + amount1[~linear] = np.sin( a[~linear] * omegas) / sinoms + + return Quaternions( + amount0[...,np.newaxis] * fst + + amount1[...,np.newaxis] * snd) + + @classmethod + def between(cls, v0s, v1s): + a = np.cross(v0s, v1s) + w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1) + return Quaternions(np.concatenate([w[...,np.newaxis], a], axis=-1)).normalized() + + @classmethod + def from_angle_axis(cls, angles, axis): + axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[...,np.newaxis] + sines = np.sin(angles / 2.0)[...,np.newaxis] + cosines = np.cos(angles / 2.0)[...,np.newaxis] + return Quaternions(np.concatenate([cosines, axis * sines], axis=-1)) + + @classmethod + def from_euler(cls, es, order='xyz', world=False): + + axis = { + 'x' : np.array([1,0,0]), + 'y' : np.array([0,1,0]), + 'z' : np.array([0,0,1]), + } + + q0s = Quaternions.from_angle_axis(es[...,0], axis[order[0]]) + q1s = Quaternions.from_angle_axis(es[...,1], axis[order[1]]) + q2s = Quaternions.from_angle_axis(es[...,2], axis[order[2]]) + + return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s)) + + @classmethod + def from_transforms(cls, ts): + + d0, d1, d2 = ts[...,0,0], ts[...,1,1], ts[...,2,2] + + q0 = ( d0 + d1 + d2 + 1.0) / 4.0 + q1 = ( d0 - d1 - d2 + 1.0) / 4.0 + q2 = (-d0 + d1 - d2 + 1.0) / 4.0 + q3 = (-d0 - d1 + d2 + 1.0) / 4.0 + + q0 = np.sqrt(q0.clip(0,None)) + q1 = np.sqrt(q1.clip(0,None)) + q2 = 
np.sqrt(q2.clip(0,None)) + q3 = np.sqrt(q3.clip(0,None)) + + c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3) + c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3) + c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3) + c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2) + + q1[c0] *= np.sign(ts[c0,2,1] - ts[c0,1,2]) + q2[c0] *= np.sign(ts[c0,0,2] - ts[c0,2,0]) + q3[c0] *= np.sign(ts[c0,1,0] - ts[c0,0,1]) + + q0[c1] *= np.sign(ts[c1,2,1] - ts[c1,1,2]) + q2[c1] *= np.sign(ts[c1,1,0] + ts[c1,0,1]) + q3[c1] *= np.sign(ts[c1,0,2] + ts[c1,2,0]) + + q0[c2] *= np.sign(ts[c2,0,2] - ts[c2,2,0]) + q1[c2] *= np.sign(ts[c2,1,0] + ts[c2,0,1]) + q3[c2] *= np.sign(ts[c2,2,1] + ts[c2,1,2]) + + q0[c3] *= np.sign(ts[c3,1,0] - ts[c3,0,1]) + q1[c3] *= np.sign(ts[c3,2,0] + ts[c3,0,2]) + q2[c3] *= np.sign(ts[c3,2,1] + ts[c3,1,2]) + + qs = np.empty(ts.shape[:-2] + (4,)) + qs[...,0] = q0 + qs[...,1] = q1 + qs[...,2] = q2 + qs[...,3] = q3 + + return cls(qs) + + + \ No newline at end of file diff --git a/dataloaders/pymo/__init__.py b/dataloaders/pymo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dataloaders/pymo/__pycache__/Quaternions.cpython-310.pyc b/dataloaders/pymo/__pycache__/Quaternions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c442d4bb4e465e0b2ac01b857277011dc21b1a0a Binary files /dev/null and b/dataloaders/pymo/__pycache__/Quaternions.cpython-310.pyc differ diff --git a/dataloaders/pymo/__pycache__/Quaternions.cpython-38.pyc b/dataloaders/pymo/__pycache__/Quaternions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a633c51e23554d02d76e0c91871ea03071adb39 Binary files /dev/null and b/dataloaders/pymo/__pycache__/Quaternions.cpython-38.pyc differ diff --git a/dataloaders/pymo/__pycache__/__init__.cpython-310.pyc b/dataloaders/pymo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ca872db26fcc35f1ca291eb73f69d1e00a65a28 Binary files /dev/null and b/dataloaders/pymo/__pycache__/__init__.cpython-310.pyc differ diff --git a/dataloaders/pymo/__pycache__/__init__.cpython-38.pyc b/dataloaders/pymo/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddaae9de304099439b221c6969b125291b42b2da Binary files /dev/null and b/dataloaders/pymo/__pycache__/__init__.cpython-38.pyc differ diff --git a/dataloaders/pymo/__pycache__/data.cpython-310.pyc b/dataloaders/pymo/__pycache__/data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8ee1fd950153646a547b06997ab7a4923b90fa1 Binary files /dev/null and b/dataloaders/pymo/__pycache__/data.cpython-310.pyc differ diff --git a/dataloaders/pymo/__pycache__/data.cpython-38.pyc b/dataloaders/pymo/__pycache__/data.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e21f34f2ac73ab547c8e3ee741a0716229768d3 Binary files /dev/null and b/dataloaders/pymo/__pycache__/data.cpython-38.pyc differ diff --git a/dataloaders/pymo/__pycache__/parsers.cpython-310.pyc b/dataloaders/pymo/__pycache__/parsers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6fa4d089bc3fc42c0e5f813a88af38aab0b82e7 Binary files /dev/null and b/dataloaders/pymo/__pycache__/parsers.cpython-310.pyc differ diff --git a/dataloaders/pymo/__pycache__/parsers.cpython-38.pyc b/dataloaders/pymo/__pycache__/parsers.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5fd7b42e790095e0304bfe2d6d171abf2b2a7e4d Binary files /dev/null and b/dataloaders/pymo/__pycache__/parsers.cpython-38.pyc differ diff --git a/dataloaders/pymo/__pycache__/preprocessing.cpython-310.pyc b/dataloaders/pymo/__pycache__/preprocessing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..437e256e692b503b119ceed149866f8c75a7f49b Binary files /dev/null and b/dataloaders/pymo/__pycache__/preprocessing.cpython-310.pyc differ diff --git a/dataloaders/pymo/__pycache__/preprocessing.cpython-38.pyc b/dataloaders/pymo/__pycache__/preprocessing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88a4768f8d007de4b32e62874354dc3ae16b7875 Binary files /dev/null and b/dataloaders/pymo/__pycache__/preprocessing.cpython-38.pyc differ diff --git a/dataloaders/pymo/__pycache__/rotation_tools.cpython-310.pyc b/dataloaders/pymo/__pycache__/rotation_tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..361b98cfa75690ed68f5a39330a39b3fa717bc1b Binary files /dev/null and b/dataloaders/pymo/__pycache__/rotation_tools.cpython-310.pyc differ diff --git a/dataloaders/pymo/__pycache__/rotation_tools.cpython-38.pyc b/dataloaders/pymo/__pycache__/rotation_tools.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f6607c79dae42826f81737c479d227abeb7f93b Binary files /dev/null and b/dataloaders/pymo/__pycache__/rotation_tools.cpython-38.pyc differ diff --git a/dataloaders/pymo/__pycache__/viz_tools.cpython-310.pyc b/dataloaders/pymo/__pycache__/viz_tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a343c349362720914fc8f9ac8ed5a5d9e6bdf28 Binary files /dev/null and b/dataloaders/pymo/__pycache__/viz_tools.cpython-310.pyc differ diff --git a/dataloaders/pymo/__pycache__/viz_tools.cpython-38.pyc b/dataloaders/pymo/__pycache__/viz_tools.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..317a125318df8ac36572f0fd96fabeb43e484b09 Binary files /dev/null and b/dataloaders/pymo/__pycache__/viz_tools.cpython-38.pyc differ diff --git a/dataloaders/pymo/data.py b/dataloaders/pymo/data.py new file mode 100644 index 0000000000000000000000000000000000000000..7be4f0a819aa041218b8a3d78e700017253d277c --- /dev/null +++ b/dataloaders/pymo/data.py @@ -0,0 +1,53 @@ +import numpy as np + +class Joint(): + def __init__(self, name, parent=None, children=None): + self.name = name + self.parent = parent + self.children = children + +class MocapData(): + def __init__(self): + self.skeleton = {} + self.values = None + self.channel_names = [] + self.framerate = 0.0 + self.root_name = '' + + def traverse(self, j=None): + stack = [self.root_name] + while stack: + joint = stack.pop() + yield joint + for c in self.skeleton[joint]['children']: + stack.append(c) + + def clone(self): + import copy + new_data = MocapData() + new_data.skeleton = copy.copy(self.skeleton) + new_data.values = copy.copy(self.values) + new_data.channel_names = copy.copy(self.channel_names) + new_data.root_name = copy.copy(self.root_name) + new_data.framerate = copy.copy(self.framerate) + return new_data + + def get_all_channels(self): + '''Returns all of the channels parsed from the file as a 2D numpy array''' + + frames = [f[1] for f in self.values] + return np.asarray([[channel[2] for channel in frame] for frame in frames]) + + def get_skeleton_tree(self): + tree = [] + root_key = [j for j in self.skeleton if 
self.skeleton[j]['parent']==None][0] + + root_joint = Joint(root_key) + + def get_empty_channels(self): + #TODO + pass + + def get_constant_channels(self): + #TODO + pass diff --git a/dataloaders/pymo/features.py b/dataloaders/pymo/features.py new file mode 100644 index 0000000000000000000000000000000000000000..fec29ed5758f79b61f296e01e9b077cba573f495 --- /dev/null +++ b/dataloaders/pymo/features.py @@ -0,0 +1,43 @@ +''' +A set of mocap feature extraction functions + +Created by Omid Alemi | Nov 17 2017 + +''' +import numpy as np +import pandas as pd +import peakutils +import matplotlib.pyplot as plt + +def get_foot_contact_idxs(signal, t=0.02, min_dist=120): + up_idxs = peakutils.indexes(signal, thres=t/max(signal), min_dist=min_dist) + down_idxs = peakutils.indexes(-signal, thres=t/min(signal), min_dist=min_dist) + + return [up_idxs, down_idxs] + + +def create_foot_contact_signal(mocap_track, col_name, start=1, t=0.02, min_dist=120): + signal = mocap_track.values[col_name].values + idxs = get_foot_contact_idxs(signal, t, min_dist) + + step_signal = [] + + c = start + for f in range(len(signal)): + if f in idxs[1]: + c = 0 + elif f in idxs[0]: + c = 1 + + step_signal.append(c) + + return step_signal + +def plot_foot_up_down(mocap_track, col_name, t=0.02, min_dist=120): + + signal = mocap_track.values[col_name].values + idxs = get_foot_contact_idxs(signal, t, min_dist) + + plt.plot(mocap_track.values.index, signal) + plt.plot(mocap_track.values.index[idxs[0]], signal[idxs[0]], 'ro') + plt.plot(mocap_track.values.index[idxs[1]], signal[idxs[1]], 'go') diff --git a/dataloaders/pymo/mocapplayer/data-template.js b/dataloaders/pymo/mocapplayer/data-template.js new file mode 100644 index 0000000000000000000000000000000000000000..68a51392fb7d2458487eae2a00a3ed03c1e7153a --- /dev/null +++ b/dataloaders/pymo/mocapplayer/data-template.js @@ -0,0 +1,3 @@ +var dataBuffer = `$$DATA$$`; + +start(dataBuffer); \ No newline at end of file diff --git a/dataloaders/pymo/mocapplayer/js/skeletonFactory.js b/dataloaders/pymo/mocapplayer/js/skeletonFactory.js new file mode 100644 index 0000000000000000000000000000000000000000..e1d072b7df2fb40772e93f2dee595e467744e36b --- /dev/null +++ b/dataloaders/pymo/mocapplayer/js/skeletonFactory.js @@ -0,0 +1,233 @@ +bm_v = new THREE.MeshPhongMaterial({ + color: 0x08519c, + emissive: 0x08306b, + specular: 0x08519c, + shininess: 10, + side: THREE.DoubleSide +}); + +jm_v = new THREE.MeshPhongMaterial({ + color: 0x08306b, + emissive: 0x000000, + specular: 0x111111, + shininess: 90, + side: THREE.DoubleSide +}); + +bm_a = new THREE.MeshPhongMaterial({ + color: 0x980043, + emissive: 0x67001f, + specular: 0x6a51a3, + shininess: 10, + side: THREE.DoubleSide +}); + +jm_a = new THREE.MeshPhongMaterial({ + color: 0x67001f, + emissive: 0x000000, + specular: 0x111111, + shininess: 90, + side: THREE.DoubleSide +}); + +bm_b = new THREE.MeshPhongMaterial({ + color: 0x3f007d, + emissive: 0x3f007d, + specular: 0x807dba, + shininess: 2, + side: THREE.DoubleSide +}); + +jm_b = new THREE.MeshPhongMaterial({ + color: 0x3f007d, + emissive: 0x000000, + specular: 0x807dba, + shininess: 90, + side: THREE.DoubleSide +}); + +//------------------ + + +jointmaterial = new THREE.MeshLambertMaterial({ + color: 0xc57206, + emissive: 0x271c18, + side: THREE.DoubleSide, + // shading: THREE.FlatShading, + wireframe: false, + shininess: 90, +}); + +bonematerial = new THREE.MeshPhongMaterial({ + color: 0xbd9a6d, + emissive: 0x271c18, + side: THREE.DoubleSide, + // shading: THREE.FlatShading, + wireframe: false 
+}); + +jointmaterial2 = new THREE.MeshPhongMaterial({ + color: 0x1562a2, + emissive: 0x000000, + specular: 0x111111, + shininess: 30, + side: THREE.DoubleSide +}); + +bonematerial2 = new THREE.MeshPhongMaterial({ + color: 0x552211, + emissive: 0x882211, + // emissive: 0x000000, + specular: 0x111111, + shininess: 30, + side: THREE.DoubleSide +}); + +bonematerial3 = new THREE.MeshPhongMaterial({ + color: 0x176793, + emissive: 0x000000, + specular: 0x111111, + shininess: 90, + side: THREE.DoubleSide +}); + + + +jointmaterial4 = new THREE.MeshPhongMaterial({ + color: 0xFF8A00, + emissive: 0x000000, + specular: 0x111111, + shininess: 90, + side: THREE.DoubleSide +}); + + +bonematerial4 = new THREE.MeshPhongMaterial({ + color: 0x53633D, + emissive: 0x000000, + specular: 0xFFC450, + shininess: 90, + side: THREE.DoubleSide +}); + + + +bonematerial44 = new THREE.MeshPhongMaterial({ + color: 0x582A72, + emissive: 0x000000, + specular: 0xFFC450, + shininess: 90, + side: THREE.DoubleSide +}); + +jointmaterial5 = new THREE.MeshPhongMaterial({ + color: 0xAA5533, + emissive: 0x000000, + specular: 0x111111, + shininess: 30, + side: THREE.DoubleSide +}); + +bonematerial5 = new THREE.MeshPhongMaterial({ + color: 0x552211, + emissive: 0x772211, + specular: 0x111111, + shininess: 30, + side: THREE.DoubleSide +}); + + +markermaterial = new THREE.MeshPhongMaterial({ + color: 0xc57206, + emissive: 0x271c18, + side: THREE.DoubleSide, + // shading: THREE.FlatShading, + wireframe: false, + shininess: 20, +}); + +markermaterial2 = new THREE.MeshPhongMaterial({ + color: 0x1562a2, + emissive: 0x271c18, + side: THREE.DoubleSide, + // shading: THREE.FlatShading, + wireframe: false, + shininess: 20, +}); + +markermaterial3 = new THREE.MeshPhongMaterial({ + color: 0x555555, + emissive: 0x999999, + side: THREE.DoubleSide, + // shading: THREE.FlatShading, + wireframe: false, + shininess: 20, +}); + + +var makeMarkerGeometry_Sphere10 = function(markerName, scale) { + return new THREE.SphereGeometry(10, 60, 60); +}; + +var makeMarkerGeometry_Sphere3 = function(markerName, scale) { + return new THREE.SphereGeometry(3, 60, 60); +}; + +var makeMarkerGeometry_SphereX = function(markerName, scale) { + return new THREE.SphereGeometry(5, 60, 60); +}; + +var makeJointGeometry_SphereX = function(X) { + return function(jointName, scale) { + return new THREE.SphereGeometry(X, 60, 60); + }; +}; + + +var makeJointGeometry_Sphere1 = function(jointName, scale) { + return new THREE.SphereGeometry(2 / scale, 60, 60); +}; + +var makeJointGeometry_Sphere2 = function(jointName, scale) { + return new THREE.SphereGeometry(1 / scale, 60, 60); +}; + +var makeJointGeometry_Dode = function(jointName, scale) { + return new THREE.DodecahedronGeometry(1 / scale, 0); +}; + +var makeBoneGeometry_Cylinder1 = function(joint1Name, joint2Name, length, scale) { + return new THREE.CylinderGeometry(1.5 / scale, 0.7 / scale, length, 40); +}; + +var makeBoneGeometry_Cylinder2 = function(joint1Name, joint2Name, length, scale) { + // if (joint1Name.includes("LeftHip")) + // length = 400; + return new THREE.CylinderGeometry(1.5 / scale, 0.2 / scale, length, 40); +}; + +var makeBoneGeometry_Cylinder3 = function(joint1Name, joint2Name, length, scale) { + var c1 = new THREE.CylinderGeometry(1.5 / scale, 0.2 / scale, length / 1, 20); + var c2 = new THREE.CylinderGeometry(0.2 / scale, 1.5 / scale, length / 1, 40); + + var material = new THREE.MeshPhongMaterial({ + color: 0xF7FE2E + }); + var mmesh = new THREE.Mesh(c1, material); + mmesh.updateMatrix(); + 
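+    // c1 tapers one way and c2 the other; merging c1 (via the mesh's matrix) into c2 yields a
+    // single bone geometry that is wide at both joint ends and narrower in the middle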
c2.merge(mmesh.geometry, mmesh.matrix); + return c2; +}; + +var makeBoneGeometry_Box1 = function(joint1Name, joint2Name, length, scale) { + return new THREE.BoxGeometry(1 / scale, length, 1 / scale, 40); +}; + + +var makeJointGeometry_Empty = function(jointName, scale) { + return new THREE.SphereGeometry(0.001, 60, 60); +}; + +var makeBoneGeometry_Empty = function(joint1Name, joint2Name, length, scale) { + return new THREE.CylinderGeometry(0.001, 0.001, 0.001, 40); +}; diff --git a/dataloaders/pymo/mocapplayer/libs/jquery.min.js b/dataloaders/pymo/mocapplayer/libs/jquery.min.js new file mode 100644 index 0000000000000000000000000000000000000000..b8c4187de18dd413ad3029839ce0773549e92a14 --- /dev/null +++ b/dataloaders/pymo/mocapplayer/libs/jquery.min.js @@ -0,0 +1,4 @@ +/*! jQuery v2.2.3 | (c) jQuery Foundation | jquery.org/license */ +!function(a,b){"object"==typeof module&&"object"==typeof module.exports?module.exports=a.document?b(a,!0):function(a){if(!a.document)throw new Error("jQuery requires a window with a document");return b(a)}:b(a)}("undefined"!=typeof window?window:this,function(a,b){var c=[],d=a.document,e=c.slice,f=c.concat,g=c.push,h=c.indexOf,i={},j=i.toString,k=i.hasOwnProperty,l={},m="2.2.3",n=function(a,b){return new n.fn.init(a,b)},o=/^[\s\uFEFF\xA0]+|[\s\uFEFF\xA0]+$/g,p=/^-ms-/,q=/-([\da-z])/gi,r=function(a,b){return b.toUpperCase()};n.fn=n.prototype={jquery:m,constructor:n,selector:"",length:0,toArray:function(){return e.call(this)},get:function(a){return null!=a?0>a?this[a+this.length]:this[a]:e.call(this)},pushStack:function(a){var b=n.merge(this.constructor(),a);return b.prevObject=this,b.context=this.context,b},each:function(a){return n.each(this,a)},map:function(a){return this.pushStack(n.map(this,function(b,c){return a.call(b,c,b)}))},slice:function(){return this.pushStack(e.apply(this,arguments))},first:function(){return this.eq(0)},last:function(){return this.eq(-1)},eq:function(a){var b=this.length,c=+a+(0>a?b:0);return this.pushStack(c>=0&&b>c?[this[c]]:[])},end:function(){return this.prevObject||this.constructor()},push:g,sort:c.sort,splice:c.splice},n.extend=n.fn.extend=function(){var a,b,c,d,e,f,g=arguments[0]||{},h=1,i=arguments.length,j=!1;for("boolean"==typeof g&&(j=g,g=arguments[h]||{},h++),"object"==typeof g||n.isFunction(g)||(g={}),h===i&&(g=this,h--);i>h;h++)if(null!=(a=arguments[h]))for(b in a)c=g[b],d=a[b],g!==d&&(j&&d&&(n.isPlainObject(d)||(e=n.isArray(d)))?(e?(e=!1,f=c&&n.isArray(c)?c:[]):f=c&&n.isPlainObject(c)?c:{},g[b]=n.extend(j,f,d)):void 0!==d&&(g[b]=d));return g},n.extend({expando:"jQuery"+(m+Math.random()).replace(/\D/g,""),isReady:!0,error:function(a){throw new Error(a)},noop:function(){},isFunction:function(a){return"function"===n.type(a)},isArray:Array.isArray,isWindow:function(a){return null!=a&&a===a.window},isNumeric:function(a){var b=a&&a.toString();return!n.isArray(a)&&b-parseFloat(b)+1>=0},isPlainObject:function(a){var b;if("object"!==n.type(a)||a.nodeType||n.isWindow(a))return!1;if(a.constructor&&!k.call(a,"constructor")&&!k.call(a.constructor.prototype||{},"isPrototypeOf"))return!1;for(b in a);return void 0===b||k.call(a,b)},isEmptyObject:function(a){var b;for(b in a)return!1;return!0},type:function(a){return null==a?a+"":"object"==typeof a||"function"==typeof a?i[j.call(a)]||"object":typeof a},globalEval:function(a){var b,c=eval;a=n.trim(a),a&&(1===a.indexOf("use strict")?(b=d.createElement("script"),b.text=a,d.head.appendChild(b).parentNode.removeChild(b)):c(a))},camelCase:function(a){return 
a.replace(p,"ms-").replace(q,r)},nodeName:function(a,b){return a.nodeName&&a.nodeName.toLowerCase()===b.toLowerCase()},each:function(a,b){var c,d=0;if(s(a)){for(c=a.length;c>d;d++)if(b.call(a[d],d,a[d])===!1)break}else for(d in a)if(b.call(a[d],d,a[d])===!1)break;return a},trim:function(a){return null==a?"":(a+"").replace(o,"")},makeArray:function(a,b){var c=b||[];return null!=a&&(s(Object(a))?n.merge(c,"string"==typeof a?[a]:a):g.call(c,a)),c},inArray:function(a,b,c){return null==b?-1:h.call(b,a,c)},merge:function(a,b){for(var c=+b.length,d=0,e=a.length;c>d;d++)a[e++]=b[d];return a.length=e,a},grep:function(a,b,c){for(var d,e=[],f=0,g=a.length,h=!c;g>f;f++)d=!b(a[f],f),d!==h&&e.push(a[f]);return e},map:function(a,b,c){var d,e,g=0,h=[];if(s(a))for(d=a.length;d>g;g++)e=b(a[g],g,c),null!=e&&h.push(e);else for(g in a)e=b(a[g],g,c),null!=e&&h.push(e);return f.apply([],h)},guid:1,proxy:function(a,b){var c,d,f;return"string"==typeof b&&(c=a[b],b=a,a=c),n.isFunction(a)?(d=e.call(arguments,2),f=function(){return a.apply(b||this,d.concat(e.call(arguments)))},f.guid=a.guid=a.guid||n.guid++,f):void 0},now:Date.now,support:l}),"function"==typeof Symbol&&(n.fn[Symbol.iterator]=c[Symbol.iterator]),n.each("Boolean Number String Function Array Date RegExp Object Error Symbol".split(" "),function(a,b){i["[object "+b+"]"]=b.toLowerCase()});function s(a){var b=!!a&&"length"in a&&a.length,c=n.type(a);return"function"===c||n.isWindow(a)?!1:"array"===c||0===b||"number"==typeof b&&b>0&&b-1 in a}var t=function(a){var b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u="sizzle"+1*new Date,v=a.document,w=0,x=0,y=ga(),z=ga(),A=ga(),B=function(a,b){return a===b&&(l=!0),0},C=1<<31,D={}.hasOwnProperty,E=[],F=E.pop,G=E.push,H=E.push,I=E.slice,J=function(a,b){for(var c=0,d=a.length;d>c;c++)if(a[c]===b)return c;return-1},K="checked|selected|async|autofocus|autoplay|controls|defer|disabled|hidden|ismap|loop|multiple|open|readonly|required|scoped",L="[\\x20\\t\\r\\n\\f]",M="(?:\\\\.|[\\w-]|[^\\x00-\\xa0])+",N="\\["+L+"*("+M+")(?:"+L+"*([*^$|!~]?=)"+L+"*(?:'((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\"|("+M+"))|)"+L+"*\\]",O=":("+M+")(?:\\((('((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\")|((?:\\\\.|[^\\\\()[\\]]|"+N+")*)|.*)\\)|)",P=new RegExp(L+"+","g"),Q=new RegExp("^"+L+"+|((?:^|[^\\\\])(?:\\\\.)*)"+L+"+$","g"),R=new RegExp("^"+L+"*,"+L+"*"),S=new RegExp("^"+L+"*([>+~]|"+L+")"+L+"*"),T=new RegExp("="+L+"*([^\\]'\"]*?)"+L+"*\\]","g"),U=new RegExp(O),V=new RegExp("^"+M+"$"),W={ID:new RegExp("^#("+M+")"),CLASS:new RegExp("^\\.("+M+")"),TAG:new RegExp("^("+M+"|[*])"),ATTR:new RegExp("^"+N),PSEUDO:new RegExp("^"+O),CHILD:new RegExp("^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\("+L+"*(even|odd|(([+-]|)(\\d*)n|)"+L+"*(?:([+-]|)"+L+"*(\\d+)|))"+L+"*\\)|)","i"),bool:new RegExp("^(?:"+K+")$","i"),needsContext:new RegExp("^"+L+"*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\("+L+"*((?:-\\d)?\\d*)"+L+"*\\)|)(?=[^-]|$)","i")},X=/^(?:input|select|textarea|button)$/i,Y=/^h\d$/i,Z=/^[^{]+\{\s*\[native \w/,$=/^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/,_=/[+~]/,aa=/'|\\/g,ba=new RegExp("\\\\([\\da-f]{1,6}"+L+"?|("+L+")|.)","ig"),ca=function(a,b,c){var d="0x"+b-65536;return d!==d||c?b:0>d?String.fromCharCode(d+65536):String.fromCharCode(d>>10|55296,1023&d|56320)},da=function(){m()};try{H.apply(E=I.call(v.childNodes),v.childNodes),E[v.childNodes.length].nodeType}catch(ea){H={apply:E.length?function(a,b){G.apply(a,I.call(b))}:function(a,b){var c=a.length,d=0;while(a[c++]=b[d++]);a.length=c-1}}}function fa(a,b,d,e){var 
f,h,j,k,l,o,r,s,w=b&&b.ownerDocument,x=b?b.nodeType:9;if(d=d||[],"string"!=typeof a||!a||1!==x&&9!==x&&11!==x)return d;if(!e&&((b?b.ownerDocument||b:v)!==n&&m(b),b=b||n,p)){if(11!==x&&(o=$.exec(a)))if(f=o[1]){if(9===x){if(!(j=b.getElementById(f)))return d;if(j.id===f)return d.push(j),d}else if(w&&(j=w.getElementById(f))&&t(b,j)&&j.id===f)return d.push(j),d}else{if(o[2])return H.apply(d,b.getElementsByTagName(a)),d;if((f=o[3])&&c.getElementsByClassName&&b.getElementsByClassName)return H.apply(d,b.getElementsByClassName(f)),d}if(c.qsa&&!A[a+" "]&&(!q||!q.test(a))){if(1!==x)w=b,s=a;else if("object"!==b.nodeName.toLowerCase()){(k=b.getAttribute("id"))?k=k.replace(aa,"\\$&"):b.setAttribute("id",k=u),r=g(a),h=r.length,l=V.test(k)?"#"+k:"[id='"+k+"']";while(h--)r[h]=l+" "+qa(r[h]);s=r.join(","),w=_.test(a)&&oa(b.parentNode)||b}if(s)try{return H.apply(d,w.querySelectorAll(s)),d}catch(y){}finally{k===u&&b.removeAttribute("id")}}}return i(a.replace(Q,"$1"),b,d,e)}function ga(){var a=[];function b(c,e){return a.push(c+" ")>d.cacheLength&&delete b[a.shift()],b[c+" "]=e}return b}function ha(a){return a[u]=!0,a}function ia(a){var b=n.createElement("div");try{return!!a(b)}catch(c){return!1}finally{b.parentNode&&b.parentNode.removeChild(b),b=null}}function ja(a,b){var c=a.split("|"),e=c.length;while(e--)d.attrHandle[c[e]]=b}function ka(a,b){var c=b&&a,d=c&&1===a.nodeType&&1===b.nodeType&&(~b.sourceIndex||C)-(~a.sourceIndex||C);if(d)return d;if(c)while(c=c.nextSibling)if(c===b)return-1;return a?1:-1}function la(a){return function(b){var c=b.nodeName.toLowerCase();return"input"===c&&b.type===a}}function ma(a){return function(b){var c=b.nodeName.toLowerCase();return("input"===c||"button"===c)&&b.type===a}}function na(a){return ha(function(b){return b=+b,ha(function(c,d){var e,f=a([],c.length,b),g=f.length;while(g--)c[e=f[g]]&&(c[e]=!(d[e]=c[e]))})})}function oa(a){return a&&"undefined"!=typeof a.getElementsByTagName&&a}c=fa.support={},f=fa.isXML=function(a){var b=a&&(a.ownerDocument||a).documentElement;return b?"HTML"!==b.nodeName:!1},m=fa.setDocument=function(a){var b,e,g=a?a.ownerDocument||a:v;return g!==n&&9===g.nodeType&&g.documentElement?(n=g,o=n.documentElement,p=!f(n),(e=n.defaultView)&&e.top!==e&&(e.addEventListener?e.addEventListener("unload",da,!1):e.attachEvent&&e.attachEvent("onunload",da)),c.attributes=ia(function(a){return a.className="i",!a.getAttribute("className")}),c.getElementsByTagName=ia(function(a){return a.appendChild(n.createComment("")),!a.getElementsByTagName("*").length}),c.getElementsByClassName=Z.test(n.getElementsByClassName),c.getById=ia(function(a){return o.appendChild(a).id=u,!n.getElementsByName||!n.getElementsByName(u).length}),c.getById?(d.find.ID=function(a,b){if("undefined"!=typeof b.getElementById&&p){var c=b.getElementById(a);return c?[c]:[]}},d.filter.ID=function(a){var b=a.replace(ba,ca);return function(a){return a.getAttribute("id")===b}}):(delete d.find.ID,d.filter.ID=function(a){var b=a.replace(ba,ca);return function(a){var c="undefined"!=typeof a.getAttributeNode&&a.getAttributeNode("id");return c&&c.value===b}}),d.find.TAG=c.getElementsByTagName?function(a,b){return"undefined"!=typeof b.getElementsByTagName?b.getElementsByTagName(a):c.qsa?b.querySelectorAll(a):void 0}:function(a,b){var c,d=[],e=0,f=b.getElementsByTagName(a);if("*"===a){while(c=f[e++])1===c.nodeType&&d.push(c);return d}return f},d.find.CLASS=c.getElementsByClassName&&function(a,b){return"undefined"!=typeof b.getElementsByClassName&&p?b.getElementsByClassName(a):void 
0},r=[],q=[],(c.qsa=Z.test(n.querySelectorAll))&&(ia(function(a){o.appendChild(a).innerHTML="",a.querySelectorAll("[msallowcapture^='']").length&&q.push("[*^$]="+L+"*(?:''|\"\")"),a.querySelectorAll("[selected]").length||q.push("\\["+L+"*(?:value|"+K+")"),a.querySelectorAll("[id~="+u+"-]").length||q.push("~="),a.querySelectorAll(":checked").length||q.push(":checked"),a.querySelectorAll("a#"+u+"+*").length||q.push(".#.+[+~]")}),ia(function(a){var b=n.createElement("input");b.setAttribute("type","hidden"),a.appendChild(b).setAttribute("name","D"),a.querySelectorAll("[name=d]").length&&q.push("name"+L+"*[*^$|!~]?="),a.querySelectorAll(":enabled").length||q.push(":enabled",":disabled"),a.querySelectorAll("*,:x"),q.push(",.*:")})),(c.matchesSelector=Z.test(s=o.matches||o.webkitMatchesSelector||o.mozMatchesSelector||o.oMatchesSelector||o.msMatchesSelector))&&ia(function(a){c.disconnectedMatch=s.call(a,"div"),s.call(a,"[s!='']:x"),r.push("!=",O)}),q=q.length&&new RegExp(q.join("|")),r=r.length&&new RegExp(r.join("|")),b=Z.test(o.compareDocumentPosition),t=b||Z.test(o.contains)?function(a,b){var c=9===a.nodeType?a.documentElement:a,d=b&&b.parentNode;return a===d||!(!d||1!==d.nodeType||!(c.contains?c.contains(d):a.compareDocumentPosition&&16&a.compareDocumentPosition(d)))}:function(a,b){if(b)while(b=b.parentNode)if(b===a)return!0;return!1},B=b?function(a,b){if(a===b)return l=!0,0;var d=!a.compareDocumentPosition-!b.compareDocumentPosition;return d?d:(d=(a.ownerDocument||a)===(b.ownerDocument||b)?a.compareDocumentPosition(b):1,1&d||!c.sortDetached&&b.compareDocumentPosition(a)===d?a===n||a.ownerDocument===v&&t(v,a)?-1:b===n||b.ownerDocument===v&&t(v,b)?1:k?J(k,a)-J(k,b):0:4&d?-1:1)}:function(a,b){if(a===b)return l=!0,0;var c,d=0,e=a.parentNode,f=b.parentNode,g=[a],h=[b];if(!e||!f)return a===n?-1:b===n?1:e?-1:f?1:k?J(k,a)-J(k,b):0;if(e===f)return ka(a,b);c=a;while(c=c.parentNode)g.unshift(c);c=b;while(c=c.parentNode)h.unshift(c);while(g[d]===h[d])d++;return d?ka(g[d],h[d]):g[d]===v?-1:h[d]===v?1:0},n):n},fa.matches=function(a,b){return fa(a,null,null,b)},fa.matchesSelector=function(a,b){if((a.ownerDocument||a)!==n&&m(a),b=b.replace(T,"='$1']"),c.matchesSelector&&p&&!A[b+" "]&&(!r||!r.test(b))&&(!q||!q.test(b)))try{var d=s.call(a,b);if(d||c.disconnectedMatch||a.document&&11!==a.document.nodeType)return d}catch(e){}return fa(b,n,null,[a]).length>0},fa.contains=function(a,b){return(a.ownerDocument||a)!==n&&m(a),t(a,b)},fa.attr=function(a,b){(a.ownerDocument||a)!==n&&m(a);var e=d.attrHandle[b.toLowerCase()],f=e&&D.call(d.attrHandle,b.toLowerCase())?e(a,b,!p):void 0;return void 0!==f?f:c.attributes||!p?a.getAttribute(b):(f=a.getAttributeNode(b))&&f.specified?f.value:null},fa.error=function(a){throw new Error("Syntax error, unrecognized expression: "+a)},fa.uniqueSort=function(a){var b,d=[],e=0,f=0;if(l=!c.detectDuplicates,k=!c.sortStable&&a.slice(0),a.sort(B),l){while(b=a[f++])b===a[f]&&(e=d.push(f));while(e--)a.splice(d[e],1)}return k=null,a},e=fa.getText=function(a){var b,c="",d=0,f=a.nodeType;if(f){if(1===f||9===f||11===f){if("string"==typeof a.textContent)return a.textContent;for(a=a.firstChild;a;a=a.nextSibling)c+=e(a)}else if(3===f||4===f)return a.nodeValue}else while(b=a[d++])c+=e(b);return c},d=fa.selectors={cacheLength:50,createPseudo:ha,match:W,attrHandle:{},find:{},relative:{">":{dir:"parentNode",first:!0}," ":{dir:"parentNode"},"+":{dir:"previousSibling",first:!0},"~":{dir:"previousSibling"}},preFilter:{ATTR:function(a){return 
a[1]=a[1].replace(ba,ca),a[3]=(a[3]||a[4]||a[5]||"").replace(ba,ca),"~="===a[2]&&(a[3]=" "+a[3]+" "),a.slice(0,4)},CHILD:function(a){return a[1]=a[1].toLowerCase(),"nth"===a[1].slice(0,3)?(a[3]||fa.error(a[0]),a[4]=+(a[4]?a[5]+(a[6]||1):2*("even"===a[3]||"odd"===a[3])),a[5]=+(a[7]+a[8]||"odd"===a[3])):a[3]&&fa.error(a[0]),a},PSEUDO:function(a){var b,c=!a[6]&&a[2];return W.CHILD.test(a[0])?null:(a[3]?a[2]=a[4]||a[5]||"":c&&U.test(c)&&(b=g(c,!0))&&(b=c.indexOf(")",c.length-b)-c.length)&&(a[0]=a[0].slice(0,b),a[2]=c.slice(0,b)),a.slice(0,3))}},filter:{TAG:function(a){var b=a.replace(ba,ca).toLowerCase();return"*"===a?function(){return!0}:function(a){return a.nodeName&&a.nodeName.toLowerCase()===b}},CLASS:function(a){var b=y[a+" "];return b||(b=new RegExp("(^|"+L+")"+a+"("+L+"|$)"))&&y(a,function(a){return b.test("string"==typeof a.className&&a.className||"undefined"!=typeof a.getAttribute&&a.getAttribute("class")||"")})},ATTR:function(a,b,c){return function(d){var e=fa.attr(d,a);return null==e?"!="===b:b?(e+="","="===b?e===c:"!="===b?e!==c:"^="===b?c&&0===e.indexOf(c):"*="===b?c&&e.indexOf(c)>-1:"$="===b?c&&e.slice(-c.length)===c:"~="===b?(" "+e.replace(P," ")+" ").indexOf(c)>-1:"|="===b?e===c||e.slice(0,c.length+1)===c+"-":!1):!0}},CHILD:function(a,b,c,d,e){var f="nth"!==a.slice(0,3),g="last"!==a.slice(-4),h="of-type"===b;return 1===d&&0===e?function(a){return!!a.parentNode}:function(b,c,i){var j,k,l,m,n,o,p=f!==g?"nextSibling":"previousSibling",q=b.parentNode,r=h&&b.nodeName.toLowerCase(),s=!i&&!h,t=!1;if(q){if(f){while(p){m=b;while(m=m[p])if(h?m.nodeName.toLowerCase()===r:1===m.nodeType)return!1;o=p="only"===a&&!o&&"nextSibling"}return!0}if(o=[g?q.firstChild:q.lastChild],g&&s){m=q,l=m[u]||(m[u]={}),k=l[m.uniqueID]||(l[m.uniqueID]={}),j=k[a]||[],n=j[0]===w&&j[1],t=n&&j[2],m=n&&q.childNodes[n];while(m=++n&&m&&m[p]||(t=n=0)||o.pop())if(1===m.nodeType&&++t&&m===b){k[a]=[w,n,t];break}}else if(s&&(m=b,l=m[u]||(m[u]={}),k=l[m.uniqueID]||(l[m.uniqueID]={}),j=k[a]||[],n=j[0]===w&&j[1],t=n),t===!1)while(m=++n&&m&&m[p]||(t=n=0)||o.pop())if((h?m.nodeName.toLowerCase()===r:1===m.nodeType)&&++t&&(s&&(l=m[u]||(m[u]={}),k=l[m.uniqueID]||(l[m.uniqueID]={}),k[a]=[w,t]),m===b))break;return t-=e,t===d||t%d===0&&t/d>=0}}},PSEUDO:function(a,b){var c,e=d.pseudos[a]||d.setFilters[a.toLowerCase()]||fa.error("unsupported pseudo: "+a);return e[u]?e(b):e.length>1?(c=[a,a,"",b],d.setFilters.hasOwnProperty(a.toLowerCase())?ha(function(a,c){var d,f=e(a,b),g=f.length;while(g--)d=J(a,f[g]),a[d]=!(c[d]=f[g])}):function(a){return e(a,0,c)}):e}},pseudos:{not:ha(function(a){var b=[],c=[],d=h(a.replace(Q,"$1"));return d[u]?ha(function(a,b,c,e){var f,g=d(a,null,e,[]),h=a.length;while(h--)(f=g[h])&&(a[h]=!(b[h]=f))}):function(a,e,f){return b[0]=a,d(b,null,f,c),b[0]=null,!c.pop()}}),has:ha(function(a){return function(b){return fa(a,b).length>0}}),contains:ha(function(a){return a=a.replace(ba,ca),function(b){return(b.textContent||b.innerText||e(b)).indexOf(a)>-1}}),lang:ha(function(a){return V.test(a||"")||fa.error("unsupported lang: "+a),a=a.replace(ba,ca).toLowerCase(),function(b){var c;do if(c=p?b.lang:b.getAttribute("xml:lang")||b.getAttribute("lang"))return c=c.toLowerCase(),c===a||0===c.indexOf(a+"-");while((b=b.parentNode)&&1===b.nodeType);return!1}}),target:function(b){var c=a.location&&a.location.hash;return c&&c.slice(1)===b.id},root:function(a){return a===o},focus:function(a){return a===n.activeElement&&(!n.hasFocus||n.hasFocus())&&!!(a.type||a.href||~a.tabIndex)},enabled:function(a){return 
a.disabled===!1},disabled:function(a){return a.disabled===!0},checked:function(a){var b=a.nodeName.toLowerCase();return"input"===b&&!!a.checked||"option"===b&&!!a.selected},selected:function(a){return a.parentNode&&a.parentNode.selectedIndex,a.selected===!0},empty:function(a){for(a=a.firstChild;a;a=a.nextSibling)if(a.nodeType<6)return!1;return!0},parent:function(a){return!d.pseudos.empty(a)},header:function(a){return Y.test(a.nodeName)},input:function(a){return X.test(a.nodeName)},button:function(a){var b=a.nodeName.toLowerCase();return"input"===b&&"button"===a.type||"button"===b},text:function(a){var b;return"input"===a.nodeName.toLowerCase()&&"text"===a.type&&(null==(b=a.getAttribute("type"))||"text"===b.toLowerCase())},first:na(function(){return[0]}),last:na(function(a,b){return[b-1]}),eq:na(function(a,b,c){return[0>c?c+b:c]}),even:na(function(a,b){for(var c=0;b>c;c+=2)a.push(c);return a}),odd:na(function(a,b){for(var c=1;b>c;c+=2)a.push(c);return a}),lt:na(function(a,b,c){for(var d=0>c?c+b:c;--d>=0;)a.push(d);return a}),gt:na(function(a,b,c){for(var d=0>c?c+b:c;++db;b++)d+=a[b].value;return d}function ra(a,b,c){var d=b.dir,e=c&&"parentNode"===d,f=x++;return b.first?function(b,c,f){while(b=b[d])if(1===b.nodeType||e)return a(b,c,f)}:function(b,c,g){var h,i,j,k=[w,f];if(g){while(b=b[d])if((1===b.nodeType||e)&&a(b,c,g))return!0}else while(b=b[d])if(1===b.nodeType||e){if(j=b[u]||(b[u]={}),i=j[b.uniqueID]||(j[b.uniqueID]={}),(h=i[d])&&h[0]===w&&h[1]===f)return k[2]=h[2];if(i[d]=k,k[2]=a(b,c,g))return!0}}}function sa(a){return a.length>1?function(b,c,d){var e=a.length;while(e--)if(!a[e](b,c,d))return!1;return!0}:a[0]}function ta(a,b,c){for(var d=0,e=b.length;e>d;d++)fa(a,b[d],c);return c}function ua(a,b,c,d,e){for(var f,g=[],h=0,i=a.length,j=null!=b;i>h;h++)(f=a[h])&&(c&&!c(f,d,e)||(g.push(f),j&&b.push(h)));return g}function va(a,b,c,d,e,f){return d&&!d[u]&&(d=va(d)),e&&!e[u]&&(e=va(e,f)),ha(function(f,g,h,i){var j,k,l,m=[],n=[],o=g.length,p=f||ta(b||"*",h.nodeType?[h]:h,[]),q=!a||!f&&b?p:ua(p,m,a,h,i),r=c?e||(f?a:o||d)?[]:g:q;if(c&&c(q,r,h,i),d){j=ua(r,n),d(j,[],h,i),k=j.length;while(k--)(l=j[k])&&(r[n[k]]=!(q[n[k]]=l))}if(f){if(e||a){if(e){j=[],k=r.length;while(k--)(l=r[k])&&j.push(q[k]=l);e(null,r=[],j,i)}k=r.length;while(k--)(l=r[k])&&(j=e?J(f,l):m[k])>-1&&(f[j]=!(g[j]=l))}}else r=ua(r===g?r.splice(o,r.length):r),e?e(null,g,r,i):H.apply(g,r)})}function wa(a){for(var b,c,e,f=a.length,g=d.relative[a[0].type],h=g||d.relative[" "],i=g?1:0,k=ra(function(a){return a===b},h,!0),l=ra(function(a){return J(b,a)>-1},h,!0),m=[function(a,c,d){var e=!g&&(d||c!==j)||((b=c).nodeType?k(a,c,d):l(a,c,d));return b=null,e}];f>i;i++)if(c=d.relative[a[i].type])m=[ra(sa(m),c)];else{if(c=d.filter[a[i].type].apply(null,a[i].matches),c[u]){for(e=++i;f>e;e++)if(d.relative[a[e].type])break;return va(i>1&&sa(m),i>1&&qa(a.slice(0,i-1).concat({value:" "===a[i-2].type?"*":""})).replace(Q,"$1"),c,e>i&&wa(a.slice(i,e)),f>e&&wa(a=a.slice(e)),f>e&&qa(a))}m.push(c)}return sa(m)}function xa(a,b){var c=b.length>0,e=a.length>0,f=function(f,g,h,i,k){var l,o,q,r=0,s="0",t=f&&[],u=[],v=j,x=f||e&&d.find.TAG("*",k),y=w+=null==v?1:Math.random()||.1,z=x.length;for(k&&(j=g===n||g||k);s!==z&&null!=(l=x[s]);s++){if(e&&l){o=0,g||l.ownerDocument===n||(m(l),h=!p);while(q=a[o++])if(q(l,g||n,h)){i.push(l);break}k&&(w=y)}c&&((l=!q&&l)&&r--,f&&t.push(l))}if(r+=s,c&&s!==r){o=0;while(q=b[o++])q(t,u,g,h);if(f){if(r>0)while(s--)t[s]||u[s]||(u[s]=F.call(i));u=ua(u)}H.apply(i,u),k&&!f&&u.length>0&&r+b.length>1&&fa.uniqueSort(i)}return 
k&&(w=y,j=v),t};return c?ha(f):f}return h=fa.compile=function(a,b){var c,d=[],e=[],f=A[a+" "];if(!f){b||(b=g(a)),c=b.length;while(c--)f=wa(b[c]),f[u]?d.push(f):e.push(f);f=A(a,xa(e,d)),f.selector=a}return f},i=fa.select=function(a,b,e,f){var i,j,k,l,m,n="function"==typeof a&&a,o=!f&&g(a=n.selector||a);if(e=e||[],1===o.length){if(j=o[0]=o[0].slice(0),j.length>2&&"ID"===(k=j[0]).type&&c.getById&&9===b.nodeType&&p&&d.relative[j[1].type]){if(b=(d.find.ID(k.matches[0].replace(ba,ca),b)||[])[0],!b)return e;n&&(b=b.parentNode),a=a.slice(j.shift().value.length)}i=W.needsContext.test(a)?0:j.length;while(i--){if(k=j[i],d.relative[l=k.type])break;if((m=d.find[l])&&(f=m(k.matches[0].replace(ba,ca),_.test(j[0].type)&&oa(b.parentNode)||b))){if(j.splice(i,1),a=f.length&&qa(j),!a)return H.apply(e,f),e;break}}}return(n||h(a,o))(f,b,!p,e,!b||_.test(a)&&oa(b.parentNode)||b),e},c.sortStable=u.split("").sort(B).join("")===u,c.detectDuplicates=!!l,m(),c.sortDetached=ia(function(a){return 1&a.compareDocumentPosition(n.createElement("div"))}),ia(function(a){return a.innerHTML="","#"===a.firstChild.getAttribute("href")})||ja("type|href|height|width",function(a,b,c){return c?void 0:a.getAttribute(b,"type"===b.toLowerCase()?1:2)}),c.attributes&&ia(function(a){return a.innerHTML="",a.firstChild.setAttribute("value",""),""===a.firstChild.getAttribute("value")})||ja("value",function(a,b,c){return c||"input"!==a.nodeName.toLowerCase()?void 0:a.defaultValue}),ia(function(a){return null==a.getAttribute("disabled")})||ja(K,function(a,b,c){var d;return c?void 0:a[b]===!0?b.toLowerCase():(d=a.getAttributeNode(b))&&d.specified?d.value:null}),fa}(a);n.find=t,n.expr=t.selectors,n.expr[":"]=n.expr.pseudos,n.uniqueSort=n.unique=t.uniqueSort,n.text=t.getText,n.isXMLDoc=t.isXML,n.contains=t.contains;var u=function(a,b,c){var d=[],e=void 0!==c;while((a=a[b])&&9!==a.nodeType)if(1===a.nodeType){if(e&&n(a).is(c))break;d.push(a)}return d},v=function(a,b){for(var c=[];a;a=a.nextSibling)1===a.nodeType&&a!==b&&c.push(a);return c},w=n.expr.match.needsContext,x=/^<([\w-]+)\s*\/?>(?:<\/\1>|)$/,y=/^.[^:#\[\.,]*$/;function z(a,b,c){if(n.isFunction(b))return n.grep(a,function(a,d){return!!b.call(a,d,a)!==c});if(b.nodeType)return n.grep(a,function(a){return a===b!==c});if("string"==typeof b){if(y.test(b))return n.filter(b,a,c);b=n.filter(b,a)}return n.grep(a,function(a){return h.call(b,a)>-1!==c})}n.filter=function(a,b,c){var d=b[0];return c&&(a=":not("+a+")"),1===b.length&&1===d.nodeType?n.find.matchesSelector(d,a)?[d]:[]:n.find.matches(a,n.grep(b,function(a){return 1===a.nodeType}))},n.fn.extend({find:function(a){var b,c=this.length,d=[],e=this;if("string"!=typeof a)return this.pushStack(n(a).filter(function(){for(b=0;c>b;b++)if(n.contains(e[b],this))return!0}));for(b=0;c>b;b++)n.find(a,e[b],d);return d=this.pushStack(c>1?n.unique(d):d),d.selector=this.selector?this.selector+" "+a:a,d},filter:function(a){return this.pushStack(z(this,a||[],!1))},not:function(a){return this.pushStack(z(this,a||[],!0))},is:function(a){return!!z(this,"string"==typeof a&&w.test(a)?n(a):a||[],!1).length}});var A,B=/^(?:\s*(<[\w\W]+>)[^>]*|#([\w-]*))$/,C=n.fn.init=function(a,b,c){var e,f;if(!a)return this;if(c=c||A,"string"==typeof a){if(e="<"===a[0]&&">"===a[a.length-1]&&a.length>=3?[null,a,null]:B.exec(a),!e||!e[1]&&b)return!b||b.jquery?(b||c).find(a):this.constructor(b).find(a);if(e[1]){if(b=b instanceof n?b[0]:b,n.merge(this,n.parseHTML(e[1],b&&b.nodeType?b.ownerDocument||b:d,!0)),x.test(e[1])&&n.isPlainObject(b))for(e in 
b)n.isFunction(this[e])?this[e](b[e]):this.attr(e,b[e]);return this}return f=d.getElementById(e[2]),f&&f.parentNode&&(this.length=1,this[0]=f),this.context=d,this.selector=a,this}return a.nodeType?(this.context=this[0]=a,this.length=1,this):n.isFunction(a)?void 0!==c.ready?c.ready(a):a(n):(void 0!==a.selector&&(this.selector=a.selector,this.context=a.context),n.makeArray(a,this))};C.prototype=n.fn,A=n(d);var D=/^(?:parents|prev(?:Until|All))/,E={children:!0,contents:!0,next:!0,prev:!0};n.fn.extend({has:function(a){var b=n(a,this),c=b.length;return this.filter(function(){for(var a=0;c>a;a++)if(n.contains(this,b[a]))return!0})},closest:function(a,b){for(var c,d=0,e=this.length,f=[],g=w.test(a)||"string"!=typeof a?n(a,b||this.context):0;e>d;d++)for(c=this[d];c&&c!==b;c=c.parentNode)if(c.nodeType<11&&(g?g.index(c)>-1:1===c.nodeType&&n.find.matchesSelector(c,a))){f.push(c);break}return this.pushStack(f.length>1?n.uniqueSort(f):f)},index:function(a){return a?"string"==typeof a?h.call(n(a),this[0]):h.call(this,a.jquery?a[0]:a):this[0]&&this[0].parentNode?this.first().prevAll().length:-1},add:function(a,b){return this.pushStack(n.uniqueSort(n.merge(this.get(),n(a,b))))},addBack:function(a){return this.add(null==a?this.prevObject:this.prevObject.filter(a))}});function F(a,b){while((a=a[b])&&1!==a.nodeType);return a}n.each({parent:function(a){var b=a.parentNode;return b&&11!==b.nodeType?b:null},parents:function(a){return u(a,"parentNode")},parentsUntil:function(a,b,c){return u(a,"parentNode",c)},next:function(a){return F(a,"nextSibling")},prev:function(a){return F(a,"previousSibling")},nextAll:function(a){return u(a,"nextSibling")},prevAll:function(a){return u(a,"previousSibling")},nextUntil:function(a,b,c){return u(a,"nextSibling",c)},prevUntil:function(a,b,c){return u(a,"previousSibling",c)},siblings:function(a){return v((a.parentNode||{}).firstChild,a)},children:function(a){return v(a.firstChild)},contents:function(a){return a.contentDocument||n.merge([],a.childNodes)}},function(a,b){n.fn[a]=function(c,d){var e=n.map(this,b,c);return"Until"!==a.slice(-5)&&(d=c),d&&"string"==typeof d&&(e=n.filter(d,e)),this.length>1&&(E[a]||n.uniqueSort(e),D.test(a)&&e.reverse()),this.pushStack(e)}});var G=/\S+/g;function H(a){var b={};return n.each(a.match(G)||[],function(a,c){b[c]=!0}),b}n.Callbacks=function(a){a="string"==typeof a?H(a):n.extend({},a);var b,c,d,e,f=[],g=[],h=-1,i=function(){for(e=a.once,d=b=!0;g.length;h=-1){c=g.shift();while(++h-1)f.splice(c,1),h>=c&&h--}),this},has:function(a){return a?n.inArray(a,f)>-1:f.length>0},empty:function(){return f&&(f=[]),this},disable:function(){return e=g=[],f=c="",this},disabled:function(){return!f},lock:function(){return e=g=[],c||(f=c=""),this},locked:function(){return!!e},fireWith:function(a,c){return e||(c=c||[],c=[a,c.slice?c.slice():c],g.push(c),b||i()),this},fire:function(){return j.fireWith(this,arguments),this},fired:function(){return!!d}};return j},n.extend({Deferred:function(a){var b=[["resolve","done",n.Callbacks("once memory"),"resolved"],["reject","fail",n.Callbacks("once memory"),"rejected"],["notify","progress",n.Callbacks("memory")]],c="pending",d={state:function(){return c},always:function(){return e.done(arguments).fail(arguments),this},then:function(){var a=arguments;return n.Deferred(function(c){n.each(b,function(b,f){var g=n.isFunction(a[b])&&a[b];e[f[1]](function(){var 
a=g&&g.apply(this,arguments);a&&n.isFunction(a.promise)?a.promise().progress(c.notify).done(c.resolve).fail(c.reject):c[f[0]+"With"](this===d?c.promise():this,g?[a]:arguments)})}),a=null}).promise()},promise:function(a){return null!=a?n.extend(a,d):d}},e={};return d.pipe=d.then,n.each(b,function(a,f){var g=f[2],h=f[3];d[f[1]]=g.add,h&&g.add(function(){c=h},b[1^a][2].disable,b[2][2].lock),e[f[0]]=function(){return e[f[0]+"With"](this===e?d:this,arguments),this},e[f[0]+"With"]=g.fireWith}),d.promise(e),a&&a.call(e,e),e},when:function(a){var b=0,c=e.call(arguments),d=c.length,f=1!==d||a&&n.isFunction(a.promise)?d:0,g=1===f?a:n.Deferred(),h=function(a,b,c){return function(d){b[a]=this,c[a]=arguments.length>1?e.call(arguments):d,c===i?g.notifyWith(b,c):--f||g.resolveWith(b,c)}},i,j,k;if(d>1)for(i=new Array(d),j=new Array(d),k=new Array(d);d>b;b++)c[b]&&n.isFunction(c[b].promise)?c[b].promise().progress(h(b,j,i)).done(h(b,k,c)).fail(g.reject):--f;return f||g.resolveWith(k,c),g.promise()}});var I;n.fn.ready=function(a){return n.ready.promise().done(a),this},n.extend({isReady:!1,readyWait:1,holdReady:function(a){a?n.readyWait++:n.ready(!0)},ready:function(a){(a===!0?--n.readyWait:n.isReady)||(n.isReady=!0,a!==!0&&--n.readyWait>0||(I.resolveWith(d,[n]),n.fn.triggerHandler&&(n(d).triggerHandler("ready"),n(d).off("ready"))))}});function J(){d.removeEventListener("DOMContentLoaded",J),a.removeEventListener("load",J),n.ready()}n.ready.promise=function(b){return I||(I=n.Deferred(),"complete"===d.readyState||"loading"!==d.readyState&&!d.documentElement.doScroll?a.setTimeout(n.ready):(d.addEventListener("DOMContentLoaded",J),a.addEventListener("load",J))),I.promise(b)},n.ready.promise();var K=function(a,b,c,d,e,f,g){var h=0,i=a.length,j=null==c;if("object"===n.type(c)){e=!0;for(h in c)K(a,b,h,c[h],!0,f,g)}else if(void 0!==d&&(e=!0,n.isFunction(d)||(g=!0),j&&(g?(b.call(a,d),b=null):(j=b,b=function(a,b,c){return j.call(n(a),c)})),b))for(;i>h;h++)b(a[h],c,g?d:d.call(a[h],h,b(a[h],c)));return e?a:j?b.call(a):i?b(a[0],c):f},L=function(a){return 1===a.nodeType||9===a.nodeType||!+a.nodeType};function M(){this.expando=n.expando+M.uid++}M.uid=1,M.prototype={register:function(a,b){var c=b||{};return a.nodeType?a[this.expando]=c:Object.defineProperty(a,this.expando,{value:c,writable:!0,configurable:!0}),a[this.expando]},cache:function(a){if(!L(a))return{};var b=a[this.expando];return b||(b={},L(a)&&(a.nodeType?a[this.expando]=b:Object.defineProperty(a,this.expando,{value:b,configurable:!0}))),b},set:function(a,b,c){var d,e=this.cache(a);if("string"==typeof b)e[b]=c;else for(d in b)e[d]=b[d];return e},get:function(a,b){return void 0===b?this.cache(a):a[this.expando]&&a[this.expando][b]},access:function(a,b,c){var d;return void 0===b||b&&"string"==typeof b&&void 0===c?(d=this.get(a,b),void 0!==d?d:this.get(a,n.camelCase(b))):(this.set(a,b,c),void 0!==c?c:b)},remove:function(a,b){var c,d,e,f=a[this.expando];if(void 0!==f){if(void 0===b)this.register(a);else{n.isArray(b)?d=b.concat(b.map(n.camelCase)):(e=n.camelCase(b),b in f?d=[b,e]:(d=e,d=d in f?[d]:d.match(G)||[])),c=d.length;while(c--)delete f[d[c]]}(void 0===b||n.isEmptyObject(f))&&(a.nodeType?a[this.expando]=void 0:delete a[this.expando])}},hasData:function(a){var b=a[this.expando];return void 0!==b&&!n.isEmptyObject(b)}};var N=new M,O=new M,P=/^(?:\{[\w\W]*\}|\[[\w\W]*\])$/,Q=/[A-Z]/g;function R(a,b,c){var d;if(void 0===c&&1===a.nodeType)if(d="data-"+b.replace(Q,"-$&").toLowerCase(),c=a.getAttribute(d),"string"==typeof 
c){try{c="true"===c?!0:"false"===c?!1:"null"===c?null:+c+""===c?+c:P.test(c)?n.parseJSON(c):c; +}catch(e){}O.set(a,b,c)}else c=void 0;return c}n.extend({hasData:function(a){return O.hasData(a)||N.hasData(a)},data:function(a,b,c){return O.access(a,b,c)},removeData:function(a,b){O.remove(a,b)},_data:function(a,b,c){return N.access(a,b,c)},_removeData:function(a,b){N.remove(a,b)}}),n.fn.extend({data:function(a,b){var c,d,e,f=this[0],g=f&&f.attributes;if(void 0===a){if(this.length&&(e=O.get(f),1===f.nodeType&&!N.get(f,"hasDataAttrs"))){c=g.length;while(c--)g[c]&&(d=g[c].name,0===d.indexOf("data-")&&(d=n.camelCase(d.slice(5)),R(f,d,e[d])));N.set(f,"hasDataAttrs",!0)}return e}return"object"==typeof a?this.each(function(){O.set(this,a)}):K(this,function(b){var c,d;if(f&&void 0===b){if(c=O.get(f,a)||O.get(f,a.replace(Q,"-$&").toLowerCase()),void 0!==c)return c;if(d=n.camelCase(a),c=O.get(f,d),void 0!==c)return c;if(c=R(f,d,void 0),void 0!==c)return c}else d=n.camelCase(a),this.each(function(){var c=O.get(this,d);O.set(this,d,b),a.indexOf("-")>-1&&void 0!==c&&O.set(this,a,b)})},null,b,arguments.length>1,null,!0)},removeData:function(a){return this.each(function(){O.remove(this,a)})}}),n.extend({queue:function(a,b,c){var d;return a?(b=(b||"fx")+"queue",d=N.get(a,b),c&&(!d||n.isArray(c)?d=N.access(a,b,n.makeArray(c)):d.push(c)),d||[]):void 0},dequeue:function(a,b){b=b||"fx";var c=n.queue(a,b),d=c.length,e=c.shift(),f=n._queueHooks(a,b),g=function(){n.dequeue(a,b)};"inprogress"===e&&(e=c.shift(),d--),e&&("fx"===b&&c.unshift("inprogress"),delete f.stop,e.call(a,g,f)),!d&&f&&f.empty.fire()},_queueHooks:function(a,b){var c=b+"queueHooks";return N.get(a,c)||N.access(a,c,{empty:n.Callbacks("once memory").add(function(){N.remove(a,[b+"queue",c])})})}}),n.fn.extend({queue:function(a,b){var c=2;return"string"!=typeof a&&(b=a,a="fx",c--),arguments.length",""],thead:[1,"","
"],col:[2,"","
"],tr:[2,"","
"],td:[3,"","
"],_default:[0,"",""]};$.optgroup=$.option,$.tbody=$.tfoot=$.colgroup=$.caption=$.thead,$.th=$.td;function _(a,b){var c="undefined"!=typeof a.getElementsByTagName?a.getElementsByTagName(b||"*"):"undefined"!=typeof a.querySelectorAll?a.querySelectorAll(b||"*"):[];return void 0===b||b&&n.nodeName(a,b)?n.merge([a],c):c}function aa(a,b){for(var c=0,d=a.length;d>c;c++)N.set(a[c],"globalEval",!b||N.get(b[c],"globalEval"))}var ba=/<|&#?\w+;/;function ca(a,b,c,d,e){for(var f,g,h,i,j,k,l=b.createDocumentFragment(),m=[],o=0,p=a.length;p>o;o++)if(f=a[o],f||0===f)if("object"===n.type(f))n.merge(m,f.nodeType?[f]:f);else if(ba.test(f)){g=g||l.appendChild(b.createElement("div")),h=(Y.exec(f)||["",""])[1].toLowerCase(),i=$[h]||$._default,g.innerHTML=i[1]+n.htmlPrefilter(f)+i[2],k=i[0];while(k--)g=g.lastChild;n.merge(m,g.childNodes),g=l.firstChild,g.textContent=""}else m.push(b.createTextNode(f));l.textContent="",o=0;while(f=m[o++])if(d&&n.inArray(f,d)>-1)e&&e.push(f);else if(j=n.contains(f.ownerDocument,f),g=_(l.appendChild(f),"script"),j&&aa(g),c){k=0;while(f=g[k++])Z.test(f.type||"")&&c.push(f)}return l}!function(){var a=d.createDocumentFragment(),b=a.appendChild(d.createElement("div")),c=d.createElement("input");c.setAttribute("type","radio"),c.setAttribute("checked","checked"),c.setAttribute("name","t"),b.appendChild(c),l.checkClone=b.cloneNode(!0).cloneNode(!0).lastChild.checked,b.innerHTML="",l.noCloneChecked=!!b.cloneNode(!0).lastChild.defaultValue}();var da=/^key/,ea=/^(?:mouse|pointer|contextmenu|drag|drop)|click/,fa=/^([^.]*)(?:\.(.+)|)/;function ga(){return!0}function ha(){return!1}function ia(){try{return d.activeElement}catch(a){}}function ja(a,b,c,d,e,f){var g,h;if("object"==typeof b){"string"!=typeof c&&(d=d||c,c=void 0);for(h in b)ja(a,h,c,d,b[h],f);return a}if(null==d&&null==e?(e=c,d=c=void 0):null==e&&("string"==typeof c?(e=d,d=void 0):(e=d,d=c,c=void 0)),e===!1)e=ha;else if(!e)return a;return 1===f&&(g=e,e=function(a){return n().off(a),g.apply(this,arguments)},e.guid=g.guid||(g.guid=n.guid++)),a.each(function(){n.event.add(this,b,e,d,c)})}n.event={global:{},add:function(a,b,c,d,e){var f,g,h,i,j,k,l,m,o,p,q,r=N.get(a);if(r){c.handler&&(f=c,c=f.handler,e=f.selector),c.guid||(c.guid=n.guid++),(i=r.events)||(i=r.events={}),(g=r.handle)||(g=r.handle=function(b){return"undefined"!=typeof n&&n.event.triggered!==b.type?n.event.dispatch.apply(a,arguments):void 0}),b=(b||"").match(G)||[""],j=b.length;while(j--)h=fa.exec(b[j])||[],o=q=h[1],p=(h[2]||"").split(".").sort(),o&&(l=n.event.special[o]||{},o=(e?l.delegateType:l.bindType)||o,l=n.event.special[o]||{},k=n.extend({type:o,origType:q,data:d,handler:c,guid:c.guid,selector:e,needsContext:e&&n.expr.match.needsContext.test(e),namespace:p.join(".")},f),(m=i[o])||(m=i[o]=[],m.delegateCount=0,l.setup&&l.setup.call(a,d,p,g)!==!1||a.addEventListener&&a.addEventListener(o,g)),l.add&&(l.add.call(a,k),k.handler.guid||(k.handler.guid=c.guid)),e?m.splice(m.delegateCount++,0,k):m.push(k),n.event.global[o]=!0)}},remove:function(a,b,c,d,e){var f,g,h,i,j,k,l,m,o,p,q,r=N.hasData(a)&&N.get(a);if(r&&(i=r.events)){b=(b||"").match(G)||[""],j=b.length;while(j--)if(h=fa.exec(b[j])||[],o=q=h[1],p=(h[2]||"").split(".").sort(),o){l=n.event.special[o]||{},o=(d?l.delegateType:l.bindType)||o,m=i[o]||[],h=h[2]&&new 
RegExp("(^|\\.)"+p.join("\\.(?:.*\\.|)")+"(\\.|$)"),g=f=m.length;while(f--)k=m[f],!e&&q!==k.origType||c&&c.guid!==k.guid||h&&!h.test(k.namespace)||d&&d!==k.selector&&("**"!==d||!k.selector)||(m.splice(f,1),k.selector&&m.delegateCount--,l.remove&&l.remove.call(a,k));g&&!m.length&&(l.teardown&&l.teardown.call(a,p,r.handle)!==!1||n.removeEvent(a,o,r.handle),delete i[o])}else for(o in i)n.event.remove(a,o+b[j],c,d,!0);n.isEmptyObject(i)&&N.remove(a,"handle events")}},dispatch:function(a){a=n.event.fix(a);var b,c,d,f,g,h=[],i=e.call(arguments),j=(N.get(this,"events")||{})[a.type]||[],k=n.event.special[a.type]||{};if(i[0]=a,a.delegateTarget=this,!k.preDispatch||k.preDispatch.call(this,a)!==!1){h=n.event.handlers.call(this,a,j),b=0;while((f=h[b++])&&!a.isPropagationStopped()){a.currentTarget=f.elem,c=0;while((g=f.handlers[c++])&&!a.isImmediatePropagationStopped())a.rnamespace&&!a.rnamespace.test(g.namespace)||(a.handleObj=g,a.data=g.data,d=((n.event.special[g.origType]||{}).handle||g.handler).apply(f.elem,i),void 0!==d&&(a.result=d)===!1&&(a.preventDefault(),a.stopPropagation()))}return k.postDispatch&&k.postDispatch.call(this,a),a.result}},handlers:function(a,b){var c,d,e,f,g=[],h=b.delegateCount,i=a.target;if(h&&i.nodeType&&("click"!==a.type||isNaN(a.button)||a.button<1))for(;i!==this;i=i.parentNode||this)if(1===i.nodeType&&(i.disabled!==!0||"click"!==a.type)){for(d=[],c=0;h>c;c++)f=b[c],e=f.selector+" ",void 0===d[e]&&(d[e]=f.needsContext?n(e,this).index(i)>-1:n.find(e,this,null,[i]).length),d[e]&&d.push(f);d.length&&g.push({elem:i,handlers:d})}return h]*)\/>/gi,la=/\s*$/g;function pa(a,b){return n.nodeName(a,"table")&&n.nodeName(11!==b.nodeType?b:b.firstChild,"tr")?a.getElementsByTagName("tbody")[0]||a.appendChild(a.ownerDocument.createElement("tbody")):a}function qa(a){return a.type=(null!==a.getAttribute("type"))+"/"+a.type,a}function ra(a){var b=na.exec(a.type);return b?a.type=b[1]:a.removeAttribute("type"),a}function sa(a,b){var c,d,e,f,g,h,i,j;if(1===b.nodeType){if(N.hasData(a)&&(f=N.access(a),g=N.set(b,f),j=f.events)){delete g.handle,g.events={};for(e in j)for(c=0,d=j[e].length;d>c;c++)n.event.add(b,e,j[e][c])}O.hasData(a)&&(h=O.access(a),i=n.extend({},h),O.set(b,i))}}function ta(a,b){var c=b.nodeName.toLowerCase();"input"===c&&X.test(a.type)?b.checked=a.checked:"input"!==c&&"textarea"!==c||(b.defaultValue=a.defaultValue)}function ua(a,b,c,d){b=f.apply([],b);var e,g,h,i,j,k,m=0,o=a.length,p=o-1,q=b[0],r=n.isFunction(q);if(r||o>1&&"string"==typeof q&&!l.checkClone&&ma.test(q))return a.each(function(e){var f=a.eq(e);r&&(b[0]=q.call(this,e,f.html())),ua(f,b,c,d)});if(o&&(e=ca(b,a[0].ownerDocument,!1,a,d),g=e.firstChild,1===e.childNodes.length&&(e=g),g||d)){for(h=n.map(_(e,"script"),qa),i=h.length;o>m;m++)j=e,m!==p&&(j=n.clone(j,!0,!0),i&&n.merge(h,_(j,"script"))),c.call(a[m],j,m);if(i)for(k=h[h.length-1].ownerDocument,n.map(h,ra),m=0;i>m;m++)j=h[m],Z.test(j.type||"")&&!N.access(j,"globalEval")&&n.contains(k,j)&&(j.src?n._evalUrl&&n._evalUrl(j.src):n.globalEval(j.textContent.replace(oa,"")))}return a}function va(a,b,c){for(var d,e=b?n.filter(b,a):a,f=0;null!=(d=e[f]);f++)c||1!==d.nodeType||n.cleanData(_(d)),d.parentNode&&(c&&n.contains(d.ownerDocument,d)&&aa(_(d,"script")),d.parentNode.removeChild(d));return a}n.extend({htmlPrefilter:function(a){return a.replace(ka,"<$1>")},clone:function(a,b,c){var 
d,e,f,g,h=a.cloneNode(!0),i=n.contains(a.ownerDocument,a);if(!(l.noCloneChecked||1!==a.nodeType&&11!==a.nodeType||n.isXMLDoc(a)))for(g=_(h),f=_(a),d=0,e=f.length;e>d;d++)ta(f[d],g[d]);if(b)if(c)for(f=f||_(a),g=g||_(h),d=0,e=f.length;e>d;d++)sa(f[d],g[d]);else sa(a,h);return g=_(h,"script"),g.length>0&&aa(g,!i&&_(a,"script")),h},cleanData:function(a){for(var b,c,d,e=n.event.special,f=0;void 0!==(c=a[f]);f++)if(L(c)){if(b=c[N.expando]){if(b.events)for(d in b.events)e[d]?n.event.remove(c,d):n.removeEvent(c,d,b.handle);c[N.expando]=void 0}c[O.expando]&&(c[O.expando]=void 0)}}}),n.fn.extend({domManip:ua,detach:function(a){return va(this,a,!0)},remove:function(a){return va(this,a)},text:function(a){return K(this,function(a){return void 0===a?n.text(this):this.empty().each(function(){1!==this.nodeType&&11!==this.nodeType&&9!==this.nodeType||(this.textContent=a)})},null,a,arguments.length)},append:function(){return ua(this,arguments,function(a){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var b=pa(this,a);b.appendChild(a)}})},prepend:function(){return ua(this,arguments,function(a){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var b=pa(this,a);b.insertBefore(a,b.firstChild)}})},before:function(){return ua(this,arguments,function(a){this.parentNode&&this.parentNode.insertBefore(a,this)})},after:function(){return ua(this,arguments,function(a){this.parentNode&&this.parentNode.insertBefore(a,this.nextSibling)})},empty:function(){for(var a,b=0;null!=(a=this[b]);b++)1===a.nodeType&&(n.cleanData(_(a,!1)),a.textContent="");return this},clone:function(a,b){return a=null==a?!1:a,b=null==b?a:b,this.map(function(){return n.clone(this,a,b)})},html:function(a){return K(this,function(a){var b=this[0]||{},c=0,d=this.length;if(void 0===a&&1===b.nodeType)return b.innerHTML;if("string"==typeof a&&!la.test(a)&&!$[(Y.exec(a)||["",""])[1].toLowerCase()]){a=n.htmlPrefilter(a);try{for(;d>c;c++)b=this[c]||{},1===b.nodeType&&(n.cleanData(_(b,!1)),b.innerHTML=a);b=0}catch(e){}}b&&this.empty().append(a)},null,a,arguments.length)},replaceWith:function(){var a=[];return ua(this,arguments,function(b){var c=this.parentNode;n.inArray(this,a)<0&&(n.cleanData(_(this)),c&&c.replaceChild(b,this))},a)}}),n.each({appendTo:"append",prependTo:"prepend",insertBefore:"before",insertAfter:"after",replaceAll:"replaceWith"},function(a,b){n.fn[a]=function(a){for(var c,d=[],e=n(a),f=e.length-1,h=0;f>=h;h++)c=h===f?this:this.clone(!0),n(e[h])[b](c),g.apply(d,c.get());return this.pushStack(d)}});var wa,xa={HTML:"block",BODY:"block"};function ya(a,b){var c=n(b.createElement(a)).appendTo(b.body),d=n.css(c[0],"display");return c.detach(),d}function za(a){var b=d,c=xa[a];return c||(c=ya(a,b),"none"!==c&&c||(wa=(wa||n("' + link = 'New Window'%url + return IPython.display.HTML(iframe+link) + +def nb_play_mocap(mocap, mf, meta=None, frame_time=1/30, scale=1, camera_z=500, base_url=None): + data_template = 'var dataBuffer = `$$DATA$$`;' + data_template += 'var metadata = $$META$$;' + data_template += 'start(dataBuffer, metadata, $$CZ$$, $$SCALE$$, $$FRAMETIME$$);' + dir_path = os.path.dirname(os.path.realpath(__file__)) + + + if base_url is None: + base_url = os.path.join(dir_path, 'mocapplayer/playBuffer.html') + + # print(dir_path) + + if mf == 'bvh': + pass + elif mf == 'pos': + cols = list(mocap.values.columns) + for c in cols: + if 'rotation' in c: + cols.remove(c) + + data_csv = mocap.values.to_csv(index=False, columns=cols) + + if meta is not None: + lines = [','.join(item) for item in meta.astype('str')] + meta_csv = 
'[' + ','.join('[%s]'%l for l in lines) +']' + else: + meta_csv = '[]' + + data_assigned = data_template.replace('$$DATA$$', data_csv) + data_assigned = data_assigned.replace('$$META$$', meta_csv) + data_assigned = data_assigned.replace('$$CZ$$', str(camera_z)) + data_assigned = data_assigned.replace('$$SCALE$$', str(scale)) + data_assigned = data_assigned.replace('$$FRAMETIME$$', str(frame_time)) + + else: + return + + + + with open(os.path.join(dir_path, 'mocapplayer/data.js'), 'w') as oFile: + oFile.write(data_assigned) + + url = '%s?&cz=200&order=xzyi&frame_time=%f&scale=%f'%(base_url, frame_time, scale) + iframe = '' + link = 'New Window'%url + return IPython.display.HTML(iframe+link) diff --git a/dataloaders/pymo/writers.py b/dataloaders/pymo/writers.py new file mode 100644 index 0000000000000000000000000000000000000000..834ef639bb3c86e7ca94a0c6de2fa868a48c3ff9 --- /dev/null +++ b/dataloaders/pymo/writers.py @@ -0,0 +1,55 @@ +import numpy as np +import pandas as pd + +class BVHWriter(): + def __init__(self): + pass + + def write(self, X, ofile): + + # Writing the skeleton info + ofile.write('HIERARCHY\n') + + self.motions_ = [] + self._printJoint(X, X.root_name, 0, ofile) + + # Writing the motion header + ofile.write('MOTION\n') + ofile.write('Frames: %d\n'%X.values.shape[0]) + ofile.write('Frame Time: %f\n'%X.framerate) + + # Writing the data + self.motions_ = np.asarray(self.motions_).T + lines = [" ".join(item) for item in self.motions_.astype(str)] + ofile.write("".join("%s\n"%l for l in lines)) + + def _printJoint(self, X, joint, tab, ofile): + + if X.skeleton[joint]['parent'] == None: + ofile.write('ROOT %s\n'%joint) + elif len(X.skeleton[joint]['children']) > 0: + ofile.write('%sJOINT %s\n'%('\t'*(tab), joint)) + else: + ofile.write('%sEnd site\n'%('\t'*(tab))) + + ofile.write('%s{\n'%('\t'*(tab))) + + ofile.write('%sOFFSET %3.5f %3.5f %3.5f\n'%('\t'*(tab+1), + X.skeleton[joint]['offsets'][0], + X.skeleton[joint]['offsets'][1], + X.skeleton[joint]['offsets'][2])) + channels = X.skeleton[joint]['channels'] + n_channels = len(channels) + + if n_channels > 0: + for ch in channels: + self.motions_.append(np.asarray(X.values['%s_%s'%(joint, ch)].values)) + + if len(X.skeleton[joint]['children']) > 0: + ch_str = ''.join(' %s'*n_channels%tuple(channels)) + ofile.write('%sCHANNELS %d%s\n' %('\t'*(tab+1), n_channels, ch_str)) + + for c in X.skeleton[joint]['children']: + self._printJoint(X, c, tab+1, ofile) + + ofile.write('%s}\n'%('\t'*(tab))) diff --git a/dataloaders/utils/__pycache__/audio_features.cpython-310.pyc b/dataloaders/utils/__pycache__/audio_features.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d221146c7895dbdba76987d632b873d3152584c4 Binary files /dev/null and b/dataloaders/utils/__pycache__/audio_features.cpython-310.pyc differ diff --git a/dataloaders/utils/__pycache__/audio_features.cpython-38.pyc b/dataloaders/utils/__pycache__/audio_features.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f48142fa047687a562f49e7868276b4ba004dbf Binary files /dev/null and b/dataloaders/utils/__pycache__/audio_features.cpython-38.pyc differ diff --git a/dataloaders/utils/__pycache__/other_tools.cpython-38.pyc b/dataloaders/utils/__pycache__/other_tools.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ffdd44ee8648bd8f004c680ae32cbf2c2a18361 Binary files /dev/null and b/dataloaders/utils/__pycache__/other_tools.cpython-38.pyc differ diff --git 
a/dataloaders/utils/__pycache__/other_tools_hf.cpython-310.pyc b/dataloaders/utils/__pycache__/other_tools_hf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..009c7a3ee56779cf310ac47a2392ea307efc9298 Binary files /dev/null and b/dataloaders/utils/__pycache__/other_tools_hf.cpython-310.pyc differ diff --git a/dataloaders/utils/__pycache__/other_tools_hf.cpython-38.pyc b/dataloaders/utils/__pycache__/other_tools_hf.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3dc360c1e968ba44468f6a34774eb40cede5db9 Binary files /dev/null and b/dataloaders/utils/__pycache__/other_tools_hf.cpython-38.pyc differ diff --git a/dataloaders/utils/__pycache__/rotation_conversions.cpython-310.pyc b/dataloaders/utils/__pycache__/rotation_conversions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c99a0c2e48d9745bf6ad6b306052d4d5ea5f0ab8 Binary files /dev/null and b/dataloaders/utils/__pycache__/rotation_conversions.cpython-310.pyc differ diff --git a/dataloaders/utils/__pycache__/rotation_conversions.cpython-38.pyc b/dataloaders/utils/__pycache__/rotation_conversions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3618b9ecddd3397fa4ba72158c929778ac8c02fb Binary files /dev/null and b/dataloaders/utils/__pycache__/rotation_conversions.cpython-38.pyc differ diff --git a/dataloaders/utils/audio_features.py b/dataloaders/utils/audio_features.py new file mode 100644 index 0000000000000000000000000000000000000000..596c40ee03cf21c6ec159e3c3542b3086ef38e73 --- /dev/null +++ b/dataloaders/utils/audio_features.py @@ -0,0 +1,209 @@ +"""modified from https://github.com/yesheng-THU/GFGE/blob/main/data_processing/audio_features.py""" +import numpy as np +import librosa +import math +import os +import scipy.io.wavfile as wav +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from tqdm import tqdm +from transformers import Wav2Vec2Model, Wav2Vec2Config +from transformers.modeling_outputs import BaseModelOutput +from typing import Optional, Tuple +_CONFIG_FOR_DOC = "Wav2Vec2Config" + +# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model +# initialize our encoder with the pre-trained wav2vec 2.0 weights. 
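+# --- Editorial usage sketch (not part of the original module) ---------------------------------
+# The Wav2Vec2Model subclass defined below keeps the frozen CNN front-end of wav2vec 2.0
+# (~49 feature frames per second of 16 kHz audio) and resamples its output with
+# linear_interpolation() to the motion frame rate (audio_fps = 15 for BEAT), so one
+# audio-feature row lines up with one pose frame. The checkpoint id below is an assumption;
+# the original extract_wav2vec2() loads a local copy of wav2vec2-base-960h and also
+# normalises the waveform with precomputed mean/std first.
+def _example_frame_features(audio_np, ckpt="facebook/wav2vec2-base-960h"):
+    """Return per-frame features of shape (n_frames_at_15fps, 768) for a 16 kHz mono waveform."""
+    model = Wav2Vec2Model.from_pretrained(ckpt)  # resolves to the subclass defined later in this file
+    model.feature_extractor._freeze_parameters()
+    model.eval()
+    with torch.no_grad():
+        wav = torch.from_numpy(audio_np).float().reshape(1, -1)  # (1, num_samples) at 16 kHz
+        return model(wav, dataset="beat").last_hidden_state.squeeze(0).cpu().numpy()
+# -----------------------------------------------------------------------------------------------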
+def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.Tensor] = None, + min_masks: int = 0, +) -> np.ndarray: + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + mask_idcs = [] + padding_mask = attention_mask.ne(1) if attention_mask is not None else None + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + lengths = np.full(num_mask, mask_length) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + return mask + +# linear interpolation layer +def linear_interpolation(features, input_fps, output_fps, output_len=None): + features = features.transpose(1, 2) + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features,size=output_len,align_corners=True,mode='linear') + return output_features.transpose(1, 2) + +class Wav2Vec2Model(Wav2Vec2Model): + def __init__(self, config): + super().__init__(config) + self.audio_fps = 15 #args.audio_fps + #input_values 16K hz, 49fps, 20ms overlap, 25ms recepion field + def forward( + self, + input_values, + dataset="beat", + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + frame_num=None + ): + #print(input_values.shape) + self.config.output_attentions = True + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.feature_extractor(input_values) + hidden_states = hidden_states.transpose(1, 2) + #print(hidden_states.shape) + if dataset == "beat": + hidden_states = linear_interpolation(hidden_states, 49, self.audio_fps, output_len=frame_num) + #print(hidden_states.shape) + if attention_mask is not None: + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) + attention_mask = torch.zeros( + hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device + ) + attention_mask[ + (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1) + ] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + + hidden_states = self.feature_projection(hidden_states)[0] + #print(hidden_states.shape) + if self.config.apply_spec_augment and self.training: + batch_size, sequence_length, hidden_size = hidden_states.size() + if self.config.mask_time_prob > 0: + mask_time_indices = 
_compute_mask_indices( + (batch_size, sequence_length), + self.config.mask_time_prob, + self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=2, + ) + hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) + if self.config.mask_feature_prob > 0: + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + self.config.mask_feature_prob, + self.config.mask_feature_length, + ) + mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device) + hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = encoder_outputs[0] + #print(encoder_outputs.shape) + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +def extract_wav2vec2(file_folder, destpath, fps, inference_length=16000*20): + wav2vec_model = Wav2Vec2Model.from_pretrained("/home/ma-user/work/datasets/hub/transformer/wav2vec2-base-960h") + wav2vec_model.feature_extractor._freeze_parameters() + wav2vec_model = wav2vec_model.cuda() + wav2vec_model.eval() + audio_mean = np.load("/home/ma-user/work/datasets/beat_cache/beat_english_15_141/train/wave16k/npy_mean.npy") + audio_std = np.load("/home/ma-user/work/datasets/beat_cache/beat_english_15_141/train/wave16k/npy_std.npy") + if not os.path.exists(destpath): os.mkdir(destpath) + with torch.no_grad(): + for file_name in tqdm(os.listdir(file_folder)): + if "mean" in file_name or "std" in file_name or "pynb" in file_name: continue + audio_np = np.load(file_folder+file_name) + audio_np = (audio_np-audio_mean)/audio_std + audio_torch = torch.from_numpy(audio_np).cuda() + audio_torch = audio_torch.reshape(1, -1) + #print(audio_torch.shape, audio_np.shape) + + if audio_torch.shape[1] > inference_length: + num_div = audio_torch.shape[1] // inference_length + remain = audio_torch.shape[1] % inference_length + for i in range(num_div): + audio_feat = wav2vec_model(audio_torch[:, i*inference_length:(i+1)*inference_length]).last_hidden_state.cpu().numpy().reshape(-1, 768) + if i == 0: + audio_feat_all = audio_feat + else: + audio_feat_all = np.concatenate((audio_feat_all, audio_feat), 0) + if remain > 1600: #0.25s + audio_feat = wav2vec_model(audio_torch[:, num_div*inference_length:num_div*inference_length+remain]).last_hidden_state.cpu().numpy().reshape(-1, 768) + audio_feat_all = np.concatenate((audio_feat_all, audio_feat), 0) + else: + audio_feat_all = wav2vec_model(audio_torch).last_hidden_state.cpu().numpy().reshape(-1, 768) + #print(audio_feat_all.shape, audio_np.shape[0]/16000*15, torch.cuda.memory_cached() / 1E9) + np.save(destpath+file_name, audio_feat_all) + +def extract_melspec(file, destpath, fps, n_mels=128): + fs,X = wav.read(file) + X = X.astype(float)/math.pow(2,15) + target_sr = 48000 + X_48k = librosa.resample(X, orig_sr=fs, target_sr=target_sr, res_type="kaiser_best") + n_fft=int(target_sr*0.13) + hop_len=int(target_sr/fps) + C = librosa.feature.melspectrogram(y=X_48k, sr=target_sr, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels, fmin=0.0, fmax=8000) + #C2 = librosa.feature.melspectrogram(y=X, sr=fs, n_fft=1024, hop_length=512) + #print(C.shape, C2.shape) + C = np.log(C) + 
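+    # C is now a log-mel spectrogram; hop_len = target_sr / fps makes each column correspond to
+    # one motion frame at `fps`, and it is saved time-major as (n_frames, n_mels)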
np.save(destpath,np.transpose(C)) + + +if __name__ == "__main__": + #calculate mean and build cache for data. + target_fps = 15 + ori_data_path = f"/home/ma-user/work/datasets/beat_cache/beat_english_{target_fps}_141/" + for data_type in ["train", "val", "test"]: + extract_wav2vec2(ori_data_path+data_type+"/wave16k/", ori_data_path+data_type+f"/wav2vec2_{target_fps}/", target_fps) \ No newline at end of file diff --git a/dataloaders/utils/other_tools.py b/dataloaders/utils/other_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..08ce836b36ef1148f7d0eabaed131ac10f287020 --- /dev/null +++ b/dataloaders/utils/other_tools.py @@ -0,0 +1,676 @@ +import os +import numpy as np +import random +import torch +import shutil +import csv +import pprint +import pandas as pd +from loguru import logger +from collections import OrderedDict +import matplotlib.pyplot as plt +import pickle +import time + +import numpy as np + +def adjust_array(x, k): + len_x = len(x) + len_k = len(k) + + # If x is shorter than k, pad with zeros + if len_x < len_k: + return np.pad(x, (0, len_k - len_x), 'constant') + + # If x is longer than k, truncate x + elif len_x > len_k: + return x[:len_k] + + # If both are of same length + else: + return x + +def onset_to_frame(onset_times, audio_length, fps): + # Calculate total number of frames for the given audio length + total_frames = int(audio_length * fps) + + # Create an array of zeros of shape (total_frames,) + frame_array = np.zeros(total_frames, dtype=np.int32) + + # For each onset time, calculate the frame number and set it to 1 + for onset in onset_times: + frame_num = int(onset * fps) + # Check if the frame number is within the array bounds + if 0 <= frame_num < total_frames: + frame_array[frame_num] = 1 + + return frame_array + +def smooth_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using linear interpolation. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. + - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. 
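+
+    Note: the blend is built by linearly interpolating, over 2 * blend_frames frames,
+    between the single frame animation1[-blend_frames] and the single frame
+    animation2[blend_frames - 1]; the untouched portions of both clips are kept
+    on either side of the blended segment.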
+ """ + + # Ensure blend_frames doesn't exceed the length of either animation + blend_frames = min(blend_frames, len(animation1), len(animation2)) + + # Extract overlapping sections + overlap_a1 = animation1[-blend_frames:-blend_frames+1, :] + overlap_a2 = animation2[blend_frames-1:blend_frames, :] + + # Create blend weights for linear interpolation + alpha = np.linspace(0, 1, 2 * blend_frames).reshape(-1, 1) + + # Linearly interpolate between overlapping sections + blended_overlap = overlap_a1 * (1 - alpha) + overlap_a2 * alpha + + # Extend the animations to form the result with 2n frames + if blend_frames == len(animation1) and blend_frames == len(animation2): + result = blended_overlap + else: + before_blend = animation1[:-blend_frames] + after_blend = animation2[blend_frames:] + result = np.vstack((before_blend, blended_overlap, after_blend)) + return result + + +def interpolate_sequence(quaternions): + bs, n, j, _ = quaternions.shape + new_n = 2 * n + new_quaternions = torch.zeros((bs, new_n, j, 4), device=quaternions.device, dtype=quaternions.dtype) + + for i in range(n): + q1 = quaternions[:, i, :, :] + new_quaternions[:, 2*i, :, :] = q1 + + if i < n - 1: + q2 = quaternions[:, i + 1, :, :] + new_quaternions[:, 2*i + 1, :, :] = slerp(q1, q2, 0.5) + else: + # For the last point, duplicate the value + new_quaternions[:, 2*i + 1, :, :] = q1 + + return new_quaternions + +def quaternion_multiply(q1, q2): + w1, x1, y1, z1 = q1 + w2, x2, y2, z2 = q2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + return w, x, y, z + +def quaternion_conjugate(q): + w, x, y, z = q + return (w, -x, -y, -z) + +def slerp(q1, q2, t): + dot = torch.sum(q1 * q2, dim=-1, keepdim=True) + + flip = (dot < 0).float() + q2 = (1 - flip * 2) * q2 + dot = dot * (1 - flip * 2) + + DOT_THRESHOLD = 0.9995 + mask = (dot > DOT_THRESHOLD).float() + + theta_0 = torch.acos(dot) + theta = theta_0 * t + + q3 = q2 - q1 * dot + q3 = q3 / torch.norm(q3, dim=-1, keepdim=True) + + interpolated = (torch.cos(theta) * q1 + torch.sin(theta) * q3) + + return mask * (q1 + t * (q2 - q1)) + (1 - mask) * interpolated + +def estimate_linear_velocity(data_seq, dt): + ''' + Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates + the velocity for the middle T-2 steps using a second order central difference scheme. + The first and last frames are with forward and backward first-order + differences, respectively + - h : step size + ''' + # first steps is forward diff (t+1 - t) / dt + init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt + + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) + return vel_seq + + +def estimate_angular_velocity(rot_seq, dt): + ''' + Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. 
+ Input sequence should be of shape (B, T, ..., 3, 3) + ''' + # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix + dRdt = estimate_linear_velocity(rot_seq, dt) + R = rot_seq + RT = R.transpose(-1, -2) + # compute skew-symmetric angular velocity tensor + w_mat = torch.matmul(dRdt, RT) + # pull out angular velocity vector by averaging symmetric entries + w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 + w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 + w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 + w = torch.stack([w_x, w_y, w_z], axis=-1) + return w + +import matplotlib.image as mpimg +from io import BytesIO + +def image_from_bytes(image_bytes): + return mpimg.imread(BytesIO(image_bytes), format='PNG') + + + +def process_frame(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1): + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + import trimesh + import pyvirtualdisplay as Display + + vertices = vertices_all[i] + vertices1 = vertices1_all[i] + filename = f"{output_dir}frame_{i}.png" + filenames.append(filename) + if i%100 == 0: + print('processed', i, 'frames') + #time_s = time.time() + #print(vertices.shape) + if use_matplotlib: + fig = plt.figure(figsize=(20, 10)) + ax = fig.add_subplot(121, projection="3d") + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + #ax.view_init(elev=0, azim=90) + x = vertices[:, 0] + y = vertices[:, 1] + z = vertices[:, 2] + ax.scatter(x, y, z, s=0.5) + ax.set_xlim([-1.0, 1.0]) + ax.set_ylim([-0.5, 1.5])#heigth + ax.set_zlim([-0, 2])#depth + ax.set_box_aspect((1,1,1)) + else: + mesh = trimesh.Trimesh(vertices, faces) + scene = mesh.scene() + scene.camera.fov = camera_params['fov'] + scene.camera.resolution = camera_params['resolution'] + scene.camera.z_near = camera_params['z_near'] + scene.camera.z_far = camera_params['z_far'] + scene.graph[scene.camera.name] = camera_params['transform'] + fig, ax =plt.subplots(1,2, figsize=(16, 6)) + image = scene.save_image(resolution=[640, 480], visible=False) + im0 = ax[0].imshow(image_from_bytes(image)) + ax[0].axis('off') + + if use_matplotlib: + ax2 = fig.add_subplot(122, projection="3d") + ax2.set_box_aspect((1,1,1)) + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + x1 = vertices1[:, 0] + y1 = vertices1[:, 1] + z1 = vertices1[:, 2] + ax2.scatter(x1, y1, z1, s=0.5) + ax2.set_xlim([-1.0, 1.0]) + ax2.set_ylim([-0.5, 1.5])#heigth + ax2.set_zlim([-0, 2]) + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + else: + mesh1 = trimesh.Trimesh(vertices1, faces) + scene1 = mesh1.scene() + scene1.camera.fov = camera_params1['fov'] + scene1.camera.resolution = camera_params1['resolution'] + scene1.camera.z_near = camera_params1['z_near'] + scene1.camera.z_far = camera_params1['z_far'] + scene1.graph[scene1.camera.name] = camera_params1['transform'] + image1 = scene1.save_image(resolution=[640, 480], visible=False) + im1 = ax[1].imshow(image_from_bytes(image1)) + ax[1].axis('off') + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + +def generate_images(frames, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames): + import multiprocessing + import trimesh + num_cores = multiprocessing.cpu_count() # This will get the number of cores on your machine. 
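+    # Descriptive note: the camera parameters below are captured once from the first
+    # frame's trimesh scene and passed to every worker, so all frames are rendered
+    # from an identical viewpoint; rendering is then parallelised across CPU cores
+    # with multiprocessing.Pool.starmap over process_frame.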
+ mesh = trimesh.Trimesh(vertices_all[0], faces) + scene = mesh.scene() + camera_params = { + 'fov': scene.camera.fov, + 'resolution': scene.camera.resolution, + 'focal': scene.camera.focal, + 'z_near': scene.camera.z_near, + "z_far": scene.camera.z_far, + 'transform': scene.graph[scene.camera.name][0] + } + mesh1 = trimesh.Trimesh(vertices1_all[0], faces) + scene1 = mesh1.scene() + camera_params1 = { + 'fov': scene1.camera.fov, + 'resolution': scene1.camera.resolution, + 'focal': scene1.camera.focal, + 'z_near': scene1.camera.z_near, + "z_far": scene1.camera.z_far, + 'transform': scene1.graph[scene1.camera.name][0] + } + # Use a Pool to manage the processes + # print(num_cores) + progress = multiprocessing.Value('i', 0) + lock = multiprocessing.Lock() + with multiprocessing.Pool(num_cores) as pool: + pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + +def render_one_sequence( + res_npz_path, + gt_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import moviepy.editor as mp + import librosa + + model = smplx.create(model_folder, model_type=model_type, + gender=gender, use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, use_pca=False).cuda() + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + filenames = [] + if not use_matplotlib: + import trimesh + #import pyrender + from pyvirtualdisplay import Display + display = Display(visible=0, size=(640, 480)) + display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + #data_npz["jaw_pose"].shape[0] + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + beta = beta.repeat(n, 1) + expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).cuda() + pose = torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).cuda() + transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).cuda() + # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape) + output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose, + global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3], + leye_pose=pose[:, 69:72], + reye_pose=pose[:, 72:75], + return_verts=True) + vertices_all = output["vertices"].cpu().detach().numpy() + + beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).cuda() + pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).cuda() + transl1 = 
torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).cuda() + output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3], + leye_pose=pose1[:, 69:72], + reye_pose=pose1[:, 72:75],return_verts=True) + vertices1_all = output1["vertices"].cpu().detach().numpy() + if args.debug: + seconds = 1 + else: + seconds = vertices_all.shape[0]//30 + # camera_settings = None + time_s = time.time() + generate_images(int(seconds*30), vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames) + filenames = [f"{output_dir}frame_{i}.png" for i in range(int(seconds*30))] + # print(time.time()-time_s) + # for i in tqdm(range(seconds*30)): + # vertices = vertices_all[i] + # vertices1 = vertices1_all[i] + # filename = f"{output_dir}frame_{i}.png" + # filenames.append(filename) + # #time_s = time.time() + # #print(vertices.shape) + # if use_matplotlib: + # fig = plt.figure(figsize=(20, 10)) + # ax = fig.add_subplot(121, projection="3d") + # fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + # #ax.view_init(elev=0, azim=90) + # x = vertices[:, 0] + # y = vertices[:, 1] + # z = vertices[:, 2] + # ax.scatter(x, y, z, s=0.5) + # ax.set_xlim([-1.0, 1.0]) + # ax.set_ylim([-0.5, 1.5])#heigth + # ax.set_zlim([-0, 2])#depth + # ax.set_box_aspect((1,1,1)) + # else: + # mesh = trimesh.Trimesh(vertices, faces) + # if i == 0: + # scene = mesh.scene() + # camera_params = { + # 'fov': scene.camera.fov, + # 'resolution': scene.camera.resolution, + # 'focal': scene.camera.focal, + # 'z_near': scene.camera.z_near, + # "z_far": scene.camera.z_far, + # 'transform': scene.graph[scene.camera.name][0] + # } + # else: + # scene = mesh.scene() + # scene.camera.fov = camera_params['fov'] + # scene.camera.resolution = camera_params['resolution'] + # scene.camera.z_near = camera_params['z_near'] + # scene.camera.z_far = camera_params['z_far'] + # scene.graph[scene.camera.name] = camera_params['transform'] + # fig, ax =plt.subplots(1,2, figsize=(16, 6)) + # image = scene.save_image(resolution=[640, 480], visible=False) + # #print((time.time()-time_s)) + # im0 = ax[0].imshow(image_from_bytes(image)) + # ax[0].axis('off') + + # # beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0) + # # expression1 = torch.from_numpy(gt_np_body["expressions"][i]).to(torch.float32).unsqueeze(0) + # # jaw_pose1 = torch.from_numpy(gt_np_body["poses"][i][66:69]).to(torch.float32).unsqueeze(0) + # # pose1 = torch.from_numpy(gt_np_body["poses"][i]).to(torch.float32).unsqueeze(0) + # # transl1 = torch.from_numpy(gt_np_body["trans"][i]).to(torch.float32).unsqueeze(0) + # # #print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape)global_orient=pose[0:1,:3], + # # output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[0:1,:3], body_pose=pose1[0:1,3:21*3+3], left_hand_pose=pose1[0:1,25*3:40*3], right_hand_pose=pose1[0:1,40*3:55*3], return_verts=True) + # # vertices1 = output1["vertices"].cpu().detach().numpy()[0] + + # if use_matplotlib: + # ax2 = fig.add_subplot(122, projection="3d") + # ax2.set_box_aspect((1,1,1)) + # fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + # #ax2.view_init(elev=0, azim=90) + # x1 = vertices1[:, 0] + # y1 = vertices1[:, 1] + # z1 = vertices1[:, 2] + # ax2.scatter(x1, y1, z1, s=0.5) + # ax2.set_xlim([-1.0, 1.0]) + # ax2.set_ylim([-0.5, 1.5])#heigth + # 
ax2.set_zlim([-0, 2]) + # plt.savefig(filename, bbox_inches='tight') + # plt.close(fig) + # else: + # mesh1 = trimesh.Trimesh(vertices1, faces) + # if i == 0: + # scene1 = mesh1.scene() + # camera_params1 = { + # 'fov': scene1.camera.fov, + # 'resolution': scene1.camera.resolution, + # 'focal': scene1.camera.focal, + # 'z_near': scene1.camera.z_near, + # "z_far": scene1.camera.z_far, + # 'transform': scene1.graph[scene1.camera.name][0] + # } + # else: + # scene1 = mesh1.scene() + # scene1.camera.fov = camera_params1['fov'] + # scene1.camera.resolution = camera_params1['resolution'] + # scene1.camera.z_near = camera_params1['z_near'] + # scene1.camera.z_far = camera_params1['z_far'] + # scene1.graph[scene1.camera.name] = camera_params1['transform'] + # image1 = scene1.save_image(resolution=[640, 480], visible=False) + # im1 = ax[1].imshow(image_from_bytes(image1)) + # ax[1].axis('off') + # plt.savefig(filename, bbox_inches='tight') + # plt.close(fig) + + # display.stop() + # print(filenames) + images = [imageio.imread(filename) for filename in filenames] + imageio.mimsave(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4", images, fps=30) + for filename in filenames: + os.remove(filename) + + video = mp.VideoFileClip(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4") + # audio, sr = librosa.load(audio_path) + # audio = audio[:seconds*sr] + # print(audio.shape, seconds, sr) + # import soundfile as sf + # sf.write(f"{output_dir}{res_npz_path.split('/')[-1][:-4]}.wav", audio, 16000, 'PCM_24') + # audio_tmp = librosa.output.write_wav(f"{output_dir}{res_npz_path.split('/')[-1][:-4]}.wav", audio, sr=16000) + audio = mp.AudioFileClip(audio_path) + if audio.duration > video.duration: + audio = audio.subclip(0, video.duration) + final_clip = video.set_audio(audio) + final_clip.write_videofile(f"{output_dir}{res_npz_path.split('/')[-1][4:-4]}.mp4") + os.remove(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4") + +def print_exp_info(args): + logger.info(pprint.pformat(vars(args))) + logger.info(f"# ------------ {args.name} ----------- #") + logger.info("PyTorch version: {}".format(torch.__version__)) + logger.info("CUDA version: {}".format(torch.version.cuda)) + logger.info("{} GPUs".format(torch.cuda.device_count())) + logger.info(f"Random Seed: {args.random_seed}") + +def args2csv(args, get_head=False, list4print=[]): + for k, v in args.items(): + if isinstance(args[k], dict): + args2csv(args[k], get_head, list4print) + else: list4print.append(k) if get_head else list4print.append(v) + return list4print + +class EpochTracker: + def __init__(self, metric_names, metric_directions): + assert len(metric_names) == len(metric_directions), "Metric names and directions should have the same length" + + + self.metric_names = metric_names + self.states = ['train', 'val', 'test'] + self.types = ['last', 'best'] + + + self.values = {name: {state: {type_: {'value': np.inf if not is_higher_better else -np.inf, 'epoch': 0} + for type_ in self.types} + for state in self.states} + for name, is_higher_better in zip(metric_names, metric_directions)} + + self.loss_meters = {name: {state: AverageMeter(f"{name}_{state}") + for state in self.states} + for name in metric_names} + + + self.is_higher_better = {name: direction for name, direction in zip(metric_names, metric_directions)} + self.train_history = {name: [] for name in metric_names} + self.val_history = {name: [] for name in metric_names} + + + def update_meter(self, name, state, value): + self.loss_meters[name][state].update(value) + + + def 
update_values(self, name, state, epoch): + value_avg = self.loss_meters[name][state].avg + new_best = False + + + if ((value_avg < self.values[name][state]['best']['value'] and not self.is_higher_better[name]) or + (value_avg > self.values[name][state]['best']['value'] and self.is_higher_better[name])): + self.values[name][state]['best']['value'] = value_avg + self.values[name][state]['best']['epoch'] = epoch + new_best = True + self.values[name][state]['last']['value'] = value_avg + self.values[name][state]['last']['epoch'] = epoch + return new_best + + + def get(self, name, state, type_): + return self.values[name][state][type_] + + + def reset(self): + for name in self.metric_names: + for state in self.states: + self.loss_meters[name][state].reset() + + + def flatten_values(self): + flat_dict = {} + for name in self.metric_names: + for state in self.states: + for type_ in self.types: + value_key = f"{name}_{state}_{type_}" + epoch_key = f"{name}_{state}_{type_}_epoch" + flat_dict[value_key] = self.values[name][state][type_]['value'] + flat_dict[epoch_key] = self.values[name][state][type_]['epoch'] + return flat_dict + + def update_and_plot(self, name, epoch, save_path): + new_best_train = self.update_values(name, 'train', epoch) + new_best_val = self.update_values(name, 'val', epoch) + + + self.train_history[name].append(self.loss_meters[name]['train'].avg) + self.val_history[name].append(self.loss_meters[name]['val'].avg) + + + train_values = self.train_history[name] + val_values = self.val_history[name] + epochs = list(range(1, len(train_values) + 1)) + + + plt.figure(figsize=(10, 6)) + plt.plot(epochs, train_values, label='Train') + plt.plot(epochs, val_values, label='Val') + plt.title(f'Train vs Val {name} over epochs') + plt.xlabel('Epochs') + plt.ylabel(name) + plt.legend() + plt.savefig(save_path) + plt.close() + + + return new_best_train, new_best_val + + + + +def record_trial(args, tracker): + """ + 1. 
record notes, score, env_name, experments_path, + """ + csv_path = args.out_path + "custom/" +args.csv_name+".csv" + all_print_dict = vars(args) + all_print_dict.update(tracker.flatten_values()) + if not os.path.exists(csv_path): + pd.DataFrame([all_print_dict]).to_csv(csv_path, index=False) + else: + df_existing = pd.read_csv(csv_path) + df_new = pd.DataFrame([all_print_dict]) + df_aligned = df_existing.append(df_new).fillna("") + df_aligned.to_csv(csv_path, index=False) + + +def set_random_seed(args): + os.environ['PYTHONHASHSEED'] = str(args.random_seed) + random.seed(args.random_seed) + np.random.seed(args.random_seed) + torch.manual_seed(args.random_seed) + torch.cuda.manual_seed_all(args.random_seed) + torch.cuda.manual_seed(args.random_seed) + torch.backends.cudnn.deterministic = args.deterministic #args.CUDNN_DETERMINISTIC + torch.backends.cudnn.benchmark = args.benchmark + torch.backends.cudnn.enabled = args.cudnn_enabled + + +def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None): + if lrs is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(), + 'lrs':lrs.state_dict(),} + elif opt is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(),} + else: + states = { 'model_state': model.state_dict(),} + torch.save(states, save_path) + + +def load_checkpoints(model, save_path, load_name='model'): + states = torch.load(save_path) + new_weights = OrderedDict() + flag=False + for k, v in states['model_state'].items(): + #print(k) + if "module" not in k: + break + else: + new_weights[k[7:]]=v + flag=True + if flag: + try: + model.load_state_dict(new_weights) + except: + #print(states['model_state']) + model.load_state_dict(states['model_state']) + else: + model.load_state_dict(states['model_state']) + logger.info(f"load self-pretrained checkpoints for {load_name}") + + +def model_complexity(model, args): + from ptflops import get_model_complexity_info + flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN), + as_strings=False, print_per_layer_stat=False) + logging.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9)) + logging.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6)) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) \ No newline at end of file diff --git a/dataloaders/utils/other_tools_hf.py b/dataloaders/utils/other_tools_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..90ee9fcc909fbd0f3da6129bc90732c0d168d5aa --- /dev/null +++ b/dataloaders/utils/other_tools_hf.py @@ -0,0 +1,991 @@ +import os +import numpy as np +import random +import torch +import shutil +import csv +import pprint +import pandas as pd +from loguru import logger +from collections import OrderedDict +import matplotlib.pyplot as plt +import pickle +import time +import hashlib +from scipy.spatial.transform import Rotation as R +from scipy.spatial.transform import Slerp +import cv2 + +def 
write_wav_names_to_csv(folder_path, csv_path): + """ + Traverse a folder and write the base names of all .wav files to a CSV file. + + :param folder_path: Path to the folder to traverse. + :param csv_path: Path to the CSV file to write. + """ + # Open the CSV file for writing + with open(csv_path, mode='w', newline='') as file: + writer = csv.writer(file) + # Write the header + writer.writerow(['id', 'type']) + + # Walk through the folder + for root, dirs, files in os.walk(folder_path): + for file in files: + # Check if the file ends with .wav + if file.endswith('.wav'): + # Extract the base name without the extension + base_name = os.path.splitext(file)[0] + # Write the base name and type to the CSV + writer.writerow([base_name, 'test']) + +def resize_motion_sequence_tensor(sequence, target_frames): + """ + Resize a batch of 8-frame motion sequences to a specified number of frames using interpolation. + + :param sequence: A (bs, 8, 165) tensor representing a batch of 8-frame motion sequences + :param target_frames: An integer representing the desired number of frames in the output sequences + :return: A (bs, target_frames, 165) tensor representing the resized motion sequences + """ + bs, _, _ = sequence.shape + + # Create a time vector for the original and target sequences + original_time = torch.linspace(0, 1, 8, device=sequence.device).view(1, -1, 1) + target_time = torch.linspace(0, 1, target_frames, device=sequence.device).view(1, -1, 1) + + # Permute the dimensions to (bs, 165, 8) for interpolation + sequence = sequence.permute(0, 2, 1) + + # Interpolate each joint's motion to the target number of frames + resized_sequence = torch.nn.functional.interpolate(sequence, size=target_frames, mode='linear', align_corners=True) + + # Permute the dimensions back to (bs, target_frames, 165) + resized_sequence = resized_sequence.permute(0, 2, 1) + + return resized_sequence + +def adjust_speed_according_to_ratio_tensor(chunks): + """ + Adjust the playback speed within a batch of 32-frame chunks according to random intervals. 
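+    One set of interval boundaries is sampled (and re-sampled until every resulting
+    speed ratio lies in [0.33, 3]); the same boundaries are then applied to every
+    sequence in the batch.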
+ + :param chunks: A (bs, 32, 165) tensor representing a batch of motion chunks + :return: A (bs, 32, 165) tensor representing the motion chunks after speed adjustment + """ + bs, _, _ = chunks.shape + + # Step 1: Divide the chunk into 4 equal intervals of 8 frames + equal_intervals = torch.chunk(chunks, 4, dim=1) + + # Step 2: Randomly sample 3 points within the chunk to determine new intervals + success = 0 + all_success = [] + #sample_points = torch.sort(torch.randint(1, 32, (bs, 3), device=chunks.device), dim=1).values + # new_intervals_boundaries = torch.cat([torch.zeros((bs, 1), device=chunks.device, dtype=torch.long), sample_points, 32*torch.ones((bs, 1), device=chunks.device, dtype=torch.long)], dim=1) + while success != 1: + sample_points = sorted(random.sample(range(1, 32), 3)) + new_intervals_boundaries = [0] + sample_points + [32] + new_intervals = [chunks[0][new_intervals_boundaries[i]:new_intervals_boundaries[i+1]] for i in range(4)] + speed_ratios = [8 / len(new_interval) for new_interval in new_intervals] + # if any of the speed ratios is greater than 3 or less than 0.33, resample + if all([0.33 <= speed_ratio <= 3 for speed_ratio in speed_ratios]): + success += 1 + all_success.append(new_intervals_boundaries) + new_intervals_boundaries = torch.from_numpy(np.array(all_success)) + # print(new_intervals_boundaries) + all_shapes = new_intervals_boundaries[:, 1:] - new_intervals_boundaries[:, :-1] + # Step 4: Adjust the speed of each new interval + adjusted_intervals = [] + # print(equal_intervals[0].shape) + for i in range(4): + adjusted_interval = resize_motion_sequence_tensor(equal_intervals[i], all_shapes[0, i]) + adjusted_intervals.append(adjusted_interval) + + # Step 5: Concatenate the adjusted intervals + adjusted_chunk = torch.cat(adjusted_intervals, dim=1) + + return adjusted_chunk + +def compute_exact_iou(bbox1, bbox2): + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2]) + y2 = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3]) + + intersection_area = max(0, x2 - x1) * max(0, y2 - y1) + bbox1_area = bbox1[2] * bbox1[3] + bbox2_area = bbox2[2] * bbox2[3] + union_area = bbox1_area + bbox2_area - intersection_area + + if union_area == 0: + return 0 + + return intersection_area / union_area + +def compute_iou(mask1, mask2): + # Compute the intersection + intersection = np.logical_and(mask1, mask2).sum() + + # Compute the union + union = np.logical_or(mask1, mask2).sum() + + # Compute the IoU + iou = intersection / union + + return iou + +def blankblending(all_frames, x, n): + return all_frames[x:x+n+1] + +def synthesize_intermediate_frames_FILM(frame1, frame2, t, name, save_path): + import replicate + from urllib.request import urlretrieve + import os + cv2.imwrite(save_path[:-9]+name+"_frame1.png", frame1) + cv2.imwrite(save_path[:-9]+name+"_frame2.png", frame2) + os.environ["REPLICATE_API_TOKEN"] = "r8_He1rkPk9GAxNQ3LpOohK8sYw1SUfMYV3Fxk9b" + output = replicate.run( + "google-research/frame-interpolation:4f88a16a13673a8b589c18866e540556170a5bcb2ccdc12de556e800e9456d3d", + input={ + "frame1": open(save_path[:-9]+name+"_frame1.png", "rb"), + "frame2": open(save_path[:-9]+name+"_frame2.png", "rb"), + "times_to_interpolate": t, + } + ) + print(output) + urlretrieve(output, save_path[:-9]+name+"_inter.mp4") + return load_video_as_numpy_array(save_path[:-9]+name+"_inter.mp4") + +def load_video_as_numpy_array(video_path): + cap = cv2.VideoCapture(video_path) + + # Using list comprehension to read frames and store in a 
list + frames = [frame for ret, frame in iter(lambda: cap.read(), (False, None)) if ret] + + cap.release() + + return np.array(frames) + +def synthesize_intermediate_frames_bidirectional(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + # Convert the frames to grayscale + gray1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) + gray2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) + + # Calculate the forward and backward optical flow + forward_flow = cv2.calcOpticalFlowFarneback(gray1, gray2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + backward_flow = cv2.calcOpticalFlowFarneback(gray2, gray1, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / n # Interpolation factor + + # Compute the intermediate forward and backward flow + intermediate_forward_flow = forward_flow * alpha + intermediate_backward_flow = backward_flow * (1 - alpha) + + # Warp the frames based on the intermediate flow + h, w = frame1.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + forward_displacement = flow_map + intermediate_forward_flow.reshape(-1, 2) + backward_displacement = flow_map - intermediate_backward_flow.reshape(-1, 2) + + # Use cv2.remap for efficient warping + remap_x_forward, remap_y_forward = np.clip(forward_displacement[:, 1], 0, w - 1), np.clip(forward_displacement[:, 0], 0, h - 1) + remap_x_backward, remap_y_backward = np.clip(backward_displacement[:, 1], 0, w - 1), np.clip(backward_displacement[:, 0], 0, h - 1) + + warped_forward = cv2.remap(frame1, remap_x_forward.reshape(h, w).astype(np.float32), remap_y_forward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + warped_backward = cv2.remap(frame2, remap_x_backward.reshape(h, w).astype(np.float32), remap_y_backward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + + # Blend the warped frames to generate the intermediate frame + intermediate_frame = cv2.addWeighted(warped_forward, 1 - alpha, warped_backward, alpha, 0) + synthesized_frames.append(intermediate_frame) + + return synthesized_frames # Return n-2 synthesized intermediate frames + + +def linear_interpolate_frames(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / (n) # Correct interpolation factor + inter_frame = cv2.addWeighted(frame1, 1 - alpha, frame2, alpha, 0) + synthesized_frames.append(inter_frame) + return synthesized_frames[:-1] + +def warp_frame(src_frame, flow): + h, w = flow.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + displacement = flow_map + flow.reshape(-1, 2) + + # Extract x and y coordinates of the displacement + x_coords = np.clip(displacement[:, 1], 0, w - 1).reshape(h, w).astype(np.float32) + y_coords = np.clip(displacement[:, 0], 0, h - 1).reshape(h, w).astype(np.float32) + + # Use cv2.remap for efficient warping + warped_frame = cv2.remap(src_frame, x_coords, y_coords, interpolation=cv2.INTER_LINEAR) + + return warped_frame + +def synthesize_intermediate_frames(all_frames, x, n): + # Calculate Optical Flow between the first and last frame + frame1 = cv2.cvtColor(all_frames[x], cv2.COLOR_BGR2GRAY) + frame2 = cv2.cvtColor(all_frames[x + n], cv2.COLOR_BGR2GRAY) + flow = cv2.calcOpticalFlowFarneback(frame1, frame2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For 
each intermediate frame + alpha = i / (n) # Interpolation factor + intermediate_flow = flow * alpha # Interpolate the flow + intermediate_frame = warp_frame(all_frames[x], intermediate_flow) # Warp the first frame + synthesized_frames.append(intermediate_frame) + + return synthesized_frames + + +def map2color(s): + m = hashlib.md5() + m.update(s.encode('utf-8')) + color_code = m.hexdigest()[:6] + return '#' + color_code + +def euclidean_distance(a, b): + return np.sqrt(np.sum((a - b)**2)) + +def adjust_array(x, k): + len_x = len(x) + len_k = len(k) + + # If x is shorter than k, pad with zeros + if len_x < len_k: + return np.pad(x, (0, len_k - len_x), 'constant') + + # If x is longer than k, truncate x + elif len_x > len_k: + return x[:len_k] + + # If both are of same length + else: + return x + +def onset_to_frame(onset_times, audio_length, fps): + # Calculate total number of frames for the given audio length + total_frames = int(audio_length * fps) + + # Create an array of zeros of shape (total_frames,) + frame_array = np.zeros(total_frames, dtype=np.int32) + + # For each onset time, calculate the frame number and set it to 1 + for onset in onset_times: + frame_num = int(onset * fps) + # Check if the frame number is within the array bounds + if 0 <= frame_num < total_frames: + frame_array[frame_num] = 1 + + return frame_array + +# def np_slerp(q1, q2, t): +# dot_product = np.sum(q1 * q2, axis=-1) +# q2_flip = np.where(dot_product[:, None] < 0, -q2, q2) # Flip quaternions where dot_product is negative +# dot_product = np.abs(dot_product) + +# angle = np.arccos(np.clip(dot_product, -1, 1)) +# sin_angle = np.sin(angle) + +# t1 = np.sin((1.0 - t) * angle) / sin_angle +# t2 = np.sin(t * angle) / sin_angle + +# return t1 * q1 + t2 * q2_flip + + +def smooth_rotvec_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using SLERP. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. + - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + n1, k1 = animation1.shape + n2, k2 = animation2.shape + animation1 = animation1.reshape(n1, k1//3, 3) + animation2 = animation2.reshape(n2, k2//3, 3) + blend_frames = min(blend_frames, len(animation1), len(animation2)) + all_int = [] + for i in range(k1//3): + # Convert rotation vectors to quaternion for the overlapping part + q = R.from_rotvec(np.concatenate([animation1[0:1, i], animation2[-2:-1, i]], axis=0))#.as_quat() + # q2 = R.from_rotvec()#.as_quat() + times = [0, blend_frames * 2 - 1] + slerp = Slerp(times, q) + interpolated = slerp(np.arange(blend_frames * 2)) + interpolated_rotvecs = interpolated.as_rotvec() + all_int.append(interpolated_rotvecs) + interpolated_rotvecs = np.concatenate(all_int, axis=1) + # result = np.vstack((animation1[:-blend_frames], interpolated_rotvecs, animation2[blend_frames:])) + result = interpolated_rotvecs.reshape(2*n1, k1) + return result + +def smooth_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using linear interpolation. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. 
+ - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + blend_frames = min(blend_frames, len(animation1), len(animation2)) + + # Extract overlapping sections + overlap_a1 = animation1[-blend_frames:-blend_frames+1, :] + overlap_a2 = animation2[blend_frames-1:blend_frames, :] + + # Create blend weights for linear interpolation + alpha = np.linspace(0, 1, 2 * blend_frames).reshape(-1, 1) + + # Linearly interpolate between overlapping sections + blended_overlap = overlap_a1 * (1 - alpha) + overlap_a2 * alpha + + # Extend the animations to form the result with 2n frames + if blend_frames == len(animation1) and blend_frames == len(animation2): + result = blended_overlap + else: + before_blend = animation1[:-blend_frames] + after_blend = animation2[blend_frames:] + result = np.vstack((before_blend, blended_overlap, after_blend)) + return result + +def interpolate_sequence(quaternions): + bs, n, j, _ = quaternions.shape + new_n = 2 * n + new_quaternions = torch.zeros((bs, new_n, j, 4), device=quaternions.device, dtype=quaternions.dtype) + + for i in range(n): + q1 = quaternions[:, i, :, :] + new_quaternions[:, 2*i, :, :] = q1 + + if i < n - 1: + q2 = quaternions[:, i + 1, :, :] + new_quaternions[:, 2*i + 1, :, :] = slerp(q1, q2, 0.5) + else: + # For the last point, duplicate the value + new_quaternions[:, 2*i + 1, :, :] = q1 + + return new_quaternions + +def quaternion_multiply(q1, q2): + w1, x1, y1, z1 = q1 + w2, x2, y2, z2 = q2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + return w, x, y, z + +def quaternion_conjugate(q): + w, x, y, z = q + return (w, -x, -y, -z) + +def slerp(q1, q2, t): + dot = torch.sum(q1 * q2, dim=-1, keepdim=True) + + flip = (dot < 0).float() + q2 = (1 - flip * 2) * q2 + dot = dot * (1 - flip * 2) + + DOT_THRESHOLD = 0.9995 + mask = (dot > DOT_THRESHOLD).float() + + theta_0 = torch.acos(dot) + theta = theta_0 * t + + q3 = q2 - q1 * dot + q3 = q3 / torch.norm(q3, dim=-1, keepdim=True) + + interpolated = (torch.cos(theta) * q1 + torch.sin(theta) * q3) + + return mask * (q1 + t * (q2 - q1)) + (1 - mask) * interpolated + +def estimate_linear_velocity(data_seq, dt): + ''' + Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates + the velocity for the middle T-2 steps using a second order central difference scheme. + The first and last frames are with forward and backward first-order + differences, respectively + - h : step size + ''' + # first steps is forward diff (t+1 - t) / dt + init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt + + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) + return vel_seq + +def velocity2position(data_seq, dt, init_pos): + res_trans = [] + for i in range(data_seq.shape[1]): + if i == 0: + res_trans.append(init_pos.unsqueeze(1)) + else: + res = data_seq[:, i-1:i] * dt + res_trans[-1] + res_trans.append(res) + return torch.cat(res_trans, dim=1) + +def estimate_angular_velocity(rot_seq, dt): + ''' + Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. 
+ Input sequence should be of shape (B, T, ..., 3, 3) + ''' + # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix + dRdt = estimate_linear_velocity(rot_seq, dt) + R = rot_seq + RT = R.transpose(-1, -2) + # compute skew-symmetric angular velocity tensor + w_mat = torch.matmul(dRdt, RT) + # pull out angular velocity vector by averaging symmetric entries + w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 + w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 + w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 + w = torch.stack([w_x, w_y, w_z], axis=-1) + return w + +def image_from_bytes(image_bytes): + import matplotlib.image as mpimg + from io import BytesIO + return mpimg.imread(BytesIO(image_bytes), format='PNG') + +def process_frame(i, vertices_all, vertices1_all, faces, output_dir, filenames): + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + import trimesh + import pyrender + + def deg_to_rad(degrees): + return degrees * np.pi / 180 + + uniform_color = [220, 220, 220, 255] + resolution = (1000, 1000) + figsize = (10, 10) + + fig, axs = plt.subplots( + nrows=1, + ncols=2, + figsize=(figsize[0] * 2, figsize[1] * 1) + ) + axs = axs.flatten() + + vertices = vertices_all[i] + vertices1 = vertices1_all[i] + filename = f"{output_dir}frame_{i}.png" + filenames.append(filename) + if i%100 == 0: + print('processed', i, 'frames') + #time_s = time.time() + #print(vertices.shape) + angle_rad = deg_to_rad(-2) + pose_camera = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 1.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 5.0], + [0.0, 0.0, 0.0, 1.0] + ]) + angle_rad = deg_to_rad(-30) + pose_light = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 0.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 3.0], + [0.0, 0.0, 0.0, 1.0] + ]) + + for vtx_idx, vtx in enumerate([vertices, vertices1]): + trimesh_mesh = trimesh.Trimesh( + vertices=vtx, + faces=faces, + vertex_colors=uniform_color + ) + mesh = pyrender.Mesh.from_trimesh( + trimesh_mesh, smooth=True + ) + scene = pyrender.Scene() + scene.add(mesh) + camera = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0) + scene.add(camera, pose=pose_camera) + light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=4.0) + scene.add(light, pose=pose_light) + renderer = pyrender.OffscreenRenderer(*resolution) + color, _ = renderer.render(scene) + axs[vtx_idx].imshow(color) + axs[vtx_idx].axis('off') + renderer.delete() + + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + +def generate_images(frames, vertices_all, vertices1_all, faces, output_dir, filenames): + import multiprocessing + # import trimesh + num_cores = multiprocessing.cpu_count() - 1 # This will get the number of cores on your machine. 
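+    # Descriptive note: in this Hugging Face variant the trimesh camera set-up and the
+    # multiprocessing pool below are left commented out; frames are rendered
+    # sequentially and only every third pose frame is drawn (process_frame(i*3)),
+    # matching the 10 fps video written later, so num_cores is currently unused.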
+ # mesh = trimesh.Trimesh(vertices_all[0], faces) + # scene = mesh.scene() + # fov = scene.camera.fov.copy() + # fov[0] = 80.0 + # fov[1] = 60.0 + # camera_params = { + # 'fov': fov, + # 'resolution': scene.camera.resolution, + # 'focal': scene.camera.focal, + # 'z_near': scene.camera.z_near, + # "z_far": scene.camera.z_far, + # 'transform': scene.graph[scene.camera.name][0] + # } + # mesh1 = trimesh.Trimesh(vertices1_all[0], faces) + # scene1 = mesh1.scene() + # camera_params1 = { + # 'fov': fov, + # 'resolution': scene1.camera.resolution, + # 'focal': scene1.camera.focal, + # 'z_near': scene1.camera.z_near, + # "z_far": scene1.camera.z_far, + # 'transform': scene1.graph[scene1.camera.name][0] + # } + # Use a Pool to manage the processes + # print(num_cores) + # for i in range(frames): + # process_frame(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) + for i in range(frames): + process_frame(i*3, vertices_all, vertices1_all, faces, output_dir, filenames) + + # progress = multiprocessing.Value('i', 0) + # lock = multiprocessing.Lock() + # with multiprocessing.Pool(num_cores) as pool: + # # pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + # pool.starmap( + # process_frame, + # [ + # (i, vertices_all, vertices1_all, faces, output_dir, filenames) + # for i in range(frames) + # ] + # ) + + # progress = multiprocessing.Value('i', 0) + # lock = multiprocessing.Lock() + # with multiprocessing.Pool(num_cores) as pool: + # # pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + # pool.starmap( + # process_frame, + # [ + # (i, vertices_all, vertices1_all, faces, output_dir, filenames) + # for i in range(frames) + # ] + # ) + +def render_one_sequence( + res_npz_path, + gt_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import moviepy.editor as mp + import librosa + + model = smplx.create(model_folder, model_type=model_type, + gender=gender, use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, use_pca=False).cuda() + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + filenames = [] + # if not use_matplotlib: + # import trimesh + #import pyrender + from pyvirtualdisplay import Display + #''' + #display = Display(visible=0, size=(1000, 1000)) + #display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + #data_npz["jaw_pose"].shape[0] + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + beta = beta.repeat(n, 1) + expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).cuda() + pose = 
torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).cuda() + transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).cuda() + # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape) + output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose, + global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3], + leye_pose=pose[:, 69:72], + reye_pose=pose[:, 72:75], + return_verts=True) + vertices_all = output["vertices"].cpu().detach().numpy() + + beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).cuda() + pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).cuda() + transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).cuda() + output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3], + leye_pose=pose1[:, 69:72], + reye_pose=pose1[:, 72:75],return_verts=True) + vertices1_all = output1["vertices"].cpu().detach().numpy() + if args.debug: + seconds = 1 + else: + seconds = vertices_all.shape[0]//30 + # camera_settings = None + time_s = time.time() + generate_images(int(seconds*10), vertices_all, vertices1_all, faces, output_dir, filenames) + filenames = ["{}frame_{}.png".format(output_dir, i*3) for i in range(int(seconds*10))] + # print(time.time()-time_s) + # for i in tqdm(range(seconds*10)): + # vertices = vertices_all[i] + # vertices1 = vertices1_all[i] + # filename = f"{output_dir}frame_{i}.png" + # filenames.append(filename) + # #time_s = time.time() + # #print(vertices.shape) + # if use_matplotlib: + # fig = plt.figure(figsize=(20, 10)) + # ax = fig.add_subplot(121, projection="3d") + # fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + # #ax.view_init(elev=0, azim=90) + # x = vertices[:, 0] + # y = vertices[:, 1] + # z = vertices[:, 2] + # ax.scatter(x, y, z, s=0.5) + # ax.set_xlim([-1.0, 1.0]) + # ax.set_ylim([-0.5, 1.5])#heigth + # ax.set_zlim([-0, 2])#depth + # ax.set_box_aspect((1,1,1)) + # else: + # mesh = trimesh.Trimesh(vertices, faces) + # if i == 0: + # scene = mesh.scene() + # camera_params = { + # 'fov': scene.camera.fov, + # 'resolution': scene.camera.resolution, + # 'focal': scene.camera.focal, + # 'z_near': scene.camera.z_near, + # "z_far": scene.camera.z_far, + # 'transform': scene.graph[scene.camera.name][0] + # } + # else: + # scene = mesh.scene() + # scene.camera.fov = camera_params['fov'] + # scene.camera.resolution = camera_params['resolution'] + # scene.camera.z_near = camera_params['z_near'] + # scene.camera.z_far = camera_params['z_far'] + # scene.graph[scene.camera.name] = camera_params['transform'] + # fig, ax =plt.subplots(1,2, figsize=(16, 6)) + # image = scene.save_image(resolution=[640, 480], visible=False) + # #print((time.time()-time_s)) + # im0 = ax[0].imshow(image_from_bytes(image)) + # ax[0].axis('off') + + # # beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0) + # # expression1 = torch.from_numpy(gt_np_body["expressions"][i]).to(torch.float32).unsqueeze(0) + # # jaw_pose1 = torch.from_numpy(gt_np_body["poses"][i][66:69]).to(torch.float32).unsqueeze(0) + # # pose1 = 
torch.from_numpy(gt_np_body["poses"][i]).to(torch.float32).unsqueeze(0) + # # transl1 = torch.from_numpy(gt_np_body["trans"][i]).to(torch.float32).unsqueeze(0) + # # #print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape)global_orient=pose[0:1,:3], + # # output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[0:1,:3], body_pose=pose1[0:1,3:21*3+3], left_hand_pose=pose1[0:1,25*3:40*3], right_hand_pose=pose1[0:1,40*3:55*3], return_verts=True) + # # vertices1 = output1["vertices"].cpu().detach().numpy()[0] + + # if use_matplotlib: + # ax2 = fig.add_subplot(122, projection="3d") + # ax2.set_box_aspect((1,1,1)) + # fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + # #ax2.view_init(elev=0, azim=90) + # x1 = vertices1[:, 0] + # y1 = vertices1[:, 1] + # z1 = vertices1[:, 2] + # ax2.scatter(x1, y1, z1, s=0.5) + # ax2.set_xlim([-1.0, 1.0]) + # ax2.set_ylim([-0.5, 1.5])#heigth + # ax2.set_zlim([-0, 2]) + # plt.savefig(filename, bbox_inches='tight') + # plt.close(fig) + # else: + # mesh1 = trimesh.Trimesh(vertices1, faces) + # if i == 0: + # scene1 = mesh1.scene() + # camera_params1 = { + # 'fov': scene1.camera.fov, + # 'resolution': scene1.camera.resolution, + # 'focal': scene1.camera.focal, + # 'z_near': scene1.camera.z_near, + # "z_far": scene1.camera.z_far, + # 'transform': scene1.graph[scene1.camera.name][0] + # } + # else: + # scene1 = mesh1.scene() + # scene1.camera.fov = camera_params1['fov'] + # scene1.camera.resolution = camera_params1['resolution'] + # scene1.camera.z_near = camera_params1['z_near'] + # scene1.camera.z_far = camera_params1['z_far'] + # scene1.graph[scene1.camera.name] = camera_params1['transform'] + # image1 = scene1.save_image(resolution=[640, 480], visible=False) + # im1 = ax[1].imshow(image_from_bytes(image1)) + # ax[1].axis('off') + # plt.savefig(filename, bbox_inches='tight') + # plt.close(fig) + + #display.stop() + #''' + # print(filenames) + images = [imageio.imread(filename) for filename in filenames] + imageio.mimsave(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4", images, fps=10) + for filename in filenames: + os.remove(filename) + + video = mp.VideoFileClip(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4") + # audio, sr = librosa.load(audio_path) + # audio = audio[:seconds*sr] + # print(audio.shape, seconds, sr) + # import soundfile as sf + # sf.write(f"{output_dir}{res_npz_path.split('/')[-1][:-4]}.wav", audio, 16000, 'PCM_24') + # audio_tmp = librosa.output.write_wav(f"{output_dir}{res_npz_path.split('/')[-1][:-4]}.wav", audio, sr=16000) + audio = mp.AudioFileClip(audio_path) + if audio.duration > video.duration: + audio = audio.subclip(0, video.duration) + final_clip = video.set_audio(audio) + final_clip.write_videofile(f"{output_dir}{res_npz_path.split('/')[-1][4:-4]}.mp4") + os.remove(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4") + +def print_exp_info(args): + logger.info(pprint.pformat(vars(args))) + logger.info(f"# ------------ {args.name} ----------- #") + logger.info("PyTorch version: {}".format(torch.__version__)) + logger.info("CUDA version: {}".format(torch.version.cuda)) + logger.info("{} GPUs".format(torch.cuda.device_count())) + logger.info(f"Random Seed: {args.random_seed}") + +def args2csv(args, get_head=False, list4print=[]): + for k, v in args.items(): + if isinstance(args[k], dict): + args2csv(args[k], get_head, list4print) + else: list4print.append(k) if get_head else list4print.append(v) + return list4print + +class 
EpochTracker: + def __init__(self, metric_names, metric_directions): + assert len(metric_names) == len(metric_directions), "Metric names and directions should have the same length" + + + self.metric_names = metric_names + self.states = ['train', 'val', 'test'] + self.types = ['last', 'best'] + + + self.values = {name: {state: {type_: {'value': np.inf if not is_higher_better else -np.inf, 'epoch': 0} + for type_ in self.types} + for state in self.states} + for name, is_higher_better in zip(metric_names, metric_directions)} + + self.loss_meters = {name: {state: AverageMeter(f"{name}_{state}") + for state in self.states} + for name in metric_names} + + + self.is_higher_better = {name: direction for name, direction in zip(metric_names, metric_directions)} + self.train_history = {name: [] for name in metric_names} + self.val_history = {name: [] for name in metric_names} + + + def update_meter(self, name, state, value): + self.loss_meters[name][state].update(value) + + + def update_values(self, name, state, epoch): + value_avg = self.loss_meters[name][state].avg + new_best = False + + + if ((value_avg < self.values[name][state]['best']['value'] and not self.is_higher_better[name]) or + (value_avg > self.values[name][state]['best']['value'] and self.is_higher_better[name])): + self.values[name][state]['best']['value'] = value_avg + self.values[name][state]['best']['epoch'] = epoch + new_best = True + self.values[name][state]['last']['value'] = value_avg + self.values[name][state]['last']['epoch'] = epoch + return new_best + + + def get(self, name, state, type_): + return self.values[name][state][type_] + + + def reset(self): + for name in self.metric_names: + for state in self.states: + self.loss_meters[name][state].reset() + + + def flatten_values(self): + flat_dict = {} + for name in self.metric_names: + for state in self.states: + for type_ in self.types: + value_key = f"{name}_{state}_{type_}" + epoch_key = f"{name}_{state}_{type_}_epoch" + flat_dict[value_key] = self.values[name][state][type_]['value'] + flat_dict[epoch_key] = self.values[name][state][type_]['epoch'] + return flat_dict + + def update_and_plot(self, name, epoch, save_path): + new_best_train = self.update_values(name, 'train', epoch) + new_best_val = self.update_values(name, 'val', epoch) + + + self.train_history[name].append(self.loss_meters[name]['train'].avg) + self.val_history[name].append(self.loss_meters[name]['val'].avg) + + + train_values = self.train_history[name] + val_values = self.val_history[name] + epochs = list(range(1, len(train_values) + 1)) + + + plt.figure(figsize=(10, 6)) + plt.plot(epochs, train_values, label='Train') + plt.plot(epochs, val_values, label='Val') + plt.title(f'Train vs Val {name} over epochs') + plt.xlabel('Epochs') + plt.ylabel(name) + plt.legend() + plt.savefig(save_path) + plt.close() + + + return new_best_train, new_best_val + +def record_trial(args, tracker): + """ + 1. 
record notes, score, env_name, experments_path, + """ + csv_path = args.out_path + "custom/" +args.csv_name+".csv" + all_print_dict = vars(args) + all_print_dict.update(tracker.flatten_values()) + if not os.path.exists(csv_path): + pd.DataFrame([all_print_dict]).to_csv(csv_path, index=False) + else: + df_existing = pd.read_csv(csv_path) + df_new = pd.DataFrame([all_print_dict]) + df_aligned = df_existing.append(df_new).fillna("") + df_aligned.to_csv(csv_path, index=False) + +def set_random_seed(args): + os.environ['PYTHONHASHSEED'] = str(args.random_seed) + random.seed(args.random_seed) + np.random.seed(args.random_seed) + torch.manual_seed(args.random_seed) + torch.cuda.manual_seed_all(args.random_seed) + torch.cuda.manual_seed(args.random_seed) + torch.backends.cudnn.deterministic = args.deterministic #args.CUDNN_DETERMINISTIC + torch.backends.cudnn.benchmark = args.benchmark + torch.backends.cudnn.enabled = args.cudnn_enabled + +def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None): + if lrs is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(), + 'lrs':lrs.state_dict(),} + elif opt is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(),} + else: + states = { 'model_state': model.state_dict(),} + torch.save(states, save_path) + +def load_checkpoints(model, save_path, load_name='model'): + states = torch.load(save_path) + new_weights = OrderedDict() + flag=False + for k, v in states['model_state'].items(): + #print(k) + if "module" not in k: + break + else: + new_weights[k[7:]]=v + flag=True + if flag: + try: + model.load_state_dict(new_weights) + except: + #print(states['model_state']) + model.load_state_dict(states['model_state']) + else: + model.load_state_dict(states['model_state']) + logger.info(f"load self-pretrained checkpoints for {load_name}") + +def model_complexity(model, args): + from ptflops import get_model_complexity_info + flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN), + as_strings=False, print_per_layer_stat=False) + logging.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9)) + logging.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6)) + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) diff --git a/dataloaders/utils/rotation_conversions.py b/dataloaders/utils/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bfaa1b2247622bff35d3f9b15e8eb84064aa53 --- /dev/null +++ b/dataloaders/utils/rotation_conversions.py @@ -0,0 +1,550 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. 
the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
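+
+    For example, axis "Z" with angle pi/2 yields approximately
+    [[0, -1, 0], [1, 0, 0], [0, 0, 1]].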
+ """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). 
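+
+    For angles within the principal range this inverts euler_angles_to_matrix,
+    e.g. matrix_to_euler_angles(euler_angles_to_matrix(a, "XYZ"), "XYZ") recovers a.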
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. 
+ + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
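+
+    For example, the axis-angle vector [0, 0, pi / 2] (a 90 degree rotation
+    about Z) maps to approximately [0.7071, 0, 0, 0.7071].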
+ """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/datasets/beat_cache/beat_smplx_en_emage_test/test/smplxflame_30_cache/data.mdb b/datasets/beat_cache/beat_smplx_en_emage_test/test/smplxflame_30_cache/data.mdb new file mode 100644 index 0000000000000000000000000000000000000000..0d662d186ed0e4ef934911bfc5b984880db9a384 Binary files /dev/null and b/datasets/beat_cache/beat_smplx_en_emage_test/test/smplxflame_30_cache/data.mdb differ diff --git a/datasets/beat_cache/beat_smplx_en_emage_test/test/smplxflame_30_cache/lock.mdb b/datasets/beat_cache/beat_smplx_en_emage_test/test/smplxflame_30_cache/lock.mdb new file mode 100644 index 0000000000000000000000000000000000000000..f1ae3c350a05efc4c1a6f29555ac22620def584c Binary files /dev/null and b/datasets/beat_cache/beat_smplx_en_emage_test/test/smplxflame_30_cache/lock.mdb differ diff --git a/emage_trainer.py b/emage_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..292bbf6f797ed3a2a8afa41d20cd5619f15bf98e --- /dev/null +++ b/emage_trainer.py @@ -0,0 +1,907 @@ +import train +import os +import time +import csv +import sys +import warnings +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parallel import DistributedDataParallel as DDP +import numpy as np +import time +import pprint +from loguru import logger +from utils import rotation_conversions as rc +import smplx +from utils import config, logger_tools, other_tools, metric, data_transfer +from dataloaders import data_tools +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func +from dataloaders.data_tools import joints_list +import librosa + +class CustomTrainer(train.BaseTrainer): + def __init__(self, args): + super().__init__(args) + self.args = args + self.joints = self.train_data.joints + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list_face = joints_list["beat_smplx_face"] + self.tar_joint_list_upper = joints_list["beat_smplx_upper"] + self.tar_joint_list_hands = joints_list["beat_smplx_hands"] + self.tar_joint_list_lower = joints_list["beat_smplx_lower"] + + self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = 55 + for joint_name in self.tar_joint_list_face: + self.joint_mask_face[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_upper: + self.joint_mask_upper[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_hands: + self.joint_mask_hands[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_lower: + self.joint_mask_lower[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + + self.tracker = other_tools.EpochTracker(["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'exp', 'lvd', 'mse', 
"cls", "rec_face", "latent", "cls_full", "cls_self", "cls_word", "latent_word","latent_self"], [False,True,True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False,False,False,False]) + + vq_model_module = __import__(f"models.motion_representation", fromlist=["something"]) + self.args.vae_layer = 2 + self.args.vae_length = 256 + self.args.vae_test_dim = 106 + self.vq_model_face = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + # print(self.vq_model_face) + other_tools.load_checkpoints(self.vq_model_face, self.args.data_path_1 + "pretrained_vq/last_790_face_v2.bin", args.e_name) + self.args.vae_test_dim = 78 + self.vq_model_upper = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.vq_model_upper, self.args.data_path_1 + "pretrained_vq/upper_vertex_1layer_710.bin", args.e_name) + self.args.vae_test_dim = 180 + self.vq_model_hands = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.vq_model_hands, self.args.data_path_1 + "pretrained_vq/hands_vertex_1layer_710.bin", args.e_name) + self.args.vae_test_dim = 61 + self.args.vae_layer = 4 + self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.vq_model_lower, self.args.data_path_1 + "pretrained_vq/lower_foot_600.bin", args.e_name) + self.args.vae_test_dim = 61 + self.args.vae_layer = 4 + self.global_motion = getattr(vq_model_module, "VAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.global_motion, self.args.data_path_1 + "pretrained_vq/last_1700_foot.bin", args.e_name) + self.args.vae_test_dim = 330 + self.args.vae_layer = 4 + self.args.vae_length = 240 + + self.vq_model_face.eval() + self.vq_model_upper.eval() + self.vq_model_hands.eval() + self.vq_model_lower.eval() + self.global_motion.eval() + + self.cls_loss = nn.NLLLoss().to(self.rank) + self.reclatent_loss = nn.MSELoss().to(self.rank) + self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank) + self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank) + self.log_softmax = nn.LogSoftmax(dim=2).to(self.rank) + + + def inverse_selection(self, filtered_t, selection_array, n): + original_shape_t = np.zeros((n, selection_array.size)) + selected_indices = np.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + selection_array = torch.from_numpy(selection_array).cuda() + original_shape_t = torch.zeros((n, 165)).cuda() + selected_indices = torch.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def _load_data(self, dict_data): + tar_pose_raw = dict_data["pose"] + tar_pose = tar_pose_raw[:, :, :165].to(self.rank) + tar_contact = tar_pose_raw[:, :, 165:169].to(self.rank) + tar_trans = dict_data["trans"].to(self.rank) + tar_exps = dict_data["facial"].to(self.rank) + in_audio = dict_data["audio"].to(self.rank) + in_word = dict_data["word"].to(self.rank) + tar_beta = dict_data["beta"].to(self.rank) + tar_id = dict_data["id"].to(self.rank).long() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) + tar_pose_jaw = 
rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + # tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + # tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2) + + tar_index_value_face_top = self.vq_model_face.map2index(tar_pose_face) # bs*n/4 + tar_index_value_upper_top = self.vq_model_upper.map2index(tar_pose_upper) # bs*n/4 + tar_index_value_hands_top = self.vq_model_hands.map2index(tar_pose_hands) # bs*n/4 + tar_index_value_lower_top = self.vq_model_lower.map2index(tar_pose_lower) # bs*n/4 + + latent_face_top = self.vq_model_face.map2latent(tar_pose_face) # bs*n/4 + latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper) # bs*n/4 + latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands) # bs*n/4 + latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower) # bs*n/4 + + latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) + + index_in = torch.stack([tar_index_value_upper_top, tar_index_value_hands_top, tar_index_value_lower_top], dim=-1).long() + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + # print(tar_index_value_upper_top.shape, index_in.shape) + return { + "tar_pose_jaw": tar_pose_jaw, + "tar_pose_face": tar_pose_face, + "tar_pose_upper": tar_pose_upper, + "tar_pose_lower": tar_pose_lower, + "tar_pose_hands": tar_pose_hands, + 'tar_pose_leg': tar_pose_leg, + "in_audio": in_audio, + "in_word": in_word, + "tar_trans": tar_trans, + "tar_exps": tar_exps, + "tar_beta": tar_beta, + "tar_pose": tar_pose, + "tar4dis": tar4dis, + "tar_index_value_face_top": tar_index_value_face_top, + "tar_index_value_upper_top": tar_index_value_upper_top, + "tar_index_value_hands_top": tar_index_value_hands_top, + "tar_index_value_lower_top": tar_index_value_lower_top, + "latent_face_top": latent_face_top, + "latent_upper_top": latent_upper_top, + "latent_hands_top": latent_hands_top, + "latent_lower_top": latent_lower_top, + "latent_in": latent_in, + "index_in": index_in, + "tar_id": tar_id, + "latent_all": latent_all, + "tar_pose_6d": tar_pose_6d, + "tar_contact": tar_contact, + } + + def _g_training(self, loaded_data, use_adv, mode="train", epoch=0): + bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints + # ------ full generatation task ------ # + mask_val = torch.ones(bs, n, self.args.pose_dims+3+4).float().cuda() + mask_val[:, :self.args.pre_frames, :] = 0.0 + + net_out_val = self.model( + loaded_data['in_audio'], 
loaded_data['in_word'], mask=mask_val, + in_id = loaded_data['tar_id'], in_motion = loaded_data['latent_all'], + use_attentions = True) + g_loss_final = 0 + loss_latent_face = self.reclatent_loss(net_out_val["rec_face"], loaded_data["latent_face_top"]) + loss_latent_lower = self.reclatent_loss(net_out_val["rec_lower"], loaded_data["latent_lower_top"]) + loss_latent_hands = self.reclatent_loss(net_out_val["rec_hands"], loaded_data["latent_hands_top"]) + loss_latent_upper = self.reclatent_loss(net_out_val["rec_upper"], loaded_data["latent_upper_top"]) + loss_latent = self.args.lf*loss_latent_face + self.args.ll*loss_latent_lower + self.args.lh*loss_latent_hands + self.args.lu*loss_latent_upper + self.tracker.update_meter("latent", "train", loss_latent.item()) + g_loss_final += loss_latent + + rec_index_face_val = self.log_softmax(net_out_val["cls_face"]).reshape(-1, self.args.vae_codebook_size) + rec_index_upper_val = self.log_softmax(net_out_val["cls_upper"]).reshape(-1, self.args.vae_codebook_size) + rec_index_lower_val = self.log_softmax(net_out_val["cls_lower"]).reshape(-1, self.args.vae_codebook_size) + rec_index_hands_val = self.log_softmax(net_out_val["cls_hands"]).reshape(-1, self.args.vae_codebook_size) + tar_index_value_face_top = loaded_data["tar_index_value_face_top"].reshape(-1) + tar_index_value_upper_top = loaded_data["tar_index_value_upper_top"].reshape(-1) + tar_index_value_lower_top = loaded_data["tar_index_value_lower_top"].reshape(-1) + tar_index_value_hands_top = loaded_data["tar_index_value_hands_top"].reshape(-1) + loss_cls = self.args.cf*self.cls_loss(rec_index_face_val, tar_index_value_face_top)\ + + self.args.cu*self.cls_loss(rec_index_upper_val, tar_index_value_upper_top)\ + + self.args.cl*self.cls_loss(rec_index_lower_val, tar_index_value_lower_top)\ + + self.args.ch*self.cls_loss(rec_index_hands_val, tar_index_value_hands_top) + self.tracker.update_meter("cls_full", "train", loss_cls.item()) + g_loss_final += loss_cls + + if mode == 'train': + # # ------ masked gesture moderling------ # + mask_ratio = (epoch / self.args.epochs) * 0.95 + 0.05 + mask = torch.rand(bs, n, self.args.pose_dims+3+4) < mask_ratio + mask = mask.float().cuda() + net_out_self = self.model( + loaded_data['in_audio'], loaded_data['in_word'], mask=mask, + in_id = loaded_data['tar_id'], in_motion = loaded_data['latent_all'], + use_attentions = True, use_word=False) + + loss_latent_face_self = self.reclatent_loss(net_out_self["rec_face"], loaded_data["latent_face_top"]) + loss_latent_lower_self = self.reclatent_loss(net_out_self["rec_lower"], loaded_data["latent_lower_top"]) + loss_latent_hands_self = self.reclatent_loss(net_out_self["rec_hands"], loaded_data["latent_hands_top"]) + loss_latent_upper_self = self.reclatent_loss(net_out_self["rec_upper"], loaded_data["latent_upper_top"]) + loss_latent_self = self.args.lf*loss_latent_face_self + self.args.ll*loss_latent_lower_self + self.args.lh*loss_latent_hands_self + self.args.lu*loss_latent_upper_self + self.tracker.update_meter("latent_self", "train", loss_latent_self.item()) + g_loss_final += loss_latent_self + rec_index_face_self = self.log_softmax(net_out_self["cls_face"]).reshape(-1, self.args.vae_codebook_size) + rec_index_upper_self = self.log_softmax(net_out_self["cls_upper"]).reshape(-1, self.args.vae_codebook_size) + rec_index_lower_self = self.log_softmax(net_out_self["cls_lower"]).reshape(-1, self.args.vae_codebook_size) + rec_index_hands_self = self.log_softmax(net_out_self["cls_hands"]).reshape(-1, self.args.vae_codebook_size) + 
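+            # cls_loss is nn.NLLLoss over the log-softmax code logits above, i.e. cross-entropy
+            # against the target VQ codebook indices; unlike the full-generation loss, the four
+            # body-part terms below are summed without the cf/cu/cl/ch weights.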
index_loss_top_self = self.cls_loss(rec_index_face_self, tar_index_value_face_top) + self.cls_loss(rec_index_upper_self, tar_index_value_upper_top) + self.cls_loss(rec_index_lower_self, tar_index_value_lower_top) + self.cls_loss(rec_index_hands_self, tar_index_value_hands_top) + self.tracker.update_meter("cls_self", "train", index_loss_top_self.item()) + g_loss_final += index_loss_top_self + + # ------ masked audio gesture moderling ------ # + net_out_word = self.model( + loaded_data['in_audio'], loaded_data['in_word'], mask=mask, + in_id = loaded_data['tar_id'], in_motion = loaded_data['latent_all'], + use_attentions = True, use_word=True) + + loss_latent_face_word = self.reclatent_loss(net_out_word["rec_face"], loaded_data["latent_face_top"]) + loss_latent_lower_word = self.reclatent_loss(net_out_word["rec_lower"], loaded_data["latent_lower_top"]) + loss_latent_hands_word = self.reclatent_loss(net_out_word["rec_hands"], loaded_data["latent_hands_top"]) + loss_latent_upper_word = self.reclatent_loss(net_out_word["rec_upper"], loaded_data["latent_upper_top"]) + loss_latent_word = self.args.lf*loss_latent_face_word + self.args.ll*loss_latent_lower_word + self.args.lh*loss_latent_hands_word + self.args.lu*loss_latent_upper_word + self.tracker.update_meter("latent_word", "train", loss_latent_word.item()) + g_loss_final += loss_latent_word + + rec_index_face_word = self.log_softmax(net_out_word["cls_face"]).reshape(-1, self.args.vae_codebook_size) + rec_index_upper_word = self.log_softmax(net_out_word["cls_upper"]).reshape(-1, self.args.vae_codebook_size) + rec_index_lower_word = self.log_softmax(net_out_word["cls_lower"]).reshape(-1, self.args.vae_codebook_size) + rec_index_hands_word = self.log_softmax(net_out_word["cls_hands"]).reshape(-1, self.args.vae_codebook_size) + index_loss_top_word = self.cls_loss(rec_index_face_word, tar_index_value_face_top) + self.cls_loss(rec_index_upper_word, tar_index_value_upper_top) + self.cls_loss(rec_index_lower_word, tar_index_value_lower_top) + self.cls_loss(rec_index_hands_word, tar_index_value_hands_top) + self.tracker.update_meter("cls_word", "train", index_loss_top_word.item()) + g_loss_final += index_loss_top_word + + if mode != 'train': + if self.args.cu != 0: + _, rec_index_upper = torch.max(rec_index_upper_val.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + rec_upper = self.vq_model_upper.decode(rec_index_upper) + else: + _, rec_index_upper, _, _ = self.vq_model_upper.quantizer(net_out_val["rec_upper"]) + rec_upper = self.vq_model_upper.decoder(rec_index_upper) + if self.args.cl != 0: + _, rec_index_lower = torch.max(rec_index_lower_val.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + rec_lower = self.vq_model_lower.decode(rec_index_lower) + else: + _, rec_index_lower, _, _ = self.vq_model_lower.quantizer(net_out_val["rec_lower"]) + rec_lower = self.vq_model_lower.decoder(rec_index_lower) + if self.args.ch != 0: + _, rec_index_hands = torch.max(rec_index_hands_val.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + rec_hands = self.vq_model_hands.decode(rec_index_hands) + else: + _, rec_index_hands, _, _ = self.vq_model_hands.quantizer(net_out_val["rec_hands"]) + rec_hands = self.vq_model_hands.decoder(rec_index_hands) + if self.args.cf != 0: + _, rec_index_face = torch.max(rec_index_face_val.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + rec_face = self.vq_model_face.decode(rec_index_face) + else: + _, rec_index_face, _, _ = 
self.vq_model_face.quantizer(net_out_val["rec_face"]) + rec_face = self.vq_model_face.decoder(rec_index_face) + + rec_pose_jaw = rec_face[:, :, :6] + rec_pose_legs = rec_lower[:, :, :54] + + rec_pose_upper = rec_upper.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n) + rec_pose_hands = rec_hands.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n) + rec_pose_jaw = rec_pose_jaw.reshape(bs*n, 6) + rec_pose_jaw = rc.rotation_6d_to_matrix(rec_pose_jaw) + rec_pose_jaw = rc.matrix_to_axis_angle(rec_pose_jaw).reshape(bs*n, 1*3) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose[:, 66:69] = rec_pose_jaw + # print(rec_pose.shape, tar_pose.shape) + + # tar_trans = loaded_data["tar_trans"] + # rec_trans_v_s = rec_lower[:, :, 54:57] + # rec_x_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + # rec_z_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + # rec_y_trans = rec_trans_v_s[:,:,1:2] + # rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + # tar_pose = loaded_data["tar_pose"] + # tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + # tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + + + if mode == 'train': + return g_loss_final + elif mode == 'val': + return { + 'rec_pose': rec_pose, + # rec_trans': rec_pose_trans, + 'tar_pose': loaded_data["tar_pose_6d"], + } + else: + return { + 'rec_pose': rec_pose, + # 'rec_trans': rec_trans, + 'tar_pose': loaded_data["tar_pose"], + 'tar_exps': loaded_data["tar_exps"], + 'tar_beta': loaded_data["tar_beta"], + 'tar_trans': loaded_data["tar_trans"], + # 'rec_exps': rec_exps, + } + + + def _g_test(self, loaded_data): + mode = 'test' + bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints + tar_pose = loaded_data["tar_pose"] + tar_beta = loaded_data["tar_beta"] + in_word = loaded_data["in_word"] + tar_exps = loaded_data["tar_exps"] + tar_contact = loaded_data["tar_contact"] + in_audio = loaded_data["in_audio"] + tar_trans = loaded_data["tar_trans"] + + remain = n%8 + if remain != 0: + tar_pose = tar_pose[:, :-remain, :] + tar_beta = tar_beta[:, :-remain, :] + tar_trans = tar_trans[:, :-remain, :] + in_word = in_word[:, :-remain] + tar_exps = tar_exps[:, :-remain, :] + tar_contact = tar_contact[:, :-remain, :] + n = n - remain + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) + tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + 
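+        # the face stream is the 6D jaw rotation (6 dims) concatenated with the expression
+        # coefficients (100 dims), matching the 106-dim input the face VQ-VAE was built with.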
+ tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + + rec_index_all_face = [] + rec_index_all_upper = [] + rec_index_all_lower = [] + rec_index_all_hands = [] + # rec_index_all_face_bot = [] + # rec_index_all_upper_bot = [] + # rec_index_all_lower_bot = [] + # rec_index_all_hands_bot = [] + + roundt = (n - self.args.pre_frames) // (self.args.pose_length - self.args.pre_frames) + remain = (n - self.args.pre_frames) % (self.args.pose_length - self.args.pre_frames) + round_l = self.args.pose_length - self.args.pre_frames + + # pad latent_all_9 to the same length with latent_all + # if n - latent_all_9.shape[1] >= 0: + # latent_all = torch.cat([latent_all_9, torch.zeros(bs, n - latent_all_9.shape[1], latent_all_9.shape[2]).cuda()], dim=1) + # else: + # latent_all = latent_all_9[:, :n, :] + + for i in range(0, roundt): + in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames] + # audio fps is 16000 and pose fps is 30 + in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*self.args.pre_frames] + in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames] + mask_val = torch.ones(bs, self.args.pose_length, self.args.pose_dims+3+4).float().cuda() + mask_val[:, :self.args.pre_frames, :] = 0.0 + if i == 0: + latent_all_tmp = latent_all[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames, :] + else: + latent_all_tmp = latent_all[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames, :] + # print(latent_all_tmp.shape, latent_last.shape) + latent_all_tmp[:, :self.args.pre_frames, :] = latent_last[:, -self.args.pre_frames:, :] + + net_out_val = self.model( + in_audio = in_audio_tmp, + in_word=in_word_tmp, + mask=mask_val, + in_motion = latent_all_tmp, + in_id = in_id_tmp, + use_attentions=True,) + + if self.args.cu != 0: + rec_index_upper = self.log_softmax(net_out_val["cls_upper"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_upper = torch.max(rec_index_upper.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_upper = self.vq_model_upper.decode(rec_index_upper) + else: + _, rec_index_upper, _, _ = self.vq_model_upper.quantizer(net_out_val["rec_upper"]) + #rec_upper = self.vq_model_upper.decoder(rec_index_upper) + if self.args.cl != 0: + rec_index_lower = self.log_softmax(net_out_val["cls_lower"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_lower = torch.max(rec_index_lower.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_lower = self.vq_model_lower.decode(rec_index_lower) + else: + _, rec_index_lower, _, _ = 
self.vq_model_lower.quantizer(net_out_val["rec_lower"]) + #rec_lower = self.vq_model_lower.decoder(rec_index_lower) + if self.args.ch != 0: + rec_index_hands = self.log_softmax(net_out_val["cls_hands"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_hands = torch.max(rec_index_hands.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_hands = self.vq_model_hands.decode(rec_index_hands) + else: + _, rec_index_hands, _, _ = self.vq_model_hands.quantizer(net_out_val["rec_hands"]) + #rec_hands = self.vq_model_hands.decoder(rec_index_hands) + if self.args.cf != 0: + rec_index_face = self.log_softmax(net_out_val["cls_face"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_face = torch.max(rec_index_face.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_face = self.vq_model_face.decoder(rec_index_face) + else: + _, rec_index_face, _, _ = self.vq_model_face.quantizer(net_out_val["rec_face"]) + #rec_face = self.vq_model_face.decoder(rec_index_face) + + if i == 0: + rec_index_all_face.append(rec_index_face) + rec_index_all_upper.append(rec_index_upper) + rec_index_all_lower.append(rec_index_lower) + rec_index_all_hands.append(rec_index_hands) + else: + rec_index_all_face.append(rec_index_face[:, self.args.pre_frames:]) + rec_index_all_upper.append(rec_index_upper[:, self.args.pre_frames:]) + rec_index_all_lower.append(rec_index_lower[:, self.args.pre_frames:]) + rec_index_all_hands.append(rec_index_hands[:, self.args.pre_frames:]) + + if self.args.cu != 0: + rec_upper_last = self.vq_model_upper.decode(rec_index_upper) + else: + rec_upper_last = self.vq_model_upper.decoder(rec_index_upper) + if self.args.cl != 0: + rec_lower_last = self.vq_model_lower.decode(rec_index_lower) + else: + rec_lower_last = self.vq_model_lower.decoder(rec_index_lower) + if self.args.ch != 0: + rec_hands_last = self.vq_model_hands.decode(rec_index_hands) + else: + rec_hands_last = self.vq_model_hands.decoder(rec_index_hands) + # if self.args.cf != 0: + # rec_face_last = self.vq_model_face.decode(rec_index_face) + # else: + # rec_face_last = self.vq_model_face.decoder(rec_index_face) + + rec_pose_legs = rec_lower_last[:, :, :54] + bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1] + rec_pose_upper = rec_upper_last.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n) + rec_pose_hands = rec_hands_last.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + rec_trans_v_s = rec_lower_last[:, :, 54:57] + rec_x_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = 
other_tools.velocity2position(rec_trans_v_s[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + latent_last = torch.cat([rec_pose, rec_trans, rec_lower_last[:, :, 57:61]], dim=-1) + + rec_index_face = torch.cat(rec_index_all_face, dim=1) + rec_index_upper = torch.cat(rec_index_all_upper, dim=1) + rec_index_lower = torch.cat(rec_index_all_lower, dim=1) + rec_index_hands = torch.cat(rec_index_all_hands, dim=1) + if self.args.cu != 0: + rec_upper = self.vq_model_upper.decode(rec_index_upper) + else: + rec_upper = self.vq_model_upper.decoder(rec_index_upper) + if self.args.cl != 0: + rec_lower = self.vq_model_lower.decode(rec_index_lower) + else: + rec_lower = self.vq_model_lower.decoder(rec_index_lower) + if self.args.ch != 0: + rec_hands = self.vq_model_hands.decode(rec_index_hands) + else: + rec_hands = self.vq_model_hands.decoder(rec_index_hands) + if self.args.cf != 0: + rec_face = self.vq_model_face.decode(rec_index_face) + else: + rec_face = self.vq_model_face.decoder(rec_index_face) + + rec_exps = rec_face[:, :, 6:] + rec_pose_jaw = rec_face[:, :, :6] + rec_pose_legs = rec_lower[:, :, :54] + bs, n = rec_pose_jaw.shape[0], rec_pose_jaw.shape[1] + rec_pose_upper = rec_upper.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n) + rec_pose_hands = rec_hands.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n) + rec_pose_jaw = rec_pose_jaw.reshape(bs*n, 6) + rec_pose_jaw = rc.rotation_6d_to_matrix(rec_pose_jaw) + rec_pose_jaw = rc.matrix_to_axis_angle(rec_pose_jaw).reshape(bs*n, 1*3) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose[:, 66:69] = rec_pose_jaw + + to_global = rec_lower + to_global[:, :, 54:57] = 0.0 + to_global[:, :, :54] = rec_lower2global + rec_global = self.global_motion(to_global) + + rec_trans_v_s = rec_global["rec_pose"][:, :, 54:57] + rec_x_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + tar_pose = tar_pose[:, :n, :] + tar_exps = tar_exps[:, :n, :] + tar_trans = tar_trans[:, :n, :] + tar_beta = tar_beta[:, :n, :] + + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + + return { + 'rec_pose': rec_pose, + 'rec_trans': 
rec_trans, + 'tar_pose': tar_pose, + 'tar_exps': tar_exps, + 'tar_beta': tar_beta, + 'tar_trans': tar_trans, + 'rec_exps': rec_exps, + } + + + def train(self, epoch): + #torch.autograd.set_detect_anomaly(True) + use_adv = bool(epoch>=self.args.no_adv_epoch) + self.model.train() + # self.d_model.train() + t_start = time.time() + self.tracker.reset() + for its, batch_data in enumerate(self.train_loader): + loaded_data = self._load_data(batch_data) + t_data = time.time() - t_start + + self.opt.zero_grad() + g_loss_final = 0 + g_loss_final += self._g_training(loaded_data, use_adv, 'train', epoch) + #with torch.autograd.detect_anomaly(): + g_loss_final.backward() + if self.args.grad_norm != 0: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm) + self.opt.step() + + mem_cost = torch.cuda.memory_cached() / 1E9 + lr_g = self.opt.param_groups[0]['lr'] + # lr_d = self.opt_d.param_groups[0]['lr'] + t_train = time.time() - t_start - t_data + t_start = time.time() + if its % self.args.log_period == 0: + self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g) + if self.args.debug: + if its == 1: break + self.opt_s.step(epoch) + # self.opt_d_s.step(epoch) + + def val(self, epoch): + self.model.eval() + # self.d_model.eval() + with torch.no_grad(): + for its, batch_data in enumerate(self.train_loader): + loaded_data = self._load_data(batch_data) + net_out = self._g_training(loaded_data, False, 'val', epoch) + tar_pose = net_out['tar_pose'] + rec_pose = net_out['rec_pose'] + if (30/self.args.pose_fps) != 1: + assert 30%self.args.pose_fps == 0 + n *= int(30/self.args.pose_fps) + tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + n = tar_pose.shape[1] + remain = n%self.args.vae_test_len + tar_pose = tar_pose[:, :n-remain, :] + rec_pose = rec_pose[:, :n-remain, :] + latent_out = self.eval_copy.map2latent(rec_pose).reshape(-1, self.args.vae_length).cpu().numpy() + latent_ori = self.eval_copy.map2latent(tar_pose).reshape(-1, self.args.vae_length).cpu().numpy() + if its == 0: + latent_out_motion_all = latent_out + latent_ori_all = latent_ori + else: + latent_out_motion_all = np.concatenate([latent_out_motion_all, latent_out], axis=0) + latent_ori_all = np.concatenate([latent_ori_all, latent_ori], axis=0) + if self.args.debug: + if its == 1: break + fid_motion = data_tools.FIDCalculator.frechet_distance(latent_out_motion_all, latent_ori_all) + self.tracker.update_meter("fid", "val", fid_motion) + self.val_recording(epoch) + + + def test(self, epoch): + + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + return 0 + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + align = 0 + latent_out = [] + latent_ori = [] + l2_all = 0 + lvel = 0 + self.model.eval() + self.smplx.eval() + self.eval_copy.eval() + with torch.no_grad(): + for its, batch_data in enumerate(self.test_loader): + loaded_data = self._load_data(batch_data) + net_out = self._g_test(loaded_data) + tar_pose = net_out['tar_pose'] + rec_pose = net_out['rec_pose'] + tar_exps = net_out['tar_exps'] + tar_beta = net_out['tar_beta'] + rec_trans = net_out['rec_trans'] + tar_trans = net_out['tar_trans'] + rec_exps = net_out['rec_exps'] + # print(rec_pose.shape, tar_pose.shape) + bs, n, j 
= tar_pose.shape[0], tar_pose.shape[1], self.joints + if (30/self.args.pose_fps) != 1: + assert 30%self.args.pose_fps == 0 + n *= int(30/self.args.pose_fps) + tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + + # print(rec_pose.shape, tar_pose.shape) + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + remain = n%self.args.vae_test_len + latent_out.append(self.eval_copy.map2latent(rec_pose[:, :n-remain]).reshape(-1, self.args.vae_length).detach().cpu().numpy()) # bs * n/8 * 240 + latent_ori.append(self.eval_copy.map2latent(tar_pose[:, :n-remain]).reshape(-1, self.args.vae_length).detach().cpu().numpy()) + + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=rec_trans.reshape(bs*n, 3)-rec_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100)-tar_exps.reshape(bs*n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:,:3], + body_pose=rec_pose[:,3:21*3+3], + left_hand_pose=rec_pose[:,25*3:40*3], + right_hand_pose=rec_pose[:,40*3:55*3], + return_joints=True, + leye_pose=rec_pose[:, 69:72], + reye_pose=rec_pose[:, 72:75], + ) + # vertices_tar = self.smplx( + # betas=tar_beta.reshape(bs*n, 300), + # transl=rec_trans.reshape(bs*n, 3)-rec_trans.reshape(bs*n, 3), + # expression=tar_exps.reshape(bs*n, 100)-tar_exps.reshape(bs*n, 100), + # jaw_pose=tar_pose[:, 66:69], + # global_orient=tar_pose[:,:3], + # body_pose=tar_pose[:,3:21*3+3], + # left_hand_pose=tar_pose[:,25*3:40*3], + # right_hand_pose=tar_pose[:,40*3:55*3], + # return_joints=True, + # leye_pose=tar_pose[:, 69:72], + # reye_pose=tar_pose[:, 72:75], + # ) + vertices_rec_face = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=rec_trans.reshape(bs*n, 3)-rec_trans.reshape(bs*n, 3), + expression=rec_exps.reshape(bs*n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:,:3]-rec_pose[:,:3], + body_pose=rec_pose[:,3:21*3+3]-rec_pose[:,3:21*3+3], + left_hand_pose=rec_pose[:,25*3:40*3]-rec_pose[:,25*3:40*3], + right_hand_pose=rec_pose[:,40*3:55*3]-rec_pose[:,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=rec_pose[:, 69:72]-rec_pose[:, 69:72], + reye_pose=rec_pose[:, 72:75]-rec_pose[:, 72:75], + ) + vertices_tar_face = self.smplx( + betas=tar_beta.reshape(bs*n, 300), + transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), + expression=tar_exps.reshape(bs*n, 100), + jaw_pose=tar_pose[:, 66:69], + global_orient=tar_pose[:,:3]-tar_pose[:,:3], + body_pose=tar_pose[:,3:21*3+3]-tar_pose[:,3:21*3+3], + left_hand_pose=tar_pose[:,25*3:40*3]-tar_pose[:,25*3:40*3], + right_hand_pose=tar_pose[:,40*3:55*3]-tar_pose[:,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=tar_pose[:, 69:72]-tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75]-tar_pose[:, 72:75], + ) + joints_rec = vertices_rec["joints"].detach().cpu().numpy().reshape(1, n, 127*3)[0, :n, 
:55*3] + # joints_tar = vertices_tar["joints"].detach().cpu().numpy().reshape(1, n, 127*3)[0, :n, :55*3] + facial_rec = vertices_rec_face['vertices'].reshape(1, n, -1)[0, :n] + facial_tar = vertices_tar_face['vertices'].reshape(1, n, -1)[0, :n] + face_vel_loss = self.vel_loss(facial_rec[1:, :] - facial_tar[:-1, :], facial_tar[1:, :] - facial_tar[:-1, :]) + l2 = self.reclatent_loss(facial_rec, facial_tar) + l2_all += l2.item() * n + lvel += face_vel_loss.item() * n + + _ = self.l1_calculator.run(joints_rec) + if self.alignmenter is not None: + in_audio_eval, sr = librosa.load(self.args.data_path+"wave16k/"+test_seq_list.iloc[its]['id']+".wav") + in_audio_eval = librosa.resample(in_audio_eval, orig_sr=sr, target_sr=self.args.audio_sr) + a_offset = int(self.align_mask * (self.args.audio_sr / self.args.pose_fps)) + onset_bt = self.alignmenter.load_audio(in_audio_eval[:int(self.args.audio_sr / self.args.pose_fps*n)], a_offset, len(in_audio_eval)-a_offset, True) + beat_vel = self.alignmenter.load_pose(joints_rec, self.align_mask, n-self.align_mask, 30, True) + # print(beat_vel) + align += (self.alignmenter.calculate_align(onset_bt, beat_vel, 30) * (n-2*self.align_mask)) + + tar_pose_np = tar_pose.detach().cpu().numpy() + rec_pose_np = rec_pose.detach().cpu().numpy() + rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3) + rec_exp_np = rec_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3) + gt_npz = np.load(self.args.data_path+self.args.pose_rep +"/"+test_seq_list.iloc[its]['id']+".npz", allow_pickle=True) + np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose_np, + expressions=tar_exp_np, + trans=tar_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose_np, + expressions=rec_exp_np, + trans=rec_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30, + ) + total_length += n + + logger.info(f"l2 loss: {l2_all/total_length}") + logger.info(f"lvel loss: {lvel/total_length}") + + latent_out_all = np.concatenate(latent_out, axis=0) + latent_ori_all = np.concatenate(latent_ori, axis=0) + fid = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all) + logger.info(f"fid score: {fid}") + self.test_recording("fid", fid, epoch) + + align_avg = align/(total_length-2*len(self.test_loader)*self.align_mask) + logger.info(f"align score: {align_avg}") + self.test_recording("bc", align_avg, epoch) + + l1div = self.l1_calculator.avg() + logger.info(f"l1div score: {l1div}") + self.test_recording("l1div", l1div, epoch) + + # data_tools.result2target_vis(self.args.pose_version, results_save_path, results_save_path, self.test_demo, False) + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") + + + def test_demo(self, epoch): + ''' + input audio and text, output motion + do not calculate loss and metric + save video + ''' + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + return 0 + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + align = 0 + latent_out = [] + latent_ori = [] + l2_all = 0 + lvel = 0 + self.model.eval() + 
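+        # eval() turns off dropout and freezes BatchNorm statistics in the generator
+        # and the SMPL-X layer before inference; combined with torch.no_grad() below,
+        # test_demo only writes the gt_/res_ .npz files and skips losses and metrics.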
self.smplx.eval() + # self.eval_copy.eval() + with torch.no_grad(): + for its, batch_data in enumerate(self.test_loader): + loaded_data = self._load_data(batch_data) + net_out = self._g_test(loaded_data) + tar_pose = net_out['tar_pose'] + rec_pose = net_out['rec_pose'] + tar_exps = net_out['tar_exps'] + tar_beta = net_out['tar_beta'] + rec_trans = net_out['rec_trans'] + tar_trans = net_out['tar_trans'] + rec_exps = net_out['rec_exps'] + # print(rec_pose.shape, tar_pose.shape) + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + + # interpolate to 30fps + if (30/self.args.pose_fps) != 1: + assert 30%self.args.pose_fps == 0 + n *= int(30/self.args.pose_fps) + tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + + # print(rec_pose.shape, tar_pose.shape) + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + + tar_pose_np = tar_pose.detach().cpu().numpy() + rec_pose_np = rec_pose.detach().cpu().numpy() + rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3) + rec_exp_np = rec_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3) + + gt_npz = np.load(self.args.data_path+self.args.pose_rep +"/"+test_seq_list.iloc[its]['id']+".npz", allow_pickle=True) + np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose_np, + expressions=tar_exp_np, + trans=tar_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose_np, + expressions=rec_exp_np, + trans=rec_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30, + ) + total_length += n + + data_tools.result2target_vis(self.args.pose_version, results_save_path, results_save_path, self.test_demo, False) + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") diff --git a/models/.ipynb_checkpoints/emage_audio-checkpoint.py b/models/.ipynb_checkpoints/emage_audio-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..06d43a28284c4b4d339c38efeb44c4a131e3e62d --- /dev/null +++ b/models/.ipynb_checkpoints/emage_audio-checkpoint.py @@ -0,0 +1,243 @@ +import copy +import math +import pickle +import numpy as np +import torch +import torch.nn as nn +from .utils.layer import BasicBlock +from .motion_encoder import * + + +class WavEncoder(nn.Module): + def __init__(self, out_dim, audio_in=1): + super().__init__() + self.out_dim = out_dim + self.feat_extractor = nn.Sequential( + BasicBlock(audio_in, out_dim//4, 15, 5, first_dilation=1600, downsample=True), + BasicBlock(out_dim//4, out_dim//4, 15, 6, first_dilation=0, downsample=True), + BasicBlock(out_dim//4, out_dim//4, 15, 1, first_dilation=7, ), + BasicBlock(out_dim//4, out_dim//2, 15, 6, first_dilation=0, downsample=True), + BasicBlock(out_dim//2, out_dim//2, 15, 1, first_dilation=7), + BasicBlock(out_dim//2, 
out_dim, 15, 3, first_dilation=0,downsample=True), + ) + def forward(self, wav_data): + # print(wav_data.shape) + if wav_data.dim() == 2: + wav_data = wav_data.unsqueeze(1) + else: + wav_data = wav_data.transpose(1, 2) + out = self.feat_extractor(wav_data) + return out.transpose(1, 2) + + +class MLP(nn.Module): + def __init__(self, in_dim, hidden_size, out_dim): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(in_dim, hidden_size), + nn.LeakyReLU(0.2, True), + nn.Linear(hidden_size, out_dim) + ) + def forward(self, inputs): + out = self.mlp(inputs) + return out + + +class PeriodicPositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=60): + super(PeriodicPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + pe = torch.zeros(period, d_model) + position = torch.arange(0, period, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) # (1, period, d_model) + repeat_num = (max_seq_len//period) + 1 + pe = pe.repeat(1, repeat_num, 1) # (1, repeat_num, period, d_model) + self.register_buffer('pe', pe) + def forward(self, x): + # print(self.pe.shape, x.shape) + x = x + self.pe[:, :x.size(1), :] + return self.dropout(x) + + +class MAGE_Transformer(nn.Module): + def __init__(self, args): + super(MAGE_Transformer, self).__init__() + self.args = args + # with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + # self.lang_model = pickle.load(f) + # pre_trained_embedding = self.lang_model.word_embedding_weights + # self.text_pre_encoder_face = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding),freeze=args.t_fix_pre) + # self.text_encoder_face = nn.Linear(300, args.audio_f) + # self.text_encoder_face = nn.Linear(300, args.audio_f) + # self.text_pre_encoder_body = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding),freeze=args.t_fix_pre) + # self.text_encoder_body = nn.Linear(300, args.audio_f) + # self.text_encoder_body = nn.Linear(300, args.audio_f) + + self.audio_pre_encoder_face = WavEncoder(args.audio_f, audio_in=1) + self.audio_pre_encoder_body = WavEncoder(args.audio_f, audio_in=1) + + # self.at_attn_face = nn.Linear(args.audio_f*2, args.audio_f*2) + # self.at_attn_body = nn.Linear(args.audio_f*2, args.audio_f*2) + + args_top = copy.deepcopy(self.args) + args_top.vae_layer = 3 + args_top.vae_length = args.motion_f + args_top.vae_test_dim = args.pose_dims+3+4 + self.motion_encoder = VQEncoderV6(args_top) # masked motion to latent bs t 333 to bs t 256 + + # face decoder + self.feature2face = nn.Linear(args.audio_f*2, args.hidden_size) + self.face2latent = nn.Linear(args.hidden_size, args.vae_codebook_size) + self.transformer_de_layer = nn.TransformerDecoderLayer( + d_model=self.args.hidden_size, + nhead=4, + dim_feedforward=self.args.hidden_size*2, + batch_first=True + ) + self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4) + self.position_embeddings = PeriodicPositionalEncoding(self.args.hidden_size, period=self.args.pose_length, max_seq_len=self.args.pose_length) + + # motion decoder + self.transformer_en_layer = nn.TransformerEncoderLayer( + d_model=self.args.hidden_size, + nhead=4, + dim_feedforward=self.args.hidden_size*2, + batch_first=True + ) + self.motion_self_encoder = nn.TransformerEncoder(self.transformer_en_layer, num_layers=1) + self.audio_feature2motion 
= nn.Linear(args.audio_f, args.hidden_size) + self.feature2motion = nn.Linear(args.motion_f, args.hidden_size) + + self.bodyhints_face = MLP(args.motion_f, args.hidden_size, args.motion_f) + self.bodyhints_body = MLP(args.motion_f, args.hidden_size, args.motion_f) + self.motion2latent_upper = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size) + self.motion2latent_hands = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size) + self.motion2latent_lower = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size) + self.wordhints_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=8) + + self.upper_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1) + self.hands_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1) + self.lower_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1) + + self.face_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + self.upper_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + self.hands_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + self.lower_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + + self.mask_embeddings = nn.Parameter(torch.zeros(1, 1, self.args.pose_dims+3+4)) + self.motion_down_upper = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_hands = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_lower = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_upper = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_hands = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_lower = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self._reset_parameters() + + self.spearker_encoder_body = nn.Embedding(25, args.hidden_size) + self.spearker_encoder_face = nn.Embedding(25, args.hidden_size) + + def _reset_parameters(self): + nn.init.normal_(self.mask_embeddings, 0, self.args.hidden_size ** -0.5) + + def forward(self, in_audio=None, in_word=None, mask=None, is_test=None, in_motion=None, use_attentions=True, use_word=True, in_id = None): + # in_word_face = self.text_pre_encoder_face(in_word) + # in_word_face = self.text_encoder_face(in_word_face) + # in_word_body = self.text_pre_encoder_body(in_word) + # in_word_body = self.text_encoder_body(in_word_body) + # bs, t, c = in_word_face.shape + in_audio_face = self.audio_pre_encoder_face(in_audio) + in_audio_body = self.audio_pre_encoder_body(in_audio) + bs, t, c = in_audio_body.shape + # if in_audio_face.shape[1] != in_motion.shape[1]: + # diff_length = in_motion.shape[1]- in_audio_face.shape[1] + # if diff_length < 0: + # in_audio_face = in_audio_face[:, :diff_length, :] + # in_audio_body = in_audio_body[:, :diff_length, :] + # else: + # in_audio_face = torch.cat((in_audio_face, in_audio_face[:,-diff_length:]),1) + # in_audio_body = torch.cat((in_audio_body, in_audio_body[:,-diff_length:]),1) + + # if use_attentions: + # alpha_at_face = torch.cat([in_word_face, in_audio_face], dim=-1).reshape(bs, t, c*2) + # alpha_at_face = self.at_attn_face(alpha_at_face).reshape(bs, t, c, 2) + # alpha_at_face = alpha_at_face.softmax(dim=-1) + # fusion_face = in_word_face * alpha_at_face[:,:,:,1] + in_audio_face * alpha_at_face[:,:,:,0] + # alpha_at_body = torch.cat([in_word_body, in_audio_body], dim=-1).reshape(bs, t, 
c*2) + # alpha_at_body = self.at_attn_body(alpha_at_body).reshape(bs, t, c, 2) + # alpha_at_body = alpha_at_body.softmax(dim=-1) + # fusion_body = in_word_body * alpha_at_body[:,:,:,1] + in_audio_body * alpha_at_body[:,:,:,0] + # else: + fusion_face = in_audio_face + fusion_body = in_audio_body + + masked_embeddings = self.mask_embeddings.expand_as(in_motion) + masked_motion = torch.where(mask == 1, masked_embeddings, in_motion) # bs, t, 256 + body_hint = self.motion_encoder(masked_motion) # bs t 256 + speaker_embedding_face = self.spearker_encoder_face(in_id).squeeze(2) + speaker_embedding_body = self.spearker_encoder_body(in_id).squeeze(2) + + # decode face + use_body_hints = True + if use_body_hints: + body_hint_face = self.bodyhints_face(body_hint) + fusion_face = torch.cat([fusion_face, body_hint_face], dim=2) + a2g_face = self.feature2face(fusion_face) + face_embeddings = speaker_embedding_face + face_embeddings = self.position_embeddings(face_embeddings) + decoded_face = self.face_decoder(tgt=face_embeddings, memory=a2g_face) + face_latent = self.face2latent(decoded_face) + cls_face = self.face_classifier(face_latent) + + # motion spatial encoder + body_hint_body = self.bodyhints_body(body_hint) + motion_embeddings = self.feature2motion(body_hint_body) + motion_embeddings = speaker_embedding_body + motion_embeddings + motion_embeddings = self.position_embeddings(motion_embeddings) + + # bi-directional self-attention + motion_refined_embeddings = self.motion_self_encoder(motion_embeddings) + + # audio to gesture cross-modal attention + if use_word: + a2g_motion = self.audio_feature2motion(fusion_body) + motion_refined_embeddings_in = motion_refined_embeddings + speaker_embedding_body + motion_refined_embeddings_in = self.position_embeddings(motion_refined_embeddings) + word_hints = self.wordhints_decoder(tgt=motion_refined_embeddings_in, memory=a2g_motion) + motion_refined_embeddings = motion_refined_embeddings + word_hints + + # feedforward + upper_latent = self.motion2latent_upper(motion_refined_embeddings) + hands_latent = self.motion2latent_hands(motion_refined_embeddings) + lower_latent = self.motion2latent_lower(motion_refined_embeddings) + + upper_latent_in = upper_latent + speaker_embedding_body + upper_latent_in = self.position_embeddings(upper_latent_in) + hands_latent_in = hands_latent + speaker_embedding_body + hands_latent_in = self.position_embeddings(hands_latent_in) + lower_latent_in = lower_latent + speaker_embedding_body + lower_latent_in = self.position_embeddings(lower_latent_in) + + # transformer decoder + motion_upper = self.upper_decoder(tgt=upper_latent_in, memory=hands_latent+lower_latent) + motion_hands = self.hands_decoder(tgt=hands_latent_in, memory=upper_latent+lower_latent) + motion_lower = self.lower_decoder(tgt=lower_latent_in, memory=upper_latent+hands_latent) + upper_latent = self.motion_down_upper(motion_upper+upper_latent) + hands_latent = self.motion_down_hands(motion_hands+hands_latent) + lower_latent = self.motion_down_lower(motion_lower+lower_latent) + cls_lower = self.lower_classifier(lower_latent) + cls_upper = self.upper_classifier(upper_latent) + cls_hands = self.hands_classifier(hands_latent) + + return { + "rec_face":face_latent, + "rec_upper":upper_latent, + "rec_lower":lower_latent, + "rec_hands":hands_latent, + "cls_face":cls_face, + "cls_upper":cls_upper, + "cls_lower":cls_lower, + "cls_hands":cls_hands, + } \ No newline at end of file diff --git a/models/__pycache__/emage_audio.cpython-310.pyc 
b/models/__pycache__/emage_audio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a8b3e3900da28a70bbc073f0bb594b4a51bd8ea Binary files /dev/null and b/models/__pycache__/emage_audio.cpython-310.pyc differ diff --git a/models/__pycache__/emage_audio.cpython-38.pyc b/models/__pycache__/emage_audio.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beb71c71308d3293fab5de822488d0379016820d Binary files /dev/null and b/models/__pycache__/emage_audio.cpython-38.pyc differ diff --git a/models/__pycache__/motion_encoder.cpython-310.pyc b/models/__pycache__/motion_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e935774f62f0c45067cc41b10b65750945b89252 Binary files /dev/null and b/models/__pycache__/motion_encoder.cpython-310.pyc differ diff --git a/models/__pycache__/motion_encoder.cpython-38.pyc b/models/__pycache__/motion_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acd5cc4ab7080407e7af5d339e81996408d6b38e Binary files /dev/null and b/models/__pycache__/motion_encoder.cpython-38.pyc differ diff --git a/models/__pycache__/motion_representation.cpython-310.pyc b/models/__pycache__/motion_representation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfa83fe83cff4d94af606a674bb6f03f3ca26c8b Binary files /dev/null and b/models/__pycache__/motion_representation.cpython-310.pyc differ diff --git a/models/__pycache__/motion_representation.cpython-38.pyc b/models/__pycache__/motion_representation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56e457c6db3b08ef818f5877a4bec772590e03cb Binary files /dev/null and b/models/__pycache__/motion_representation.cpython-38.pyc differ diff --git a/models/__pycache__/quantizer.cpython-310.pyc b/models/__pycache__/quantizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6305026311d655daf77a96524ade637a4ef3687 Binary files /dev/null and b/models/__pycache__/quantizer.cpython-310.pyc differ diff --git a/models/__pycache__/quantizer.cpython-38.pyc b/models/__pycache__/quantizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68949657987ca62fef62d1e04074040839eabc73 Binary files /dev/null and b/models/__pycache__/quantizer.cpython-38.pyc differ diff --git a/models/camn.py b/models/camn.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9cd7f9f70447fd64a8376203a2a22925991c16 --- /dev/null +++ b/models/camn.py @@ -0,0 +1,449 @@ +import torch +import torch.nn as nn +import os +import pickle +import numpy as np +from torch.nn.utils import weight_norm +from .utils.build_vocab import Vocab + +class Chomp1d(nn.Module): + def __init__(self, chomp_size): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x): + return x[:, :, :-self.chomp_size].contiguous() + + +class TemporalBlock(nn.Module): + def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): + super(TemporalBlock, self).__init__() + self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp1 = Chomp1d(padding) + self.relu1 = nn.ReLU() + self.dropout1 = nn.Dropout(dropout) + + self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp2 = Chomp1d(padding) + self.relu2 = 
nn.ReLU() + self.dropout2 = nn.Dropout(dropout) + + self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, + self.conv2, self.chomp2, self.relu2, self.dropout2) + self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None + self.relu = nn.ReLU() + self.init_weights() + + def init_weights(self): + self.conv1.weight.data.normal_(0, 0.01) + self.conv2.weight.data.normal_(0, 0.01) + if self.downsample is not None: + self.downsample.weight.data.normal_(0, 0.01) + + def forward(self, x): + out = self.net(x) + res = x if self.downsample is None else self.downsample(x) + return self.relu(out + res) + + +class TemporalConvNet(nn.Module): + def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): + super(TemporalConvNet, self).__init__() + layers = [] + num_levels = len(num_channels) + for i in range(num_levels): + dilation_size = 2 ** i + in_channels = num_inputs if i == 0 else num_channels[i-1] + out_channels = num_channels[i] + layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, + padding=(kernel_size-1) * dilation_size, dropout=dropout)] + + self.network = nn.Sequential(*layers) + + def forward(self, x): + return self.network(x) + + +class TextEncoderTCN(nn.Module): + """ based on https://github.com/locuslab/TCN/blob/master/TCN/word_cnn/model.py """ + def __init__(self, args, n_words, embed_size=300, pre_trained_embedding=None, + kernel_size=2, dropout=0.3, emb_dropout=0.1, word_cache=False): + super(TextEncoderTCN, self).__init__() + if word_cache: + self.embedding = None + else: + if pre_trained_embedding is not None: # use pre-trained embedding (fasttext) + #print(pre_trained_embedding.shape) + assert pre_trained_embedding.shape[0] == n_words + assert pre_trained_embedding.shape[1] == embed_size + self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding), + freeze=args.freeze_wordembed) + else: + self.embedding = nn.Embedding(n_words, embed_size) + + num_channels = [args.hidden_size] * args.n_layer + self.tcn = TemporalConvNet(embed_size, num_channels, kernel_size, dropout=dropout) + + self.decoder = nn.Linear(num_channels[-1], args.word_f) + self.drop = nn.Dropout(emb_dropout) + self.emb_dropout = emb_dropout + self.init_weights() + + def init_weights(self): + self.decoder.bias.data.fill_(0) + self.decoder.weight.data.normal_(0, 0.01) + + def forward(self, input): + #print(input.shape) + if self.embedding is None: + emb = self.drop(input) + else: + emb = self.drop(self.embedding(input)) + y = self.tcn(emb.transpose(1, 2)).transpose(1, 2) + y = self.decoder(y) + return y.contiguous(), 0 + + +class BasicBlock(nn.Module): + """ based on timm: https://github.com/rwightman/pytorch-image-models """ + def __init__(self, inplanes, planes, ker_size, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm1d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(BasicBlock, self).__init__() + + self.conv1 = nn.Conv1d( + inplanes, planes, kernel_size=ker_size, stride=stride, padding=first_dilation, + dilation=dilation, bias=True) + self.bn1 = norm_layer(planes) + self.act1 = act_layer(inplace=True) + self.conv2 = nn.Conv1d( + planes, planes, kernel_size=ker_size, padding=ker_size//2, dilation=dilation, bias=True) + self.bn2 = norm_layer(planes) + self.act2 = act_layer(inplace=True) + if downsample is not None: + self.downsample = nn.Sequential( 
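+                # downsample branch: a strided conv plus BatchNorm so the shortcut
+                # matches the main path's channels and temporal length before the
+                # residual addition in forward()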
+ nn.Conv1d(inplanes, planes, stride=stride, kernel_size=ker_size, padding=first_dilation, dilation=dilation, bias=True), + norm_layer(planes), + ) + else: self.downsample=None + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.bn2(x) + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act2(x) + return x + + +class WavEncoder(nn.Module): + def __init__(self, out_dim): + super().__init__() + self.out_dim = out_dim + self.feat_extractor = nn.Sequential( + BasicBlock(1, 32, 15, 5, first_dilation=1600, downsample=True), + BasicBlock(32, 32, 15, 6, first_dilation=0, downsample=True), + BasicBlock(32, 32, 15, 1, first_dilation=7, ), + BasicBlock(32, 64, 15, 6, first_dilation=0, downsample=True), + BasicBlock(64, 64, 15, 1, first_dilation=7), + BasicBlock(64, 128, 15, 6, first_dilation=0,downsample=True), + ) + + def forward(self, wav_data): + wav_data = wav_data.unsqueeze(1) + out = self.feat_extractor(wav_data) + return out.transpose(1, 2) + + +class PoseGenerator(nn.Module): + """ + End2End model + audio, text and speaker ID encoder are customized based on Yoon et al. SIGGRAPH ASIA 2020 + """ + def __init__(self, args): + super().__init__() + self.args = args + self.pre_length = args.pre_frames + self.gen_length = args.pose_length - args.pre_frames + self.pose_dims = args.pose_dims + self.facial_f = args.facial_f + self.speaker_f = args.speaker_f + self.audio_f = args.audio_f + self.word_f = args.word_f + self.emotion_f = args.emotion_f + self.facial_dims = args.facial_dims + self.args.speaker_dims = args.speaker_dims + self.emotion_dims = args.emotion_dims + + self.in_size = self.audio_f + self.pose_dims + self.facial_f + self.word_f + 1 + self.audio_encoder = WavEncoder(self.audio_f) + self.hidden_size = args.hidden_size + self.n_layer = args.n_layer + + if self.facial_f is not 0: + self.facial_encoder = nn.Sequential( + BasicBlock(self.facial_dims, self.facial_f//2, 7, 1, first_dilation=3, downsample=True), + BasicBlock(self.facial_f//2, self.facial_f//2, 3, 1, first_dilation=1, downsample=True), + BasicBlock(self.facial_f//2, self.facial_f//2, 3, 1, first_dilation=1, ), + BasicBlock(self.facial_f//2, self.facial_f, 3, 1, first_dilation=1, downsample=True), + ) + else: + self.facial_encoder = None + + self.text_encoder = None + if self.word_f is not 0: + if args.word_cache: + self.text_encoder = TextEncoderTCN(args, args.word_index_num, args.word_dims, pre_trained_embedding=None, + dropout=args.dropout_prob, word_cache=True) + else: + with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + pre_trained_embedding = self.lang_model.word_embedding_weights + self.text_encoder = TextEncoderTCN(args, args.word_index_num, args.word_dims, pre_trained_embedding=pre_trained_embedding, + dropout=args.dropout_prob) + + self.speaker_embedding = None + if self.speaker_f is not 0: + self.in_size += self.speaker_f + self.speaker_embedding = nn.Sequential( + nn.Embedding(self.args.speaker_dims, self.speaker_f), + nn.Linear(self.speaker_f, self.speaker_f), + nn.LeakyReLU(True) + ) + + + self.emotion_embedding = None + if self.emotion_f is not 0: + self.in_size += self.emotion_f + + self.emotion_embedding = nn.Sequential( + nn.Embedding(self.emotion_dims, self.emotion_f), 
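+                # look up a learned emotion_f-dim vector per emotion id, then refine
+                # it with a linear layer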
+ nn.Linear(self.emotion_f, self.emotion_f) + ) + + # self.emotion_embedding_tail = nn.Sequential( + # nn.Conv1d(self.emotion_f, 8, 9, 1, 4), + # nn.BatchNorm1d(8), + # nn.LeakyReLU(0.3, inplace=True), + # nn.Conv1d(8, 16, 9, 1, 4), + # nn.BatchNorm1d(16), + # nn.LeakyReLU(0.3, inplace=True), + # nn.Conv1d(16, 16, 9, 1, 4), + # nn.BatchNorm1d(16), + # nn.LeakyReLU(0.3, inplace=True), + # nn.Conv1d(16, self.emotion_f, 9, 1, 4), + # nn.BatchNorm1d(self.emotion_f), + # nn.LeakyReLU(0.3, inplace=True), + # ) + + self.LSTM = nn.LSTM(self.in_size+3, hidden_size=self.hidden_size, num_layers=args.n_layer, batch_first=True, + bidirectional=True, dropout=args.dropout_prob) + self.out = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size//2), + nn.LeakyReLU(True), + nn.Linear(self.hidden_size//2, 330-180) + ) + + self.LSTM_hands = nn.LSTM(self.in_size+150+3, hidden_size=self.hidden_size, num_layers=args.n_layer, batch_first=True, + bidirectional=True, dropout=args.dropout_prob) + self.out_hands = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size//2), + nn.LeakyReLU(True), + nn.Linear(self.hidden_size//2, 180+3) + ) + + self.do_flatten_parameters = False + if torch.cuda.device_count() > 1: + self.do_flatten_parameters = True + + + def forward(self, pre_seq, in_audio=None, in_facial=None, in_text=None, in_id=None, in_emo=None, is_test=False): + if self.do_flatten_parameters: + self.LSTM.flatten_parameters() + + text_feat_seq = audio_feat_seq = None + if in_audio is not None: + audio_feat_seq = self.audio_encoder(in_audio) + if in_text is not None: + text_feat_seq, _ = self.text_encoder(in_text) + assert(audio_feat_seq.shape[1] == text_feat_seq.shape[1]) + + if self.facial_f is not 0: + face_feat_seq = self.facial_encoder(in_facial.permute([0, 2, 1])) + face_feat_seq = face_feat_seq.permute([0, 2, 1]) + speaker_feat_seq = None + if self.speaker_embedding: + speaker_feat_seq = self.speaker_embedding(in_id) + emo_feat_seq = None + if self.emotion_embedding: + emo_feat_seq = self.emotion_embedding(in_emo) + emo_feat_seq = emo_feat_seq.permute([0,2,1]) + emo_feat_seq = self.emotion_embedding_tail(emo_feat_seq) + emo_feat_seq = emo_feat_seq.permute([0,2,1]) + + if audio_feat_seq.shape[1] != pre_seq.shape[1]: + diff_length = pre_seq.shape[1] - audio_feat_seq.shape[1] + audio_feat_seq = torch.cat((audio_feat_seq, audio_feat_seq[:,-diff_length:, :].reshape(1,diff_length,-1)),1) + + if self.audio_f is not 0 and self.facial_f is 0: + in_data = torch.cat((pre_seq, audio_feat_seq), dim=2) + elif self.audio_f is not 0 and self.facial_f is not 0: + in_data = torch.cat((pre_seq, audio_feat_seq, face_feat_seq), dim=2) + else: pass + + if text_feat_seq is not None: + in_data = torch.cat((in_data, text_feat_seq), dim=2) + if emo_feat_seq is not None: + in_data = torch.cat((in_data, emo_feat_seq), dim=2) + + if speaker_feat_seq is not None: + repeated_s = speaker_feat_seq + if len(repeated_s.shape) == 2: + repeated_s = repeated_s.reshape(1, repeated_s.shape[1], repeated_s.shape[0]) + repeated_s = repeated_s.repeat(1, in_data.shape[1], 1) + in_data = torch.cat((in_data, repeated_s), dim=2) + + output, _ = self.LSTM(in_data) + output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:] + output = self.out(output.reshape(-1, output.shape[2])) + decoder_outputs = output.reshape(in_data.shape[0], in_data.shape[1], -1) + return decoder_outputs + + +class CaMN(PoseGenerator): + def __init__(self, args): + super().__init__(args) + self.audio_fusion_dim = 
self.audio_f+self.speaker_f+self.emotion_f+self.word_f + self.facial_fusion_dim = self.audio_fusion_dim + self.facial_f + self.audio_fusion = nn.Sequential( + nn.Linear(self.audio_fusion_dim, self.hidden_size//2), + nn.LeakyReLU(True), + nn.Linear(self.hidden_size//2, self.audio_f), + nn.LeakyReLU(True), + ) + + self.facial_fusion = nn.Sequential( + nn.Linear(self.facial_fusion_dim, self.hidden_size//2), + nn.LeakyReLU(True), + nn.Linear(self.hidden_size//2, self.facial_f), + nn.LeakyReLU(True), + ) + + def forward(self, pre_seq, in_audio=None, in_facial=None, in_text=None, in_id=None, in_emo=None): + if self.do_flatten_parameters: + self.LSTM.flatten_parameters() + + decoder_hidden = decoder_hidden_hands = None + text_feat_seq = audio_feat_seq = speaker_feat_seq = emo_feat_seq = face_feat_seq = None + in_data = None + + if self.speaker_embedding: + speaker_feat_seq = self.speaker_embedding(in_id).squeeze(2) + in_data = torch.cat((in_data, speaker_feat_seq), 2) if in_data is not None else speaker_feat_seq + + if self.emotion_embedding: + emo_feat_seq = self.emotion_embedding(in_emo).squeeze(2) + in_data = torch.cat((in_data, emo_feat_seq), 2) + + if in_text is not None: + text_feat_seq, _ = self.text_encoder(in_text) + in_data = torch.cat((in_data, text_feat_seq), 2) if in_data is not None else text_feat_seq + + if in_audio is not None: + audio_feat_seq = self.audio_encoder(in_audio) + if in_text is not None: + if (audio_feat_seq.shape[1] != text_feat_seq.shape[1]): + min_gap = text_feat_seq.shape[1] - audio_feat_seq.shape[1] + audio_feat_seq = torch.cat((audio_feat_seq, audio_feat_seq[:,-min_gap:, :]),1) + audio_fusion_seq = self.audio_fusion(torch.cat((audio_feat_seq, emo_feat_seq, speaker_feat_seq, text_feat_seq), dim=2).reshape(-1, self.audio_fusion_dim)) + audio_feat_seq = audio_fusion_seq.reshape(*audio_feat_seq.shape) + in_data = torch.cat((in_data, audio_feat_seq), 2) if in_data is not None else audio_feat_seq + + if self.facial_f is not 0: + face_feat_seq = self.facial_encoder(in_facial.permute([0, 2, 1])) + face_feat_seq = face_feat_seq.permute([0, 2, 1]) + if (audio_feat_seq.shape[1] != face_feat_seq.shape[1]): + min_gap_2 = face_feat_seq.shape[1] - audio_feat_seq.shape[1] + if min_gap_2 > 0: + face_feat_seq = face_feat_seq[:,:audio_feat_seq.shape[1], :] + else: + face_feat_seq = torch.cat((face_feat_seq, face_feat_seq[:,-min_gap_2:, :]),1) + + face_fusion_seq = self.facial_fusion(torch.cat((face_feat_seq, audio_feat_seq, emo_feat_seq, speaker_feat_seq, text_feat_seq), dim=2).reshape(-1, self.facial_fusion_dim)) + face_feat_seq = face_fusion_seq.reshape(*face_feat_seq.shape) + in_data = torch.cat((in_data, face_feat_seq), 2) if in_data is not None else face_feat_seq + + + in_data = torch.cat((pre_seq, in_data), dim=2) + output, _ = self.LSTM(in_data) + output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:] + output = self.out(output.reshape(-1, output.shape[2])) + decoder_outputs = output.reshape(in_data.shape[0], in_data.shape[1], -1) + + in_data = torch.cat((in_data, decoder_outputs), dim=2) + output_hands, _ = self.LSTM_hands(in_data) + output_hands = output_hands[:, :, :self.hidden_size] + output_hands[:, :, self.hidden_size:] + output_hands = self.out_hands(output_hands.reshape(-1, output_hands.shape[2])) + decoder_outputs_hands = output_hands.reshape(in_data.shape[0], in_data.shape[1], -1) + + decoder_outputs_final = torch.zeros((in_data.shape[0], in_data.shape[1], 333)).to(in_data.device) + decoder_outputs_final[:, :, 0:150] = decoder_outputs[:, :, 
0:150] + decoder_outputs_final[:, :, 150:333] = decoder_outputs_hands[:, :, 0:183] + return { + "rec_pose": decoder_outputs_final, + } + + +class ConvDiscriminator(nn.Module): + def __init__(self, args): + super().__init__() + self.input_size = args.pose_dims + + self.hidden_size = 64 + self.pre_conv = nn.Sequential( + nn.Conv1d(self.input_size, 16, 3), + nn.BatchNorm1d(16), + nn.LeakyReLU(True), + nn.Conv1d(16, 8, 3), + nn.BatchNorm1d(8), + nn.LeakyReLU(True), + nn.Conv1d(8, 8, 3), + ) + + self.LSTM = nn.LSTM(8, hidden_size=self.hidden_size, num_layers=4, bidirectional=True, + dropout=0.3, batch_first=True) + self.out = nn.Linear(self.hidden_size, 1) + self.out2 = nn.Linear(34-6, 1) + + self.do_flatten_parameters = False + if torch.cuda.device_count() > 1: + self.do_flatten_parameters = True + + def forward(self, poses): + if self.do_flatten_parameters: + self.LSTM.flatten_parameters() + poses = poses.transpose(1, 2) + feat = self.pre_conv(poses) + feat = feat.transpose(1, 2) + output, _ = self.LSTM(feat) + output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:] + batch_size = poses.shape[0] + output = output.contiguous().view(-1, output.shape[2]) + output = self.out(output) # apply linear to every output + output = output.view(batch_size, -1) + output = self.out2(output) + output = torch.sigmoid(output) + return output \ No newline at end of file diff --git a/models/emage.py b/models/emage.py new file mode 100644 index 0000000000000000000000000000000000000000..1849cc9af7ba7875b7a686ad091a3c87331739bf --- /dev/null +++ b/models/emage.py @@ -0,0 +1,241 @@ +import copy +import math +import pickle +import numpy as np +import torch +import torch.nn as nn +from .utils.layer import BasicBlock +from .motion_encoder import * + + +class WavEncoder(nn.Module): + def __init__(self, out_dim, audio_in=1): + super().__init__() + self.out_dim = out_dim + self.feat_extractor = nn.Sequential( + BasicBlock(audio_in, out_dim//4, 15, 5, first_dilation=1600, downsample=True), + BasicBlock(out_dim//4, out_dim//4, 15, 6, first_dilation=0, downsample=True), + BasicBlock(out_dim//4, out_dim//4, 15, 1, first_dilation=7, ), + BasicBlock(out_dim//4, out_dim//2, 15, 6, first_dilation=0, downsample=True), + BasicBlock(out_dim//2, out_dim//2, 15, 1, first_dilation=7), + BasicBlock(out_dim//2, out_dim, 15, 3, first_dilation=0,downsample=True), + ) + def forward(self, wav_data): + if wav_data.dim() == 2: + wav_data = wav_data.unsqueeze(1) + else: + wav_data = wav_data.transpose(1, 2) + out = self.feat_extractor(wav_data) + return out.transpose(1, 2) + + +class MLP(nn.Module): + def __init__(self, in_dim, hidden_size, out_dim): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(in_dim, hidden_size), + nn.LeakyReLU(0.2, True), + nn.Linear(hidden_size, out_dim) + ) + def forward(self, inputs): + out = self.mlp(inputs) + return out + + +class PeriodicPositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=60): + super(PeriodicPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + pe = torch.zeros(period, d_model) + position = torch.arange(0, period, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) # (1, period, d_model) + repeat_num = (max_seq_len//period) + 1 + pe = pe.repeat(1, repeat_num, 1) # (1, repeat_num, period, d_model) + 
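+        # the sinusoidal table of length `period` is tiled so the same pattern repeats
+        # every `period` frames up to max_seq_len; register_buffer keeps it on the
+        # module's device and in its state_dict without making it a trainable parameter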
self.register_buffer('pe', pe) + def forward(self, x): + # print(self.pe.shape, x.shape) + x = x + self.pe[:, :x.size(1), :] + return self.dropout(x) + + +class MAGE_Transformer(nn.Module): + def __init__(self, args): + super(MAGE_Transformer, self).__init__() + self.args = args + with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + pre_trained_embedding = self.lang_model.word_embedding_weights + self.text_pre_encoder_face = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding),freeze=args.t_fix_pre) + self.text_encoder_face = nn.Linear(300, args.audio_f) + self.text_encoder_face = nn.Linear(300, args.audio_f) + self.text_pre_encoder_body = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding),freeze=args.t_fix_pre) + self.text_encoder_body = nn.Linear(300, args.audio_f) + self.text_encoder_body = nn.Linear(300, args.audio_f) + + self.audio_pre_encoder_face = WavEncoder(args.audio_f, audio_in=2) + self.audio_pre_encoder_body = WavEncoder(args.audio_f, audio_in=2) + + self.at_attn_face = nn.Linear(args.audio_f*2, args.audio_f*2) + self.at_attn_body = nn.Linear(args.audio_f*2, args.audio_f*2) + + args_top = copy.deepcopy(self.args) + args_top.vae_layer = 3 + args_top.vae_length = args.motion_f + args_top.vae_test_dim = args.pose_dims+3+4 + self.motion_encoder = VQEncoderV6(args_top) # masked motion to latent bs t 333 to bs t 256 + + # face decoder + self.feature2face = nn.Linear(args.audio_f*2, args.hidden_size) + self.face2latent = nn.Linear(args.hidden_size, args.vae_codebook_size) + self.transformer_de_layer = nn.TransformerDecoderLayer( + d_model=self.args.hidden_size, + nhead=4, + dim_feedforward=self.args.hidden_size*2, + batch_first=True + ) + self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4) + self.position_embeddings = PeriodicPositionalEncoding(self.args.hidden_size, period=self.args.pose_length, max_seq_len=self.args.pose_length) + + # motion decoder + self.transformer_en_layer = nn.TransformerEncoderLayer( + d_model=self.args.hidden_size, + nhead=4, + dim_feedforward=self.args.hidden_size*2, + batch_first=True + ) + self.motion_self_encoder = nn.TransformerEncoder(self.transformer_en_layer, num_layers=1) + self.audio_feature2motion = nn.Linear(args.audio_f, args.hidden_size) + self.feature2motion = nn.Linear(args.motion_f, args.hidden_size) + + self.bodyhints_face = MLP(args.motion_f, args.hidden_size, args.motion_f) + self.bodyhints_body = MLP(args.motion_f, args.hidden_size, args.motion_f) + self.motion2latent_upper = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size) + self.motion2latent_hands = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size) + self.motion2latent_lower = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size) + self.wordhints_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=8) + + self.upper_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1) + self.hands_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1) + self.lower_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1) + + self.face_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + self.upper_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + self.hands_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + self.lower_classifier = 
MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + + self.mask_embeddings = nn.Parameter(torch.zeros(1, 1, self.args.pose_dims+3+4)) + self.motion_down_upper = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_hands = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_lower = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_upper = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_hands = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_lower = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self._reset_parameters() + + self.spearker_encoder_body = nn.Embedding(25, args.hidden_size) + self.spearker_encoder_face = nn.Embedding(25, args.hidden_size) + + def _reset_parameters(self): + nn.init.normal_(self.mask_embeddings, 0, self.args.hidden_size ** -0.5) + + def forward(self, in_audio=None, in_word=None, mask=None, is_test=None, in_motion=None, use_attentions=True, use_word=True, in_id = None): + in_word_face = self.text_pre_encoder_face(in_word) + in_word_face = self.text_encoder_face(in_word_face) + in_word_body = self.text_pre_encoder_body(in_word) + in_word_body = self.text_encoder_body(in_word_body) + bs, t, c = in_word_face.shape + in_audio_face = self.audio_pre_encoder_face(in_audio) + in_audio_body = self.audio_pre_encoder_body(in_audio) + if in_audio_face.shape[1] != in_motion.shape[1]: + diff_length = in_motion.shape[1]- in_audio_face.shape[1] + if diff_length < 0: + in_audio_face = in_audio_face[:, :diff_length, :] + in_audio_body = in_audio_body[:, :diff_length, :] + else: + in_audio_face = torch.cat((in_audio_face, in_audio_face[:,-diff_length:]),1) + in_audio_body = torch.cat((in_audio_body, in_audio_body[:,-diff_length:]),1) + + if use_attentions: + alpha_at_face = torch.cat([in_word_face, in_audio_face], dim=-1).reshape(bs, t, c*2) + alpha_at_face = self.at_attn_face(alpha_at_face).reshape(bs, t, c, 2) + alpha_at_face = alpha_at_face.softmax(dim=-1) + fusion_face = in_word_face * alpha_at_face[:,:,:,1] + in_audio_face * alpha_at_face[:,:,:,0] + alpha_at_body = torch.cat([in_word_body, in_audio_body], dim=-1).reshape(bs, t, c*2) + alpha_at_body = self.at_attn_body(alpha_at_body).reshape(bs, t, c, 2) + alpha_at_body = alpha_at_body.softmax(dim=-1) + fusion_body = in_word_body * alpha_at_body[:,:,:,1] + in_audio_body * alpha_at_body[:,:,:,0] + else: + fusion_face = in_word_face + in_audio_face + fusion_body = in_word_body + in_audio_body + + masked_embeddings = self.mask_embeddings.expand_as(in_motion) + masked_motion = torch.where(mask == 1, masked_embeddings, in_motion) # bs, t, 256 + body_hint = self.motion_encoder(masked_motion) # bs t 256 + speaker_embedding_face = self.spearker_encoder_face(in_id).squeeze(2) + speaker_embedding_body = self.spearker_encoder_body(in_id).squeeze(2) + + # decode face + use_body_hints = True + if use_body_hints: + body_hint_face = self.bodyhints_face(body_hint) + fusion_face = torch.cat([fusion_face, body_hint_face], dim=2) + a2g_face = self.feature2face(fusion_face) + face_embeddings = speaker_embedding_face + face_embeddings = self.position_embeddings(face_embeddings) + decoded_face = self.face_decoder(tgt=face_embeddings, memory=a2g_face) + face_latent = self.face2latent(decoded_face) + cls_face = self.face_classifier(face_latent) + + # motion spatial encoder + body_hint_body = self.bodyhints_body(body_hint) + motion_embeddings = self.feature2motion(body_hint_body) + 
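+        # condition the projected body hints on the speaker-id embedding and add the
+        # periodic positional encoding before the bi-directional self-attention encoder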
motion_embeddings = speaker_embedding_body + motion_embeddings + motion_embeddings = self.position_embeddings(motion_embeddings) + + # bi-directional self-attention + motion_refined_embeddings = self.motion_self_encoder(motion_embeddings) + + # audio to gesture cross-modal attention + if use_word: + a2g_motion = self.audio_feature2motion(fusion_body) + motion_refined_embeddings_in = motion_refined_embeddings + speaker_embedding_body + motion_refined_embeddings_in = self.position_embeddings(motion_refined_embeddings) + word_hints = self.wordhints_decoder(tgt=motion_refined_embeddings_in, memory=a2g_motion) + motion_refined_embeddings = motion_refined_embeddings + word_hints + + # feedforward + upper_latent = self.motion2latent_upper(motion_refined_embeddings) + hands_latent = self.motion2latent_hands(motion_refined_embeddings) + lower_latent = self.motion2latent_lower(motion_refined_embeddings) + + upper_latent_in = upper_latent + speaker_embedding_body + upper_latent_in = self.position_embeddings(upper_latent_in) + hands_latent_in = hands_latent + speaker_embedding_body + hands_latent_in = self.position_embeddings(hands_latent_in) + lower_latent_in = lower_latent + speaker_embedding_body + lower_latent_in = self.position_embeddings(lower_latent_in) + + # transformer decoder + motion_upper = self.upper_decoder(tgt=upper_latent_in, memory=hands_latent+lower_latent) + motion_hands = self.hands_decoder(tgt=hands_latent_in, memory=upper_latent+lower_latent) + motion_lower = self.lower_decoder(tgt=lower_latent_in, memory=upper_latent+hands_latent) + upper_latent = self.motion_down_upper(motion_upper+upper_latent) + hands_latent = self.motion_down_hands(motion_hands+hands_latent) + lower_latent = self.motion_down_lower(motion_lower+lower_latent) + cls_lower = self.lower_classifier(lower_latent) + cls_upper = self.upper_classifier(upper_latent) + cls_hands = self.hands_classifier(hands_latent) + + return { + "rec_face":face_latent, + "rec_upper":upper_latent, + "rec_lower":lower_latent, + "rec_hands":hands_latent, + "cls_face":cls_face, + "cls_upper":cls_upper, + "cls_lower":cls_lower, + "cls_hands":cls_hands, + } \ No newline at end of file diff --git a/models/emage_audio.py b/models/emage_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..06d43a28284c4b4d339c38efeb44c4a131e3e62d --- /dev/null +++ b/models/emage_audio.py @@ -0,0 +1,243 @@ +import copy +import math +import pickle +import numpy as np +import torch +import torch.nn as nn +from .utils.layer import BasicBlock +from .motion_encoder import * + + +class WavEncoder(nn.Module): + def __init__(self, out_dim, audio_in=1): + super().__init__() + self.out_dim = out_dim + self.feat_extractor = nn.Sequential( + BasicBlock(audio_in, out_dim//4, 15, 5, first_dilation=1600, downsample=True), + BasicBlock(out_dim//4, out_dim//4, 15, 6, first_dilation=0, downsample=True), + BasicBlock(out_dim//4, out_dim//4, 15, 1, first_dilation=7, ), + BasicBlock(out_dim//4, out_dim//2, 15, 6, first_dilation=0, downsample=True), + BasicBlock(out_dim//2, out_dim//2, 15, 1, first_dilation=7), + BasicBlock(out_dim//2, out_dim, 15, 3, first_dilation=0,downsample=True), + ) + def forward(self, wav_data): + # print(wav_data.shape) + if wav_data.dim() == 2: + wav_data = wav_data.unsqueeze(1) + else: + wav_data = wav_data.transpose(1, 2) + out = self.feat_extractor(wav_data) + return out.transpose(1, 2) + + +class MLP(nn.Module): + def __init__(self, in_dim, hidden_size, out_dim): + super().__init__() + self.mlp = nn.Sequential( + 
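+            # two-layer projection: in_dim -> hidden_size -> out_dim with LeakyReLU(0.2)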
nn.Linear(in_dim, hidden_size), + nn.LeakyReLU(0.2, True), + nn.Linear(hidden_size, out_dim) + ) + def forward(self, inputs): + out = self.mlp(inputs) + return out + + +class PeriodicPositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=60): + super(PeriodicPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + pe = torch.zeros(period, d_model) + position = torch.arange(0, period, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) # (1, period, d_model) + repeat_num = (max_seq_len//period) + 1 + pe = pe.repeat(1, repeat_num, 1) # (1, repeat_num, period, d_model) + self.register_buffer('pe', pe) + def forward(self, x): + # print(self.pe.shape, x.shape) + x = x + self.pe[:, :x.size(1), :] + return self.dropout(x) + + +class MAGE_Transformer(nn.Module): + def __init__(self, args): + super(MAGE_Transformer, self).__init__() + self.args = args + # with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + # self.lang_model = pickle.load(f) + # pre_trained_embedding = self.lang_model.word_embedding_weights + # self.text_pre_encoder_face = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding),freeze=args.t_fix_pre) + # self.text_encoder_face = nn.Linear(300, args.audio_f) + # self.text_encoder_face = nn.Linear(300, args.audio_f) + # self.text_pre_encoder_body = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding),freeze=args.t_fix_pre) + # self.text_encoder_body = nn.Linear(300, args.audio_f) + # self.text_encoder_body = nn.Linear(300, args.audio_f) + + self.audio_pre_encoder_face = WavEncoder(args.audio_f, audio_in=1) + self.audio_pre_encoder_body = WavEncoder(args.audio_f, audio_in=1) + + # self.at_attn_face = nn.Linear(args.audio_f*2, args.audio_f*2) + # self.at_attn_body = nn.Linear(args.audio_f*2, args.audio_f*2) + + args_top = copy.deepcopy(self.args) + args_top.vae_layer = 3 + args_top.vae_length = args.motion_f + args_top.vae_test_dim = args.pose_dims+3+4 + self.motion_encoder = VQEncoderV6(args_top) # masked motion to latent bs t 333 to bs t 256 + + # face decoder + self.feature2face = nn.Linear(args.audio_f*2, args.hidden_size) + self.face2latent = nn.Linear(args.hidden_size, args.vae_codebook_size) + self.transformer_de_layer = nn.TransformerDecoderLayer( + d_model=self.args.hidden_size, + nhead=4, + dim_feedforward=self.args.hidden_size*2, + batch_first=True + ) + self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4) + self.position_embeddings = PeriodicPositionalEncoding(self.args.hidden_size, period=self.args.pose_length, max_seq_len=self.args.pose_length) + + # motion decoder + self.transformer_en_layer = nn.TransformerEncoderLayer( + d_model=self.args.hidden_size, + nhead=4, + dim_feedforward=self.args.hidden_size*2, + batch_first=True + ) + self.motion_self_encoder = nn.TransformerEncoder(self.transformer_en_layer, num_layers=1) + self.audio_feature2motion = nn.Linear(args.audio_f, args.hidden_size) + self.feature2motion = nn.Linear(args.motion_f, args.hidden_size) + + self.bodyhints_face = MLP(args.motion_f, args.hidden_size, args.motion_f) + self.bodyhints_body = MLP(args.motion_f, args.hidden_size, args.motion_f) + self.motion2latent_upper = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size) + self.motion2latent_hands = MLP(args.hidden_size, 
args.hidden_size, self.args.hidden_size) + self.motion2latent_lower = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size) + self.wordhints_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=8) + + self.upper_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1) + self.hands_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1) + self.lower_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1) + + self.face_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + self.upper_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + self.hands_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + self.lower_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size) + + self.mask_embeddings = nn.Parameter(torch.zeros(1, 1, self.args.pose_dims+3+4)) + self.motion_down_upper = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_hands = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_lower = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_upper = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_hands = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self.motion_down_lower = nn.Linear(args.hidden_size, self.args.vae_codebook_size) + self._reset_parameters() + + self.spearker_encoder_body = nn.Embedding(25, args.hidden_size) + self.spearker_encoder_face = nn.Embedding(25, args.hidden_size) + + def _reset_parameters(self): + nn.init.normal_(self.mask_embeddings, 0, self.args.hidden_size ** -0.5) + + def forward(self, in_audio=None, in_word=None, mask=None, is_test=None, in_motion=None, use_attentions=True, use_word=True, in_id = None): + # in_word_face = self.text_pre_encoder_face(in_word) + # in_word_face = self.text_encoder_face(in_word_face) + # in_word_body = self.text_pre_encoder_body(in_word) + # in_word_body = self.text_encoder_body(in_word_body) + # bs, t, c = in_word_face.shape + in_audio_face = self.audio_pre_encoder_face(in_audio) + in_audio_body = self.audio_pre_encoder_body(in_audio) + bs, t, c = in_audio_body.shape + # if in_audio_face.shape[1] != in_motion.shape[1]: + # diff_length = in_motion.shape[1]- in_audio_face.shape[1] + # if diff_length < 0: + # in_audio_face = in_audio_face[:, :diff_length, :] + # in_audio_body = in_audio_body[:, :diff_length, :] + # else: + # in_audio_face = torch.cat((in_audio_face, in_audio_face[:,-diff_length:]),1) + # in_audio_body = torch.cat((in_audio_body, in_audio_body[:,-diff_length:]),1) + + # if use_attentions: + # alpha_at_face = torch.cat([in_word_face, in_audio_face], dim=-1).reshape(bs, t, c*2) + # alpha_at_face = self.at_attn_face(alpha_at_face).reshape(bs, t, c, 2) + # alpha_at_face = alpha_at_face.softmax(dim=-1) + # fusion_face = in_word_face * alpha_at_face[:,:,:,1] + in_audio_face * alpha_at_face[:,:,:,0] + # alpha_at_body = torch.cat([in_word_body, in_audio_body], dim=-1).reshape(bs, t, c*2) + # alpha_at_body = self.at_attn_body(alpha_at_body).reshape(bs, t, c, 2) + # alpha_at_body = alpha_at_body.softmax(dim=-1) + # fusion_body = in_word_body * alpha_at_body[:,:,:,1] + in_audio_body * alpha_at_body[:,:,:,0] + # else: + fusion_face = in_audio_face + fusion_body = in_audio_body + + masked_embeddings = self.mask_embeddings.expand_as(in_motion) + masked_motion = torch.where(mask == 1, 
masked_embeddings, in_motion) # bs, t, 256 + body_hint = self.motion_encoder(masked_motion) # bs t 256 + speaker_embedding_face = self.spearker_encoder_face(in_id).squeeze(2) + speaker_embedding_body = self.spearker_encoder_body(in_id).squeeze(2) + + # decode face + use_body_hints = True + if use_body_hints: + body_hint_face = self.bodyhints_face(body_hint) + fusion_face = torch.cat([fusion_face, body_hint_face], dim=2) + a2g_face = self.feature2face(fusion_face) + face_embeddings = speaker_embedding_face + face_embeddings = self.position_embeddings(face_embeddings) + decoded_face = self.face_decoder(tgt=face_embeddings, memory=a2g_face) + face_latent = self.face2latent(decoded_face) + cls_face = self.face_classifier(face_latent) + + # motion spatial encoder + body_hint_body = self.bodyhints_body(body_hint) + motion_embeddings = self.feature2motion(body_hint_body) + motion_embeddings = speaker_embedding_body + motion_embeddings + motion_embeddings = self.position_embeddings(motion_embeddings) + + # bi-directional self-attention + motion_refined_embeddings = self.motion_self_encoder(motion_embeddings) + + # audio to gesture cross-modal attention + if use_word: + a2g_motion = self.audio_feature2motion(fusion_body) + motion_refined_embeddings_in = motion_refined_embeddings + speaker_embedding_body + motion_refined_embeddings_in = self.position_embeddings(motion_refined_embeddings) + word_hints = self.wordhints_decoder(tgt=motion_refined_embeddings_in, memory=a2g_motion) + motion_refined_embeddings = motion_refined_embeddings + word_hints + + # feedforward + upper_latent = self.motion2latent_upper(motion_refined_embeddings) + hands_latent = self.motion2latent_hands(motion_refined_embeddings) + lower_latent = self.motion2latent_lower(motion_refined_embeddings) + + upper_latent_in = upper_latent + speaker_embedding_body + upper_latent_in = self.position_embeddings(upper_latent_in) + hands_latent_in = hands_latent + speaker_embedding_body + hands_latent_in = self.position_embeddings(hands_latent_in) + lower_latent_in = lower_latent + speaker_embedding_body + lower_latent_in = self.position_embeddings(lower_latent_in) + + # transformer decoder + motion_upper = self.upper_decoder(tgt=upper_latent_in, memory=hands_latent+lower_latent) + motion_hands = self.hands_decoder(tgt=hands_latent_in, memory=upper_latent+lower_latent) + motion_lower = self.lower_decoder(tgt=lower_latent_in, memory=upper_latent+hands_latent) + upper_latent = self.motion_down_upper(motion_upper+upper_latent) + hands_latent = self.motion_down_hands(motion_hands+hands_latent) + lower_latent = self.motion_down_lower(motion_lower+lower_latent) + cls_lower = self.lower_classifier(lower_latent) + cls_upper = self.upper_classifier(upper_latent) + cls_hands = self.hands_classifier(hands_latent) + + return { + "rec_face":face_latent, + "rec_upper":upper_latent, + "rec_lower":lower_latent, + "rec_hands":hands_latent, + "cls_face":cls_face, + "cls_upper":cls_upper, + "cls_lower":cls_lower, + "cls_hands":cls_hands, + } \ No newline at end of file diff --git a/models/motion_encoder.py b/models/motion_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..afa8513ea66ea0446230d796aff277ed142b8801 --- /dev/null +++ b/models/motion_encoder.py @@ -0,0 +1,789 @@ +import random +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import smplx + +# ----------- 1 full conv-based encoder------------- # +""" +from tm2t +TM2T: Stochastical and Tokenized Modeling for the 
Reciprocal Generation of 3D Human Motions and Texts +https://github.com/EricGuo5513/TM2T +""" +from .quantizer import * +from .utils.layer import ResBlock, init_weight + +class SCFormer(nn.Module): + def __init__(self, args): + super(VQEncoderV3, self).__init__() + + + n_down = args.vae_layer + channels = [args.vae_length] + for i in range(n_down-1): + channels.append(args.vae_length) + + input_size = args.vae_test_dim + assert len(channels) == n_down + layers = [ + nn.Conv1d(input_size, channels[0], 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[0]), + ] + + for i in range(1, n_down): + layers += [ + nn.Conv1d(channels[i-1], channels[i], 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[i]), + ] + self.main = nn.Sequential(*layers) + # self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + # self.out_net.apply(init_weight) + def forward(self, inputs): # bs t n + ''' + face 51 or 106 + hand 30*(15) + upper body + lower body + global 1*3 + max length around 180 --> 450 + ''' + bs, t, n = inputs.shape + inputs = inputs.reshape(bs*t, n) + inputs = self.spatial_transformer_encoder(inputs) # bs*t c + cs = inputs.shape[1] + inputs = inputs.reshape(bs, t, cs).permute(0, 2, 1).reshape(bs*cs, t) + inputs = self.temporal_cnn_encoder(inputs) # bs*c t + ct = inputs.shape[1] + outputs = inputs.reshape(bs, cs, ct).permute(0, 2, 1) # bs ct cs + return outputs + +class VQEncoderV3(nn.Module): + def __init__(self, args): + super(VQEncoderV3, self).__init__() + n_down = args.vae_layer + channels = [args.vae_length] + for i in range(n_down-1): + channels.append(args.vae_length) + + input_size = args.vae_test_dim + assert len(channels) == n_down + layers = [ + nn.Conv1d(input_size, channels[0], 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[0]), + ] + + for i in range(1, n_down): + layers += [ + nn.Conv1d(channels[i-1], channels[i], 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[i]), + ] + self.main = nn.Sequential(*layers) + # self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + # self.out_net.apply(init_weight) + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQEncoderV6(nn.Module): + def __init__(self, args): + super(VQEncoderV6, self).__init__() + n_down = args.vae_layer + channels = [args.vae_length] + for i in range(n_down-1): + channels.append(args.vae_length) + + input_size = args.vae_test_dim + assert len(channels) == n_down + layers = [ + nn.Conv1d(input_size, channels[0], 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[0]), + ] + + for i in range(1, n_down): + layers += [ + nn.Conv1d(channels[i-1], channels[i], 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[i]), + ] + self.main = nn.Sequential(*layers) + # self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + # self.out_net.apply(init_weight) + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQEncoderV4(nn.Module): + def __init__(self, args): + super(VQEncoderV4, self).__init__() + n_down = args.vae_layer + channels = [args.vae_length] + for i in range(n_down-1): + channels.append(args.vae_length) + + input_size = args.vae_test_dim + assert len(channels) == n_down + layers = [ + nn.Conv1d(input_size, channels[0], 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[0]), + ] 
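        # In VQEncoderV4 only the first Conv1d above (kernel 4, stride 2) halves the temporal length;
        # the blocks appended in the loop below use kernel 3, stride 1 and keep the sequence length
        # fixed (compare VQEncoderV3, which downsamples at every layer, and VQEncoderV6, which never does).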
+ + for i in range(1, n_down): + layers += [ + nn.Conv1d(channels[i-1], channels[i], 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[i]), + ] + self.main = nn.Sequential(*layers) + # self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + # self.out_net.apply(init_weight) + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + # print(outputs.shape) + return outputs + +class VQEncoderV5(nn.Module): + def __init__(self, args): + super(VQEncoderV5, self).__init__() + n_down = args.vae_layer + channels = [args.vae_length] + for i in range(n_down-1): + channels.append(args.vae_length) + + input_size = args.vae_test_dim + assert len(channels) == n_down + layers = [ + nn.Conv1d(input_size, channels[0], 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[0]), + ] + + for i in range(1, n_down): + layers += [ + nn.Conv1d(channels[i-1], channels[i], 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[i]), + ] + self.main = nn.Sequential(*layers) + # self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + # self.out_net.apply(init_weight) + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + # print(outputs.shape) + return outputs + +class VQDecoderV4(nn.Module): + def __init__(self, args): + super(VQDecoderV4, self).__init__() + n_up = args.vae_layer + channels = [] + for i in range(n_up-1): + channels.append(args.vae_length) + channels.append(args.vae_length) + channels.append(args.vae_test_dim) + input_size = args.vae_length + n_resblk = 2 + assert len(channels) == n_up + 1 + if input_size == channels[0]: + layers = [] + else: + layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)] + + for i in range(n_resblk): + layers += [ResBlock(channels[0])] + # channels = channels + for i in range(n_up): + up_factor = 2 if i < n_up - 1 else 1 + layers += [ + nn.Upsample(scale_factor=up_factor, mode='nearest'), + nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)] + self.main = nn.Sequential(*layers) + self.main.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQDecoderV5(nn.Module): + def __init__(self, args): + super(VQDecoderV5, self).__init__() + n_up = args.vae_layer + channels = [] + for i in range(n_up-1): + channels.append(args.vae_length) + channels.append(args.vae_length) + channels.append(args.vae_test_dim) + input_size = args.vae_length + n_resblk = 2 + assert len(channels) == n_up + 1 + if input_size == channels[0]: + layers = [] + else: + layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)] + + for i in range(n_resblk): + layers += [ResBlock(channels[0])] + # channels = channels + for i in range(n_up): + up_factor = 2 if i < n_up - 1 else 1 + layers += [ + #nn.Upsample(scale_factor=up_factor, mode='nearest'), + nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)] + self.main = nn.Sequential(*layers) + self.main.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = 
self.main(inputs).permute(0, 2, 1) + return outputs + +class VQDecoderV7(nn.Module): + def __init__(self, args): + super(VQDecoderV7, self).__init__() + n_up = args.vae_layer + channels = [] + for i in range(n_up-1): + channels.append(args.vae_length) + channels.append(args.vae_length) + channels.append(args.vae_test_dim+4) + input_size = args.vae_length + n_resblk = 2 + assert len(channels) == n_up + 1 + if input_size == channels[0]: + layers = [] + else: + layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)] + + for i in range(n_resblk): + layers += [ResBlock(channels[0])] + # channels = channels + for i in range(n_up): + up_factor = 2 if i < n_up - 1 else 1 + layers += [ + #nn.Upsample(scale_factor=up_factor, mode='nearest'), + nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)] + self.main = nn.Sequential(*layers) + self.main.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQDecoderV3(nn.Module): + def __init__(self, args): + super(VQDecoderV3, self).__init__() + n_up = args.vae_layer + channels = [] + for i in range(n_up-1): + channels.append(args.vae_length) + channels.append(args.vae_length) + channels.append(args.vae_test_dim) + input_size = args.vae_length + n_resblk = 2 + assert len(channels) == n_up + 1 + if input_size == channels[0]: + layers = [] + else: + layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)] + + for i in range(n_resblk): + layers += [ResBlock(channels[0])] + # channels = channels + for i in range(n_up): + layers += [ + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)] + self.main = nn.Sequential(*layers) + self.main.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQDecoderV6(nn.Module): + def __init__(self, args): + super(VQDecoderV6, self).__init__() + n_up = args.vae_layer + channels = [] + for i in range(n_up-1): + channels.append(args.vae_length) + channels.append(args.vae_length) + channels.append(args.vae_test_dim) + input_size = args.vae_length * 2 + n_resblk = 2 + assert len(channels) == n_up + 1 + if input_size == channels[0]: + layers = [] + else: + layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)] + + for i in range(n_resblk): + layers += [ResBlock(channels[0])] + # channels = channels + for i in range(n_up): + layers += [ + # nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)] + self.main = nn.Sequential(*layers) + self.main.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + + +# -----------2 conv+mlp based fix-length input ae ------------- # +from .utils.layer import reparameterize, ConvNormRelu, BasicBlock +""" +from Trimodal, +encoder: + bs, n, c_in --conv--> bs, n/k, c_out_0 --mlp--> bs, c_out_1, only support 
fixed length +decoder: + bs, c_out_1 --mlp--> bs, n/k*c_out_0 --> bs, n/k, c_out_0 --deconv--> bs, n, c_in +""" +class PoseEncoderConv(nn.Module): + def __init__(self, length, dim, feature_length=32): + super().__init__() + self.base = feature_length + self.net = nn.Sequential( + ConvNormRelu(dim, self.base, batchnorm=True), #32 + ConvNormRelu(self.base, self.base*2, batchnorm=True), #30 + ConvNormRelu(self.base*2, self.base*2, True, batchnorm=True), #14 + nn.Conv1d(self.base*2, self.base, 3) + ) + self.out_net = nn.Sequential( + nn.Linear(12*self.base, self.base*4), # for 34 frames + nn.BatchNorm1d(self.base*4), + nn.LeakyReLU(True), + nn.Linear(self.base*4, self.base*2), + nn.BatchNorm1d(self.base*2), + nn.LeakyReLU(True), + nn.Linear(self.base*2, self.base), + ) + self.fc_mu = nn.Linear(self.base, self.base) + self.fc_logvar = nn.Linear(self.base, self.base) + + def forward(self, poses, variational_encoding=None): + poses = poses.transpose(1, 2) # to (bs, dim, seq) + out = self.net(poses) + out = out.flatten(1) + out = self.out_net(out) + mu = self.fc_mu(out) + logvar = self.fc_logvar(out) + if variational_encoding: + z = reparameterize(mu, logvar) + else: + z = mu + return z, mu, logvar + + +class PoseDecoderFC(nn.Module): + def __init__(self, gen_length, pose_dim, use_pre_poses=False): + super().__init__() + self.gen_length = gen_length + self.pose_dim = pose_dim + self.use_pre_poses = use_pre_poses + + in_size = 32 + if use_pre_poses: + self.pre_pose_net = nn.Sequential( + nn.Linear(pose_dim * 4, 32), + nn.BatchNorm1d(32), + nn.ReLU(), + nn.Linear(32, 32), + ) + in_size += 32 + + self.net = nn.Sequential( + nn.Linear(in_size, 128), + nn.BatchNorm1d(128), + nn.ReLU(), + nn.Linear(128, 128), + nn.BatchNorm1d(128), + nn.ReLU(), + nn.Linear(128, 256), + nn.BatchNorm1d(256), + nn.ReLU(), + nn.Linear(256, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, gen_length * pose_dim), + ) + + def forward(self, latent_code, pre_poses=None): + if self.use_pre_poses: + pre_pose_feat = self.pre_pose_net(pre_poses.reshape(pre_poses.shape[0], -1)) + feat = torch.cat((pre_pose_feat, latent_code), dim=1) + else: + feat = latent_code + output = self.net(feat) + output = output.view(-1, self.gen_length, self.pose_dim) + return output + + +class PoseDecoderConv(nn.Module): + def __init__(self, length, dim, use_pre_poses=False, feature_length=32): + super().__init__() + self.use_pre_poses = use_pre_poses + self.feat_size = feature_length + + if use_pre_poses: + self.pre_pose_net = nn.Sequential( + nn.Linear(dim * 4, 32), + nn.BatchNorm1d(32), + nn.ReLU(), + nn.Linear(32, 32), + ) + self.feat_size += 32 + + if length == 64: + self.pre_net = nn.Sequential( + nn.Linear(self.feat_size, self.feat_size), + nn.BatchNorm1d(self.feat_size), + nn.LeakyReLU(True), + nn.Linear(self.feat_size, self.feat_size//8*64), + ) + elif length == 34: + self.pre_net = nn.Sequential( + nn.Linear(self.feat_size, self.feat_size*2), + nn.BatchNorm1d(self.feat_size*2), + nn.LeakyReLU(True), + nn.Linear(self.feat_size*2, self.feat_size//8*34), + ) + elif length == 32: + self.pre_net = nn.Sequential( + nn.Linear(self.feat_size, self.feat_size*2), + nn.BatchNorm1d(self.feat_size*2), + nn.LeakyReLU(True), + nn.Linear(self.feat_size*2, self.feat_size//8*32), + ) + else: + assert False + self.decoder_size = self.feat_size//8 + self.net = nn.Sequential( + nn.ConvTranspose1d(self.decoder_size, self.feat_size, 3), + nn.BatchNorm1d(self.feat_size), + nn.LeakyReLU(0.2, True), + + nn.ConvTranspose1d(self.feat_size, self.feat_size, 3), + 
nn.BatchNorm1d(self.feat_size), + nn.LeakyReLU(0.2, True), + nn.Conv1d(self.feat_size, self.feat_size*2, 3), + nn.Conv1d(self.feat_size*2, dim, 3), + ) + + def forward(self, feat, pre_poses=None): + if self.use_pre_poses: + pre_pose_feat = self.pre_pose_net(pre_poses.reshape(pre_poses.shape[0], -1)) + feat = torch.cat((pre_pose_feat, feat), dim=1) + #print(feat.shape) + out = self.pre_net(feat) + #print(out.shape) + out = out.view(feat.shape[0], self.decoder_size, -1) + #print(out.shape) + out = self.net(out) + out = out.transpose(1, 2) + return out + +''' +Our CaMN Modification +''' +class PoseEncoderConvResNet(nn.Module): + def __init__(self, length, dim, feature_length=32): + super().__init__() + self.base = feature_length + self.conv1=BasicBlock(dim, self.base, reduce_first = 1, downsample = False, first_dilation=1) #34 + self.conv2=BasicBlock(self.base, self.base*2, downsample = False, first_dilation=1,) #34 + self.conv3=BasicBlock(self.base*2, self.base*2, first_dilation=1, downsample = True, stride=2)#17 + self.conv4=BasicBlock(self.base*2, self.base, first_dilation=1, downsample = False) + + self.out_net = nn.Sequential( + # nn.Linear(864, 256), # for 64 frames + nn.Linear(17*self.base, self.base*4), # for 34 frames + nn.BatchNorm1d(self.base*4), + nn.LeakyReLU(True), + nn.Linear(self.base*4, self.base*2), + nn.BatchNorm1d(self.base*2), + nn.LeakyReLU(True), + nn.Linear(self.base*2, self.base), + ) + + self.fc_mu = nn.Linear(self.base, self.base) + self.fc_logvar = nn.Linear(self.base, self.base) + + def forward(self, poses, variational_encoding=None): + poses = poses.transpose(1, 2) # to (bs, dim, seq) + out1 = self.conv1(poses) + out2 = self.conv2(out1) + out3 = self.conv3(out2) + out = self.conv4(out3) + out = out.flatten(1) + out = self.out_net(out) + mu = self.fc_mu(out) + logvar = self.fc_logvar(out) + if variational_encoding: + z = reparameterize(mu, logvar) + else: + z = mu + return z, mu, logvar + + +# -----------3 lstm ------------- # +''' +bs, n, c_int --> bs, n, c_out or bs, 1 (hidden), c_out +''' +class AELSTM(nn.Module): + def __init__(self, args): + super().__init__() + self.motion_emb = nn.Linear(args.vae_test_dim, args.vae_length) + self.lstm = nn.LSTM(args.vae_length, hidden_size=args.vae_length, num_layers=4, batch_first=True, + bidirectional=True, dropout=0.3) + self.out = nn.Sequential( + nn.Linear(args.vae_length, args.vae_length//2), + nn.LeakyReLU(0.2, True), + nn.Linear(args.vae_length//2, args.vae_test_dim) + ) + self.hidden_size = args.vae_length + + def forward(self, inputs): + poses = self.motion_emb(inputs) + out, _ = self.lstm(poses) + out = out[:, :, :self.hidden_size] + out[:, :, self.hidden_size:] + out_poses = self.out(out) + return { + "poses_feat":out, + "rec_pose": out_poses, + } + +class PoseDecoderLSTM(nn.Module): + """ + input bs*n*64 + """ + def __init__(self,pose_dim, feature_length): + super().__init__() + self.pose_dim = pose_dim + self.base = feature_length + self.hidden_size = 256 + self.lstm_d = nn.LSTM(self.base, hidden_size=self.hidden_size, num_layers=4, batch_first=True, + bidirectional=True, dropout=0.3) + self.out_d = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size // 2), + nn.LeakyReLU(True), + nn.Linear(self.hidden_size // 2, self.pose_dim) + ) + + def forward(self, latent_code): + output, _ = self.lstm_d(latent_code) + output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:] # sum bidirectional outputs + #print("outd:", output.shape) + output = self.out_d(output.reshape(-1, output.shape[2])) + 
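        # out_d maps each flattened frame of the (bs*n, hidden) tensor to pose_dim;
        # the next line restores the (bs, n, pose_dim) sequence layout for the caller.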
output = output.view(latent_code.shape[0], latent_code.shape[1], -1) + #print("resotuput:", output.shape) + return output + +# ---------------4 transformer --------------- # +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0)#.transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + #print(self.pe.shape, x.shape) + x = x + self.pe[:, :x.shape[1]] + return self.dropout(x) + +class Encoder_TRANSFORMER(nn.Module): + def __init__(self, args): + super().__init__() + self.skelEmbedding = nn.Linear(args.vae_test_dim, args.vae_length) + self.sequence_pos_encoder = PositionalEncoding(args.vae_length, 0.3) + seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=args.vae_length, + nhead=4, + dim_feedforward=1025, + dropout=0.3, + activation="gelu", + batch_first=True + ) + self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, + num_layers=4) + def _generate_square_subsequent_mask(self, sz): + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + def forward(self, inputs): + x = self.skelEmbedding(inputs) #bs * n * 128 + #print(x.shape) + xseq = self.sequence_pos_encoder(x) + device = xseq.device + #mask = self._generate_square_subsequent_mask(xseq.size(1)).to(device) + final = self.seqTransEncoder(xseq) + #print(final.shape) + mu = final[:, 0:1, :] + logvar = final[:, 1:2, :] + return final, mu, logvar + +class Decoder_TRANSFORMER(nn.Module): + def __init__(self, args): + super().__init__() + self.vae_test_len = args.vae_test_len + self.vae_length = args.vae_length + self.sequence_pos_encoder = PositionalEncoding(args.vae_length, 0.3) + seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=args.vae_length, + nhead=4, + dim_feedforward=1024, + dropout=0.3, + activation="gelu", + batch_first=True) + self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer, + num_layers=4) + self.finallayer = nn.Linear(args.vae_length, args.vae_test_dim) + + def forward(self, inputs): + timequeries = torch.zeros(inputs.shape[0], self.vae_test_len, self.vae_length, device=inputs.device) + timequeries = self.sequence_pos_encoder(timequeries) + output = self.seqTransDecoder(tgt=timequeries, memory=inputs) + output = self.finallayer(output) + return output + +# --------- 5 skcnn --------------- # +''' +from NeMF, +NeMF: Neural Motion Fields for Kinematic Animation +''' +from .utils.skeleton import ResidualBlock, SkeletonResidual, residual_ratio, SkeletonConv, SkeletonPool, find_neighbor, build_edge_topology +class LocalEncoder(nn.Module): + def __init__(self, args, topology): + super(LocalEncoder, self).__init__() + args.channel_base = 6 + args.activation = "tanh" + args.use_residual_blocks=True + args.z_dim=1024 + args.temporal_scale=8 + args.kernel_size=4 + args.num_layers=args.vae_layer + args.skeleton_dist=2 + args.extra_conv=0 + # check how to reflect in 1d + args.padding_mode="constant" + args.skeleton_pool="mean" + args.upsampling="linear" + + + self.topologies = [topology] + self.channel_base = 
[args.channel_base] + + self.channel_list = [] + self.edge_num = [len(topology)] + self.pooling_list = [] + self.layers = nn.ModuleList() + self.args = args + # self.convs = [] + + kernel_size = args.kernel_size + kernel_even = False if kernel_size % 2 else True + padding = (kernel_size - 1) // 2 + bias = True + self.grow = args.vae_grow + for i in range(args.num_layers): + self.channel_base.append(self.channel_base[-1]*self.grow[i]) + + for i in range(args.num_layers): + seq = [] + neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist) + in_channels = self.channel_base[i] * self.edge_num[i] + out_channels = self.channel_base[i + 1] * self.edge_num[i] + if i == 0: + self.channel_list.append(in_channels) + self.channel_list.append(out_channels) + last_pool = True if i == args.num_layers - 1 else False + + # (T, J, D) => (T, J', D) + pool = SkeletonPool(edges=self.topologies[i], pooling_mode=args.skeleton_pool, + channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool) + + if args.use_residual_blocks: + # (T, J, D) => (T/2, J', 2D) + seq.append(SkeletonResidual(self.topologies[i], neighbour_list, joint_num=self.edge_num[i], in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=2, padding=padding, padding_mode=args.padding_mode, bias=bias, + extra_conv=args.extra_conv, pooling_mode=args.skeleton_pool, activation=args.activation, last_pool=last_pool)) + else: + for _ in range(args.extra_conv): + # (T, J, D) => (T, J, D) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels, + joint_num=self.edge_num[i], kernel_size=kernel_size - 1 if kernel_even else kernel_size, + stride=1, + padding=padding, padding_mode=args.padding_mode, bias=bias)) + seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh()) + # (T, J, D) => (T/2, J, 2D) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, + joint_num=self.edge_num[i], kernel_size=kernel_size, stride=2, + padding=padding, padding_mode=args.padding_mode, bias=bias, add_offset=False, + in_offset_channel=3 * self.channel_base[i] // self.channel_base[0])) + # self.convs.append(seq[-1]) + + seq.append(pool) + seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh()) + self.layers.append(nn.Sequential(*seq)) + + self.topologies.append(pool.new_edges) + self.pooling_list.append(pool.pooling_list) + self.edge_num.append(len(self.topologies[-1])) + + # in_features = self.channel_base[-1] * len(self.pooling_list[-1]) + # in_features *= int(args.temporal_scale / 2) + # self.reduce = nn.Linear(in_features, args.z_dim) + # self.mu = nn.Linear(in_features, args.z_dim) + # self.logvar = nn.Linear(in_features, args.z_dim) + + def forward(self, input): + #bs, n, c = input.shape[0], input.shape[1], input.shape[2] + output = input.permute(0, 2, 1)#input.reshape(bs, n, -1, 6) + for layer in self.layers: + output = layer(output) + #output = output.view(output.shape[0], -1) + output = output.permute(0, 2, 1) + return output \ No newline at end of file diff --git a/models/motion_representation.py b/models/motion_representation.py new file mode 100644 index 0000000000000000000000000000000000000000..b5d93b49931d45ae0b8bf5013c76b08445e13eff --- /dev/null +++ b/models/motion_representation.py @@ -0,0 +1,431 @@ +import random +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import smplx +import copy +from .motion_encoder import * + +# ----------- AE, VAE ------------- 
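# The wrappers in this file pair the encoders and decoders defined in motion_encoder.py:
# the V3 encoder/decoder halve and re-double the temporal length at every vae_layer
# (stride-2 convs / nearest-neighbour upsampling), the V5 variants keep the sequence
# length fixed, and the VQ-VAE classes place a Quantizer between encoder and decoder.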
# +class VAEConvZero(nn.Module): + def __init__(self, args): + super(VAEConvZero, self).__init__() + self.encoder = VQEncoderV5(args) + # self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV5(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + # embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(pre_latent) + return { + # "poses_feat":vq_latent, + # "embedding_loss":embedding_loss, + # "perplexity":perplexity, + "rec_pose": rec_pose + } + +class VAEConv(nn.Module): + def __init__(self, args): + super(VAEConv, self).__init__() + self.encoder = VQEncoderV3(args) + self.decoder = VQDecoderV3(args) + self.fc_mu = nn.Linear(args.vae_length, args.vae_length) + self.fc_logvar = nn.Linear(args.vae_length, args.vae_length) + self.variational = args.variational + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + mu, logvar = None, None + if self.variational: + mu = self.fc_mu(pre_latent) + logvar = self.fc_logvar(pre_latent) + pre_latent = reparameterize(mu, logvar) + rec_pose = self.decoder(pre_latent) + return { + "poses_feat":pre_latent, + "rec_pose": rec_pose, + "pose_mu": mu, + "pose_logvar": logvar, + } + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + if self.variational: + mu = self.fc_mu(pre_latent) + logvar = self.fc_logvar(pre_latent) + pre_latent = reparameterize(mu, logvar) + return pre_latent + + def decode(self, pre_latent): + rec_pose = self.decoder(pre_latent) + return rec_pose + +class VAESKConv(VAEConv): + def __init__(self, args): + super(VAESKConv, self).__init__(args) + smpl_fname = args.data_path_1+'smplx_models/smplx/SMPLX_NEUTRAL_2020.npz' + smpl_data = np.load(smpl_fname, encoding='latin1') + parents = smpl_data['kintree_table'][0].astype(np.int32) + edges = build_edge_topology(parents) + self.encoder = LocalEncoder(args, edges) + self.decoder = VQDecoderV3(args) + +class VAEConvMLP(VAEConv): + def __init__(self, args): + super(VAEConvMLP, self).__init__(args) + self.encoder = PoseEncoderConv(args.vae_test_len, args.vae_test_dim, feature_length=args.vae_length) + self.decoder = PoseDecoderConv(args.vae_test_len, args.vae_test_dim, feature_length=args.vae_length) + +class VAELSTM(VAEConv): + def __init__(self, args): + super(VAELSTM, self).__init__(args) + pose_dim = args.vae_test_dim + feature_length = args.vae_length + self.encoder = PoseEncoderLSTM_Resnet(pose_dim, feature_length=feature_length) + self.decoder = PoseDecoderLSTM(pose_dim, feature_length=feature_length) + +class VAETransformer(VAEConv): + def __init__(self, args): + super(VAETransformer, self).__init__(args) + self.encoder = Encoder_TRANSFORMER(args) + self.decoder = Decoder_TRANSFORMER(args) + +# ----------- VQVAE --------------- # +class VQVAEConv(nn.Module): + def __init__(self, args): + super(VQVAEConv, self).__init__() + self.encoder = VQEncoderV3(args) + self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV3(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return { + "poses_feat":vq_latent, + "embedding_loss":embedding_loss, + "perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + pre_latent = self.encoder(inputs) + index = 
self.quantizer.map2index(pre_latent) + return index + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + z_q = self.quantizer.get_codebook_entry(index) + return z_q + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + rec_pose = self.decoder(z_q) + return rec_pose + +class VQVAESKConv(VQVAEConv): + def __init__(self, args): + super(VQVAESKConv, self).__init__(args) + smpl_fname = args.data_path_1+'smplx_models/smplx/SMPLX_NEUTRAL_2020.npz' + smpl_data = np.load(smpl_fname, encoding='latin1') + parents = smpl_data['kintree_table'][0].astype(np.int32) + edges = build_edge_topology(parents) + self.encoder = LocalEncoder(args, edges) + + +class VQVAEConvStride(nn.Module): + def __init__(self, args): + super(VQVAEConvStride, self).__init__() + self.encoder = VQEncoderV4(args) + self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV4(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return { + "poses_feat":vq_latent, + "embedding_loss":embedding_loss, + "perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + return index + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + z_q = self.quantizer.get_codebook_entry(index) + return z_q + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + rec_pose = self.decoder(z_q) + return rec_pose + +class VQVAEConvZero(nn.Module): + def __init__(self, args): + super(VQVAEConvZero, self).__init__() + self.encoder = VQEncoderV5(args) + self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV5(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return { + "poses_feat":vq_latent, + "embedding_loss":embedding_loss, + "perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + return index + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + z_q = self.quantizer.get_codebook_entry(index) + return z_q + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + rec_pose = self.decoder(z_q) + return rec_pose + + +class VAEConvZero(nn.Module): + def __init__(self, args): + super(VAEConvZero, self).__init__() + self.encoder = VQEncoderV5(args) + # self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV5(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + # embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(pre_latent) + return { + # "poses_feat":vq_latent, + # "embedding_loss":embedding_loss, + # "perplexity":perplexity, + "rec_pose": rec_pose + } + + # def map2index(self, inputs): + # pre_latent = self.encoder(inputs) + # index = self.quantizer.map2index(pre_latent) + # return 
index + + # def map2latent(self, inputs): + # pre_latent = self.encoder(inputs) + # index = self.quantizer.map2index(pre_latent) + # z_q = self.quantizer.get_codebook_entry(index) + # return z_q + + # def decode(self, index): + # z_q = self.quantizer.get_codebook_entry(index) + # rec_pose = self.decoder(z_q) + # return rec_pose + + +class VQVAEConvZero3(nn.Module): + def __init__(self, args): + super(VQVAEConvZero3, self).__init__() + self.encoder = VQEncoderV5(args) + self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV5(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return { + "poses_feat":vq_latent, + "embedding_loss":embedding_loss, + "perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + return index + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + z_q = self.quantizer.get_codebook_entry(index) + return z_q + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + rec_pose = self.decoder(z_q) + return rec_pose + +class VQVAEConvZero2(nn.Module): + def __init__(self, args): + super(VQVAEConvZero2, self).__init__() + self.encoder = VQEncoderV5(args) + self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV7(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return { + "poses_feat":vq_latent, + "embedding_loss":embedding_loss, + "perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + return index + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + z_q = self.quantizer.get_codebook_entry(index) + return z_q + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + rec_pose = self.decoder(z_q) + return rec_pose + +class VQVAE2(nn.Module): + def __init__(self, args): + super(VQVAE2, self).__init__() + # Bottom-level encoder and decoder + args_bottom = copy.deepcopy(args) + args_bottom.vae_layer = 2 + self.bottom_encoder = VQEncoderV6(args_bottom) + self.bottom_quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + args_bottom.vae_test_dim = args.vae_test_dim + self.bottom_decoder = VQDecoderV6(args_bottom) + + # Top-level encoder and decoder + args_top = copy.deepcopy(args) + args_top.vae_layer = 3 + args_top.vae_test_dim = args.vae_length + self.top_encoder = VQEncoderV3(args_top) # Adjust according to the top level's design + self.quantize_conv_t = nn.Conv1d(args.vae_length+args.vae_length, args.vae_length, 1) + self.top_quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + # self.upsample_t_up = nn.Upsample(scale_factor=2, mode='nearest') + layers = [ + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Upsample(scale_factor=2, 
mode='nearest'), + nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + self.upsample_t= nn.Sequential(*layers) + self.top_decoder = VQDecoderV3(args_top) # Adjust to handle top level features appropriately + + def forward(self, inputs): + # Bottom-level processing + enc_b = self.bottom_encoder(inputs) + enc_t = self.top_encoder(enc_b) + #print(enc_b.shape, enc_t.shape) + top_embedding_loss, quant_t, _, top_perplexity = self.top_quantizer(enc_t) + #print(quant_t.shape) + dec_t = self.top_decoder(quant_t) + #print(dec_t.shape) + enc_b = torch.cat([dec_t, enc_b], dim=2).permute(0,2,1) + #print(enc_b.shape) + quant_b = self.quantize_conv_t(enc_b).permute(0,2,1) + #print("5",quant_b.shape) + bottom_embedding_loss, quant_b, _, bottom_perplexity = self.bottom_quantizer(quant_b) + #print("6",quant_b.shape) + upsample_t = self.upsample_t(quant_t.permute(0,2,1)).permute(0,2,1) + #print("7",upsample_t.shape) + quant = torch.cat([upsample_t, quant_b], 2) + rec_pose = self.bottom_decoder(quant) + # print(quant_t.shape, quant_b.shape, rec_pose.shape) + return { + "poses_feat_top": quant_t, + "pose_feat_bottom": quant_b, + "embedding_loss":top_embedding_loss+bottom_embedding_loss, + #"perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + enc_b = self.bottom_encoder(inputs) + enc_t = self.top_encoder(enc_b) + + _, quant_t, _, _ = self.top_quantizer(enc_t) + top_index = self.top_quantizer.map2index(enc_t) + dec_t = self.top_decoder(quant_t) + + enc_b = torch.cat([dec_t, enc_b], dim=2).permute(0,2,1) + #print(enc_b.shape) + quant_b = self.quantize_conv_t(enc_b).permute(0,2,1) + # quant_b = self.quantize_conv_t(enc_b) + bottom_index = self.bottom_quantizer.map2index(quant_b) + return top_index, bottom_index + + def get_top_laent(self, top_index): + z_q_top = self.top_quantizer.get_codebook_entry(top_index) + return z_q_top + + def map2latent(self, inputs): + enc_b = self.bottom_encoder(inputs) + enc_t = self.top_encoder(enc_b) + + _, quant_t, _, _ = self.top_quantizer(enc_t) + top_index = self.top_quantizer.map2index(enc_t) + dec_t = self.top_decoder(quant_t) + + enc_b = torch.cat([dec_t, enc_b], dim=2).permute(0,2,1) + #print(enc_b.shape) + quant_b = self.quantize_conv_t(enc_b).permute(0,2,1) + # quant_b = self.quantize_conv_t(enc_b) + bottom_index = self.bottom_quantizer.map2index(quant_b) + z_q_top = self.top_quantizer.get_codebook_entry(top_index) + z_q_bottom = self.bottom_quantizer.get_codebook_entry(bottom_index) + return z_q_top, z_q_bottom + + def map2latent_top(self, inputs): + enc_b = self.bottom_encoder(inputs) + enc_t = self.top_encoder(enc_b) + top_index = self.top_quantizer.map2index(enc_t) + z_q_top = self.top_quantizer.get_codebook_entry(top_index) + return z_q_top + + def decode(self, top_index, bottom_index): + quant_t = self.top_quantizer.get_codebook_entry(top_index) + quant_b = self.bottom_quantizer.get_codebook_entry(bottom_index) + upsample_t = self.upsample_t(quant_t.permute(0,2,1)).permute(0,2,1) + #print("7",upsample_t.shape) + quant = torch.cat([upsample_t, quant_b], 2) + rec_pose = self.bottom_decoder(quant) + return rec_pose \ No newline at end of file diff --git a/models/quantizer.py b/models/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..896c973ea25513e27feccc564d85d5dd361a4dc5 --- 
/dev/null
+++ b/models/quantizer.py
@@ -0,0 +1,159 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Quantizer(nn.Module):
+    def __init__(self, n_e, e_dim, beta):
+        super(Quantizer, self).__init__()
+
+        self.e_dim = e_dim
+        self.n_e = n_e
+        self.beta = beta
+
+        self.embedding = nn.Embedding(self.n_e, self.e_dim)
+        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+    def forward(self, z):
+        """
+        Inputs the output of the encoder network z and maps it to a discrete
+        one-hot vector that is the index of the closest embedding vector e_j
+        z (continuous) -> z_q (discrete)
+        :param z (B, seq_len, channel):
+        :return z_q:
+        """
+        assert z.shape[-1] == self.e_dim
+        z_flattened = z.contiguous().view(-1, self.e_dim)
+
+        # B x V
+        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+            torch.matmul(z_flattened, self.embedding.weight.t())
+        # B x 1
+        min_encoding_indices = torch.argmin(d, dim=1)
+        z_q = self.embedding(min_encoding_indices).view(z.shape)
+
+        # compute loss for embedding
+        loss = torch.mean((z_q - z.detach())**2) + self.beta * \
+            torch.mean((z_q.detach() - z)**2)
+
+        # preserve gradients
+        z_q = z + (z_q - z).detach()
+
+        min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
+        e_mean = torch.mean(min_encodings, dim=0)
+        perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10)))
+        return loss, z_q, min_encoding_indices, perplexity
+
+    def map2index(self, z):
+        """
+        Inputs the output of the encoder network z and maps it to a discrete
+        one-hot vector that is the index of the closest embedding vector e_j
+        z (continuous) -> z_q (discrete)
+        :param z (B, seq_len, channel):
+        :return z_q:
+        """
+        assert z.shape[-1] == self.e_dim
+        #print(z.shape)
+        z_flattened = z.contiguous().view(-1, self.e_dim)
+        #print(z_flattened.shape)
+
+        # B x V
+        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+            torch.matmul(z_flattened, self.embedding.weight.t())
+        # B x 1
+        min_encoding_indices = torch.argmin(d, dim=1)
+        return min_encoding_indices.reshape(z.shape[0], -1)
+
+    def get_codebook_entry(self, indices):
+        """
+        :param indices(B, seq_len):
+        :return z_q(B, seq_len, e_dim):
+        """
+        index_flattened = indices.view(-1)
+        z_q = self.embedding(index_flattened)
+        z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous()
+        return z_q
+
+
+class EmbeddingEMA(nn.Module):
+    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+        super(EmbeddingEMA, self).__init__()
+        self.decay = decay
+        self.eps = eps
+        weight = torch.randn(num_tokens, codebook_dim)
+        self.weight = nn.Parameter(weight, requires_grad=False)
+        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+        self.update = True
+
+    def forward(self, embed_id):
+        return F.embedding(embed_id, self.weight)
+
+    def cluster_size_ema_update(self, new_cluster_size):
+        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+    def embed_avg_ema_update(self, new_emb_avg):
+        self.embed_avg.data.mul_(self.decay).add_(new_emb_avg, alpha=1 - self.decay)
+
+    def weight_update(self, num_tokens):
+        n = self.cluster_size.sum()
+        smoothed_cluster_size = (
+            (self.cluster_size + self.eps) / (n + num_tokens*self.eps) * n
+        )
+        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+        self.weight.data.copy_(embed_normalized)
+
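# Usage sketch for the straight-through Quantizer above (illustrative only: the
# codebook size, latent dimension, beta and tensor shapes are assumptions, not
# values taken from this repository's configs).
if __name__ == "__main__":
    _vq = Quantizer(n_e=256, e_dim=512, beta=0.25)       # assumed hyper-parameters
    _z = torch.randn(2, 16, 512, requires_grad=True)     # (B, seq_len, e_dim) encoder output
    _loss, _z_q, _idx, _ppl = _vq(_z)                    # codebook/commitment loss, quantized latents
    _z_q.sum().backward()                                # gradients reach _z via z + (z_q - z).detach()
    # map2index / get_codebook_entry expose the discrete view that the
    # VQ-VAE wrappers in models/motion_representation.py rely on.
    print(_z_q.shape, _idx.shape, float(_ppl))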
+ +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5): + super(EMAVectorQuantizer, self).__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + def forward(self, z): + z_flattened = z.view(-1, self.codebook_dim) + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + + min_encodings = F.one_hot(min_encoding_indices, self.num_tokens).type(z.dtype) + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + if self.training and self.embedding.update: + encoding_sum = min_encodings.sum(0) + embed_sum = min_encodings.transpose(0, 1)@z_flattened + + self.embedding.cluster_size_ema_update(encoding_sum) + self.embedding.embed_avg_ema_update(embed_sum) + self.embedding.weight_update(self.num_tokens) + + loss = self.beta * F.mse_loss(z_q.detach(), z) + + z_q = z + (z_q - z).detach() + return loss, z_q, min_encoding_indices, perplexity + + +# class GumbelQuantizer(nn.Module): +# def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, +# kl_weight=5e-4, temp_init=1.0): +# super(GumbelQuantizer, self).__init__() +# +# self.embedding_dim = embedding_dim +# self.n_embed = n_embed +# +# self.straight_through = straight_through +# self.temperature = temp_init +# self.kl_weight = kl_weight +# +# self.proj = nn.Linear(num_hiddens, n_embed) +# self.embed = nn.Embedding(n_embed, embedding_dim) diff --git a/models/utils/__init__.py b/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/utils/__pycache__/__init__.cpython-310.pyc b/models/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1312d64080b901886073ae45fd75f8c6e6a82fb5 Binary files /dev/null and b/models/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/utils/__pycache__/__init__.cpython-38.pyc b/models/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..783577b4c4864a2104d9fefdb56eb6396ae4ae50 Binary files /dev/null and b/models/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/utils/__pycache__/build_vocab.cpython-310.pyc b/models/utils/__pycache__/build_vocab.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d37c35aeb340258fe3f138a5ec939579a71b3457 Binary files /dev/null and b/models/utils/__pycache__/build_vocab.cpython-310.pyc differ diff --git a/models/utils/__pycache__/build_vocab.cpython-38.pyc b/models/utils/__pycache__/build_vocab.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa9e2a89e181579bb03c105c08e559e252ff809c Binary files /dev/null and b/models/utils/__pycache__/build_vocab.cpython-38.pyc differ diff --git a/models/utils/__pycache__/layer.cpython-310.pyc b/models/utils/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90c1786b13fd7cd2f2f8b3dc04a99a6e636fefab Binary files /dev/null and b/models/utils/__pycache__/layer.cpython-310.pyc differ diff --git 
a/models/utils/__pycache__/layer.cpython-38.pyc b/models/utils/__pycache__/layer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..901bd5b54d07baf7781a1cbe4adb91f5d7dc4ccd Binary files /dev/null and b/models/utils/__pycache__/layer.cpython-38.pyc differ diff --git a/models/utils/__pycache__/skeleton.cpython-310.pyc b/models/utils/__pycache__/skeleton.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c21ae1831e11126e9fd3d88ee1fa0d9e79aa7ff Binary files /dev/null and b/models/utils/__pycache__/skeleton.cpython-310.pyc differ diff --git a/models/utils/__pycache__/skeleton.cpython-38.pyc b/models/utils/__pycache__/skeleton.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e237d0c91360db550a401c5e8d9eb450755647cf Binary files /dev/null and b/models/utils/__pycache__/skeleton.cpython-38.pyc differ diff --git a/models/utils/audio_utils.py b/models/utils/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..39f428af596b2be187a78cc2abf36c458d35ccba --- /dev/null +++ b/models/utils/audio_utils.py @@ -0,0 +1,148 @@ +import numpy as np +import torch as t +import models.utils.dist_adapter as dist +import soundfile +import librosa +from models.utils.dist_utils import print_once + +class DefaultSTFTValues: + def __init__(self, hps): + self.sr = hps.sr + self.n_fft = 2048 + self.hop_length = 256 + self.window_size = 6 * self.hop_length + +class STFTValues: + def __init__(self, hps, n_fft, hop_length, window_size): + self.sr = hps.sr + self.n_fft = n_fft + self.hop_length = hop_length + self.window_size = window_size + +def calculate_bandwidth(dataset, hps, duration=600): + hps = DefaultSTFTValues(hps) + n_samples = int(dataset.sr * duration) + l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank() + spec_norm_total, spec_nelem = 0.0, 0.0 + while n_seen < n_samples: + x = dataset[idx] + if isinstance(x, (tuple, list)): + x, y = x + samples = x.astype(np.float64) + stft = librosa.core.stft(np.mean(samples, axis=1), hps.n_fft, hop_length=hps.hop_length, win_length=hps.window_size) + spec = np.absolute(stft) + spec_norm_total += np.linalg.norm(spec) + spec_nelem += 1 + n_seen += int(np.prod(samples.shape)) + l1 += np.sum(np.abs(samples)) + total += np.sum(samples) + total_sq += np.sum(samples ** 2) + idx += max(16, dist.get_world_size()) + + if dist.is_available(): + from jukebox.utils.dist_utils import allreduce + n_seen = allreduce(n_seen) + total = allreduce(total) + total_sq = allreduce(total_sq) + l1 = allreduce(l1) + spec_nelem = allreduce(spec_nelem) + spec_norm_total = allreduce(spec_norm_total) + + mean = total / n_seen + bandwidth = dict(l2 = total_sq / n_seen - mean ** 2, + l1 = l1 / n_seen, + spec = spec_norm_total / spec_nelem) + print_once(bandwidth) + return bandwidth + +def audio_preprocess(x, hps): + # Extra layer in case we want to experiment with different preprocessing + # For two channel, blend randomly into mono (standard is .5 left, .5 right) + + # x: NTC + # x = x.float() + # if x.shape[-1]==2: + # if hps.aug_blend: + # mix=t.rand((x.shape[0],1), device=x.device) #np.random.rand() + # else: + # mix = 0.5 + # x=(mix*x[:,:,0]+(1-mix)*x[:,:,1]) + # elif x.shape[-1]==1: + # x=x[:,:,0] + # else: + # assert False, f'Expected channels {hps.channels}. 
Got unknown {x.shape[-1]} channels' + + # # x: NT -> NTC + # x = x.unsqueeze(2) + return x + +def audio_postprocess(x, hps): + return x + +def stft(sig, hps): + return t.stft(sig, hps.n_fft, hps.hop_length, win_length=hps.window_size, window=t.hann_window(hps.window_size, device=sig.device)) + +def spec(x, hps): + return t.norm(stft(x, hps), p=2, dim=-1) + +def norm(x): + return (x.view(x.shape[0], -1) ** 2).sum(dim=-1).sqrt() + +def squeeze(x): + if len(x.shape) == 3: + assert x.shape[-1] in [1,2] + x = t.mean(x, -1) + if len(x.shape) != 2: + raise ValueError(f'Unknown input shape {x.shape}') + return x + +def spectral_loss(x_in, x_out, hps): + hps = DefaultSTFTValues(hps) + spec_in = spec(squeeze(x_in.float()), hps) + spec_out = spec(squeeze(x_out.float()), hps) + return norm(spec_in - spec_out) + +def multispectral_loss(x_in, x_out, hps): + losses = [] + assert len(hps.multispec_loss_n_fft) == len(hps.multispec_loss_hop_length) == len(hps.multispec_loss_window_size) + args = [hps.multispec_loss_n_fft, + hps.multispec_loss_hop_length, + hps.multispec_loss_window_size] + for n_fft, hop_length, window_size in zip(*args): + hps = STFTValues(hps, n_fft, hop_length, window_size) + spec_in = spec(squeeze(x_in.float()), hps) + spec_out = spec(squeeze(x_out.float()), hps) + losses.append(norm(spec_in - spec_out)) + return sum(losses) / len(losses) + +def spectral_convergence(x_in, x_out, hps, epsilon=2e-3): + hps = DefaultSTFTValues(hps) + spec_in = spec(squeeze(x_in.float()), hps) + spec_out = spec(squeeze(x_out.float()), hps) + + gt_norm = norm(spec_in) + residual_norm = norm(spec_in - spec_out) + mask = (gt_norm > epsilon).float() + return (residual_norm * mask) / t.clamp(gt_norm, min=epsilon) + +def log_magnitude_loss(x_in, x_out, hps, epsilon=1e-4): + hps = DefaultSTFTValues(hps) + spec_in = t.log(spec(squeeze(x_in.float()), hps) + epsilon) + spec_out = t.log(spec(squeeze(x_out.float()), hps) + epsilon) + return t.mean(t.abs(spec_in - spec_out)) + +def load_audio(file, sr, offset, duration, mono=False): + # Librosa loads more filetypes than soundfile + x, _ = librosa.load(file, sr=sr, mono=mono, offset=offset/sr, duration=duration/sr) + if len(x.shape) == 1: + x = x.reshape((1, -1)) + return x + + +def save_wav(fname, aud, sr): + # clip before saving? 
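        # clamping to the nominal [-1, 1] float range limits peaks here instead of
        # letting out-of-range samples distort when the file is written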
+ aud = t.clamp(aud, -1, 1).cpu().numpy() + for i in list(range(aud.shape[0])): + soundfile.write(f'{fname}/item_{i}.wav', aud[i], samplerate=sr, format='wav') + + diff --git a/models/utils/build_vocab.py b/models/utils/build_vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..fef722da80a10583927941af1cb811341e721862 --- /dev/null +++ b/models/utils/build_vocab.py @@ -0,0 +1,145 @@ +import numpy as np +import glob +import os +import pickle +import lmdb +import pyarrow +import fasttext +from loguru import logger +from scipy import linalg + + +class Vocab: + PAD_token = 0 + SOS_token = 1 + EOS_token = 2 + UNK_token = 3 + + def __init__(self, name, insert_default_tokens=True): + self.name = name + self.trimmed = False + self.word_embedding_weights = None + self.reset_dictionary(insert_default_tokens) + + def reset_dictionary(self, insert_default_tokens=True): + self.word2index = {} + self.word2count = {} + if insert_default_tokens: + self.index2word = {self.PAD_token: "", self.SOS_token: "", + self.EOS_token: "", self.UNK_token: ""} + else: + self.index2word = {self.UNK_token: ""} + self.n_words = len(self.index2word) # count default tokens + + def index_word(self, word): + if word not in self.word2index: + self.word2index[word] = self.n_words + self.word2count[word] = 1 + self.index2word[self.n_words] = word + self.n_words += 1 + else: + self.word2count[word] += 1 + + def add_vocab(self, other_vocab): + for word, _ in other_vocab.word2count.items(): + self.index_word(word) + + # remove words below a certain count threshold + def trim(self, min_count): + if self.trimmed: + return + self.trimmed = True + + keep_words = [] + + for k, v in self.word2count.items(): + if v >= min_count: + keep_words.append(k) + + print(' word trimming, kept %s / %s = %.4f' % ( + len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) + )) + + # reinitialize dictionary + self.reset_dictionary() + for word in keep_words: + self.index_word(word) + + def get_word_index(self, word): + if word in self.word2index: + return self.word2index[word] + else: + return self.UNK_token + + def load_word_vectors(self, pretrained_path, embedding_dim=300): + print(" loading word vectors from '{}'...".format(pretrained_path)) + + # initialize embeddings to random values for special words + init_sd = 1 / np.sqrt(embedding_dim) + weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim]) + weights = weights.astype(np.float32) + + # read word vectors + word_model = fasttext.load_model(pretrained_path) + for word, id in self.word2index.items(): + vec = word_model.get_word_vector(word) + weights[id] = vec + self.word_embedding_weights = weights + +def build_vocab(name, data_path, cache_path, word_vec_path=None, feat_dim=None): + print(' building a language model...') + lang_model = Vocab(name) + print(' indexing words from {}'.format(data_path)) + index_words_from_textgrid(lang_model, data_path) + + if word_vec_path is not None: + lang_model.load_word_vectors(word_vec_path, feat_dim) + else: + print(' loaded from {}'.format(cache_path)) + with open(cache_path, 'rb') as f: + lang_model = pickle.load(f) + if word_vec_path is None: + lang_model.word_embedding_weights = None + elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words: + logging.warning(' failed to load word embedding weights. 
check this') + assert False + + with open(cache_path, 'wb') as f: + pickle.dump(lang_model, f) + + return lang_model + +def index_words(lang_model, data_path): + #index words form text + with open(data_path, "r") as f: + for line in f.readlines(): + line = line.replace(",", " ") + line = line.replace(".", " ") + line = line.replace("?", " ") + line = line.replace("!", " ") + for word in line.split(): + lang_model.index_word(word) + print(' indexed %d words' % lang_model.n_words) + +def index_words_from_textgrid(lang_model, data_path): + import textgrid as tg + trainvaltest=os.listdir(data_path) + for loadtype in trainvaltest: + if "." in loadtype: continue #ignore .ipynb_checkpoints + texts = os.listdir(data_path+loadtype+"/text/") + for textfile in texts: + tgrid = tg.TextGrid.fromFile(data_path+loadtype+"/text/"+textfile) + for word in tgrid[0]: + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + word_n = word_n.replace(",", " ") + word_n = word_n.replace(".", " ") + word_n = word_n.replace("?", " ") + word_n = word_n.replace("!", " ") + #print(word_n) + lang_model.index_word(word_n) + print(' indexed %d words' % lang_model.n_words) + +if __name__ == "__main__": + #11195 for all, 5793 for 4 speakers + build_vocab("beat_english_15_141", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/vocab.pkl", "/home/ma-user/work/datasets/cc.en.300.bin", 300) + \ No newline at end of file diff --git a/models/utils/fk.py b/models/utils/fk.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ae32341c1ccab559772e053fecf6ded608dc1f --- /dev/null +++ b/models/utils/fk.py @@ -0,0 +1,149 @@ +"""Based on Daniel Holden code from: + A Deep Learning Framework for Character Motion Synthesis and Editing + (http://www.ipab.inf.ed.ac.uk/cgvu/motionsynthesis.pdf) +""" + +import os + +import numpy as np +import torch +import torch.nn as nn +from .rotations import euler_angles_to_matrix, quaternion_to_matrix, rotation_6d_to_matrix + + +class ForwardKinematicsLayer(nn.Module): + """ Forward Kinematics Layer Class """ + + def __init__(self, args=None, parents=None, positions=None, device=None): + super().__init__() + self.b_idxs = None + if device is None: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = device + if parents is None and positions is None: + # Load SMPL skeleton (their joint order is different from the one we use for bvh export) + smpl_fname = os.path.join(args.smpl.smpl_body_model, args.data.gender, 'model.npz') + smpl_data = np.load(smpl_fname, encoding='latin1') + self.parents = torch.from_numpy(smpl_data['kintree_table'][0].astype(np.int32)).to(self.device) + self.parents = self.parents.long() + self.positions = torch.from_numpy(smpl_data['J'].astype(np.float32)).to(self.device) + self.positions[1:] -= self.positions[self.parents[1:]] + else: + self.parents = torch.from_numpy(parents).to(self.device) + self.parents = self.parents.long() + self.positions = torch.from_numpy(positions).to(self.device) + self.positions = self.positions.float() + self.positions[0] = 0 + + def rotate(self, t0s, t1s): + return torch.matmul(t0s, t1s) + + def identity_rotation(self, rotations): + diagonal = torch.diag(torch.tensor([1.0, 1.0, 1.0, 1.0])).to(self.device) + diagonal = torch.reshape( + diagonal, torch.Size([1] * len(rotations.shape[:2]) + [4, 4])) + ts = diagonal.repeat(rotations.shape[:2] + torch.Size([1, 1])) + return ts + + def 
make_fast_rotation_matrices(self, positions, rotations): + if len(rotations.shape) == 4 and rotations.shape[-2:] == torch.Size([3, 3]): + rot_matrices = rotations + elif rotations.shape[-1] == 3: + rot_matrices = euler_angles_to_matrix(rotations, convention='XYZ') + elif rotations.shape[-1] == 4: + rot_matrices = quaternion_to_matrix(rotations) + elif rotations.shape[-1] == 6: + rot_matrices = rotation_6d_to_matrix(rotations) + else: + raise NotImplementedError(f'Unimplemented rotation representation in FK layer, shape of {rotations.shape}') + + rot_matrices = torch.cat([rot_matrices, positions[..., None]], dim=-1) + zeros = torch.zeros(rot_matrices.shape[:-2] + torch.Size([1, 3])).to(self.device) + ones = torch.ones(rot_matrices.shape[:-2] + torch.Size([1, 1])).to(self.device) + zerosones = torch.cat([zeros, ones], dim=-1) + rot_matrices = torch.cat([rot_matrices, zerosones], dim=-2) + return rot_matrices + + def rotate_global(self, parents, positions, rotations): + locals = self.make_fast_rotation_matrices(positions, rotations) + globals = self.identity_rotation(rotations) + + globals = torch.cat([locals[:, 0:1], globals[:, 1:]], dim=1) + b_size = positions.shape[0] + if self.b_idxs is None: + self.b_idxs = torch.LongTensor(np.arange(b_size)).to(self.device) + elif self.b_idxs.shape[-1] != b_size: + self.b_idxs = torch.LongTensor(np.arange(b_size)).to(self.device) + + for i in range(1, positions.shape[1]): + globals[:, i] = self.rotate( + globals[self.b_idxs, parents[i]], locals[:, i]) + + return globals + + def get_tpose_joints(self, offsets, parents): + num_joints = len(parents) + joints = [offsets[:, 0]] + for j in range(1, len(parents)): + joints.append(joints[parents[j]] + offsets[:, j]) + + return torch.stack(joints, dim=1) + + def canonical_to_local(self, canonical_xform, global_orient=None): + """ + Args: + canonical_xform: (B, J, 3, 3) + global_orient: (B, 3, 3) + + Returns: + local_xform: (B, J, 3, 3) + """ + local_xform = torch.zeros_like(canonical_xform) + + if global_orient is None: + global_xform = canonical_xform + else: + global_xform = torch.matmul(global_orient.unsqueeze(1), canonical_xform) + for i in range(global_xform.shape[1]): + if i == 0: + local_xform[:, i] = global_xform[:, i] + else: + local_xform[:, i] = torch.bmm(torch.linalg.inv(global_xform[:, self.parents[i]]), global_xform[:, i]) + + return local_xform + + def global_to_local(self, global_xform): + """ + Args: + global_xform: (B, J, 3, 3) + + Returns: + local_xform: (B, J, 3, 3) + """ + local_xform = torch.zeros_like(global_xform) + + for i in range(global_xform.shape[1]): + if i == 0: + local_xform[:, i] = global_xform[:, i] + else: + local_xform[:, i] = torch.bmm(torch.linalg.inv(global_xform[:, self.parents[i]]), global_xform[:, i]) + + return local_xform + + def forward(self, rotations, positions=None): + """ + Args: + rotations (B, J, D) + + Returns: + The global position of each joint after FK (B, J, 3) + """ + # Get the full transform with rotations for skinning + b_size = rotations.shape[0] + if positions is None: + positions = self.positions.repeat(b_size, 1, 1) + transforms = self.rotate_global(self.parents, positions, rotations) + coordinates = transforms[:, :, :3, 3] / transforms[:, :, 3:, 3] + + return coordinates, transforms diff --git a/models/utils/layer.py b/models/utils/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..86f8013512086280656ee10225952642abe7b11e --- /dev/null +++ b/models/utils/layer.py @@ -0,0 +1,217 @@ +import random +import math +import numpy 
as np +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm +import torch.nn.functional as F + +from .build_vocab import Vocab + +class Chomp1d(nn.Module): + def __init__(self, chomp_size): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x): + return x[:, :, :-self.chomp_size].contiguous() + + +class TemporalBlock(nn.Module): + def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): + super(TemporalBlock, self).__init__() + self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp1 = Chomp1d(padding) + self.relu1 = nn.ReLU() + self.dropout1 = nn.Dropout(dropout) + + self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp2 = Chomp1d(padding) + self.relu2 = nn.ReLU() + self.dropout2 = nn.Dropout(dropout) + + self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, + self.conv2, self.chomp2, self.relu2, self.dropout2) + self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None + self.relu = nn.ReLU() + self.init_weights() + + def init_weights(self): + self.conv1.weight.data.normal_(0, 0.01) + self.conv2.weight.data.normal_(0, 0.01) + if self.downsample is not None: + self.downsample.weight.data.normal_(0, 0.01) + + def forward(self, x): + out = self.net(x) + res = x if self.downsample is None else self.downsample(x) + return self.relu(out + res) + + +class TemporalConvNet(nn.Module): + def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): + super(TemporalConvNet, self).__init__() + layers = [] + num_levels = len(num_channels) + for i in range(num_levels): + dilation_size = 2 ** i + in_channels = num_inputs if i == 0 else num_channels[i-1] + out_channels = num_channels[i] + layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, + padding=(kernel_size-1) * dilation_size, dropout=dropout)] + + self.network = nn.Sequential(*layers) + + def forward(self, x): + return self.network(x) + + +class TextEncoderTCN(nn.Module): + """ based on https://github.com/locuslab/TCN/blob/master/TCN/word_cnn/model.py """ + def __init__(self, args, n_words=11195, embed_size=300, pre_trained_embedding=None, + kernel_size=2, dropout=0.3, emb_dropout=0.1, word_cache=False): + super(TextEncoderTCN, self).__init__() +# if word_cache: +# self.embedding = None +# else: +# if pre_trained_embedding is not None: # use pre-trained embedding (fasttext) +# #print(pre_trained_embedding.shape) +# assert pre_trained_embedding.shape[0] == n_words +# assert pre_trained_embedding.shape[1] == embed_size +# self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding), +# freeze=args.freeze_wordembed) +# else: +# self.embedding = nn.Embedding(n_words, embed_size) + + num_channels = [args.hidden_size] #* args.n_layer + self.tcn = TemporalConvNet(embed_size, num_channels, kernel_size, dropout=dropout) + self.decoder = nn.Linear(num_channels[-1], args.word_f) + self.drop = nn.Dropout(emb_dropout) + #self.emb_dropout = emb_dropout + self.init_weights() + + def init_weights(self): + self.decoder.bias.data.fill_(0) + self.decoder.weight.data.normal_(0, 0.01) + + def forward(self, input): + #print(input.shape) +# if self.embedding is None: +# emb = self.drop(input) +# else: +# emb = self.drop(self.embedding(input)) + y = self.tcn(input.transpose(1, 
2)).transpose(1, 2) + y = self.decoder(y) + return y, torch.max(y, dim=1)[0] + + + + + + + + + +def reparameterize(mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + +def ConvNormRelu(in_channels, out_channels, downsample=False, padding=0, batchnorm=True): + if not downsample: + k = 3 + s = 1 + else: + k = 4 + s = 2 + conv_block = nn.Conv1d(in_channels, out_channels, kernel_size=k, stride=s, padding=padding) + norm_block = nn.BatchNorm1d(out_channels) + if batchnorm: + net = nn.Sequential( + conv_block, + norm_block, + nn.LeakyReLU(0.2, True) + ) + else: + net = nn.Sequential( + conv_block, + nn.LeakyReLU(0.2, True) + ) + return net + +class BasicBlock(nn.Module): + """ based on timm: https://github.com/rwightman/pytorch-image-models """ + def __init__(self, inplanes, planes, ker_size, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm1d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(BasicBlock, self).__init__() + + self.conv1 = nn.Conv1d( + inplanes, planes, kernel_size=ker_size, stride=stride, padding=first_dilation, + dilation=dilation, bias=True) + self.bn1 = norm_layer(planes) + self.act1 = act_layer(inplace=True) + self.conv2 = nn.Conv1d( + planes, planes, kernel_size=ker_size, padding=ker_size//2, dilation=dilation, bias=True) + self.bn2 = norm_layer(planes) + self.act2 = act_layer(inplace=True) + if downsample is not None: + self.downsample = nn.Sequential( + nn.Conv1d(inplanes, planes, stride=stride, kernel_size=ker_size, padding=first_dilation, dilation=dilation, bias=True), + norm_layer(planes), + ) + else: self.downsample=None + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.bn2(x) + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act2(x) + return x + +def init_weight(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): + nn.init.xavier_normal_(m.weight) + # m.bias.data.fill_(0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + +def init_weight_skcnn(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + # m.bias.data.fill_(0.01) + if m.bias is not None: + #nn.init.constant_(m.bias, 0) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(m.bias, -bound, bound) + +class ResBlock(nn.Module): + def __init__(self, channel): + super(ResBlock, self).__init__() + self.model = nn.Sequential( + nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1), + ) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + \ No newline at end of file diff --git a/models/utils/rotation_conversions.py b/models/utils/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bfaa1b2247622bff35d3f9b15e8eb84064aa53 --- /dev/null +++ b/models/utils/rotation_conversions.py @@ -0,0 +1,550 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. All rights reserved. + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". 
+ angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). 
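+
+    Example (illustrative; the identity rotation maps to zero angles):
+        R = torch.eye(3).unsqueeze(0)        # one identity rotation matrix, shape (1, 3, 3)
+        matrix_to_euler_angles(R, "XYZ")     # approx. tensor([[0., 0., 0.]])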
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. 
+ + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
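+
+    Example (illustrative; values rounded, a 90 degree rotation about Z):
+        aa = torch.tensor([[0.0, 0.0, 1.5708]])   # angle of ~pi/2 about the Z axis
+        axis_angle_to_quaternion(aa)              # approx. tensor([[0.7071, 0., 0., 0.7071]])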
+ """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/models/utils/rotations.py b/models/utils/rotations.py new file mode 100644 index 0000000000000000000000000000000000000000..55729b2724c9c34234bddb63a826aa1f9a4321b9 --- /dev/null +++ b/models/utils/rotations.py @@ -0,0 +1,587 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Union + +import torch +import torch.nn.functional as F + +Device = Union[str, torch.device] + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
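+
+    Example (illustrative; the identity rotation maps to the unit quaternion):
+        R = torch.eye(3).unsqueeze(0)   # one identity rotation matrix, shape (1, 3, 3)
+        matrix_to_quaternion(R)         # approx. tensor([[1., 0., 0., 0.]])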
+ """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + + return quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16] + ].reshape(batch_dim + (4,)) + + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str) -> int: + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter must be either X, Y or Z.") + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). 
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None +) -> torch.Tensor: + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + if isinstance(device, str): + device = torch.device(device) + o = torch.randn((n, 4), dtype=dtype, device=device) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None +) -> torch.Tensor: + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions(n, dtype=dtype, device=device) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device: Optional[Device] = None +) -> torch.Tensor: + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device)[0] + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). 
+ """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor: + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device) + return quaternion * scaling + + +def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor: + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, {point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
+ """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) diff --git a/models/utils/skeleton.py b/models/utils/skeleton.py new file mode 100644 index 0000000000000000000000000000000000000000..123656b7516aec1b424f9f87d384837eb820ccc9 --- /dev/null +++ b/models/utils/skeleton.py @@ -0,0 +1,636 @@ +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SkeletonConv(nn.Module): + def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0, + bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0): + self.in_channels_per_joint = in_channels // joint_num + self.out_channels_per_joint = out_channels // joint_num + if in_channels % joint_num != 0 or out_channels % joint_num != 0: + raise Exception('BAD') + super(SkeletonConv, self).__init__() + + if padding_mode == 'zeros': + padding_mode = 'constant' + if padding_mode == 'reflection': + padding_mode = 'reflect' + + self.expanded_neighbour_list = [] + self.expanded_neighbour_list_offset = [] + self.neighbour_list = neighbour_list + self.add_offset = add_offset + self.joint_num = joint_num + + self.stride = stride + self.dilation = 1 + self.groups = 1 + self.padding = padding + self.padding_mode = padding_mode + self._padding_repeated_twice = (padding, padding) + + for neighbour in neighbour_list: + expanded = [] + for k in neighbour: + for i in range(self.in_channels_per_joint): + expanded.append(k * self.in_channels_per_joint + i) + self.expanded_neighbour_list.append(expanded) + + if self.add_offset: + self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels) + + for neighbour in neighbour_list: + expanded = [] + for k in neighbour: + for i in range(add_offset): + expanded.append(k * in_offset_channel + i) + self.expanded_neighbour_list_offset.append(expanded) + + self.weight = torch.zeros(out_channels, in_channels, kernel_size) + if bias: + self.bias = torch.zeros(out_channels) + else: + self.register_parameter('bias', None) + + self.mask = torch.zeros_like(self.weight) + for i, neighbour in enumerate(self.expanded_neighbour_list): + self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1 + self.mask = nn.Parameter(self.mask, requires_grad=False) + + self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \ + 'joint_num={}, stride={}, padding={}, bias={})'.format( + in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias + ) + + self.reset_parameters() + + def reset_parameters(self): + for i, neighbour in enumerate(self.expanded_neighbour_list): + """ Use temporary variable to avoid assign to copy of slice, which might lead to unexpected result """ + tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), + neighbour, ...]) + nn.init.kaiming_uniform_(tmp, a=math.sqrt(5)) + self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), + neighbour, ...] 
= tmp + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...]) + bound = 1 / math.sqrt(fan_in) + tmp = torch.zeros_like( + self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)]) + nn.init.uniform_(tmp, -bound, bound) + self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp + + self.weight = nn.Parameter(self.weight) + if self.bias is not None: + self.bias = nn.Parameter(self.bias) + + def set_offset(self, offset): + if not self.add_offset: + raise Exception('Wrong Combination of Parameters') + self.offset = offset.reshape(offset.shape[0], -1) + + def forward(self, input): + # print('SkeletonConv') + weight_masked = self.weight * self.mask + #print(f'input: {input.size()}') + res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), + weight_masked, self.bias, self.stride, + 0, self.dilation, self.groups) + + if self.add_offset: + offset_res = self.offset_enc(self.offset) + offset_res = offset_res.reshape(offset_res.shape + (1, )) + res += offset_res / 100 + #print(f'res: {res.size()}') + return res + + +class SkeletonLinear(nn.Module): + def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False): + super(SkeletonLinear, self).__init__() + self.neighbour_list = neighbour_list + self.in_channels = in_channels + self.out_channels = out_channels + self.in_channels_per_joint = in_channels // len(neighbour_list) + self.out_channels_per_joint = out_channels // len(neighbour_list) + self.extra_dim1 = extra_dim1 + self.expanded_neighbour_list = [] + + for neighbour in neighbour_list: + expanded = [] + for k in neighbour: + for i in range(self.in_channels_per_joint): + expanded.append(k * self.in_channels_per_joint + i) + self.expanded_neighbour_list.append(expanded) + + self.weight = torch.zeros(out_channels, in_channels) + self.mask = torch.zeros(out_channels, in_channels) + self.bias = nn.Parameter(torch.Tensor(out_channels)) + + self.reset_parameters() + + def reset_parameters(self): + for i, neighbour in enumerate(self.expanded_neighbour_list): + tmp = torch.zeros_like( + self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] + ) + self.mask[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = 1 + nn.init.kaiming_uniform_(tmp, a=math.sqrt(5)) + self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = tmp + + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + self.weight = nn.Parameter(self.weight) + self.mask = nn.Parameter(self.mask, requires_grad=False) + + def forward(self, input): + input = input.reshape(input.shape[0], -1) + weight_masked = self.weight * self.mask + res = F.linear(input, weight_masked, self.bias) + if self.extra_dim1: + res = res.reshape(res.shape + (1,)) + return res + + +class SkeletonPool(nn.Module): + def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False): + super(SkeletonPool, self).__init__() + + if pooling_mode != 'mean': + raise Exception('Unimplemented pooling mode in matrix_implementation') + + self.channels_per_edge = channels_per_edge + self.pooling_mode = pooling_mode + self.edge_num = len(edges) + # self.edge_num = len(edges) + 1 + self.seq_list = [] + self.pooling_list = [] + self.new_edges = [] + degree = [0] * 100 # 
each element represents the degree of the corresponding joint + + for edge in edges: + degree[edge[0]] += 1 + degree[edge[1]] += 1 + + # seq_list contains multiple sub-lists where each sub-list is an edge chain from the joint whose degree > 2 to the end effectors or joints whose degree > 2. + def find_seq(j, seq): + nonlocal self, degree, edges + + if degree[j] > 2 and j != 0: + self.seq_list.append(seq) + seq = [] + + if degree[j] == 1: + self.seq_list.append(seq) + return + + for idx, edge in enumerate(edges): + if edge[0] == j: + find_seq(edge[1], seq + [idx]) + + find_seq(0, []) + # print(f'self.seq_list: {self.seq_list}') + + for seq in self.seq_list: + if last_pool: + self.pooling_list.append(seq) + continue + if len(seq) % 2 == 1: + self.pooling_list.append([seq[0]]) + self.new_edges.append(edges[seq[0]]) + seq = seq[1:] + for i in range(0, len(seq), 2): + self.pooling_list.append([seq[i], seq[i + 1]]) + self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]]) + # print(f'self.pooling_list: {self.pooling_list}') + # print(f'self.new_egdes: {self.new_edges}') + + # add global position + # self.pooling_list.append([self.edge_num - 1]) + + self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format( + len(edges), len(self.pooling_list) + ) + + self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge) + + for i, pair in enumerate(self.pooling_list): + for j in pair: + for c in range(channels_per_edge): + self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair) + + self.weight = nn.Parameter(self.weight, requires_grad=False) + + def forward(self, input: torch.Tensor): + # print('SkeletonPool') + # print(f'input: {input.size()}') + # print(f'self.weight: {self.weight.size()}') + return torch.matmul(self.weight, input) + + +class SkeletonUnpool(nn.Module): + def __init__(self, pooling_list, channels_per_edge): + super(SkeletonUnpool, self).__init__() + self.pooling_list = pooling_list + self.input_edge_num = len(pooling_list) + self.output_edge_num = 0 + self.channels_per_edge = channels_per_edge + for t in self.pooling_list: + self.output_edge_num += len(t) + + self.description = 'SkeletonUnpool(in_edge_num={}, out_edge_num={})'.format( + self.input_edge_num, self.output_edge_num, + ) + + self.weight = torch.zeros(self.output_edge_num * channels_per_edge, self.input_edge_num * channels_per_edge) + + for i, pair in enumerate(self.pooling_list): + for j in pair: + for c in range(channels_per_edge): + self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1 + + self.weight = nn.Parameter(self.weight) + self.weight.requires_grad_(False) + + def forward(self, input: torch.Tensor): + # print('SkeletonUnpool') + # print(f'input: {input.size()}') + # print(f'self.weight: {self.weight.size()}') + return torch.matmul(self.weight, input) + + +""" +Helper functions for skeleton operation +""" + + +def dfs(x, fa, vis, dist): + vis[x] = 1 + for y in range(len(fa)): + if (fa[y] == x or fa[x] == y) and vis[y] == 0: + dist[y] = dist[x] + 1 + dfs(y, fa, vis, dist) + + +""" +def find_neighbor_joint(fa, threshold): + neighbor_list = [[]] + for x in range(1, len(fa)): + vis = [0 for _ in range(len(fa))] + dist = [0 for _ in range(len(fa))] + dist[0] = 10000 + dfs(x, fa, vis, dist) + neighbor = [] + for j in range(1, len(fa)): + if dist[j] <= threshold: + neighbor.append(j) + neighbor_list.append(neighbor) + + neighbor = [0] + for i, x in enumerate(neighbor_list): + if i == 0: continue + if 1 in x: + 
neighbor.append(i) + neighbor_list[i] = [0] + neighbor_list[i] + neighbor_list[0] = neighbor + return neighbor_list + + +def build_edge_topology(topology, offset): + # get all edges (pa, child, offset) + edges = [] + joint_num = len(topology) + for i in range(1, joint_num): + edges.append((topology[i], i, offset[i])) + return edges +""" + + +def build_edge_topology(topology): + # get all edges (pa, child) + edges = [] + joint_num = len(topology) + edges.append((0, joint_num)) # add an edge between the root joint and a virtual joint + for i in range(1, joint_num): + edges.append((topology[i], i)) + return edges + + +def build_joint_topology(edges, origin_names): + parent = [] + offset = [] + names = [] + edge2joint = [] + joint_from_edge = [] # -1 means virtual joint + joint_cnt = 0 + out_degree = [0] * (len(edges) + 10) + for edge in edges: + out_degree[edge[0]] += 1 + + # add root joint + joint_from_edge.append(-1) + parent.append(0) + offset.append(np.array([0, 0, 0])) + names.append(origin_names[0]) + joint_cnt += 1 + + def make_topology(edge_idx, pa): + nonlocal edges, parent, offset, names, edge2joint, joint_from_edge, joint_cnt + edge = edges[edge_idx] + if out_degree[edge[0]] > 1: + parent.append(pa) + offset.append(np.array([0, 0, 0])) + names.append(origin_names[edge[1]] + '_virtual') + edge2joint.append(-1) + pa = joint_cnt + joint_cnt += 1 + + parent.append(pa) + offset.append(edge[2]) + names.append(origin_names[edge[1]]) + edge2joint.append(edge_idx) + pa = joint_cnt + joint_cnt += 1 + + for idx, e in enumerate(edges): + if e[0] == edge[1]: + make_topology(idx, pa) + + for idx, e in enumerate(edges): + if e[0] == 0: + make_topology(idx, 0) + + return parent, offset, names, edge2joint + + +def calc_edge_mat(edges): + edge_num = len(edges) + # edge_mat[i][j] = distance between edge(i) and edge(j) + edge_mat = [[100000] * edge_num for _ in range(edge_num)] + for i in range(edge_num): + edge_mat[i][i] = 0 + + # initialize edge_mat with direct neighbor + for i, a in enumerate(edges): + for j, b in enumerate(edges): + link = 0 + for x in range(2): + for y in range(2): + if a[x] == b[y]: + link = 1 + if link: + edge_mat[i][j] = 1 + + # calculate all the pairs distance + for k in range(edge_num): + for i in range(edge_num): + for j in range(edge_num): + edge_mat[i][j] = min(edge_mat[i][j], edge_mat[i][k] + edge_mat[k][j]) + return edge_mat + + +def find_neighbor(edges, d): + """ + Args: + edges: The list contains N elements, each element represents (parent, child). + d: Distance between edges (the distance of the same edge is 0 and the distance of adjacent edges is 1). + + Returns: + The list contains N elements, each element is a list of edge indices whose distance <= d. + """ + edge_mat = calc_edge_mat(edges) + neighbor_list = [] + edge_num = len(edge_mat) + for i in range(edge_num): + neighbor = [] + for j in range(edge_num): + if edge_mat[i][j] <= d: + neighbor.append(j) + neighbor_list.append(neighbor) + + # # add neighbor for global part + # global_part_neighbor = neighbor_list[0].copy() + # """ + # Line #373 is buggy. Thanks @crissallan!! + # See issue #30 (https://github.com/DeepMotionEditing/deep-motion-editing/issues/30) + # However, fixing this bug will make it unable to load the pretrained model and + # affect the reproducibility of quantitative error reported in the paper. + # It is not a fatal bug so we didn't touch it and we are looking for possible solutions. 
+ # """ + # for i in global_part_neighbor: + # neighbor_list[i].append(edge_num) + # neighbor_list.append(global_part_neighbor) + + return neighbor_list + + +def calc_node_depth(topology): + def dfs(node, topology): + if topology[node] < 0: + return 0 + return 1 + dfs(topology[node], topology) + depth = [] + for i in range(len(topology)): + depth.append(dfs(i, topology)) + + return depth + + +def residual_ratio(k): + return 1 / (k + 1) + + +class Affine(nn.Module): + def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0): + super(Affine, self).__init__() + if scale: + self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init) + else: + self.register_parameter('scale', None) + + if bias: + self.bias = nn.Parameter(torch.zeros(num_parameters)) + else: + self.register_parameter('bias', None) + + def forward(self, input): + output = input + if self.scale is not None: + scale = self.scale.unsqueeze(0) + while scale.dim() < input.dim(): + scale = scale.unsqueeze(2) + output = output.mul(scale) + + if self.bias is not None: + bias = self.bias.unsqueeze(0) + while bias.dim() < input.dim(): + bias = bias.unsqueeze(2) + output += bias + + return output + + +class BatchStatistics(nn.Module): + def __init__(self, affine=-1): + super(BatchStatistics, self).__init__() + self.affine = nn.Sequential() if affine == -1 else Affine(affine) + self.loss = 0 + + def clear_loss(self): + self.loss = 0 + + def compute_loss(self, input): + input_flat = input.view(input.size(1), input.numel() // input.size(1)) + mu = input_flat.mean(1) + logvar = (input_flat.pow(2).mean(1) - mu.pow(2)).sqrt().log() + + self.loss = mu.pow(2).mean() + logvar.pow(2).mean() + + def forward(self, input): + self.compute_loss(input) + return self.affine(input) + + +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False): + super(ResidualBlock, self).__init__() + + self.residual_ratio = residual_ratio + self.shortcut_ratio = 1 - residual_ratio + + residual = [] + residual.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)) + if batch_statistics: + residual.append(BatchStatistics(out_channels)) + if not last_layer: + residual.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) + self.residual = nn.Sequential(*residual) + + self.shortcut = nn.Sequential( + nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential(), + nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0), + BatchStatistics(out_channels) if (in_channels != out_channels and batch_statistics is True) else nn.Sequential() + ) + + def forward(self, input): + return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio) + + +class ResidualBlockTranspose(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation): + super(ResidualBlockTranspose, self).__init__() + + self.residual_ratio = residual_ratio + self.shortcut_ratio = 1 - residual_ratio + + self.residual = nn.Sequential( + nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding), + nn.PReLU() if activation == 'relu' else nn.Tanh() + ) + + self.shortcut = nn.Sequential( + nn.Upsample(scale_factor=2, mode='linear', align_corners=False) if stride == 2 else nn.Sequential(), + nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, input): + return 
self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio) + + +class SkeletonResidual(nn.Module): + def __init__(self, topology, neighbour_list, joint_num, in_channels, out_channels, kernel_size, stride, padding, padding_mode, bias, extra_conv, pooling_mode, activation, last_pool): + super(SkeletonResidual, self).__init__() + + kernel_even = False if kernel_size % 2 else True + + seq = [] + for _ in range(extra_conv): + # (T, J, D) => (T, J, D) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels, + joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size, + stride=1, + padding=padding, padding_mode=padding_mode, bias=bias)) + seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) + # (T, J, D) => (T/2, J, 2D) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, + joint_num=joint_num, kernel_size=kernel_size, stride=stride, + padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False)) + seq.append(nn.GroupNorm(10, out_channels)) # FIXME: REMEMBER TO CHANGE BACK !!! + self.residual = nn.Sequential(*seq) + + # (T, J, D) => (T/2, J, 2D) + self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, + joint_num=joint_num, kernel_size=1, stride=stride, padding=0, + bias=True, add_offset=False) + + seq = [] + # (T/2, J, 2D) => (T/2, J', 2D) + pool = SkeletonPool(edges=topology, pooling_mode=pooling_mode, + channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool) + if len(pool.pooling_list) != pool.edge_num: + seq.append(pool) + seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) + self.common = nn.Sequential(*seq) + + def forward(self, input): + output = self.residual(input) + self.shortcut(input) + + return self.common(output) + + +class SkeletonResidualTranspose(nn.Module): + def __init__(self, neighbour_list, joint_num, in_channels, out_channels, kernel_size, padding, padding_mode, bias, extra_conv, pooling_list, upsampling, activation, last_layer): + super(SkeletonResidualTranspose, self).__init__() + + kernel_even = False if kernel_size % 2 else True + + seq = [] + # (T, J, D) => (2T, J, D) + if upsampling is not None: + seq.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False)) + # (2T, J, D) => (2T, J', D) + unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list)) + if unpool.input_edge_num != unpool.output_edge_num: + seq.append(unpool) + self.common = nn.Sequential(*seq) + + seq = [] + for _ in range(extra_conv): + # (2T, J', D) => (2T, J', D) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels, + joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size, + stride=1, + padding=padding, padding_mode=padding_mode, bias=bias)) + seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) + # (2T, J', D) => (2T, J', D/2) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, + joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size, + stride=1, + padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False)) + self.residual = nn.Sequential(*seq) + + # (2T, J', D) => (2T, J', D/2) + self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, + joint_num=joint_num, kernel_size=1, stride=1, padding=0, + bias=True, add_offset=False) + + if activation == 'relu': + 
self.activation = nn.PReLU() if not last_layer else None + else: + self.activation = nn.Tanh() if not last_layer else None + + def forward(self, input): + output = self.common(input) + output = self.residual(output) + self.shortcut(output) + + if self.activation is not None: + return self.activation(output) + else: + return output \ No newline at end of file diff --git a/models/utils/wav2vec.py b/models/utils/wav2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..ca23fe1d5a03834986885ed776cbf83c29e391ea --- /dev/null +++ b/models/utils/wav2vec.py @@ -0,0 +1,150 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import copy +import math +from transformers import Wav2Vec2Model,Wav2Vec2Config +from transformers.modeling_outputs import BaseModelOutput +from typing import Optional, Tuple +_CONFIG_FOR_DOC = "Wav2Vec2Config" + +# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model +# initialize our encoder with the pre-trained wav2vec 2.0 weights. +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.Tensor] = None, + min_masks: int = 0, +) -> np.ndarray: + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + mask_idcs = [] + padding_mask = attention_mask.ne(1) if attention_mask is not None else None + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + lengths = np.full(num_mask, mask_length) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + return mask + +# linear interpolation layer +def linear_interpolation(features, input_fps, output_fps, output_len=None): + features = features.transpose(1, 2) + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features,size=output_len,align_corners=True,mode='linear') + return output_features.transpose(1, 2) + +class Wav2Vec2Model(Wav2Vec2Model): + def __init__(self, config): + super().__init__(config) + self.args = config + self.args.audio_fps = 15 #args.audio_fps + #input_values 16K hz, 49fps, 20ms overlap, 25ms recepion field + def forward( + self, + input_values, + dataset="beat", + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + frame_num=None + ): + #print(input_values.shape) + self.config.output_attentions = True + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + 
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.feature_extractor(input_values) + hidden_states = hidden_states.transpose(1, 2) + #print(hidden_states.shape) + if dataset == "beat": + hidden_states = linear_interpolation(hidden_states, 49, self.args.audio_fps, output_len=frame_num) + #print(hidden_states.shape) + if attention_mask is not None: + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) + attention_mask = torch.zeros( + hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device + ) + attention_mask[ + (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1) + ] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + + hidden_states = self.feature_projection(hidden_states)[0] + #print(hidden_states.shape) + if self.config.apply_spec_augment and self.training: + batch_size, sequence_length, hidden_size = hidden_states.size() + if self.config.mask_time_prob > 0: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + self.config.mask_time_prob, + self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=2, + ) + hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) + if self.config.mask_feature_prob > 0: + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + self.config.mask_feature_prob, + self.config.mask_feature_length, + ) + mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device) + hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = encoder_outputs[0] + #print(encoder_outputs.shape) + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return hidden_states +# BaseModelOutput( +# last_hidden_state=hidden_states, +# hidden_states=encoder_outputs.hidden_states, +# attentions=encoder_outputs.attentions, +# ) \ No newline at end of file diff --git a/optimizers/__init__.py b/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optimizers/__pycache__/__init__.cpython-310.pyc b/optimizers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..257820bcf39400dcaebd8e1ad434eebafd3fbfeb Binary files /dev/null and b/optimizers/__pycache__/__init__.cpython-310.pyc differ diff --git a/optimizers/__pycache__/__init__.cpython-38.pyc b/optimizers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bfcc28797c290d5ad0460a7852e190248ec9b50 Binary files /dev/null and b/optimizers/__pycache__/__init__.cpython-38.pyc differ diff --git a/optimizers/__pycache__/loss_factory.cpython-310.pyc b/optimizers/__pycache__/loss_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c6b646249e78c313765ca30051dd44f7f0cab9e Binary files /dev/null and b/optimizers/__pycache__/loss_factory.cpython-310.pyc differ diff --git a/optimizers/__pycache__/loss_factory.cpython-38.pyc 
b/optimizers/__pycache__/loss_factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a4c4cbc28cfcd6fd41c3059f8dae07b057f93dd Binary files /dev/null and b/optimizers/__pycache__/loss_factory.cpython-38.pyc differ diff --git a/optimizers/__pycache__/optim_factory.cpython-310.pyc b/optimizers/__pycache__/optim_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b0f24c32488aedc269f0b156a193393c4c79f08 Binary files /dev/null and b/optimizers/__pycache__/optim_factory.cpython-310.pyc differ diff --git a/optimizers/__pycache__/optim_factory.cpython-38.pyc b/optimizers/__pycache__/optim_factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3feab1343dead2088b1714dbcf2a7433460dd7fb Binary files /dev/null and b/optimizers/__pycache__/optim_factory.cpython-38.pyc differ diff --git a/optimizers/__pycache__/scheduler_factory.cpython-310.pyc b/optimizers/__pycache__/scheduler_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac64acc22233c3097ed1ae20aac5878526b93cd0 Binary files /dev/null and b/optimizers/__pycache__/scheduler_factory.cpython-310.pyc differ diff --git a/optimizers/__pycache__/scheduler_factory.cpython-38.pyc b/optimizers/__pycache__/scheduler_factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79eacb06bf857aa46df939a3ba1157f01c125a0f Binary files /dev/null and b/optimizers/__pycache__/scheduler_factory.cpython-38.pyc differ diff --git a/optimizers/loss_factory.py b/optimizers/loss_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..4986ef6a5492245ac44d9e11724e9b21d56ac08e --- /dev/null +++ b/optimizers/loss_factory.py @@ -0,0 +1,118 @@ +# Copyright (c) HuaWei, Inc. and its affiliates. +# liu.haiyang@huawei.com + +import torch.nn as nn +import torch.nn.functional as F +import torch +import numpy as np + + +class GeodesicLoss(nn.Module): + def __init__(self): + super(GeodesicLoss, self).__init__() + + def compute_geodesic_distance(self, m1, m2): + """ Compute the geodesic distance between two rotation matrices. + + Args: + m1, m2: Two rotation matrices with the shape (batch x 3 x 3). + + Returns: + The minimal angular difference between two rotation matrices in radian form [0, pi]. 
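        Example (an illustrative sketch): a rotation of 0.5 rad about the
        z-axis compared against the identity gives a distance of ~0.5.

            >>> import math, torch
            >>> c, s = math.cos(0.5), math.sin(0.5)
            >>> m1 = torch.tensor([[[c, -s, 0.], [s, c, 0.], [0., 0., 1.]]])
            >>> m2 = torch.eye(3).unsqueeze(0)
            >>> GeodesicLoss()(m1, m2)   # ~tensor(0.5000)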
+ """ + m1 = m1.reshape(-1, 3, 3) + m2 = m2.reshape(-1, 3, 3) + batch = m1.shape[0] + m = torch.bmm(m1, m2.transpose(1, 2)) # batch*3*3 + + cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2 + cos = torch.clamp(cos, min=-1 + 1E-6, max=1-1E-6) + + theta = torch.acos(cos) + + return theta + + def __call__(self, m1, m2, reduction='mean'): + loss = self.compute_geodesic_distance(m1, m2) + + if reduction == 'mean': + return loss.mean() + elif reduction == 'none': + return loss + else: + raise RuntimeError(f'unsupported reduction: {reduction}') + + +class BCE_Loss(nn.Module): + def __init__(self, args=None): + super(BCE_Loss, self).__init__() + + def forward(self, fake_outputs, real_target): + final_loss = F.cross_entropy(fake_outputs, real_target, reduce="mean") + return final_loss + +class weight_Loss(nn.Module): + def __init__(self, args=None): + super(weight_Loss, self).__init__() + def forward(self, weight_f): + weight_loss_div = torch.mean(weight_f[:, :, 0]*weight_f[:, :, 1]) + weight_loss_gap = torch.mean(-torch.log(torch.max(weight_f[:, :, 0], dim=1)[0] - torch.min(weight_f[:, :, 0], dim=1)[0])) + return weight_loss_div, weight_loss_gap + + +class HuberLoss(nn.Module): + def __init__(self, beta=0.1, reduction="mean"): + super(HuberLoss, self).__init__() + self.beta = beta + self.reduction = reduction + + def forward(self, outputs, targets): + final_loss = F.smooth_l1_loss(outputs / self.beta, targets / self.beta, reduction=self.reduction) * self.beta + return final_loss + + +class KLDLoss(nn.Module): + def __init__(self, beta=0.1): + super(KLDLoss, self).__init__() + self.beta = beta + + def forward(self, outputs, targets): + final_loss = F.smooth_l1_loss((outputs / self.beta, targets / self.beta) * self.beta) + return final_loss + + +class REGLoss(nn.Module): + def __init__(self, beta=0.1): + super(REGLoss, self).__init__() + self.beta = beta + + def forward(self, outputs, targets): + final_loss = F.smooth_l1_loss((outputs / self.beta, targets / self.beta) * self.beta) + return final_loss + + +class L2Loss(nn.Module): + def __init__(self): + super(L2Loss, self).__init__() + + def forward(self, outputs, targets): + final_loss = F.l2_loss(outputs, targets) + return final_loss + +LOSS_FUNC_LUT = { + "bce_loss": BCE_Loss, + "l2_loss": L2Loss, + "huber_loss": HuberLoss, + "kl_loss": KLDLoss, + "id_loss": REGLoss, + "GeodesicLoss": GeodesicLoss, + "weight_Loss": weight_Loss, + } + + +def get_loss_func(loss_name, **kwargs): + loss_func_class = LOSS_FUNC_LUT.get(loss_name) + loss_func = loss_func_class(**kwargs) + return loss_func + + diff --git a/optimizers/optim_factory.py b/optimizers/optim_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..3016448c5dcfe0dea27f6fc305ab7a5b83063f7f --- /dev/null +++ b/optimizers/optim_factory.py @@ -0,0 +1,176 @@ +""" Optimizer Factory w/ Custom Weight Decay +Hacked together by / Copyright 2020 Ross Wightman +""" +from typing import Optional + +import torch +import torch.nn as nn +import torch.optim as optim + +from .timm.adafactor import Adafactor +from .timm.adahessian import Adahessian +from .timm.adamp import AdamP +from .timm.lookahead import Lookahead +from .timm.nadam import Nadam +from .timm.novograd import NovoGrad +from .timm.nvnovograd import NvNovoGrad +from .timm.radam import RAdam +from .timm.rmsprop_tf import RMSpropTF +from .timm.sgdp import SGDP +from .timm.adabelief import AdaBelief + +try: + from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD + has_apex = True +except ImportError: + 
has_apex = False + + +def add_weight_decay(model, weight_decay=1e-5, skip_list=()): + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + no_decay.append(param) + else: + decay.append(param) + return [ + {"params": no_decay, "weight_decay": 0.}, + {"params": decay, "weight_decay": weight_decay}] + + +def optimizer_kwargs(args, lr_weight): + """ args/argparse to kwargs helper + Convert optimizer args in argparse args or args like object to keyword args for updated create fn. + """ + kwargs = dict( + optimizer_name=args.opt, + learning_rate=args.lr_base*args.batch_size/128*lr_weight, + weight_decay=args.weight_decay, + momentum=args.momentum) + if getattr(args, "opt_eps", None) is not None: + kwargs["eps"] = args.opt_eps + if getattr(args, "opt_betas", None) is not None: + kwargs["betas"] = args.opt_betas + if getattr(args, "opt_args", None) is not None: + kwargs.update(args.opt_args) + return kwargs + + +def create_optimizer(args, model, filter_bias_and_bn=True, lr_weight=1): + """ Legacy optimizer factory for backwards compatibility. + NOTE: Use create_optimizer_v2 for new code. + """ + return create_optimizer_v2( + model, + **optimizer_kwargs(args, lr_weight), + filter_bias_and_bn=filter_bias_and_bn, + ) + + +def create_optimizer_v2( + model: nn.Module, + optimizer_name: str = "sgd", + learning_rate: Optional[float] = None, + weight_decay: float = 0., + momentum: float = 0.9, + filter_bias_and_bn: bool = True, + **kwargs): + """ Create an optimizer. + + TODO currently the model is passed in and all parameters are selected for optimization. + For more general use an interface that allows selection of parameters to optimize and lr groups, one of: + * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion + * expose the parameters interface and leave it up to caller + + Args: + model (nn.Module): model containing parameters to optimize + optimizer_name: name of optimizer to create + learning_rate: initial learning rate + weight_decay: weight decay to apply in optimizer + momentum: momentum for momentum based optimizers (others may use betas via kwargs) + filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay + **kwargs: extra optimizer specific kwargs to pass through + + Returns: + Optimizer + """ + opt_lower = optimizer_name.lower() + if weight_decay and filter_bias_and_bn: + skip = {} + if hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + parameters = add_weight_decay(model, weight_decay, skip) + weight_decay = 0. 
+ else: + parameters = model.parameters() + if "fused" in opt_lower: + assert has_apex and torch.cuda.is_available(), "APEX and CUDA required for fused optimizers" + + opt_args = dict(lr=learning_rate, weight_decay=weight_decay, **kwargs) + opt_split = opt_lower.split("_") + opt_lower = opt_split[-1] + if opt_lower == "sgd" or opt_lower == "nesterov": + opt_args.pop("eps", None) + optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == "momentum": + opt_args.pop("eps", None) + optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args) + + elif opt_lower == "adam": + optimizer = optim.Adam(parameters, **opt_args) + elif opt_lower == "adabelief": + optimizer = AdaBelief(parameters, rectify=False, **opt_args) + elif opt_lower == "adamw": + optimizer = optim.AdamW(parameters, lr=learning_rate, weight_decay=weight_decay) + + elif opt_lower == "nadam": + optimizer = Nadam(parameters, **opt_args) + elif opt_lower == "radam": + optimizer = RAdam(parameters, **opt_args) + elif opt_lower == "adamp": + optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) + elif opt_lower == "sgdp": + optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == "adadelta": + optimizer = optim.Adadelta(parameters, **opt_args) + elif opt_lower == "adafactor": + if not learning_rate: + opt_args["lr"] = None + optimizer = Adafactor(parameters, **opt_args) + elif opt_lower == "adahessian": + optimizer = Adahessian(parameters, **opt_args) + elif opt_lower == "rmsprop": + optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args) + elif opt_lower == "rmsproptf": + optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args) + elif opt_lower == "novograd": + optimizer = NovoGrad(parameters, **opt_args) + elif opt_lower == "nvnovograd": + optimizer = NvNovoGrad(parameters, **opt_args) + elif opt_lower == "fusedsgd": + opt_args.pop("eps", None) + optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == "fusedmomentum": + opt_args.pop("eps", None) + optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args) + elif opt_lower == "fusedadam": + optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) + elif opt_lower == "fusedadamw": + optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) + elif opt_lower == "fusedlamb": + optimizer = FusedLAMB(parameters, **opt_args) + elif opt_lower == "fusednovograd": + opt_args.setdefault("betas", (0.95, 0.98)) + optimizer = FusedNovoGrad(parameters, **opt_args) + else: + assert False and "Invalid optimizer" + raise ValueError + + if len(opt_split) > 1: + if opt_split[0] == "lookahead": + optimizer = Lookahead(optimizer) + + return optimizer diff --git a/optimizers/scheduler_factory.py b/optimizers/scheduler_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..8cff09723d2c2128317220bb2d316ddf7b035b76 --- /dev/null +++ b/optimizers/scheduler_factory.py @@ -0,0 +1,103 @@ +""" Scheduler Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from .timm.cosine_lr import CosineLRScheduler +from .timm.tanh_lr import TanhLRScheduler +from .timm.step_lr import StepLRScheduler +from .timm.plateau_lr import PlateauLRScheduler +import torch + +def create_scheduler(args, optimizer, **kwargs): + num_epochs = args.epochs + + if getattr(args, 'lr_noise', None) is not None: + lr_noise = getattr(args, 'lr_noise') + if isinstance(lr_noise, 
(list, tuple)): + noise_range = [n * num_epochs for n in lr_noise] + if len(noise_range) == 1: + noise_range = noise_range[0] + else: + noise_range = lr_noise * num_epochs + else: + noise_range = None + + lr_scheduler = None + if args.lr_policy == 'cosine': + lr_scheduler = CosineLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=getattr(args, 'lr_cycle_mul', 1.), + lr_min=args.lr_min, + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cycle_limit=getattr(args, 'lr_cycle_limit', 1), + t_in_epochs=True, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + num_epochs = lr_scheduler.get_cycle_length() + args.COOLDOWN_EPOCHS + elif args.lr_policy == 'tanh': + lr_scheduler = TanhLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=getattr(args, 'lr_cycle_mul', 1.), + lr_min=args.min_lr, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cycle_limit=getattr(args, 'lr_cycle_limit', 1), + t_in_epochs=True, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + num_epochs = lr_scheduler.get_cycle_length() + args.COOLDOWN_EPOCHS + elif args.lr_policy == 'step': + lr_scheduler = StepLRScheduler( + optimizer, + decay_t=args.decay_epochs - getattr(kwargs, 'init_epoch', 0), # for D + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + elif args.lr_policy == 'plateau': + mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' + lr_scheduler = PlateauLRScheduler( + optimizer, + decay_rate=args.decay_rate, + patience_t=args.patience_epochs, + lr_min=args.min_lr, + mode=mode, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cooldown_t=0, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + elif args.lr_policy == "onecyclelr": + lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=args.LR, + total_steps=kwargs["total_steps"], + pct_start=args.PCT_START, + div_factor=args.DIV_FACTOR_ONECOS, + final_div_factor=args.FIN_DACTOR_ONCCOS, + ) + elif args.lr_policy == "cosinerestart": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0 = kwargs["total_steps"], + T_mult=2, + eta_min = 1e-6, + last_epoch=-1, + ) + return lr_scheduler \ No newline at end of file diff --git a/optimizers/timm/__init__.py b/optimizers/timm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..612eaaebb068a160663278c01db7c544a67907f3 --- /dev/null +++ b/optimizers/timm/__init__.py @@ -0,0 +1,16 @@ +from .adamp import AdamP +from .adamw import AdamW +from .adafactor import Adafactor +from .adahessian import Adahessian +from .lookahead import Lookahead +from .nadam import Nadam +from .novograd import NovoGrad +from .nvnovograd import NvNovoGrad +from .radam import RAdam +from .rmsprop_tf import RMSpropTF +from .sgdp import SGDP +from .adabelief import AdaBelief +from .cosine_lr import CosineLRScheduler +from .plateau_lr import PlateauLRScheduler +from .step_lr import StepLRScheduler +from .tanh_lr import 
TanhLRScheduler \ No newline at end of file diff --git a/optimizers/timm/__pycache__/__init__.cpython-310.pyc b/optimizers/timm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f5b186623a6a0f5c7ec4d10dd9651f0ef54a7de Binary files /dev/null and b/optimizers/timm/__pycache__/__init__.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/__init__.cpython-38.pyc b/optimizers/timm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aebdda211022209fe2e81de9f52ac42433e252c7 Binary files /dev/null and b/optimizers/timm/__pycache__/__init__.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/adabelief.cpython-310.pyc b/optimizers/timm/__pycache__/adabelief.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60aa310acfb6c775f639be5ed79458a84ebd69ff Binary files /dev/null and b/optimizers/timm/__pycache__/adabelief.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/adabelief.cpython-38.pyc b/optimizers/timm/__pycache__/adabelief.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d97344ddc842c2e474dd4911ebbf59e11ce6f450 Binary files /dev/null and b/optimizers/timm/__pycache__/adabelief.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/adafactor.cpython-310.pyc b/optimizers/timm/__pycache__/adafactor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57e308907e55deb4f259a0479cd30489294dabc0 Binary files /dev/null and b/optimizers/timm/__pycache__/adafactor.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/adafactor.cpython-38.pyc b/optimizers/timm/__pycache__/adafactor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5f03be6b540c7737f7585732cc0370e5ae15669 Binary files /dev/null and b/optimizers/timm/__pycache__/adafactor.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/adahessian.cpython-310.pyc b/optimizers/timm/__pycache__/adahessian.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cbe23b3d6f62783b0282d613dfa8f77af4d83cb Binary files /dev/null and b/optimizers/timm/__pycache__/adahessian.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/adahessian.cpython-38.pyc b/optimizers/timm/__pycache__/adahessian.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f221e3c92d0ce19db071aba34fadc9fb02aaad6 Binary files /dev/null and b/optimizers/timm/__pycache__/adahessian.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/adamp.cpython-310.pyc b/optimizers/timm/__pycache__/adamp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cf3f71303b472398b10d4c7fea189e7b3b597e8 Binary files /dev/null and b/optimizers/timm/__pycache__/adamp.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/adamp.cpython-38.pyc b/optimizers/timm/__pycache__/adamp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9a7aeeb3e374ecaaf42014d36b12a60c9f08752 Binary files /dev/null and b/optimizers/timm/__pycache__/adamp.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/adamw.cpython-310.pyc b/optimizers/timm/__pycache__/adamw.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af89f30a5c09c9d050be1ff8887c762d497e1c0a Binary files /dev/null and b/optimizers/timm/__pycache__/adamw.cpython-310.pyc differ diff 
--git a/optimizers/timm/__pycache__/adamw.cpython-38.pyc b/optimizers/timm/__pycache__/adamw.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58c5d9df5d88bec6f18ffe3660c24581be254355 Binary files /dev/null and b/optimizers/timm/__pycache__/adamw.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/cosine_lr.cpython-310.pyc b/optimizers/timm/__pycache__/cosine_lr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce3ff4578cd641649e09e2adef68c8242417f598 Binary files /dev/null and b/optimizers/timm/__pycache__/cosine_lr.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/cosine_lr.cpython-38.pyc b/optimizers/timm/__pycache__/cosine_lr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fa6fd0a0f172284b6c095c6a9f677973634dffb Binary files /dev/null and b/optimizers/timm/__pycache__/cosine_lr.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/lookahead.cpython-310.pyc b/optimizers/timm/__pycache__/lookahead.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60da3f36e38a6e75354c98f7a2ba4d3f587f68dd Binary files /dev/null and b/optimizers/timm/__pycache__/lookahead.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/lookahead.cpython-38.pyc b/optimizers/timm/__pycache__/lookahead.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d5e2f822c6069b137163452cf234177d6dcc2c9 Binary files /dev/null and b/optimizers/timm/__pycache__/lookahead.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/nadam.cpython-310.pyc b/optimizers/timm/__pycache__/nadam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42797093fee962f09483cf58d13c3a24d804e820 Binary files /dev/null and b/optimizers/timm/__pycache__/nadam.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/nadam.cpython-38.pyc b/optimizers/timm/__pycache__/nadam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a07957122bc58eb846a20686a11e7311c96d799c Binary files /dev/null and b/optimizers/timm/__pycache__/nadam.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/novograd.cpython-310.pyc b/optimizers/timm/__pycache__/novograd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..653d248c2bed5d4cfebde5e16e946cea5101c9f6 Binary files /dev/null and b/optimizers/timm/__pycache__/novograd.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/novograd.cpython-38.pyc b/optimizers/timm/__pycache__/novograd.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e4c53cf657b548c563a594745ec4407a8b61c3d Binary files /dev/null and b/optimizers/timm/__pycache__/novograd.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/nvnovograd.cpython-310.pyc b/optimizers/timm/__pycache__/nvnovograd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..401cec8f449a86dde8c6cc51827483ebcbac3523 Binary files /dev/null and b/optimizers/timm/__pycache__/nvnovograd.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/nvnovograd.cpython-38.pyc b/optimizers/timm/__pycache__/nvnovograd.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af736f09523fb6eb3975cf994cd18eb713562208 Binary files /dev/null and b/optimizers/timm/__pycache__/nvnovograd.cpython-38.pyc differ diff --git 
a/optimizers/timm/__pycache__/plateau_lr.cpython-310.pyc b/optimizers/timm/__pycache__/plateau_lr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5974a1df7c8d39ac251be0f499ea2f76e097c23 Binary files /dev/null and b/optimizers/timm/__pycache__/plateau_lr.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/plateau_lr.cpython-38.pyc b/optimizers/timm/__pycache__/plateau_lr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25332f010a4586588dc53948a1111ae8145adb58 Binary files /dev/null and b/optimizers/timm/__pycache__/plateau_lr.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/radam.cpython-310.pyc b/optimizers/timm/__pycache__/radam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a03c354ce383f866100f038e6e5c083d44f8596 Binary files /dev/null and b/optimizers/timm/__pycache__/radam.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/radam.cpython-38.pyc b/optimizers/timm/__pycache__/radam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87af09551e01548b10e2989b0d751e76ff6b5978 Binary files /dev/null and b/optimizers/timm/__pycache__/radam.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/rmsprop_tf.cpython-310.pyc b/optimizers/timm/__pycache__/rmsprop_tf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..824408d5e4271f59b0ea51a50e537c584a3633ed Binary files /dev/null and b/optimizers/timm/__pycache__/rmsprop_tf.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/rmsprop_tf.cpython-38.pyc b/optimizers/timm/__pycache__/rmsprop_tf.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e3c748143dbd858ce84b38e67040883e38be9ae Binary files /dev/null and b/optimizers/timm/__pycache__/rmsprop_tf.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/scheduler.cpython-310.pyc b/optimizers/timm/__pycache__/scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dd5420bdaf017681beb8ee8087726a4f2d1a50f Binary files /dev/null and b/optimizers/timm/__pycache__/scheduler.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/scheduler.cpython-38.pyc b/optimizers/timm/__pycache__/scheduler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d832e38d4f58a547a9cc5d0c1a83f8f5b08494d Binary files /dev/null and b/optimizers/timm/__pycache__/scheduler.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/sgdp.cpython-310.pyc b/optimizers/timm/__pycache__/sgdp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28497864acc9e275d7f0d71a7d58c273dc536d0d Binary files /dev/null and b/optimizers/timm/__pycache__/sgdp.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/sgdp.cpython-38.pyc b/optimizers/timm/__pycache__/sgdp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53a0c7b635029b3a8345a9863f1245e419065d8a Binary files /dev/null and b/optimizers/timm/__pycache__/sgdp.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/step_lr.cpython-310.pyc b/optimizers/timm/__pycache__/step_lr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3825be493a251c191553916ef90e9484d2323bcd Binary files /dev/null and b/optimizers/timm/__pycache__/step_lr.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/step_lr.cpython-38.pyc 
b/optimizers/timm/__pycache__/step_lr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33e9a167bb1a826cf7b402acec01ef0157434502 Binary files /dev/null and b/optimizers/timm/__pycache__/step_lr.cpython-38.pyc differ diff --git a/optimizers/timm/__pycache__/tanh_lr.cpython-310.pyc b/optimizers/timm/__pycache__/tanh_lr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8926113a116636b2f8559ce67acc2c1e2e902343 Binary files /dev/null and b/optimizers/timm/__pycache__/tanh_lr.cpython-310.pyc differ diff --git a/optimizers/timm/__pycache__/tanh_lr.cpython-38.pyc b/optimizers/timm/__pycache__/tanh_lr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a42ada70dad57ccefc7bc5ec4722ef10994f6849 Binary files /dev/null and b/optimizers/timm/__pycache__/tanh_lr.cpython-38.pyc differ diff --git a/optimizers/timm/adabelief.py b/optimizers/timm/adabelief.py new file mode 100644 index 0000000000000000000000000000000000000000..a26d7b27ac85ce65a02bc2e938058b685d914a65 --- /dev/null +++ b/optimizers/timm/adabelief.py @@ -0,0 +1,205 @@ +import math +import torch +from torch.optim.optimizer import Optimizer + + +class AdaBelief(Optimizer): + r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-16) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + weight_decouple (boolean, optional): ( default: True) If set as True, then + the optimizer uses decoupled weight decay as in AdamW + fixed_decay (boolean, optional): (default: False) This is used when weight_decouple + is set as True. + When fixed_decay == True, the weight decay is performed as + $W_{new} = W_{old} - W_{old} \times decay$. + When fixed_decay == False, the weight decay is performed as + $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the + weight decay ratio decreases with learning rate (lr). 
+ rectify (boolean, optional): (default: True) If set as True, then perform the rectified + update similar to RAdam + degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update + when variance of gradient is high + reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020 + + For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer' + For example train/args for EfficientNet see these gists + - link to train_scipt: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037 + - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, + weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True, + degenerated_to_sgd=True): + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): + param['buffer'] = [[None, None, None] for _ in range(10)] + + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, buffer=[[None, None, None] for _ in range(10)]) + super(AdaBelief, self).__init__(params, defaults) + + self.degenerated_to_sgd = degenerated_to_sgd + self.weight_decouple = weight_decouple + self.rectify = rectify + self.fixed_decay = fixed_decay + + def __setstate__(self, state): + super(AdaBelief, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def reset(self): + for group in self.param_groups: + for p in group['params']: + state = self.state[p] + amsgrad = group['amsgrad'] + + # State initialization + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + + # Exponential moving average of squared gradient values + state['exp_avg_var'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_var'] = torch.zeros_like(p.data) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
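        Example (a minimal usage sketch; the model and data below are placeholders):
            >>> import torch
            >>> model = torch.nn.Linear(4, 2)
            >>> opt = AdaBelief(model.parameters(), lr=1e-3, weight_decouple=True, rectify=True)
            >>> loss = model(torch.randn(8, 4)).pow(2).mean()
            >>> loss.backward()
            >>> opt.step()
            >>> opt.zero_grad()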
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # cast data type + half_precision = False + if p.data.dtype == torch.float16: + half_precision = True + p.data = p.data.float() + p.grad = p.grad.float() + + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + 'AdaBelief does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + beta1, beta2 = group['betas'] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_var'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_var'] = torch.zeros_like(p.data) + + # perform weight decay, check if decoupled weight decay + if self.weight_decouple: + if not self.fixed_decay: + p.data.mul_(1.0 - group['lr'] * group['weight_decay']) + else: + p.data.mul_(1.0 - group['weight_decay']) + else: + if group['weight_decay'] != 0: + grad.add_(p.data, alpha=group['weight_decay']) + + # get current state variable + exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Update first and second moment running average + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + grad_residual = grad - exp_avg + exp_avg_var.mul_(beta2).addcmul_( grad_residual, grad_residual, value=1 - beta2) + + if amsgrad: + max_exp_avg_var = state['max_exp_avg_var'] + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var) + + # Use the max. for normalizing running avg. 
of gradient + denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + # update + if not self.rectify: + # Default update + step_size = group['lr'] / bias_correction1 + p.data.addcdiv_( exp_avg, denom, value=-step_size) + + else: # Rectified update, forked from RAdam + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state['step']) + else: + step_size = -1 + buffered[2] = step_size + + if N_sma >= 5: + denom = exp_avg_var.sqrt().add_(group['eps']) + p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) + elif step_size > 0: + p.data.add_( exp_avg, alpha=-step_size * group['lr']) + + if half_precision: + p.data = p.data.half() + p.grad = p.grad.half() + + return loss diff --git a/optimizers/timm/adafactor.py b/optimizers/timm/adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..088ce3acd82e2be1b393afafa05f48435e538a1a --- /dev/null +++ b/optimizers/timm/adafactor.py @@ -0,0 +1,174 @@ +""" Adafactor Optimizer + +Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + +Original header/copyright below. + +""" +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import math + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate depending on the + *scale_parameter*, *relative_step* and *warmup_init* options. + + To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. 
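    In this implementation `relative_step` is inferred as `lr is None`, so a brief
    illustrative sketch of the two modes (`model` stands in for any nn.Module) is:

        opt = Adafactor(model.parameters())                                   # internal, time-dependent lr
        opt = Adafactor(model.parameters(), lr=1e-3, scale_parameter=False)   # externally scheduled lr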
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) + relative_step (bool): if True, time-dependent learning rate is computed + instead of external learning rate (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0, + decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False): + relative_step = lr is None + if warmup_init and not relative_step: + raise ValueError('warmup_init requires relative_step=True') + + beta1 = None if betas is None else betas[0] # make it compat with standard betas arg + defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate, + beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, + relative_step=relative_step, warmup_init=warmup_init) + super(Adafactor, self).__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + if param_group['relative_step']: + min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 + lr_t = min(min_step, 1.0 / math.sqrt(param_state['step'])) + param_scale = 1.0 + if param_group['scale_parameter']: + param_scale = max(param_group['eps_scale'], param_state['RMS']) + param_group['lr'] = lr_t * param_scale + return param_group['lr'] + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group['beta1'] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. 
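        Note: for matrix-shaped parameters the squared-gradient statistics are
        kept factored as one row vector and one column vector (see
        `_approx_sq_grad`), which is what yields the sublinear memory cost
        mentioned in the class docstring.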
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError('Adafactor does not support sparse gradients.') + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state['step'] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + if factored: + state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state['exp_avg_sq'] = torch.zeros_like(grad) + + state['RMS'] = 0 + else: + if use_first_moment: + state['exp_avg'] = state['exp_avg'].to(grad) + if factored: + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) + else: + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state['step'] += 1 + state['RMS'] = self._rms(p_data_fp32) + lr_t = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) + update = grad ** 2 + group['eps'] + if factored: + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] + + exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1)) + exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2)) + #exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+ + #exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state['exp_avg_sq'] + + exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update) + #exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+ + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0)) + update.mul_(lr_t) + + if use_first_moment: + exp_avg = state['exp_avg'] + exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update) + #exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+ + update = exp_avg + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32) + #p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+ + + p_data_fp32.add_(-update) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss \ No newline at end of file diff --git a/optimizers/timm/adahessian.py b/optimizers/timm/adahessian.py new file mode 100644 index 0000000000000000000000000000000000000000..985c67ca686a65f61f5c5b1a7db3e5bba815a19b --- /dev/null +++ b/optimizers/timm/adahessian.py @@ -0,0 +1,156 @@ +""" AdaHessian Optimizer + +Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py +Originally licensed MIT, Copyright 2020, David Samuel +""" +import torch + + +class Adahessian(torch.optim.Optimizer): + """ + Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning" + + Arguments: + params 
(iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): learning rate (default: 0.1) + betas ((float, float), optional): coefficients used for computing running averages of gradient and the + squared hessian trace (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0) + hessian_power (float, optional): exponent of the hessian trace (default: 1.0) + update_each (int, optional): compute the hessian trace approximation only after *this* number of steps + (to save time) (default: 1) + n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1) + """ + + def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, + hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= hessian_power <= 1.0: + raise ValueError(f"Invalid Hessian power value: {hessian_power}") + + self.n_samples = n_samples + self.update_each = update_each + self.avg_conv_kernel = avg_conv_kernel + + # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training + self.seed = 2147483647 + self.generator = torch.Generator().manual_seed(self.seed) + + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power) + super(Adahessian, self).__init__(params, defaults) + + for p in self.get_params(): + p.hess = 0.0 + self.state[p]["hessian step"] = 0 + + @property + def is_second_order(self): + return True + + def get_params(self): + """ + Gets all parameters in all param_groups with gradients + """ + + return (p for group in self.param_groups for p in group['params'] if p.requires_grad) + + def zero_hessian(self): + """ + Zeros out the accumalated hessian traces. + """ + + for p in self.get_params(): + if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0: + p.hess.zero_() + + @torch.no_grad() + def set_hessian(self): + """ + Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter. 
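        The estimator relies on E[z * (H @ z)] = diag(H) for Rademacher z with
        independent +/-1 entries, so averaging `h_z * z` over `n_samples` draws
        approximates the Hessian diagonal that `step` then uses in place of the
        squared gradient.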
+ """ + + params = [] + for p in filter(lambda p: p.grad is not None, self.get_params()): + if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step + params.append(p) + self.state[p]["hessian step"] += 1 + + if len(params) == 0: + return + + if self.generator.device != params[0].device: # hackish way of casting the generator to the right device + self.generator = torch.Generator(params[0].device).manual_seed(self.seed) + + grads = [p.grad for p in params] + + for i in range(self.n_samples): + # Rademacher distribution {-1.0, 1.0} + zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params] + h_zs = torch.autograd.grad( + grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1) + for h_z, z, p in zip(h_zs, zs, params): + p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step. + Arguments: + closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) + """ + + loss = None + if closure is not None: + loss = closure() + + self.zero_hessian() + self.set_hessian() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None or p.hess is None: + continue + + if self.avg_conv_kernel and p.dim() == 4: + p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone() + + # Perform correct stepweight decay as in AdamW + p.mul_(1 - group['lr'] * group['weight_decay']) + + state = self.state[p] + + # State initialization + if len(state) == 1: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of Hessian diagonal square values + state['exp_hessian_diag_sq'] = torch.zeros_like(p) + + exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq'] + beta1, beta2 = group['betas'] + state['step'] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) + exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + k = group['hessian_power'] + denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps']) + + # make update + step_size = group['lr'] / bias_correction1 + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss diff --git a/optimizers/timm/adamp.py b/optimizers/timm/adamp.py new file mode 100644 index 0000000000000000000000000000000000000000..468c3e865e0ceb6fb2bf22f9388237a783314f07 --- /dev/null +++ b/optimizers/timm/adamp.py @@ -0,0 +1,107 @@ +""" +AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py + +Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 +Code: https://github.com/clovaai/AdamP + +Copyright (c) 2020-present NAVER Corp. 
+MIT license +""" + +import torch +import torch.nn as nn +from torch.optim.optimizer import Optimizer, required +import math + +class AdamP(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) + super(AdamP, self).__init__(params, defaults) + + def _channel_view(self, x): + return x.view(x.size(0), -1) + + def _layer_view(self, x): + return x.view(1, -1) + + def _cosine_similarity(self, x, y, eps, view_func): + x = view_func(x) + y = view_func(y) + + x_norm = x.norm(dim=1).add_(eps) + y_norm = y.norm(dim=1).add_(eps) + dot = (x * y).sum(dim=1) + + return dot.abs() / x_norm / y_norm + + def _projection(self, p, grad, perturb, delta, wd_ratio, eps): + wd = 1 + expand_size = [-1] + [1] * (len(p.shape) - 1) + for view_func in [self._channel_view, self._layer_view]: + + cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) + + if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): + p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) + perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) + wd = wd_ratio + + return perturb, wd + + return perturb, wd + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad.data + beta1, beta2 = group['betas'] + nesterov = group['nesterov'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p.data) + state['exp_avg_sq'] = torch.zeros_like(p.data) + + # Adam + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + step_size = group['lr'] / bias_correction1 + + if nesterov: + perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom + else: + perturb = exp_avg / denom + + # Projection + wd_ratio = 1 + if len(p.shape) > 1: + perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if group['weight_decay'] > 0: + p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio) + + # Step + p.data.add_(-step_size, perturb) + + return loss diff --git a/optimizers/timm/adamw.py b/optimizers/timm/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..66f9a959de586356a29ace2f9c57d3fee8d1057a --- /dev/null +++ b/optimizers/timm/adamw.py @@ -0,0 +1,117 @@ +""" AdamW Optimizer +Impl copied from PyTorch master +""" +import math +import torch +from torch.optim.optimizer import Optimizer + + +class AdamW(Optimizer): + r"""Implements AdamW algorithm. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + + p.data.addcdiv_(-step_size, exp_avg, denom) + + return loss diff --git a/optimizers/timm/cosine_lr.py b/optimizers/timm/cosine_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..1532f092b5cc8c0af5125967cfb84b32ce03ca4a --- /dev/null +++ b/optimizers/timm/cosine_lr.py @@ -0,0 +1,116 @@ +""" Cosine Scheduler + +Cosine LR schedule with warmup, cycle/restarts, noise. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class CosineLRScheduler(Scheduler): + """ + Cosine decay with restarts. + This is described in the paper https://arxiv.org/abs/1608.03983. + + Inspiration from + https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + t_mul: float = 1., + lr_min: float = 0., + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + cycle_limit=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + if t_initial == 1 and t_mul == 1 and decay_rate == 1: + _logger.warning("Cosine annealing scheduler will have no effect on the learning " + "rate since t_initial = t_mul = eta_mul = 1.") + self.t_initial = t_initial + self.t_mul = t_mul + self.lr_min = lr_min + self.decay_rate = decay_rate + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.t_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) + t_i = self.t_mul ** i * self.t_initial + t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + gamma = self.decay_rate ** i + lr_min = self.lr_min * gamma + lr_max_values = [v * gamma for v in self.base_values] + + if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): + lrs = [ + lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + if not cycles: + cycles = self.cycle_limit + 
cycles = max(1, cycles) + if self.t_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) diff --git a/optimizers/timm/lookahead.py b/optimizers/timm/lookahead.py new file mode 100644 index 0000000000000000000000000000000000000000..6b5b7f38ec8cb6594e3986b66223fa2881daeca3 --- /dev/null +++ b/optimizers/timm/lookahead.py @@ -0,0 +1,92 @@ +""" Lookahead Optimizer Wrapper. +Implementation modified from: https://github.com/alphadl/lookahead.pytorch +Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch.optim.optimizer import Optimizer +from collections import defaultdict + + +class Lookahead(Optimizer): + def __init__(self, base_optimizer, alpha=0.5, k=6): + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) + self.base_optimizer = base_optimizer + self.param_groups = self.base_optimizer.param_groups + self.defaults = base_optimizer.defaults + self.defaults.update(defaults) + self.state = defaultdict(dict) + # manually add our defaults to the param groups + for name, default in defaults.items(): + for group in self.param_groups: + group.setdefault(name, default) + + def update_slow(self, group): + for fast_p in group["params"]: + if fast_p.grad is None: + continue + param_state = self.state[fast_p] + if 'slow_buffer' not in param_state: + param_state['slow_buffer'] = torch.empty_like(fast_p.data) + param_state['slow_buffer'].copy_(fast_p.data) + slow = param_state['slow_buffer'] + slow.add_(group['lookahead_alpha'], fast_p.data - slow) + fast_p.data.copy_(slow) + + def sync_lookahead(self): + for group in self.param_groups: + self.update_slow(group) + + def step(self, closure=None): + #assert id(self.param_groups) == id(self.base_optimizer.param_groups) + loss = self.base_optimizer.step(closure) + for group in self.param_groups: + group['lookahead_step'] += 1 + if group['lookahead_step'] % group['lookahead_k'] == 0: + self.update_slow(group) + return loss + + def state_dict(self): + fast_state_dict = self.base_optimizer.state_dict() + slow_state = { + (id(k) if isinstance(k, torch.Tensor) else k): v + for k, v in self.state.items() + } + fast_state = fast_state_dict['state'] + param_groups = fast_state_dict['param_groups'] + return { + 'state': fast_state, + 'slow_state': slow_state, + 'param_groups': param_groups, + } + + def load_state_dict(self, state_dict): + fast_state_dict = { + 'state': state_dict['state'], + 'param_groups': state_dict['param_groups'], + } + self.base_optimizer.load_state_dict(fast_state_dict) + + # We want to restore the slow state, but share param_groups reference + # with base_optimizer. 
This is a bit redundant but least code + slow_state_new = False + if 'slow_state' not in state_dict: + print('Loading state_dict from optimizer without Lookahead applied.') + state_dict['slow_state'] = defaultdict(dict) + slow_state_new = True + slow_state_dict = { + 'state': state_dict['slow_state'], + 'param_groups': state_dict['param_groups'], # this is pointless but saves code + } + super(Lookahead, self).load_state_dict(slow_state_dict) + self.param_groups = self.base_optimizer.param_groups # make both ref same container + if slow_state_new: + # reapply defaults to catch missing lookahead specific ones + for name, default in self.defaults.items(): + for group in self.param_groups: + group.setdefault(name, default) diff --git a/optimizers/timm/nadam.py b/optimizers/timm/nadam.py new file mode 100644 index 0000000000000000000000000000000000000000..d994d1b83485c9b068de73f5f3cf2efb1e5bec39 --- /dev/null +++ b/optimizers/timm/nadam.py @@ -0,0 +1,88 @@ +import torch +from torch.optim import Optimizer + + +class Nadam(Optimizer): + """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). + + It has been proposed in `Incorporating Nesterov Momentum into Adam`__. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + schedule_decay (float, optional): momentum schedule decay (default: 4e-3) + + __ http://cs229.stanford.edu/proj2015/054_report.pdf + __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf + + Originally taken from: https://github.com/pytorch/pytorch/pull/1408 + NOTE: Has potential issues but does work well on some problems. + """ + + def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, schedule_decay=4e-3): + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, schedule_decay=schedule_decay) + super(Nadam, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['m_schedule'] = 1. + state['exp_avg'] = grad.new().resize_as_(grad).zero_() + state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() + + # Warming momentum schedule + m_schedule = state['m_schedule'] + schedule_decay = group['schedule_decay'] + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + eps = group['eps'] + state['step'] += 1 + t = state['step'] + + if group['weight_decay'] != 0: + grad = grad.add(group['weight_decay'], p.data) + + momentum_cache_t = beta1 * \ + (1. - 0.5 * (0.96 ** (t * schedule_decay))) + momentum_cache_t_1 = beta1 * \ + (1. 
- 0.5 * (0.96 ** ((t + 1) * schedule_decay))) + m_schedule_new = m_schedule * momentum_cache_t + m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 + state['m_schedule'] = m_schedule_new + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1. - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) + exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) + denom = exp_avg_sq_prime.sqrt_().add_(eps) + + p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) + p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) + + return loss diff --git a/optimizers/timm/novograd.py b/optimizers/timm/novograd.py new file mode 100644 index 0000000000000000000000000000000000000000..4137c6aa9406360d29f5f7234ebbdef294404d0e --- /dev/null +++ b/optimizers/timm/novograd.py @@ -0,0 +1,77 @@ +"""NovoGrad Optimizer. +Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd +Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` + - https://arxiv.org/abs/1905.11286 +""" + +import torch +from torch.optim.optimizer import Optimizer +import math + + +class NovoGrad(Optimizer): + def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(NovoGrad, self).__init__(params, defaults) + self._lr = lr + self._beta1 = betas[0] + self._beta2 = betas[1] + self._eps = eps + self._wd = weight_decay + self._grad_averaging = grad_averaging + + self._momentum_initialized = False + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + if not self._momentum_initialized: + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('NovoGrad does not support sparse gradients') + + v = torch.norm(grad)**2 + m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data + state['step'] = 0 + state['v'] = v + state['m'] = m + state['grad_ema'] = None + self._momentum_initialized = True + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + state['step'] += 1 + + step, v, m = state['step'], state['v'], state['m'] + grad_ema = state['grad_ema'] + + grad = p.grad.data + g2 = torch.norm(grad)**2 + grad_ema = g2 if grad_ema is None else grad_ema * \ + self._beta2 + g2 * (1. - self._beta2) + grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) + + if self._grad_averaging: + grad *= (1. - self._beta1) + + g2 = torch.norm(grad)**2 + v = self._beta2*v + (1. - self._beta2)*g2 + m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) + bias_correction1 = 1 - self._beta1 ** step + bias_correction2 = 1 - self._beta2 ** step + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + state['v'], state['m'] = v, m + state['grad_ema'] = grad_ema + p.data.add_(-step_size, m) + return loss diff --git a/optimizers/timm/nvnovograd.py b/optimizers/timm/nvnovograd.py new file mode 100644 index 0000000000000000000000000000000000000000..323312d2fc36d028124f7a7ec604d248e71503cd --- /dev/null +++ b/optimizers/timm/nvnovograd.py @@ -0,0 +1,118 @@ +""" Nvidia NovoGrad Optimizer. 
+Original impl by Nvidia from Jasper example: + - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper +Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` + - https://arxiv.org/abs/1905.11286 +""" + +import torch +from torch.optim.optimizer import Optimizer +import math + + +class NvNovoGrad(Optimizer): + """ + Implements Novograd algorithm. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.95, 0.98)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging: gradient averaging + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + """ + + def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, + weight_decay=0, grad_averaging=False, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + amsgrad=amsgrad) + + super(NvNovoGrad, self).__init__(params, defaults) + + def __setstate__(self, state): + super(NvNovoGrad, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Sparse gradients are not supported.') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + norm = torch.sum(torch.pow(grad, 2)) + + if exp_avg_sq == 0: + exp_avg_sq.copy_(norm) + else: + exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + grad.div_(denom) + if group['weight_decay'] != 0: + grad.add_(group['weight_decay'], p.data) + if group['grad_averaging']: + grad.mul_(1 - beta1) + exp_avg.mul_(beta1).add_(grad) + + p.data.add_(-group['lr'], exp_avg) + + return loss diff --git a/optimizers/timm/plateau_lr.py b/optimizers/timm/plateau_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2cacb65a1bf23d10aa6fd296f74579571043cf --- /dev/null +++ b/optimizers/timm/plateau_lr.py @@ -0,0 +1,113 @@ +""" Plateau Scheduler + +Adapts PyTorch plateau scheduler and allows application of noise, warmup. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + +from .scheduler import Scheduler + + +class PlateauLRScheduler(Scheduler): + """Decay the LR by a factor every time the validation loss plateaus.""" + + def __init__(self, + optimizer, + decay_rate=0.1, + patience_t=10, + verbose=True, + threshold=1e-4, + cooldown_t=0, + warmup_t=0, + warmup_lr_init=0, + lr_min=0, + mode='max', + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize=True, + ): + super().__init__(optimizer, 'lr', initialize=initialize) + + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, + patience=patience_t, + factor=decay_rate, + verbose=verbose, + threshold=threshold, + cooldown=cooldown_t, + mode=mode, + min_lr=lr_min + ) + + self.noise_range = noise_range_t + self.noise_pct = noise_pct + self.noise_type = noise_type + self.noise_std = noise_std + self.noise_seed = noise_seed if noise_seed is not None else 42 + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + self.restore_lr = None + + def state_dict(self): + return { + 'best': self.lr_scheduler.best, + 'last_epoch': self.lr_scheduler.last_epoch, + } + + def load_state_dict(self, state_dict): + self.lr_scheduler.best = state_dict['best'] + if 'last_epoch' in state_dict: + self.lr_scheduler.last_epoch = state_dict['last_epoch'] + + # override the base class step fn completely + def step(self, epoch, metric=None): + if epoch <= self.warmup_t: + lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] + super().update_groups(lrs) + else: + if self.restore_lr is not None: + # restore actual LR from before our last noise perturbation before stepping base + for i, param_group in enumerate(self.optimizer.param_groups): + param_group['lr'] = self.restore_lr[i] + self.restore_lr = None + + self.lr_scheduler.step(metric, epoch) # step the base scheduler + + if self.noise_range is not None: + if isinstance(self.noise_range, (list, tuple)): + apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] + else: + apply_noise = epoch >= self.noise_range + if apply_noise: + self._apply_noise(epoch) + + def _apply_noise(self, epoch): + g = torch.Generator() + g.manual_seed(self.noise_seed + epoch) + if self.noise_type == 'normal': + while True: + # resample if noise out of percent limit, brute force but shouldn't spin much + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + break + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + + # apply the noise on top of previous LR, cache 
the old value so we can restore for normal + # stepping of base scheduler + restore_lr = [] + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group['lr']) + restore_lr.append(old_lr) + new_lr = old_lr + old_lr * noise + param_group['lr'] = new_lr + self.restore_lr = restore_lr diff --git a/optimizers/timm/radam.py b/optimizers/timm/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..9987a334460286b1a6c8ec6d57ee023596a74219 --- /dev/null +++ b/optimizers/timm/radam.py @@ -0,0 +1,152 @@ +"""RAdam Optimizer. +Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam +Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 +""" +import math +import torch +from torch.optim.optimizer import Optimizer, required + + +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.buffer = [[None, None, None] for ind in range(10)] + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + buffered = self.buffer[int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss + + +class PlainRAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + + super(PlainRAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(PlainRAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in 
self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss diff --git a/optimizers/timm/rmsprop_tf.py b/optimizers/timm/rmsprop_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..5115555cd26040e3af297a6e79e7bd5e4d202623 --- /dev/null +++ b/optimizers/timm/rmsprop_tf.py @@ -0,0 +1,136 @@ +""" RMSProp modified to behave like Tensorflow impl + +Originally cut & paste from PyTorch RMSProp +https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py +Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE + +Modifications Copyright 2020 Ross Wightman +""" + +import torch +from torch.optim import Optimizer + + +class RMSpropTF(Optimizer): + """Implements RMSprop algorithm (TensorFlow style epsilon) + + NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt + and a few other modifications to closer match Tensorflow for matching hyper-params. + + Noteworthy changes include: + 1. Epsilon applied inside square-root + 2. square_avg initialized to ones + 3. LR scaling of update accumulated in momentum buffer + + Proposed by G. Hinton in his + `course `_. + + The centered version first appears in `Generating Sequences + With Recurrent Neural Networks `_. 
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + momentum (float, optional): momentum factor (default: 0) + alpha (float, optional): smoothing (decay) constant (default: 0.9) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-10) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 + lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer + update as per defaults in Tensorflow + + """ + + def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, + decoupled_decay=False, lr_in_momentum=True): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= momentum: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + + defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, + decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) + super(RMSpropTF, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RMSpropTF, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('momentum', 0) + group.setdefault('centered', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero + if group['momentum'] > 0: + state['momentum_buffer'] = torch.zeros_like(p.data) + if group['centered']: + state['grad_avg'] = torch.zeros_like(p.data) + + square_avg = state['square_avg'] + one_minus_alpha = 1. 
- group['alpha'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + if 'decoupled_decay' in group and group['decoupled_decay']: + p.data.add_(-group['weight_decay'], p.data) + else: + grad = grad.add(group['weight_decay'], p.data) + + # Tensorflow order of ops for updating squared avg + square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg) + # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original + + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.add_(one_minus_alpha, grad - grad_avg) + # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original + avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt + else: + avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt + + if group['momentum'] > 0: + buf = state['momentum_buffer'] + # Tensorflow accumulates the LR scaling in the momentum buffer + if 'lr_in_momentum' in group and group['lr_in_momentum']: + buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) + p.data.add_(-buf) + else: + # PyTorch scales the param update by LR + buf.mul_(group['momentum']).addcdiv_(grad, avg) + p.data.add_(-group['lr'], buf) + else: + p.data.addcdiv_(-group['lr'], grad, avg) + + return loss diff --git a/optimizers/timm/scheduler.py b/optimizers/timm/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..21d51509c87a0783c6b61986c574a3ed5366e165 --- /dev/null +++ b/optimizers/timm/scheduler.py @@ -0,0 +1,105 @@ +from typing import Dict, Any + +import torch + + +class Scheduler: + """ Parameter Scheduler Base Class + A scheduler base class that can be used to schedule any optimizer parameter groups. + + Unlike the builtin PyTorch schedulers, this is intended to be consistently called + * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value + * At the END of each optimizer update, after incrementing the update count, to calculate next update's value + + The schedulers built on this should try to remain as stateless as possible (for simplicity). + + This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' + and -1 values for special behaviour. All epoch and update counts must be tracked in the training + code and explicitly passed in to the schedulers on the corresponding step or step_update call. 
+ + Based on ideas from: + * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler + * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + param_group_field: str, + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize: bool = True) -> None: + self.optimizer = optimizer + self.param_group_field = param_group_field + self._initial_param_group_field = f"initial_{param_group_field}" + if initialize: + for i, group in enumerate(self.optimizer.param_groups): + if param_group_field not in group: + raise KeyError(f"{param_group_field} missing from param_groups[{i}]") + group.setdefault(self._initial_param_group_field, group[param_group_field]) + else: + for i, group in enumerate(self.optimizer.param_groups): + if self._initial_param_group_field not in group: + raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") + self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] + self.metric = None # any point to having this for all? + self.noise_range_t = noise_range_t + self.noise_pct = noise_pct + self.noise_type = noise_type + self.noise_std = noise_std + self.noise_seed = noise_seed if noise_seed is not None else 42 + self.update_groups(self.base_values) + + def state_dict(self) -> Dict[str, Any]: + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.__dict__.update(state_dict) + + def get_epoch_values(self, epoch: int): + return None + + def get_update_values(self, num_updates: int): + return None + + def step(self, epoch: int, metric: float = None) -> None: + self.metric = metric + values = self.get_epoch_values(epoch) + if values is not None: + values = self._add_noise(values, epoch) + self.update_groups(values) + + def step_update(self, num_updates: int, metric: float = None): + self.metric = metric + values = self.get_update_values(num_updates) + if values is not None: + values = self._add_noise(values, num_updates) + self.update_groups(values) + + def update_groups(self, values): + if not isinstance(values, (list, tuple)): + values = [values] * len(self.optimizer.param_groups) + for param_group, value in zip(self.optimizer.param_groups, values): + param_group[self.param_group_field] = value + + def _add_noise(self, lrs, t): + if self.noise_range_t is not None: + if isinstance(self.noise_range_t, (list, tuple)): + apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] + else: + apply_noise = t >= self.noise_range_t + if apply_noise: + g = torch.Generator() + g.manual_seed(self.noise_seed + t) + if self.noise_type == 'normal': + while True: + # resample if noise out of percent limit, brute force but shouldn't spin much + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + break + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + lrs = [v + v * noise for v in lrs] + return lrs diff --git a/optimizers/timm/sgdp.py b/optimizers/timm/sgdp.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a94aa332d7030a70e888342eb6cc4623d69836 --- /dev/null +++ b/optimizers/timm/sgdp.py @@ -0,0 +1,96 @@ +""" +SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py + +Paper: `Slowing Down the Weight Norm Increase in 
Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 +Code: https://github.com/clovaai/AdamP + +Copyright (c) 2020-present NAVER Corp. +MIT license +""" + +import torch +import torch.nn as nn +from torch.optim.optimizer import Optimizer, required +import math + +class SGDP(Optimizer): + def __init__(self, params, lr=required, momentum=0, dampening=0, + weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, + nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) + super(SGDP, self).__init__(params, defaults) + + def _channel_view(self, x): + return x.view(x.size(0), -1) + + def _layer_view(self, x): + return x.view(1, -1) + + def _cosine_similarity(self, x, y, eps, view_func): + x = view_func(x) + y = view_func(y) + + x_norm = x.norm(dim=1).add_(eps) + y_norm = y.norm(dim=1).add_(eps) + dot = (x * y).sum(dim=1) + + return dot.abs() / x_norm / y_norm + + def _projection(self, p, grad, perturb, delta, wd_ratio, eps): + wd = 1 + expand_size = [-1] + [1] * (len(p.shape) - 1) + for view_func in [self._channel_view, self._layer_view]: + + cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) + + if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): + p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) + perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) + wd = wd_ratio + + return perturb, wd + + return perturb, wd + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + + # State initialization + if len(state) == 0: + state['momentum'] = torch.zeros_like(p.data) + + # SGD + buf = state['momentum'] + buf.mul_(momentum).add_(1 - dampening, grad) + if nesterov: + d_p = grad + momentum * buf + else: + d_p = buf + + # Projection + wd_ratio = 1 + if len(p.shape) > 1: + d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if weight_decay != 0: + p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) + + # Step + p.data.add_(-group['lr'], d_p) + + return loss diff --git a/optimizers/timm/step_lr.py b/optimizers/timm/step_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..f797e1a8cf35999531dd5f1ccbbe09a9d0cf30a9 --- /dev/null +++ b/optimizers/timm/step_lr.py @@ -0,0 +1,63 @@ +""" Step Scheduler + +Basic step LR schedule with warmup, noise. 
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +import torch + +from .scheduler import Scheduler + + +class StepLRScheduler(Scheduler): + """ + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + decay_t: float, + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + self.decay_t = decay_t + self.decay_rate = decay_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None diff --git a/optimizers/timm/tanh_lr.py b/optimizers/timm/tanh_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc338bb1df7a564d9207b32ab0f59cdf1ef4c59 --- /dev/null +++ b/optimizers/timm/tanh_lr.py @@ -0,0 +1,120 @@ +""" TanH Scheduler + +TanH schedule with warmup, cycle/restarts, noise. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class TanhLRScheduler(Scheduler): + """ + Hyberbolic-Tangent decay with restarts. 
+ This is described in the paper https://arxiv.org/abs/1806.01593 + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lb: float = -6., + ub: float = 4., + t_mul: float = 1., + lr_min: float = 0., + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + cycle_limit=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + assert lb < ub + assert cycle_limit >= 0 + assert warmup_t >= 0 + assert warmup_lr_init >= 0 + self.lb = lb + self.ub = ub + self.t_initial = t_initial + self.t_mul = t_mul + self.lr_min = lr_min + self.decay_rate = decay_rate + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + if self.warmup_t: + t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.t_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) + t_i = self.t_mul ** i * self.t_initial + t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): + gamma = self.decay_rate ** i + lr_min = self.lr_min * gamma + lr_max_values = [v * gamma for v in self.base_values] + + tr = t_curr / t_i + lrs = [ + lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. 
- tr) + self.ub * tr)) + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + if not cycles: + cycles = self.cycle_limit + cycles = max(1, cycles) + if self.t_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) diff --git a/outputs/audio2pose/custom/hf/999/gt_2_scott_0_3_3.npz b/outputs/audio2pose/custom/hf/999/gt_2_scott_0_3_3.npz new file mode 100644 index 0000000000000000000000000000000000000000..62b0df13d933e4f482832a1f6293f1f529ac883a --- /dev/null +++ b/outputs/audio2pose/custom/hf/999/gt_2_scott_0_3_3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c22b5851a4a73eb4b5ecc8c3cc0430a344bf226b58488aa442a0d3b55d41d95 +size 71632 diff --git a/outputs/audio2pose/custom/hf/999/res_2_scott_0_3_3.mp4 b/outputs/audio2pose/custom/hf/999/res_2_scott_0_3_3.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..89fb48c981c2b468b1522f862343b257162ce866 Binary files /dev/null and b/outputs/audio2pose/custom/hf/999/res_2_scott_0_3_3.mp4 differ diff --git a/outputs/audio2pose/custom/hf/999/res_2_scott_0_3_3.npz b/outputs/audio2pose/custom/hf/999/res_2_scott_0_3_3.npz new file mode 100644 index 0000000000000000000000000000000000000000..f07491391d338fdd97b45c93d61a71b58b7b4ac6 --- /dev/null +++ b/outputs/audio2pose/custom/hf/999/res_2_scott_0_3_3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77fd86b392889c8da1d292f23326217639190933ac605cce49b2f1d41d5a81dd +size 71632 diff --git a/outputs/audio2pose/custom/hf/tmp.wav b/outputs/audio2pose/custom/hf/tmp.wav new file mode 100644 index 0000000000000000000000000000000000000000..a2bd689d57c4e35395570cb379c5dc98d364326b Binary files /dev/null and b/outputs/audio2pose/custom/hf/tmp.wav differ diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..9695433780ba56a2bbc4a9ba13e2725d6f9bbb7a --- /dev/null +++ b/packages.txt @@ -0,0 +1,4 @@ +libgl1-mesa-dev +libglu1-mesa-dev +freeglut3-dev +mesa-common-dev \ No newline at end of file diff --git a/pyrender/.coveragerc b/pyrender/.coveragerc new file mode 100644 index 0000000000000000000000000000000000000000..ee31cded3509cbd991a33dd27e2525b93a1a6558 --- /dev/null +++ b/pyrender/.coveragerc @@ -0,0 +1,5 @@ +[report] +exclude_lines = + def __repr__ + def __str__ + @abc.abstractmethod diff --git a/pyrender/.flake8 b/pyrender/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..fec4bcfc3ba774b53a866d839ea15bae6ebdb4a6 --- /dev/null +++ b/pyrender/.flake8 @@ -0,0 +1,8 @@ +[flake8] +ignore = E231,W504,F405,F403 +max-line-length = 79 +select = B,C,E,F,W,T4,B9 +exclude = + docs/source/conf.py, + __pycache__, + examples/* diff --git a/pyrender/.gitignore b/pyrender/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ae59dec631f71a23d4255aaf9c0274a699f4ba25 --- /dev/null +++ b/pyrender/.gitignore @@ -0,0 +1,106 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +docs/**/generated/** + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ 
+develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/pyrender/.pre-commit-config.yaml b/pyrender/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1817eb39bf409aff80c7d2cc79a3bc3856c70dbd --- /dev/null +++ b/pyrender/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://gitlab.com/pycqa/flake8 + rev: 3.7.1 + hooks: + - id: flake8 + exclude: ^setup.py diff --git a/pyrender/.travis.yml b/pyrender/.travis.yml new file mode 100644 index 0000000000000000000000000000000000000000..1ad289ae1513eaf8fda74f8d5ab7840be3ef56cb --- /dev/null +++ b/pyrender/.travis.yml @@ -0,0 +1,43 @@ +language: python +sudo: required +dist: xenial + +python: +- '3.6' +- '3.7' + +before_install: + # Pre-install osmesa + - sudo apt update + - sudo wget https://github.com/mmatl/travis_debs/raw/master/xenial/mesa_18.3.3-0.deb + - sudo dpkg -i ./mesa_18.3.3-0.deb || true + - sudo apt install -f + - git clone https://github.com/mmatl/pyopengl.git + - cd pyopengl + - pip install . + - cd .. + +install: + - pip install . 
+ # - pip install -q pytest pytest-cov coveralls + - pip install pytest pytest-cov coveralls + - pip install ./pyopengl + +script: + - PYOPENGL_PLATFORM=osmesa pytest --cov=pyrender tests + +after_success: +- coveralls || true + +deploy: + provider: pypi + skip_existing: true + user: mmatl + on: + tags: true + branch: master + password: + secure: O4WWMbTYb2eVYIO4mMOVa6/xyhX7mPvJpd96cxfNvJdyuqho8VapOhzqsI5kahMB1hFjWWr61yR4+Ru5hoDYf3XA6BQVk8eCY9+0H7qRfvoxex71lahKAqfHLMoE1xNdiVTgl+QN9hYjOnopLod24rx8I8eXfpHu/mfCpuTYGyLlNcDP5St3bXpXLPB5wg8Jo1YRRv6W/7fKoXyuWjewk9cJAS0KrEgnDnSkdwm6Pb+80B2tcbgdGvpGaByw5frndwKiMUMgVUownepDU5POQq2p29wwn9lCvRucULxjEgO+63jdbZRj5fNutLarFa2nISfYnrd72LOyDfbJubwAzzAIsy2JbFORyeHvCgloiuE9oE7a9oOQt/1QHBoIV0seiawMWn55Yp70wQ7HlJs4xSGJWCGa5+9883QRNsvj420atkb3cgO8P+PXwiwTi78Dq7Z/xHqccsU0b8poqBneQoA+pUGgNnF6V7Z8e9RsCcse2gAWSZWuOK3ua+9xCgH7I7MeL3afykr2aJ+yFCoYJMFrUjJeodMX2RbL0q+3FzIPZeGW3WdhTEAL9TSKRcJBSQTskaQlZx/OcpobxS7t3d2S68CCLG9uMTqOTYws55WZ1etalA75sRk9K2MR7ZGjZW3jdtvMViISc/t6Rrjea1GE8ZHGJC6/IeLIWA2c7nc= + distributions: sdist bdist_wheel +notifications: + email: false diff --git a/pyrender/LICENSE b/pyrender/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4276f7d204e4d85104246df637e0e36adbef14a7 --- /dev/null +++ b/pyrender/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Matthew Matl + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/pyrender/MANIFEST.in b/pyrender/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..097bcca3b4fccdc39ddd63c10f710ad524898e95 --- /dev/null +++ b/pyrender/MANIFEST.in @@ -0,0 +1,5 @@ +# Include the license +include LICENSE +include README.rst +include pyrender/fonts/* +include pyrender/shaders/* diff --git a/pyrender/README.md b/pyrender/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ae88ed1c5e78f247e38291ed83cf4c81230bf976 --- /dev/null +++ b/pyrender/README.md @@ -0,0 +1,92 @@ +# Pyrender + +[![Build Status](https://travis-ci.org/mmatl/pyrender.svg?branch=master)](https://travis-ci.org/mmatl/pyrender) +[![Documentation Status](https://readthedocs.org/projects/pyrender/badge/?version=latest)](https://pyrender.readthedocs.io/en/latest/?badge=latest) +[![Coverage Status](https://coveralls.io/repos/github/mmatl/pyrender/badge.svg?branch=master)](https://coveralls.io/github/mmatl/pyrender?branch=master) +[![PyPI version](https://badge.fury.io/py/pyrender.svg)](https://badge.fury.io/py/pyrender) +[![Downloads](https://pepy.tech/badge/pyrender)](https://pepy.tech/project/pyrender) + +Pyrender is a pure Python (2.7, 3.4, 3.5, 3.6) library for physically-based +rendering and visualization. +It is designed to meet the [glTF 2.0 specification from Khronos](https://www.khronos.org/gltf/). + +Pyrender is lightweight, easy to install, and simple to use. +It comes packaged with both an intuitive scene viewer and a headache-free +offscreen renderer with support for GPU-accelerated rendering on headless +servers, which makes it perfect for machine learning applications. + +Extensive documentation, including a quickstart guide, is provided [here](https://pyrender.readthedocs.io/en/latest/). + +For a minimal working example of GPU-accelerated offscreen rendering using EGL, +check out the [EGL Google CoLab Notebook](https://colab.research.google.com/drive/1pcndwqeY8vker3bLKQNJKr3B-7-SYenE?usp=sharing). + + +

+*Images: GIF of the interactive viewer and a rendered "Damaged Helmet" model.*
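If you want something runnable without opening the notebook, the sketch below condenses the quickstart guide shipped in `docs/source/examples/quickstart.rst`. It assumes a machine where EGL is available (e.g. a headless GPU server) and a local copy of the `examples/models/fuze.obj` mesh referenced by the docs; any other trimesh-loadable model works just as well. Swap `"egl"` for `"osmesa"` on a CPU-only box with OSMesa installed.

```python
# Minimal offscreen render, adapted from the pyrender quickstart (hedged sketch).
import os
os.environ["PYOPENGL_PLATFORM"] = "egl"  # must be set before pyrender touches OpenGL

import numpy as np
import trimesh
import pyrender

fuze = trimesh.load("examples/models/fuze.obj")   # placeholder path from the docs
scene = pyrender.Scene()
scene.add(pyrender.Mesh.from_trimesh(fuze))

# Camera/light pose taken from the quickstart guide.
s = np.sqrt(2) / 2
pose = np.array([
    [0.0,  -s,   s, 0.30],
    [1.0, 0.0, 0.0, 0.00],
    [0.0,   s,   s, 0.35],
    [0.0, 0.0, 0.0, 1.00],
])
scene.add(pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0), pose=pose)
scene.add(pyrender.SpotLight(color=np.ones(3), intensity=3.0,
                             innerConeAngle=np.pi / 16.0,
                             outerConeAngle=np.pi / 6.0), pose=pose)

r = pyrender.OffscreenRenderer(400, 400)
color, depth = r.render(scene)  # uint8 color image and float depth map
r.delete()
```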

+ +## Installation +You can install pyrender directly from pip. + +```bash +pip install pyrender +``` + +## Features + +Despite being lightweight, pyrender has lots of features, including: + +* Simple interoperation with the amazing [trimesh](https://github.com/mikedh/trimesh) project, +which enables out-of-the-box support for dozens of mesh types, including OBJ, +STL, DAE, OFF, PLY, and GLB. +* An easy-to-use scene viewer with support for animation, showing face and vertex +normals, toggling lighting conditions, and saving images and GIFs. +* An offscreen rendering module that supports OSMesa and EGL backends. +* Shadow mapping for directional and spot lights. +* Metallic-roughness materials for physically-based rendering, including several +types of texture and normal mapping. +* Transparency. +* Depth and color image generation. + +## Sample Usage + +For sample usage, check out the [quickstart +guide](https://pyrender.readthedocs.io/en/latest/examples/index.html) or one of +the Google CoLab Notebooks: + +* [EGL Google CoLab Notebook](https://colab.research.google.com/drive/1pcndwqeY8vker3bLKQNJKr3B-7-SYenE?usp=sharing) + +## Viewer Keyboard and Mouse Controls + +When using the viewer, the basic controls for moving about the scene are as follows: + +* To rotate the camera about the center of the scene, hold the left mouse button and drag the cursor. +* To rotate the camera about its viewing axis, hold `CTRL` left mouse button and drag the cursor. +* To pan the camera, do one of the following: + * Hold `SHIFT`, then hold the left mouse button and drag the cursor. + * Hold the middle mouse button and drag the cursor. +* To zoom the camera in or out, do one of the following: + * Scroll the mouse wheel. + * Hold the right mouse button and drag the cursor. + +The available keyboard commands are as follows: + +* `a`: Toggles rotational animation mode. +* `c`: Toggles backface culling. +* `f`: Toggles fullscreen mode. +* `h`: Toggles shadow rendering. +* `i`: Toggles axis display mode (no axes, world axis, mesh axes, all axes). +* `l`: Toggles lighting mode (scene lighting, Raymond lighting, or direct lighting). +* `m`: Toggles face normal visualization. +* `n`: Toggles vertex normal visualization. +* `o`: Toggles orthographic camera mode. +* `q`: Quits the viewer. +* `r`: Starts recording a GIF, and pressing again stops recording and opens a file dialog. +* `s`: Opens a file dialog to save the current view as an image. +* `w`: Toggles wireframe mode (scene default, flip wireframes, all wireframe, or all solid). +* `z`: Resets the camera to the default view. + +As a note, displaying shadows significantly slows down rendering, so if you're +experiencing low framerates, just kill shadows or reduce the number of lights in +your scene. diff --git a/pyrender/docs/Makefile b/pyrender/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..b1064a04362a0c4372fae351f99ed3bd9f82ff92 --- /dev/null +++ b/pyrender/docs/Makefile @@ -0,0 +1,23 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". 
+help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +clean: + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + rm -rf ./source/generated/* + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/pyrender/docs/make.bat b/pyrender/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..4d9eb83d9f9309029f4b14ff09024658bb0f5563 --- /dev/null +++ b/pyrender/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/pyrender/docs/source/api/index.rst b/pyrender/docs/source/api/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..b6e473149d8f132f176e242c93406fdb84ce0b04 --- /dev/null +++ b/pyrender/docs/source/api/index.rst @@ -0,0 +1,59 @@ +Pyrender API Documentation +========================== + +Constants +--------- +.. automodapi:: pyrender.constants + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: + +Cameras +------- +.. automodapi:: pyrender.camera + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: + +Lighting +-------- +.. automodapi:: pyrender.light + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: + +Objects +------- +.. automodapi:: pyrender + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: + :skip: Camera, DirectionalLight, Light, OffscreenRenderer, Node + :skip: OrthographicCamera, PerspectiveCamera, PointLight, RenderFlags + :skip: Renderer, Scene, SpotLight, TextAlign, Viewer, GLTF + +Scenes +------ +.. automodapi:: pyrender + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: + :skip: Camera, DirectionalLight, Light, OffscreenRenderer + :skip: OrthographicCamera, PerspectiveCamera, PointLight, RenderFlags + :skip: Renderer, SpotLight, TextAlign, Viewer, Sampler, Texture, Material + :skip: MetallicRoughnessMaterial, Primitive, Mesh, GLTF + +On-Screen Viewer +---------------- +.. automodapi:: pyrender.viewer + :no-inheritance-diagram: + :no-inherited-members: + :no-main-docstr: + :no-heading: + +Off-Screen Rendering +-------------------- +.. automodapi:: pyrender.offscreen + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: diff --git a/pyrender/docs/source/conf.py b/pyrender/docs/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf194c375e7e789b334a838953adfeaf2eb59b6 --- /dev/null +++ b/pyrender/docs/source/conf.py @@ -0,0 +1,352 @@ +# -*- coding: utf-8 -*- +# +# core documentation build configuration file, created by +# sphinx-quickstart on Sun Oct 16 14:33:48 2016. 
+# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import sys +import os +from pyrender import __version__ +from sphinx.domains.python import PythonDomain + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.insert(0, os.path.abspath('../../')) + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +#needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.coverage', + 'sphinx.ext.githubpages', + 'sphinx.ext.intersphinx', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx_automodapi.automodapi', + 'sphinx_automodapi.smart_resolver' +] +numpydoc_class_members_toctree = False +automodapi_toctreedirnm = 'generated' +automodsumm_inherited_members = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The encoding of source files. +#source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'pyrender' +copyright = u'2018, Matthew Matl' +author = u'Matthew Matl' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = __version__ +# The full version, including alpha/beta/rc tags. +release = __version__ + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +#today = '' +# Else, today_fmt is used as the format for a strftime call. +#today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = [] + +# The reST default role (used for this markup: `text`) to use for all +# documents. +#default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +#add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +#add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +#show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# A list of ignored prefixes for module index sorting. 
+#modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +#keep_warnings = False + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +import sphinx_rtd_theme +html_theme = 'sphinx_rtd_theme' +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +#html_theme_options = {} + +# Add any paths that contain custom themes here, relative to this directory. +#html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +#html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +#html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +#html_logo = None + +# The name of an image file (relative to this directory) to use as a favicon of +# the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +#html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Add any extra paths that contain custom files (such as robots.txt or +# .htaccess) here, relative to this directory. These files are copied +# directly to the root of the documentation. +#html_extra_path = [] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +#html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +#html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +#html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +#html_additional_pages = {} + +# If false, no module index is generated. +#html_domain_indices = True + +# If false, no index is generated. +#html_use_index = True + +# If true, the index is split into individual pages for each letter. +#html_split_index = False + +# If true, links to the reST sources are added to the pages. +#html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +#html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +#html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +#html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +#html_file_suffix = None + +# Language to be used for generating the HTML full-text search index. 
+# Sphinx supports the following languages: +# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' +# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' +#html_search_language = 'en' + +# A dictionary with options for the search language support, empty by default. +# Now only 'ja' uses this config value +#html_search_options = {'type': 'default'} + +# The name of a javascript file (relative to the configuration directory) that +# implements a search results scorer. If empty, the default will be used. +#html_search_scorer = 'scorer.js' + +# Output file base name for HTML help builder. +htmlhelp_basename = 'coredoc' + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { +# The paper size ('letterpaper' or 'a4paper'). +#'papersize': 'letterpaper', + +# The font size ('10pt', '11pt' or '12pt'). +#'pointsize': '10pt', + +# Additional stuff for the LaTeX preamble. +#'preamble': '', + +# Latex figure (float) alignment +#'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'pyrender.tex', u'pyrender Documentation', + u'Matthew Matl', 'manual'), +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +#latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +#latex_use_parts = False + +# If true, show page references after internal links. +#latex_show_pagerefs = False + +# If true, show URL addresses after external links. +#latex_show_urls = False + +# Documents to append as an appendix to all manuals. +#latex_appendices = [] + +# If false, no module index is generated. +#latex_domain_indices = True + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'pyrender', u'pyrender Documentation', + [author], 1) +] + +# If true, show URL addresses after external links. +#man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'pyrender', u'pyrender Documentation', + author, 'pyrender', 'One line description of project.', + 'Miscellaneous'), +] + +# Documents to append as an appendix to all manuals. +#texinfo_appendices = [] + +# If false, no module index is generated. +#texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +#texinfo_show_urls = 'footnote' + +# If true, do not generate a @detailmenu in the "Top" node's menu. 
+#texinfo_no_detailmenu = False + +intersphinx_mapping = { + 'python' : ('https://docs.python.org/', None), + 'pyrender' : ('https://pyrender.readthedocs.io/en/latest/', None), +} + +# Autosummary fix +autosummary_generate = True + +# Try to suppress multiple-definition warnings by always taking the shorter +# path when two or more paths have the same base module + +class MyPythonDomain(PythonDomain): + + def find_obj(self, env, modname, classname, name, type, searchmode=0): + """Ensures an object always resolves to the desired module + if defined there.""" + orig_matches = PythonDomain.find_obj( + self, env, modname, classname, name, type, searchmode + ) + + if len(orig_matches) <= 1: + return orig_matches + + # If multiple matches, try to take the shortest if all the modules are + # the same + first_match_name_sp = orig_matches[0][0].split('.') + base_name = first_match_name_sp[0] + min_len = len(first_match_name_sp) + best_match = orig_matches[0] + + for match in orig_matches[1:]: + match_name = match[0] + match_name_sp = match_name.split('.') + match_base = match_name_sp[0] + + # If we have mismatched bases, return them all to trigger warnings + if match_base != base_name: + return orig_matches + + # Otherwise, check and see if it's shorter + if len(match_name_sp) < min_len: + min_len = len(match_name_sp) + best_match = match + + return (best_match,) + + +def setup(sphinx): + """Use MyPythonDomain in place of PythonDomain""" + sphinx.override_domain(MyPythonDomain) + diff --git a/pyrender/docs/source/examples/cameras.rst b/pyrender/docs/source/examples/cameras.rst new file mode 100644 index 0000000000000000000000000000000000000000..39186b75b16584d11fd1606b92291c104e0452bd --- /dev/null +++ b/pyrender/docs/source/examples/cameras.rst @@ -0,0 +1,26 @@ +.. _camera_guide: + +Creating Cameras +================ + +Pyrender supports three camera types -- :class:`.PerspectiveCamera` and +:class:`.IntrinsicsCamera` types, +which render scenes as a human would see them, and +:class:`.OrthographicCamera` types, which preserve distances between points. + +Creating cameras is easy -- just specify their basic attributes: + +>>> pc = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.414) +>>> oc = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0) + +For more information, see the Khronos group's documentation here_: + +.. _here: https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#projection-matrices + +When you add cameras to the scene, make sure that you're using OpenGL camera +coordinates to specify their pose. See the illustration below for details. +Basically, the camera z-axis points away from the scene, the x-axis points +right in image space, and the y-axis points up in image space. + +.. image:: /_static/camera_coords.png + diff --git a/pyrender/docs/source/examples/index.rst b/pyrender/docs/source/examples/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..4be536cd62c1cca112228f4e114e783be77a0ab8 --- /dev/null +++ b/pyrender/docs/source/examples/index.rst @@ -0,0 +1,20 @@ +.. _guide: + +User Guide +========== + +This section contains guides on how to use Pyrender to quickly visualize +your 3D data, including a quickstart guide and more detailed descriptions +of each part of the rendering pipeline. + + +.. 
toctree:: + :maxdepth: 2 + + quickstart.rst + models.rst + lighting.rst + cameras.rst + scenes.rst + offscreen.rst + viewer.rst diff --git a/pyrender/docs/source/examples/lighting.rst b/pyrender/docs/source/examples/lighting.rst new file mode 100644 index 0000000000000000000000000000000000000000..f89bee7d15027a0f52711622b053b49cc6e1b410 --- /dev/null +++ b/pyrender/docs/source/examples/lighting.rst @@ -0,0 +1,21 @@ +.. _lighting_guide: + +Creating Lights +=============== + +Pyrender supports three types of punctual light: + +- :class:`.PointLight`: Point-based light sources, such as light bulbs. +- :class:`.SpotLight`: A conical light source, like a flashlight. +- :class:`.DirectionalLight`: A general light that does not attenuate with + distance. + +Creating lights is easy -- just specify their basic attributes: + +>>> pl = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=2.0) +>>> sl = pyrender.SpotLight(color=[1.0, 1.0, 1.0], intensity=2.0, +... innerConeAngle=0.05, outerConeAngle=0.5) +>>> dl = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=2.0) + +For more information about how these lighting models are implemented, +see their class documentation. diff --git a/pyrender/docs/source/examples/models.rst b/pyrender/docs/source/examples/models.rst new file mode 100644 index 0000000000000000000000000000000000000000..84e71c4ff41a8d2e0eb2dc48434caedb757ff954 --- /dev/null +++ b/pyrender/docs/source/examples/models.rst @@ -0,0 +1,143 @@ +.. _model_guide: + +Loading and Configuring Models +============================== +The first step to any rendering application is loading your models. +Pyrender implements the GLTF 2.0 specification, which means that all +models are composed of a hierarchy of objects. + +At the top level, we have a :class:`.Mesh`. The :class:`.Mesh` is +basically a wrapper of any number of :class:`.Primitive` types, +which actually represent geometry that can be drawn to the screen. + +Primitives are composed of a variety of parameters, including +vertex positions, vertex normals, color and texture information, +and triangle indices if smooth rendering is desired. +They can implement point clouds, triangular meshes, or lines +depending on how you configure their data and set their +:attr:`.Primitive.mode` parameter. + +Although you can create primitives yourself if you want to, +it's probably easier to just use the utility functions provided +in the :class:`.Mesh` class. + +Creating Triangular Meshes +-------------------------- + +Simple Construction +~~~~~~~~~~~~~~~~~~~ +Pyrender allows you to create a :class:`.Mesh` containing a +triangular mesh model directly from a :class:`~trimesh.base.Trimesh` object +using the :meth:`.Mesh.from_trimesh` static method. + +>>> import trimesh +>>> import pyrender +>>> import numpy as np +>>> tm = trimesh.load('examples/models/fuze.obj') +>>> m = pyrender.Mesh.from_trimesh(tm) +>>> m.primitives +[] + +You can also create a single :class:`.Mesh` from a list of +:class:`~trimesh.base.Trimesh` objects: + +>>> tms = [trimesh.creation.icosahedron(), trimesh.creation.cylinder()] +>>> m = pyrender.Mesh.from_trimesh(tms) +[, + ] + +Vertex Smoothing +~~~~~~~~~~~~~~~~ + +The :meth:`.Mesh.from_trimesh` method has a few additional optional parameters. +If you want to render the mesh without interpolating face normals, which can +be useful for meshes that are supposed to be angular (e.g. a cube), you +can specify ``smooth=False``. 
+ +>>> m = pyrender.Mesh.from_trimesh(tm, smooth=False) + +Per-Face or Per-Vertex Coloration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you have an untextured trimesh, you can color it in with per-face or +per-vertex colors: + +>>> tm.visual.vertex_colors = np.random.uniform(size=tm.vertices.shape) +>>> tm.visual.face_colors = np.random.uniform(size=tm.faces.shape) +>>> m = pyrender.Mesh.from_trimesh(tm) + +Instancing +~~~~~~~~~~ + +If you want to render many copies of the same mesh at different poses, +you can statically create a vast array of them in an efficient manner. +Simply specify the ``poses`` parameter to be a list of ``N`` 4x4 homogenous +transformation matrics that position the meshes relative to their common +base frame: + +>>> tfs = np.tile(np.eye(4), (3,1,1)) +>>> tfs[1,:3,3] = [0.1, 0.0, 0.0] +>>> tfs[2,:3,3] = [0.2, 0.0, 0.0] +>>> tfs +array([[[1. , 0. , 0. , 0. ], + [0. , 1. , 0. , 0. ], + [0. , 0. , 1. , 0. ], + [0. , 0. , 0. , 1. ]], + [[1. , 0. , 0. , 0.1], + [0. , 1. , 0. , 0. ], + [0. , 0. , 1. , 0. ], + [0. , 0. , 0. , 1. ]], + [[1. , 0. , 0. , 0.2], + [0. , 1. , 0. , 0. ], + [0. , 0. , 1. , 0. ], + [0. , 0. , 0. , 1. ]]]) + +>>> m = pyrender.Mesh.from_trimesh(tm, poses=tfs) + +Custom Materials +~~~~~~~~~~~~~~~~ + +You can also specify a custom material for any triangular mesh you create +in the ``material`` parameter of :meth:`.Mesh.from_trimesh`. +The main material supported by Pyrender is the +:class:`.MetallicRoughnessMaterial`. +The metallic-roughness model supports rendering highly-realistic objects across +a wide gamut of materials. + +For more information, see the documentation of the +:class:`.MetallicRoughnessMaterial` constructor or look at the Khronos_ +documentation for more information. + +.. _Khronos: https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#materials + +Creating Point Clouds +--------------------- + +Point Sprites +~~~~~~~~~~~~~ +Pyrender also allows you to create a :class:`.Mesh` containing a +point cloud directly from :class:`numpy.ndarray` instances +using the :meth:`.Mesh.from_points` static method. + +Simply provide a list of points and optional per-point colors and normals. + +>>> pts = tm.vertices.copy() +>>> colors = np.random.uniform(size=pts.shape) +>>> m = pyrender.Mesh.from_points(pts, colors=colors) + +Point clouds created in this way will be rendered as square point sprites. + +.. image:: /_static/points.png + +Point Spheres +~~~~~~~~~~~~~ +If you have a monochromatic point cloud and would like to render it with +spheres, you can render it by instancing a spherical trimesh: + +>>> sm = trimesh.creation.uv_sphere(radius=0.1) +>>> sm.visual.vertex_colors = [1.0, 0.0, 0.0] +>>> tfs = np.tile(np.eye(4), (len(pts), 1, 1)) +>>> tfs[:,:3,3] = pts +>>> m = pyrender.Mesh.from_trimesh(sm, poses=tfs) + +.. image:: /_static/points2.png diff --git a/pyrender/docs/source/examples/offscreen.rst b/pyrender/docs/source/examples/offscreen.rst new file mode 100644 index 0000000000000000000000000000000000000000..291532b6e0c0e512df35a97e3c826cc83015aeca --- /dev/null +++ b/pyrender/docs/source/examples/offscreen.rst @@ -0,0 +1,87 @@ +.. _offscreen_guide: + +Offscreen Rendering +=================== + +.. note:: + If you're using a headless server, you'll need to use either EGL (for + GPU-accelerated rendering) or OSMesa (for CPU-only software rendering). + If you're using OSMesa, be sure that you've installed it properly. See + :ref:`osmesa` for details. 
+ +Choosing a Backend +------------------ + +Once you have a scene set up with its geometry, cameras, and lights, +you can render it using the :class:`.OffscreenRenderer`. Pyrender supports +three backends for offscreen rendering: + +- Pyglet, the same engine that runs the viewer. This requires an active + display manager, so you can't run it on a headless server. This is the + default option. +- OSMesa, a software renderer. +- EGL, which allows for GPU-accelerated rendering without a display manager. + +If you want to use OSMesa or EGL, you need to set the ``PYOPENGL_PLATFORM`` +environment variable before importing pyrender or any other OpenGL library. +You can do this at the command line: + +.. code-block:: bash + + PYOPENGL_PLATFORM=osmesa python render.py + +or at the top of your Python script: + +.. code-block:: bash + + # Top of main python script + import os + os.environ['PYOPENGL_PLATFORM'] = 'egl' + +The handle for EGL is ``egl``, and the handle for OSMesa is ``osmesa``. + +Running the Renderer +-------------------- + +Once you've set your environment variable appropriately, create your scene and +then configure the :class:`.OffscreenRenderer` object with a window width, +a window height, and a size for point-cloud points: + +>>> r = pyrender.OffscreenRenderer(viewport_width=640, +... viewport_height=480, +... point_size=1.0) + +Then, just call the :meth:`.OffscreenRenderer.render` function: + +>>> color, depth = r.render(scene) + +.. image:: /_static/scene.png + +This will return a ``(w,h,3)`` channel floating-point color image and +a ``(w,h)`` floating-point depth image rendered from the scene's main camera. + +You can customize the rendering process by using flag options from +:class:`.RenderFlags` and bitwise or-ing them together. For example, +the following code renders a color image with an alpha channel +and enables shadow mapping for all directional lights: + +>>> flags = RenderFlags.RGBA | RenderFlags.SHADOWS_DIRECTIONAL +>>> color, depth = r.render(scene, flags=flags) + +Once you're done with the offscreen renderer, you need to close it before you +can run a different renderer or open the viewer for the same scene: + +>>> r.delete() + +Google CoLab Examples +--------------------- + +For a minimal working example of offscreen rendering using OSMesa, +see the `OSMesa Google CoLab notebook`_. + +.. _OSMesa Google CoLab notebook: https://colab.research.google.com/drive/1Z71mHIc-Sqval92nK290vAsHZRUkCjUx + +For a minimal working example of offscreen rendering using EGL, +see the `EGL Google CoLab notebook`_. + +.. _EGL Google CoLab notebook: https://colab.research.google.com/drive/1rTLHk0qxh4dn8KNe-mCnN8HAWdd2_BEh diff --git a/pyrender/docs/source/examples/quickstart.rst b/pyrender/docs/source/examples/quickstart.rst new file mode 100644 index 0000000000000000000000000000000000000000..ac556419e5206c2ccd4bc985feb1a8c7347310af --- /dev/null +++ b/pyrender/docs/source/examples/quickstart.rst @@ -0,0 +1,71 @@ +.. _quickstart_guide: + +Quickstart +========== + + +Minimal Example for 3D Viewer +----------------------------- +Here is a minimal example of loading and viewing a triangular mesh model +in pyrender. + +>>> import trimesh +>>> import pyrender +>>> fuze_trimesh = trimesh.load('examples/models/fuze.obj') +>>> mesh = pyrender.Mesh.from_trimesh(fuze_trimesh) +>>> scene = pyrender.Scene() +>>> scene.add(mesh) +>>> pyrender.Viewer(scene, use_raymond_lighting=True) + +.. 
image:: /_static/fuze.png + + +Minimal Example for Offscreen Rendering +--------------------------------------- +.. note:: + If you're using a headless server, make sure that you followed the guide + for installing OSMesa. See :ref:`osmesa`. + +Here is a minimal example of rendering a mesh model offscreen in pyrender. +The only additional necessities are that you need to add lighting and a camera. + +>>> import numpy as np +>>> import trimesh +>>> import pyrender +>>> import matplotlib.pyplot as plt + +>>> fuze_trimesh = trimesh.load('examples/models/fuze.obj') +>>> mesh = pyrender.Mesh.from_trimesh(fuze_trimesh) +>>> scene = pyrender.Scene() +>>> scene.add(mesh) +>>> camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0) +>>> s = np.sqrt(2)/2 +>>> camera_pose = np.array([ +... [0.0, -s, s, 0.3], +... [1.0, 0.0, 0.0, 0.0], +... [0.0, s, s, 0.35], +... [0.0, 0.0, 0.0, 1.0], +... ]) +>>> scene.add(camera, pose=camera_pose) +>>> light = pyrender.SpotLight(color=np.ones(3), intensity=3.0, +... innerConeAngle=np.pi/16.0, +... outerConeAngle=np.pi/6.0) +>>> scene.add(light, pose=camera_pose) +>>> r = pyrender.OffscreenRenderer(400, 400) +>>> color, depth = r.render(scene) +>>> plt.figure() +>>> plt.subplot(1,2,1) +>>> plt.axis('off') +>>> plt.imshow(color) +>>> plt.subplot(1,2,2) +>>> plt.axis('off') +>>> plt.imshow(depth, cmap=plt.cm.gray_r) +>>> plt.show() + +.. image:: /_static/minexcolor.png + :width: 45% + :align: left +.. image:: /_static/minexdepth.png + :width: 45% + :align: right + diff --git a/pyrender/docs/source/examples/scenes.rst b/pyrender/docs/source/examples/scenes.rst new file mode 100644 index 0000000000000000000000000000000000000000..94c243f8b860b9669ac26105fd2b9906054f4568 --- /dev/null +++ b/pyrender/docs/source/examples/scenes.rst @@ -0,0 +1,78 @@ +.. _scene_guide: + +Creating Scenes +=============== + +Before you render anything, you need to put all of your lights, cameras, +and meshes into a scene. The :class:`.Scene` object keeps track of the relative +poses of these primitives by inserting them into :class:`.Node` objects and +keeping them in a directed acyclic graph. + +Adding Objects +-------------- + +To create a :class:`.Scene`, simply call the constructor. You can optionally +specify an ambient light color and a background color: + +>>> scene = pyrender.Scene(ambient_light=[0.02, 0.02, 0.02], +... bg_color=[1.0, 1.0, 1.0]) + +You can add objects to a scene by first creating a :class:`.Node` object +and adding the object and its pose to the :class:`.Node`. Poses are specified +as 4x4 homogenous transformation matrices that are stored in the node's +:attr:`.Node.matrix` attribute. Note that the :class:`.Node` +constructor requires you to specify whether you're adding a mesh, light, +or camera. + +>>> mesh = pyrender.Mesh.from_trimesh(tm) +>>> light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=2.0) +>>> cam = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.414) +>>> nm = pyrender.Node(mesh=mesh, matrix=np.eye(4)) +>>> nl = pyrender.Node(light=light, matrix=np.eye(4)) +>>> nc = pyrender.Node(camera=cam, matrix=np.eye(4)) +>>> scene.add_node(nm) +>>> scene.add_node(nl) +>>> scene.add_node(nc) + +You can also add objects directly to a scene with the :meth:`.Scene.add` function, +which takes care of creating a :class:`.Node` for you. 
+ +>>> scene.add(mesh, pose=np.eye(4)) +>>> scene.add(light, pose=np.eye(4)) +>>> scene.add(cam, pose=np.eye(4)) + +Nodes can be hierarchical, in which case the node's :attr:`.Node.matrix` +specifies that node's pose relative to its parent frame. You can add nodes to +a scene hierarchically by specifying a parent node in your calls to +:meth:`.Scene.add` or :meth:`.Scene.add_node`: + +>>> scene.add_node(nl, parent_node=nc) +>>> scene.add(cam, parent_node=nm) + +If you add multiple cameras to a scene, you can specify which one to render from +by setting the :attr:`.Scene.main_camera_node` attribute. + +Updating Objects +---------------- + +You can update the poses of existing nodes with the :meth:`.Scene.set_pose` +function. Simply call it with a :class:`.Node` that is already in the scene +and the new pose of that node with respect to its parent as a 4x4 homogenous +transformation matrix: + +>>> scene.set_pose(nl, pose=np.eye(4)) + +If you want to get the local pose of a node, you can just access its +:attr:`.Node.matrix` attribute. However, if you want to the get +the pose of a node *with respect to the world frame*, you can call the +:meth:`.Scene.get_pose` method. + +>>> tf = scene.get_pose(nl) + +Removing Objects +---------------- + +Finally, you can remove a :class:`.Node` and all of its children from the +scene with the :meth:`.Scene.remove_node` function: + +>>> scene.remove_node(nl) diff --git a/pyrender/docs/source/examples/viewer.rst b/pyrender/docs/source/examples/viewer.rst new file mode 100644 index 0000000000000000000000000000000000000000..00a7973b46ec7da33b51b65581af6f25c1b1652f --- /dev/null +++ b/pyrender/docs/source/examples/viewer.rst @@ -0,0 +1,61 @@ +.. _viewer_guide: + +Live Scene Viewer +================= + +Standard Usage +-------------- +In addition to the offscreen renderer, Pyrender comes with a live scene viewer. +In its standard invocation, calling the :class:`.Viewer`'s constructor will +immediately pop a viewing window that you can navigate around in. + +>>> pyrender.Viewer(scene) + +By default, the viewer uses your scene's lighting. If you'd like to start with +some additional lighting that moves around with the camera, you can specify that +with: + +>>> pyrender.Viewer(scene, use_raymond_lighting=True) + +For a full list of the many options that the :class:`.Viewer` supports, check out its +documentation. + +.. image:: /_static/rotation.gif + +Running the Viewer in a Separate Thread +--------------------------------------- +If you'd like to animate your models, you'll want to run the viewer in a +separate thread so that you can update the scene while the viewer is running. +To do this, first pop the viewer in a separate thread by calling its constructor +with the ``run_in_thread`` option set: + +>>> v = pyrender.Viewer(scene, run_in_thread=True) + +Then, you can manipulate the :class:`.Scene` while the viewer is running to +animate things. However, be careful to acquire the viewer's +:attr:`.Viewer.render_lock` before editing the scene to prevent data corruption: + +>>> i = 0 +>>> while True: +... pose = np.eye(4) +... pose[:3,3] = [i, 0, 0] +... v.render_lock.acquire() +... scene.set_pose(mesh_node, pose) +... v.render_lock.release() +... i += 0.01 + +.. image:: /_static/scissors.gif + +You can wait on the viewer to be closed manually: + +>>> while v.is_active: +... pass + +Or you can close it from the main thread forcibly. +Make sure to still loop and block for the viewer to actually exit before using +the scene object again. 
+ +>>> v.close_external() +>>> while v.is_active: +... pass + diff --git a/pyrender/docs/source/index.rst b/pyrender/docs/source/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..baf189ede6bb3435cad5b8795e1937ef1a3c2c56 --- /dev/null +++ b/pyrender/docs/source/index.rst @@ -0,0 +1,41 @@ +.. core documentation master file, created by + sphinx-quickstart on Sun Oct 16 14:33:48 2016. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Pyrender Documentation +======================== +Pyrender is a pure Python (2.7, 3.4, 3.5, 3.6) library for physically-based +rendering and visualization. +It is designed to meet the glTF 2.0 specification_ from Khronos + +.. _specification: https://www.khronos.org/gltf/ + +Pyrender is lightweight, easy to install, and simple to use. +It comes packaged with both an intuitive scene viewer and a headache-free +offscreen renderer with support for GPU-accelerated rendering on headless +servers, which makes it perfect for machine learning applications. +Check out the :ref:`guide` for a full tutorial, or fork me on +Github_. + +.. _Github: https://github.com/mmatl/pyrender + +.. image:: _static/rotation.gif + +.. image:: _static/damaged_helmet.png + +.. toctree:: + :maxdepth: 2 + + install/index.rst + examples/index.rst + api/index.rst + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` + diff --git a/pyrender/docs/source/install/index.rst b/pyrender/docs/source/install/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..c785f202d877f8bbaf286c21eddca1925973f75e --- /dev/null +++ b/pyrender/docs/source/install/index.rst @@ -0,0 +1,172 @@ +Installation Guide +================== + +Python Installation +------------------- + +This package is available via ``pip``. + +.. code-block:: bash + + pip install pyrender + +If you're on MacOS, you'll need +to pre-install my fork of ``pyglet``, as the version on PyPI hasn't yet included +my change that enables OpenGL contexts on MacOS. + +.. code-block:: bash + + git clone https://github.com/mmatl/pyglet.git + cd pyglet + pip install . + +.. _osmesa: + +Getting Pyrender Working with OSMesa +------------------------------------ +If you want to render scenes offscreen but don't want to have to +install a display manager or deal with the pains of trying to get +OpenGL to work over SSH, you have two options. + +The first (and preferred) option is using EGL, which enables you to perform +GPU-accelerated rendering on headless servers. +However, you'll need EGL 1.5 to get modern OpenGL contexts. +This comes packaged with NVIDIA's current drivers, but if you are having issues +getting EGL to work with your hardware, you can try using OSMesa, +a software-based offscreen renderer that is included with any Mesa +install. + +If you want to use OSMesa with pyrender, you'll have to perform two additional +installation steps: + +- :ref:`installmesa` +- :ref:`installpyopengl` + +Then, read the offscreen rendering tutorial. See :ref:`offscreen_guide`. + +.. _installmesa: + +Installing OSMesa +***************** + +As a first step, you'll need to rebuild and re-install Mesa with support +for fast offscreen rendering and OpenGL 3+ contexts. +I'd recommend installing from source, but you can also try my ``.deb`` +for Ubuntu 16.04 and up. 
+ +Installing from a Debian Package +******************************** + +If you're running Ubuntu 16.04 or newer, you should be able to install the +required version of Mesa from my ``.deb`` file. + +.. code-block:: bash + + sudo apt update + sudo wget https://github.com/mmatl/travis_debs/raw/master/xenial/mesa_18.3.3-0.deb + sudo dpkg -i ./mesa_18.3.3-0.deb || true + sudo apt install -f + +If this doesn't work, try building from source. + +Building From Source +******************** + +First, install build dependencies via `apt` or your system's package manager. + +.. code-block:: bash + + sudo apt-get install llvm-6.0 freeglut3 freeglut3-dev + +Then, download the current release of Mesa from here_. +Unpack the source and go to the source folder: + +.. _here: https://archive.mesa3d.org/mesa-18.3.3.tar.gz + +.. code-block:: bash + + tar xfv mesa-18.3.3.tar.gz + cd mesa-18.3.3 + +Replace ``PREFIX`` with the path you want to install Mesa at. +If you're not worried about overwriting your default Mesa install, +a good place is at ``/usr/local``. + +Now, configure the installation by running the following command: + +.. code-block:: bash + + ./configure --prefix=PREFIX \ + --enable-opengl --disable-gles1 --disable-gles2 \ + --disable-va --disable-xvmc --disable-vdpau \ + --enable-shared-glapi \ + --disable-texture-float \ + --enable-gallium-llvm --enable-llvm-shared-libs \ + --with-gallium-drivers=swrast,swr \ + --disable-dri --with-dri-drivers= \ + --disable-egl --with-egl-platforms= --disable-gbm \ + --disable-glx \ + --disable-osmesa --enable-gallium-osmesa \ + ac_cv_path_LLVM_CONFIG=llvm-config-6.0 + +Finally, build and install Mesa. + +.. code-block:: bash + + make -j8 + make install + +Finally, if you didn't install Mesa in the system path, +add the following lines to your ``~/.bashrc`` file after +changing ``MESA_HOME`` to your mesa installation path (i.e. what you used as +``PREFIX`` during the configure command). + +.. code-block:: bash + + MESA_HOME=/path/to/your/mesa/installation + export LIBRARY_PATH=$LIBRARY_PATH:$MESA_HOME/lib + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$MESA_HOME/lib + export C_INCLUDE_PATH=$C_INCLUDE_PATH:$MESA_HOME/include/ + export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$MESA_HOME/include/ + +.. _installpyopengl: + +Installing a Compatible Fork of PyOpenGL +**************************************** + +Next, install and use my fork of ``PyOpenGL``. +This fork enables getting modern OpenGL contexts with OSMesa. +My patch has been included in ``PyOpenGL``, but it has not yet been released +on PyPI. + +.. code-block:: bash + + git clone https://github.com/mmatl/pyopengl.git + pip install ./pyopengl + + +Building Documentation +---------------------- + +The online documentation for ``pyrender`` is automatically built by Read The Docs. +Building ``pyrender``'s documentation locally requires a few extra dependencies -- +specifically, `sphinx`_ and a few plugins. + +.. _sphinx: http://www.sphinx-doc.org/en/master/ + +To install the dependencies required, simply change directories into the `pyrender` source and run + +.. code-block:: bash + + $ pip install .[docs] + +Then, go to the ``docs`` directory and run ``make`` with the appropriate target. +For example, + +.. code-block:: bash + + $ cd docs/ + $ make html + +will generate a set of web pages. Any documentation files +generated in this manner can be found in ``docs/build``. 
diff --git a/pyrender/examples/duck.py b/pyrender/examples/duck.py new file mode 100644 index 0000000000000000000000000000000000000000..9a94bad5bfb30493f7364f2e52cbb4badbccb2c7 --- /dev/null +++ b/pyrender/examples/duck.py @@ -0,0 +1,13 @@ +from pyrender import Mesh, Scene, Viewer +from io import BytesIO +import numpy as np +import trimesh +import requests + +duck_source = "https://github.com/KhronosGroup/glTF-Sample-Models/raw/master/2.0/Duck/glTF-Binary/Duck.glb" + +duck = trimesh.load(BytesIO(requests.get(duck_source).content), file_type='glb') +duckmesh = Mesh.from_trimesh(list(duck.geometry.values())[0]) +scene = Scene(ambient_light=np.array([1.0, 1.0, 1.0, 1.0])) +scene.add(duckmesh) +Viewer(scene) diff --git a/pyrender/examples/example.py b/pyrender/examples/example.py new file mode 100644 index 0000000000000000000000000000000000000000..599a4850a5899cdeb1a76db1c5cf1c91c263cd41 --- /dev/null +++ b/pyrender/examples/example.py @@ -0,0 +1,157 @@ +"""Examples of using pyrender for viewing and offscreen rendering. +""" +import pyglet +pyglet.options['shadow_window'] = False +import os +import numpy as np +import trimesh + +from pyrender import PerspectiveCamera,\ + DirectionalLight, SpotLight, PointLight,\ + MetallicRoughnessMaterial,\ + Primitive, Mesh, Node, Scene,\ + Viewer, OffscreenRenderer, RenderFlags + +#============================================================================== +# Mesh creation +#============================================================================== + +#------------------------------------------------------------------------------ +# Creating textured meshes from trimeshes +#------------------------------------------------------------------------------ + +# Fuze trimesh +fuze_trimesh = trimesh.load('./models/fuze.obj') +fuze_mesh = Mesh.from_trimesh(fuze_trimesh) + +# Drill trimesh +drill_trimesh = trimesh.load('./models/drill.obj') +drill_mesh = Mesh.from_trimesh(drill_trimesh) +drill_pose = np.eye(4) +drill_pose[0,3] = 0.1 +drill_pose[2,3] = -np.min(drill_trimesh.vertices[:,2]) + +# Wood trimesh +wood_trimesh = trimesh.load('./models/wood.obj') +wood_mesh = Mesh.from_trimesh(wood_trimesh) + +# Water bottle trimesh +bottle_gltf = trimesh.load('./models/WaterBottle.glb') +bottle_trimesh = bottle_gltf.geometry[list(bottle_gltf.geometry.keys())[0]] +bottle_mesh = Mesh.from_trimesh(bottle_trimesh) +bottle_pose = np.array([ + [1.0, 0.0, 0.0, 0.1], + [0.0, 0.0, -1.0, -0.16], + [0.0, 1.0, 0.0, 0.13], + [0.0, 0.0, 0.0, 1.0], +]) + +#------------------------------------------------------------------------------ +# Creating meshes with per-vertex colors +#------------------------------------------------------------------------------ +boxv_trimesh = trimesh.creation.box(extents=0.1*np.ones(3)) +boxv_vertex_colors = np.random.uniform(size=(boxv_trimesh.vertices.shape)) +boxv_trimesh.visual.vertex_colors = boxv_vertex_colors +boxv_mesh = Mesh.from_trimesh(boxv_trimesh, smooth=False) + +#------------------------------------------------------------------------------ +# Creating meshes with per-face colors +#------------------------------------------------------------------------------ +boxf_trimesh = trimesh.creation.box(extents=0.1*np.ones(3)) +boxf_face_colors = np.random.uniform(size=boxf_trimesh.faces.shape) +boxf_trimesh.visual.face_colors = boxf_face_colors +boxf_mesh = Mesh.from_trimesh(boxf_trimesh, smooth=False) + +#------------------------------------------------------------------------------ +# Creating meshes from point clouds 
+#------------------------------------------------------------------------------ +points = trimesh.creation.icosphere(radius=0.05).vertices +point_colors = np.random.uniform(size=points.shape) +points_mesh = Mesh.from_points(points, colors=point_colors) + +#============================================================================== +# Light creation +#============================================================================== + +direc_l = DirectionalLight(color=np.ones(3), intensity=1.0) +spot_l = SpotLight(color=np.ones(3), intensity=10.0, + innerConeAngle=np.pi/16, outerConeAngle=np.pi/6) +point_l = PointLight(color=np.ones(3), intensity=10.0) + +#============================================================================== +# Camera creation +#============================================================================== + +cam = PerspectiveCamera(yfov=(np.pi / 3.0)) +cam_pose = np.array([ + [0.0, -np.sqrt(2)/2, np.sqrt(2)/2, 0.5], + [1.0, 0.0, 0.0, 0.0], + [0.0, np.sqrt(2)/2, np.sqrt(2)/2, 0.4], + [0.0, 0.0, 0.0, 1.0] +]) + +#============================================================================== +# Scene creation +#============================================================================== + +scene = Scene(ambient_light=np.array([0.02, 0.02, 0.02, 1.0])) + +#============================================================================== +# Adding objects to the scene +#============================================================================== + +#------------------------------------------------------------------------------ +# By manually creating nodes +#------------------------------------------------------------------------------ +fuze_node = Node(mesh=fuze_mesh, translation=np.array([0.1, 0.15, -np.min(fuze_trimesh.vertices[:,2])])) +scene.add_node(fuze_node) +boxv_node = Node(mesh=boxv_mesh, translation=np.array([-0.1, 0.10, 0.05])) +scene.add_node(boxv_node) +boxf_node = Node(mesh=boxf_mesh, translation=np.array([-0.1, -0.10, 0.05])) +scene.add_node(boxf_node) + +#------------------------------------------------------------------------------ +# By using the add() utility function +#------------------------------------------------------------------------------ +drill_node = scene.add(drill_mesh, pose=drill_pose) +bottle_node = scene.add(bottle_mesh, pose=bottle_pose) +wood_node = scene.add(wood_mesh) +direc_l_node = scene.add(direc_l, pose=cam_pose) +spot_l_node = scene.add(spot_l, pose=cam_pose) + +#============================================================================== +# Using the viewer with a default camera +#============================================================================== + +v = Viewer(scene, shadows=True) + +#============================================================================== +# Using the viewer with a pre-specified camera +#============================================================================== +cam_node = scene.add(cam, pose=cam_pose) +v = Viewer(scene, central_node=drill_node) + +#============================================================================== +# Rendering offscreen from that camera +#============================================================================== + +r = OffscreenRenderer(viewport_width=640*2, viewport_height=480*2) +color, depth = r.render(scene) + +import matplotlib.pyplot as plt +plt.figure() +plt.imshow(color) +plt.show() + +#============================================================================== +# Segmask rendering 
+#============================================================================== + +nm = {node: 20*(i + 1) for i, node in enumerate(scene.mesh_nodes)} +seg = r.render(scene, RenderFlags.SEG, nm)[0] +plt.figure() +plt.imshow(seg) +plt.show() + +r.delete() + diff --git a/pyrender/pyrender/__init__.py b/pyrender/pyrender/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3709846823b7c4b71b22da0e24d63d805528a8 --- /dev/null +++ b/pyrender/pyrender/__init__.py @@ -0,0 +1,24 @@ +from .camera import (Camera, PerspectiveCamera, OrthographicCamera, + IntrinsicsCamera) +from .light import Light, PointLight, DirectionalLight, SpotLight +from .sampler import Sampler +from .texture import Texture +from .material import Material, MetallicRoughnessMaterial +from .primitive import Primitive +from .mesh import Mesh +from .node import Node +from .scene import Scene +from .renderer import Renderer +from .viewer import Viewer +from .offscreen import OffscreenRenderer +from .version import __version__ +from .constants import RenderFlags, TextAlign, GLTF + +__all__ = [ + 'Camera', 'PerspectiveCamera', 'OrthographicCamera', 'IntrinsicsCamera', + 'Light', 'PointLight', 'DirectionalLight', 'SpotLight', + 'Sampler', 'Texture', 'Material', 'MetallicRoughnessMaterial', + 'Primitive', 'Mesh', 'Node', 'Scene', 'Renderer', 'Viewer', + 'OffscreenRenderer', '__version__', 'RenderFlags', 'TextAlign', + 'GLTF' +] diff --git a/pyrender/pyrender/camera.py b/pyrender/pyrender/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..e019358039033c3a372c990ebad3151258c3651d --- /dev/null +++ b/pyrender/pyrender/camera.py @@ -0,0 +1,437 @@ +"""Virtual cameras compliant with the glTF 2.0 specification as described at +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-camera + +Author: Matthew Matl +""" +import abc +import numpy as np +import six +import sys + +from .constants import DEFAULT_Z_NEAR, DEFAULT_Z_FAR + + +@six.add_metaclass(abc.ABCMeta) +class Camera(object): + """Abstract base class for all cameras. + + Note + ---- + Camera poses are specified in the OpenGL format, + where the z axis points away from the view direction and the + x and y axes point to the right and up in the image plane, respectively. + + Parameters + ---------- + znear : float + The floating-point distance to the near clipping plane. + zfar : float + The floating-point distance to the far clipping plane. + ``zfar`` must be greater than ``znear``. + name : str, optional + The user-defined name of this object. + """ + + def __init__(self, + znear=DEFAULT_Z_NEAR, + zfar=DEFAULT_Z_FAR, + name=None): + self.name = name + self.znear = znear + self.zfar = zfar + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def znear(self): + """float : The distance to the near clipping plane. + """ + return self._znear + + @znear.setter + def znear(self, value): + value = float(value) + if value < 0: + raise ValueError('z-near must be >= 0.0') + self._znear = value + + @property + def zfar(self): + """float : The distance to the far clipping plane. 
+ """ + return self._zfar + + @zfar.setter + def zfar(self, value): + value = float(value) + if value <= 0 or value <= self.znear: + raise ValueError('zfar must be >0 and >znear') + self._zfar = value + + @abc.abstractmethod + def get_projection_matrix(self, width=None, height=None): + """Return the OpenGL projection matrix for this camera. + + Parameters + ---------- + width : int + Width of the current viewport, in pixels. + height : int + Height of the current viewport, in pixels. + """ + pass + + +class PerspectiveCamera(Camera): + + """A perspective camera for perspective projection. + + Parameters + ---------- + yfov : float + The floating-point vertical field of view in radians. + znear : float + The floating-point distance to the near clipping plane. + If not specified, defaults to 0.05. + zfar : float, optional + The floating-point distance to the far clipping plane. + ``zfar`` must be greater than ``znear``. + If None, the camera uses an infinite projection matrix. + aspectRatio : float, optional + The floating-point aspect ratio of the field of view. + If not specified, the camera uses the viewport's aspect ratio. + name : str, optional + The user-defined name of this object. + """ + + def __init__(self, + yfov, + znear=DEFAULT_Z_NEAR, + zfar=None, + aspectRatio=None, + name=None): + super(PerspectiveCamera, self).__init__( + znear=znear, + zfar=zfar, + name=name, + ) + + self.yfov = yfov + self.aspectRatio = aspectRatio + + @property + def yfov(self): + """float : The vertical field of view in radians. + """ + return self._yfov + + @yfov.setter + def yfov(self, value): + value = float(value) + if value <= 0.0: + raise ValueError('Field of view must be positive') + self._yfov = value + + @property + def zfar(self): + """float : The distance to the far clipping plane. + """ + return self._zfar + + @zfar.setter + def zfar(self, value): + if value is not None: + value = float(value) + if value <= 0 or value <= self.znear: + raise ValueError('zfar must be >0 and >znear') + self._zfar = value + + @property + def aspectRatio(self): + """float : The ratio of the width to the height of the field of view. + """ + return self._aspectRatio + + @aspectRatio.setter + def aspectRatio(self, value): + if value is not None: + value = float(value) + if value <= 0.0: + raise ValueError('Aspect ratio must be positive') + self._aspectRatio = value + + def get_projection_matrix(self, width=None, height=None): + """Return the OpenGL projection matrix for this camera. + + Parameters + ---------- + width : int + Width of the current viewport, in pixels. + height : int + Height of the current viewport, in pixels. + """ + aspect_ratio = self.aspectRatio + if aspect_ratio is None: + if width is None or height is None: + raise ValueError('Aspect ratio of camera must be defined') + aspect_ratio = float(width) / float(height) + + a = aspect_ratio + t = np.tan(self.yfov / 2.0) + n = self.znear + f = self.zfar + + P = np.zeros((4,4)) + P[0][0] = 1.0 / (a * t) + P[1][1] = 1.0 / t + P[3][2] = -1.0 + + if f is None: + P[2][2] = -1.0 + P[2][3] = -2.0 * n + else: + P[2][2] = (f + n) / (n - f) + P[2][3] = (2 * f * n) / (n - f) + + return P + + +class OrthographicCamera(Camera): + """An orthographic camera for orthographic projection. + + Parameters + ---------- + xmag : float + The floating-point horizontal magnification of the view. + ymag : float + The floating-point vertical magnification of the view. + znear : float + The floating-point distance to the near clipping plane. + If not specified, defaults to 0.05. 
+ zfar : float + The floating-point distance to the far clipping plane. + ``zfar`` must be greater than ``znear``. + If not specified, defaults to 100.0. + name : str, optional + The user-defined name of this object. + """ + + def __init__(self, + xmag, + ymag, + znear=DEFAULT_Z_NEAR, + zfar=DEFAULT_Z_FAR, + name=None): + super(OrthographicCamera, self).__init__( + znear=znear, + zfar=zfar, + name=name, + ) + + self.xmag = xmag + self.ymag = ymag + + @property + def xmag(self): + """float : The horizontal magnification of the view. + """ + return self._xmag + + @xmag.setter + def xmag(self, value): + value = float(value) + if value <= 0.0: + raise ValueError('X magnification must be positive') + self._xmag = value + + @property + def ymag(self): + """float : The vertical magnification of the view. + """ + return self._ymag + + @ymag.setter + def ymag(self, value): + value = float(value) + if value <= 0.0: + raise ValueError('Y magnification must be positive') + self._ymag = value + + @property + def znear(self): + """float : The distance to the near clipping plane. + """ + return self._znear + + @znear.setter + def znear(self, value): + value = float(value) + if value <= 0: + raise ValueError('z-near must be > 0.0') + self._znear = value + + def get_projection_matrix(self, width=None, height=None): + """Return the OpenGL projection matrix for this camera. + + Parameters + ---------- + width : int + Width of the current viewport, in pixels. + Unused in this function. + height : int + Height of the current viewport, in pixels. + Unused in this function. + """ + xmag = self.xmag + ymag = self.ymag + + # If screen width/height defined, rescale xmag + if width is not None and height is not None: + xmag = width / height * ymag + + n = self.znear + f = self.zfar + P = np.zeros((4,4)) + P[0][0] = 1.0 / xmag + P[1][1] = 1.0 / ymag + P[2][2] = 2.0 / (n - f) + P[2][3] = (f + n) / (n - f) + P[3][3] = 1.0 + return P + + +class IntrinsicsCamera(Camera): + """A perspective camera with custom intrinsics. + + Parameters + ---------- + fx : float + X-axis focal length in pixels. + fy : float + Y-axis focal length in pixels. + cx : float + X-axis optical center in pixels. + cy : float + Y-axis optical center in pixels. + znear : float + The floating-point distance to the near clipping plane. + If not specified, defaults to 0.05. + zfar : float + The floating-point distance to the far clipping plane. + ``zfar`` must be greater than ``znear``. + If not specified, defaults to 100.0. + name : str, optional + The user-defined name of this object. + """ + + def __init__(self, + fx, + fy, + cx, + cy, + znear=DEFAULT_Z_NEAR, + zfar=DEFAULT_Z_FAR, + name=None): + super(IntrinsicsCamera, self).__init__( + znear=znear, + zfar=zfar, + name=name, + ) + + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + + @property + def fx(self): + """float : X-axis focal length in meters. + """ + return self._fx + + @fx.setter + def fx(self, value): + self._fx = float(value) + + @property + def fy(self): + """float : Y-axis focal length in meters. + """ + return self._fy + + @fy.setter + def fy(self, value): + self._fy = float(value) + + @property + def cx(self): + """float : X-axis optical center in pixels. + """ + return self._cx + + @cx.setter + def cx(self, value): + self._cx = float(value) + + @property + def cy(self): + """float : Y-axis optical center in pixels. 
+ """ + return self._cy + + @cy.setter + def cy(self, value): + self._cy = float(value) + + def get_projection_matrix(self, width, height): + """Return the OpenGL projection matrix for this camera. + + Parameters + ---------- + width : int + Width of the current viewport, in pixels. + height : int + Height of the current viewport, in pixels. + """ + width = float(width) + height = float(height) + + cx, cy = self.cx, self.cy + fx, fy = self.fx, self.fy + if sys.platform == 'darwin': + cx = self.cx * 2.0 + cy = self.cy * 2.0 + fx = self.fx * 2.0 + fy = self.fy * 2.0 + + P = np.zeros((4,4)) + P[0][0] = 2.0 * fx / width + P[1][1] = 2.0 * fy / height + P[0][2] = 1.0 - 2.0 * cx / width + P[1][2] = 2.0 * cy / height - 1.0 + P[3][2] = -1.0 + + n = self.znear + f = self.zfar + if f is None: + P[2][2] = -1.0 + P[2][3] = -2.0 * n + else: + P[2][2] = (f + n) / (n - f) + P[2][3] = (2 * f * n) / (n - f) + + return P + + +__all__ = ['Camera', 'PerspectiveCamera', 'OrthographicCamera', + 'IntrinsicsCamera'] diff --git a/pyrender/pyrender/constants.py b/pyrender/pyrender/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5785b6fdb21910a174252c5af2f05b40ece4a5 --- /dev/null +++ b/pyrender/pyrender/constants.py @@ -0,0 +1,149 @@ +DEFAULT_Z_NEAR = 0.05 # Near clipping plane, in meters +DEFAULT_Z_FAR = 100.0 # Far clipping plane, in meters +DEFAULT_SCENE_SCALE = 2.0 # Default scene scale +MAX_N_LIGHTS = 4 # Maximum number of lights of each type allowed +TARGET_OPEN_GL_MAJOR = 4 # Target OpenGL Major Version +TARGET_OPEN_GL_MINOR = 1 # Target OpenGL Minor Version +MIN_OPEN_GL_MAJOR = 3 # Minimum OpenGL Major Version +MIN_OPEN_GL_MINOR = 3 # Minimum OpenGL Minor Version +FLOAT_SZ = 4 # Byte size of GL float32 +UINT_SZ = 4 # Byte size of GL uint32 +SHADOW_TEX_SZ = 2048 # Width and Height of Shadow Textures +TEXT_PADDING = 20 # Width of padding for rendering text (px) + + +# Flags for render type +class RenderFlags(object): + """Flags for rendering in the scene. + + Combine them with the bitwise or. For example, + + >>> flags = OFFSCREEN | SHADOWS_DIRECTIONAL | VERTEX_NORMALS + + would result in an offscreen render with directional shadows and + vertex normals enabled. + """ + NONE = 0 + """Normal PBR Render.""" + DEPTH_ONLY = 1 + """Only render the depth buffer.""" + OFFSCREEN = 2 + """Render offscreen and return the depth and (optionally) color buffers.""" + FLIP_WIREFRAME = 4 + """Invert the status of wireframe rendering for each mesh.""" + ALL_WIREFRAME = 8 + """Render all meshes as wireframes.""" + ALL_SOLID = 16 + """Render all meshes as solids.""" + SHADOWS_DIRECTIONAL = 32 + """Render shadows for directional lights.""" + SHADOWS_POINT = 64 + """Render shadows for point lights.""" + SHADOWS_SPOT = 128 + """Render shadows for spot lights.""" + SHADOWS_ALL = 32 | 64 | 128 + """Render shadows for all lights.""" + VERTEX_NORMALS = 256 + """Render vertex normals.""" + FACE_NORMALS = 512 + """Render face normals.""" + SKIP_CULL_FACES = 1024 + """Do not cull back faces.""" + RGBA = 2048 + """Render the color buffer with the alpha channel enabled.""" + FLAT = 4096 + """Render the color buffer flat, with no lighting computations.""" + SEG = 8192 + + +class TextAlign: + """Text alignment options for captions. + + Only use one at a time. 
+ """ + CENTER = 0 + """Center the text by width and height.""" + CENTER_LEFT = 1 + """Center the text by height and left-align it.""" + CENTER_RIGHT = 2 + """Center the text by height and right-align it.""" + BOTTOM_LEFT = 3 + """Put the text in the bottom-left corner.""" + BOTTOM_RIGHT = 4 + """Put the text in the bottom-right corner.""" + BOTTOM_CENTER = 5 + """Center the text by width and fix it to the bottom.""" + TOP_LEFT = 6 + """Put the text in the top-left corner.""" + TOP_RIGHT = 7 + """Put the text in the top-right corner.""" + TOP_CENTER = 8 + """Center the text by width and fix it to the top.""" + + +class GLTF(object): + """Options for GL objects.""" + NEAREST = 9728 + """Nearest neighbor interpolation.""" + LINEAR = 9729 + """Linear interpolation.""" + NEAREST_MIPMAP_NEAREST = 9984 + """Nearest mipmapping.""" + LINEAR_MIPMAP_NEAREST = 9985 + """Linear mipmapping.""" + NEAREST_MIPMAP_LINEAR = 9986 + """Nearest mipmapping.""" + LINEAR_MIPMAP_LINEAR = 9987 + """Linear mipmapping.""" + CLAMP_TO_EDGE = 33071 + """Clamp to the edge of the texture.""" + MIRRORED_REPEAT = 33648 + """Mirror the texture.""" + REPEAT = 10497 + """Repeat the texture.""" + POINTS = 0 + """Render as points.""" + LINES = 1 + """Render as lines.""" + LINE_LOOP = 2 + """Render as a line loop.""" + LINE_STRIP = 3 + """Render as a line strip.""" + TRIANGLES = 4 + """Render as triangles.""" + TRIANGLE_STRIP = 5 + """Render as a triangle strip.""" + TRIANGLE_FAN = 6 + """Render as a triangle fan.""" + + +class BufFlags(object): + POSITION = 0 + NORMAL = 1 + TANGENT = 2 + TEXCOORD_0 = 4 + TEXCOORD_1 = 8 + COLOR_0 = 16 + JOINTS_0 = 32 + WEIGHTS_0 = 64 + + +class TexFlags(object): + NONE = 0 + NORMAL = 1 + OCCLUSION = 2 + EMISSIVE = 4 + BASE_COLOR = 8 + METALLIC_ROUGHNESS = 16 + DIFFUSE = 32 + SPECULAR_GLOSSINESS = 64 + + +class ProgramFlags: + NONE = 0 + USE_MATERIAL = 1 + VERTEX_NORMALS = 2 + FACE_NORMALS = 4 + + +__all__ = ['RenderFlags', 'TextAlign', 'GLTF'] diff --git a/pyrender/pyrender/font.py b/pyrender/pyrender/font.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac530d7b949f50314a0d9cf5d744bedcace0571 --- /dev/null +++ b/pyrender/pyrender/font.py @@ -0,0 +1,272 @@ +"""Font texture loader and processor. + +Author: Matthew Matl +""" +import freetype +import numpy as np +import os + +import OpenGL +from OpenGL.GL import * + +from .constants import TextAlign, FLOAT_SZ +from .texture import Texture +from .sampler import Sampler + + +class FontCache(object): + """A cache for fonts. + """ + + def __init__(self, font_dir=None): + self._font_cache = {} + self.font_dir = font_dir + if self.font_dir is None: + base_dir, _ = os.path.split(os.path.realpath(__file__)) + self.font_dir = os.path.join(base_dir, 'fonts') + + def get_font(self, font_name, font_pt): + # If it's a file, load it directly, else, try to load from font dir. + if os.path.isfile(font_name): + font_filename = font_name + _, font_name = os.path.split(font_name) + font_name, _ = os.path.split(font_name) + else: + font_filename = os.path.join(self.font_dir, font_name) + '.ttf' + + cid = OpenGL.contextdata.getContext() + key = (cid, font_name, int(font_pt)) + + if key not in self._font_cache: + self._font_cache[key] = Font(font_filename, font_pt) + return self._font_cache[key] + + def clear(self): + for key in self._font_cache: + self._font_cache[key].delete() + self._font_cache = {} + + +class Character(object): + """A single character, with its texture and attributes. 
+ """ + + def __init__(self, texture, size, bearing, advance): + self.texture = texture + self.size = size + self.bearing = bearing + self.advance = advance + + +class Font(object): + """A font object. + + Parameters + ---------- + font_file : str + The file to load the font from. + font_pt : int + The height of the font in pixels. + """ + + def __init__(self, font_file, font_pt=40): + self.font_file = font_file + self.font_pt = int(font_pt) + self._face = freetype.Face(font_file) + self._face.set_pixel_sizes(0, font_pt) + self._character_map = {} + + for i in range(0, 128): + + # Generate texture + face = self._face + face.load_char(chr(i)) + buf = face.glyph.bitmap.buffer + src = (np.array(buf) / 255.0).astype(np.float32) + src = src.reshape((face.glyph.bitmap.rows, + face.glyph.bitmap.width)) + tex = Texture( + sampler=Sampler( + magFilter=GL_LINEAR, + minFilter=GL_LINEAR, + wrapS=GL_CLAMP_TO_EDGE, + wrapT=GL_CLAMP_TO_EDGE + ), + source=src, + source_channels='R', + ) + character = Character( + texture=tex, + size=np.array([face.glyph.bitmap.width, + face.glyph.bitmap.rows]), + bearing=np.array([face.glyph.bitmap_left, + face.glyph.bitmap_top]), + advance=face.glyph.advance.x + ) + self._character_map[chr(i)] = character + + self._vbo = None + self._vao = None + + @property + def font_file(self): + """str : The file the font was loaded from. + """ + return self._font_file + + @font_file.setter + def font_file(self, value): + self._font_file = value + + @property + def font_pt(self): + """int : The height of the font in pixels. + """ + return self._font_pt + + @font_pt.setter + def font_pt(self, value): + self._font_pt = int(value) + + def _add_to_context(self): + + self._vao = glGenVertexArrays(1) + glBindVertexArray(self._vao) + self._vbo = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self._vbo) + glBufferData(GL_ARRAY_BUFFER, FLOAT_SZ * 6 * 4, None, GL_DYNAMIC_DRAW) + glEnableVertexAttribArray(0) + glVertexAttribPointer( + 0, 4, GL_FLOAT, GL_FALSE, 4 * FLOAT_SZ, ctypes.c_void_p(0) + ) + glBindVertexArray(0) + + glPixelStorei(GL_UNPACK_ALIGNMENT, 1) + for c in self._character_map: + ch = self._character_map[c] + if not ch.texture._in_context(): + ch.texture._add_to_context() + + def _remove_from_context(self): + for c in self._character_map: + ch = self._character_map[c] + ch.texture.delete() + if self._vao is not None: + glDeleteVertexArrays(1, [self._vao]) + glDeleteBuffers(1, [self._vbo]) + self._vao = None + self._vbo = None + + def _in_context(self): + return self._vao is not None + + def _bind(self): + glBindVertexArray(self._vao) + + def _unbind(self): + glBindVertexArray(0) + + def delete(self): + self._unbind() + self._remove_from_context() + + def render_string(self, text, x, y, scale=1.0, + align=TextAlign.BOTTOM_LEFT): + """Render a string to the current view buffer. + + Note + ---- + Assumes correct shader program already bound w/ uniforms set. + + Parameters + ---------- + text : str + The text to render. + x : int + Horizontal pixel location of text. + y : int + Vertical pixel location of text. + scale : int + Scaling factor for text. + align : int + One of the TextAlign options which specifies where the ``x`` + and ``y`` parameters lie on the text. For example, + :attr:`.TextAlign.BOTTOM_LEFT` means that ``x`` and ``y`` indicate + the position of the bottom-left corner of the textbox. 
+ """ + glActiveTexture(GL_TEXTURE0) + glEnable(GL_BLEND) + glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) + glDisable(GL_DEPTH_TEST) + glPolygonMode(GL_FRONT_AND_BACK, GL_FILL) + self._bind() + + # Determine width and height of text relative to x, y + width = 0.0 + height = 0.0 + for c in text: + ch = self._character_map[c] + height = max(height, ch.bearing[1] * scale) + width += (ch.advance >> 6) * scale + + # Determine offsets based on alignments + xoff = 0 + yoff = 0 + if align == TextAlign.BOTTOM_RIGHT: + xoff = -width + elif align == TextAlign.BOTTOM_CENTER: + xoff = -width / 2.0 + elif align == TextAlign.TOP_LEFT: + yoff = -height + elif align == TextAlign.TOP_RIGHT: + yoff = -height + xoff = -width + elif align == TextAlign.TOP_CENTER: + yoff = -height + xoff = -width / 2.0 + elif align == TextAlign.CENTER: + xoff = -width / 2.0 + yoff = -height / 2.0 + elif align == TextAlign.CENTER_LEFT: + yoff = -height / 2.0 + elif align == TextAlign.CENTER_RIGHT: + xoff = -width + yoff = -height / 2.0 + + x += xoff + y += yoff + + ch = None + for c in text: + ch = self._character_map[c] + xpos = x + ch.bearing[0] * scale + ypos = y - (ch.size[1] - ch.bearing[1]) * scale + w = ch.size[0] * scale + h = ch.size[1] * scale + + vertices = np.array([ + [xpos, ypos, 0.0, 0.0], + [xpos + w, ypos, 1.0, 0.0], + [xpos + w, ypos + h, 1.0, 1.0], + [xpos + w, ypos + h, 1.0, 1.0], + [xpos, ypos + h, 0.0, 1.0], + [xpos, ypos, 0.0, 0.0], + ], dtype=np.float32) + + ch.texture._bind() + + glBindBuffer(GL_ARRAY_BUFFER, self._vbo) + glBufferData( + GL_ARRAY_BUFFER, FLOAT_SZ * 6 * 4, vertices, GL_DYNAMIC_DRAW + ) + # TODO MAKE THIS MORE EFFICIENT, lgBufferSubData is broken + # glBufferSubData( + # GL_ARRAY_BUFFER, 0, 6 * 4 * FLOAT_SZ, + # np.ascontiguousarray(vertices.flatten) + # ) + glDrawArrays(GL_TRIANGLES, 0, 6) + x += (ch.advance >> 6) * scale + + self._unbind() + if ch: + ch.texture._unbind() diff --git a/pyrender/pyrender/fonts/OpenSans-Bold.ttf b/pyrender/pyrender/fonts/OpenSans-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..fd79d43bea0293ac1b20e8aca1142627983d2c07 Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-Bold.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-BoldItalic.ttf b/pyrender/pyrender/fonts/OpenSans-BoldItalic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..9bc800958a421d937fc392e00beaef4eea76dc71 Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-BoldItalic.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-ExtraBold.ttf b/pyrender/pyrender/fonts/OpenSans-ExtraBold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..21f6f84a0799946fc4ae02c52b27e61c3762c745 Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-ExtraBold.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-ExtraBoldItalic.ttf b/pyrender/pyrender/fonts/OpenSans-ExtraBoldItalic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..31cb688340eff462dddf47efbb4dfef66cb7fbed Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-ExtraBoldItalic.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-Italic.ttf b/pyrender/pyrender/fonts/OpenSans-Italic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..c90da48ff3b8ad6167236d70c48df4d7b5de3bbb Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-Italic.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-Light.ttf b/pyrender/pyrender/fonts/OpenSans-Light.ttf new file mode 100644 index 
0000000000000000000000000000000000000000..0d381897da20345fa63112f19042561f44ee3aa0 Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-Light.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-LightItalic.ttf b/pyrender/pyrender/fonts/OpenSans-LightItalic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..68299c4bc6b5b7adfff2c9aee4aed7c1547100ef Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-LightItalic.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-Regular.ttf b/pyrender/pyrender/fonts/OpenSans-Regular.ttf new file mode 100644 index 0000000000000000000000000000000000000000..db433349b7047f72f40072630c1bc110620bf09e Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-Regular.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-Semibold.ttf b/pyrender/pyrender/fonts/OpenSans-Semibold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..1a7679e3949fb045f152f456bc4adad31e8b9f55 Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-Semibold.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-SemiboldItalic.ttf b/pyrender/pyrender/fonts/OpenSans-SemiboldItalic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..59b6d16b065f6baa6f70ddbd4322a4f44bb9636a Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-SemiboldItalic.ttf differ diff --git a/pyrender/pyrender/light.py b/pyrender/pyrender/light.py new file mode 100644 index 0000000000000000000000000000000000000000..333d9e4e553a245c259251a89b69cb46b73b1278 --- /dev/null +++ b/pyrender/pyrender/light.py @@ -0,0 +1,385 @@ +"""Punctual light sources as defined by the glTF 2.0 KHR extension at +https://github.com/KhronosGroup/glTF/tree/master/extensions/2.0/Khronos/KHR_lights_punctual + +Author: Matthew Matl +""" +import abc +import numpy as np +import six + +from OpenGL.GL import * + +from .utils import format_color_vector +from .texture import Texture +from .constants import SHADOW_TEX_SZ +from .camera import OrthographicCamera, PerspectiveCamera + + + +@six.add_metaclass(abc.ABCMeta) +class Light(object): + """Base class for all light objects. + + Parameters + ---------- + color : (3,) float + RGB value for the light's color in linear space. + intensity : float + Brightness of light. The units that this is defined in depend on the + type of light. Point and spot lights use luminous intensity in candela + (lm/sr), while directional lights use illuminance in lux (lm/m2). + name : str, optional + Name of the light. + """ + def __init__(self, + color=None, + intensity=None, + name=None): + + if color is None: + color = np.ones(3) + if intensity is None: + intensity = 1.0 + + self.name = name + self.color = color + self.intensity = intensity + self._shadow_camera = None + self._shadow_texture = None + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def color(self): + """(3,) float : The light's color. + """ + return self._color + + @color.setter + def color(self, value): + self._color = format_color_vector(value, 3) + + @property + def intensity(self): + """float : The light's intensity in candela or lux. + """ + return self._intensity + + @intensity.setter + def intensity(self, value): + self._intensity = float(value) + + @property + def shadow_texture(self): + """:class:`.Texture` : A texture used to hold shadow maps for this light. 
+ """ + return self._shadow_texture + + @shadow_texture.setter + def shadow_texture(self, value): + if self._shadow_texture is not None: + if self._shadow_texture._in_context(): + self._shadow_texture.delete() + self._shadow_texture = value + + @abc.abstractmethod + def _generate_shadow_texture(self, size=None): + """Generate a shadow texture for this light. + + Parameters + ---------- + size : int, optional + Size of texture map. Must be a positive power of two. + """ + pass + + @abc.abstractmethod + def _get_shadow_camera(self, scene_scale): + """Generate and return a shadow mapping camera for this light. + + Parameters + ---------- + scene_scale : float + Length of scene's bounding box diagonal. + + Returns + ------- + camera : :class:`.Camera` + The camera used to render shadowmaps for this light. + """ + pass + + +class DirectionalLight(Light): + """Directional lights are light sources that act as though they are + infinitely far away and emit light in the direction of the local -z axis. + This light type inherits the orientation of the node that it belongs to; + position and scale are ignored except for their effect on the inherited + node orientation. Because it is at an infinite distance, the light is + not attenuated. Its intensity is defined in lumens per metre squared, + or lux (lm/m2). + + Parameters + ---------- + color : (3,) float, optional + RGB value for the light's color in linear space. Defaults to white + (i.e. [1.0, 1.0, 1.0]). + intensity : float, optional + Brightness of light, in lux (lm/m^2). Defaults to 1.0 + name : str, optional + Name of the light. + """ + + def __init__(self, + color=None, + intensity=None, + name=None): + super(DirectionalLight, self).__init__( + color=color, + intensity=intensity, + name=name, + ) + + def _generate_shadow_texture(self, size=None): + """Generate a shadow texture for this light. + + Parameters + ---------- + size : int, optional + Size of texture map. Must be a positive power of two. + """ + if size is None: + size = SHADOW_TEX_SZ + self.shadow_texture = Texture(width=size, height=size, + source_channels='D', data_format=GL_FLOAT) + + def _get_shadow_camera(self, scene_scale): + """Generate and return a shadow mapping camera for this light. + + Parameters + ---------- + scene_scale : float + Length of scene's bounding box diagonal. + + Returns + ------- + camera : :class:`.Camera` + The camera used to render shadowmaps for this light. + """ + return OrthographicCamera( + znear=0.01 * scene_scale, + zfar=10 * scene_scale, + xmag=scene_scale, + ymag=scene_scale + ) + + +class PointLight(Light): + """Point lights emit light in all directions from their position in space; + rotation and scale are ignored except for their effect on the inherited + node position. The brightness of the light attenuates in a physically + correct manner as distance increases from the light's position (i.e. + brightness goes like the inverse square of the distance). Point light + intensity is defined in candela, which is lumens per square radian (lm/sr). + + Parameters + ---------- + color : (3,) float + RGB value for the light's color in linear space. + intensity : float + Brightness of light in candela (lm/sr). + range : float + Cutoff distance at which light's intensity may be considered to + have reached zero. If None, the range is assumed to be infinite. + name : str, optional + Name of the light. 
+ """ + + def __init__(self, + color=None, + intensity=None, + range=None, + name=None): + super(PointLight, self).__init__( + color=color, + intensity=intensity, + name=name, + ) + self.range = range + + @property + def range(self): + """float : The cutoff distance for the light. + """ + return self._range + + @range.setter + def range(self, value): + if value is not None: + value = float(value) + if value <= 0: + raise ValueError('Range must be > 0') + self._range = value + self._range = value + + def _generate_shadow_texture(self, size=None): + """Generate a shadow texture for this light. + + Parameters + ---------- + size : int, optional + Size of texture map. Must be a positive power of two. + """ + raise NotImplementedError('Shadows not implemented for point lights') + + def _get_shadow_camera(self, scene_scale): + """Generate and return a shadow mapping camera for this light. + + Parameters + ---------- + scene_scale : float + Length of scene's bounding box diagonal. + + Returns + ------- + camera : :class:`.Camera` + The camera used to render shadowmaps for this light. + """ + raise NotImplementedError('Shadows not implemented for point lights') + + +class SpotLight(Light): + """Spot lights emit light in a cone in the direction of the local -z axis. + The angle and falloff of the cone is defined using two numbers, the + ``innerConeAngle`` and ``outerConeAngle``. + As with point lights, the brightness + also attenuates in a physically correct manner as distance increases from + the light's position (i.e. brightness goes like the inverse square of the + distance). Spot light intensity refers to the brightness inside the + ``innerConeAngle`` (and at the location of the light) and is defined in + candela, which is lumens per square radian (lm/sr). A spot light's position + and orientation are inherited from its node transform. Inherited scale does + not affect cone shape, and is ignored except for its effect on position + and orientation. + + Parameters + ---------- + color : (3,) float + RGB value for the light's color in linear space. + intensity : float + Brightness of light in candela (lm/sr). + range : float + Cutoff distance at which light's intensity may be considered to + have reached zero. If None, the range is assumed to be infinite. + innerConeAngle : float + Angle, in radians, from centre of spotlight where falloff begins. + Must be greater than or equal to ``0`` and less + than ``outerConeAngle``. Defaults to ``0``. + outerConeAngle : float + Angle, in radians, from centre of spotlight where falloff ends. + Must be greater than ``innerConeAngle`` and less than or equal to + ``PI / 2.0``. Defaults to ``PI / 4.0``. + name : str, optional + Name of the light. + """ + + def __init__(self, + color=None, + intensity=None, + range=None, + innerConeAngle=0.0, + outerConeAngle=(np.pi / 4.0), + name=None): + super(SpotLight, self).__init__( + name=name, + color=color, + intensity=intensity, + ) + self.outerConeAngle = outerConeAngle + self.innerConeAngle = innerConeAngle + self.range = range + + @property + def innerConeAngle(self): + """float : The inner cone angle in radians. + """ + return self._innerConeAngle + + @innerConeAngle.setter + def innerConeAngle(self, value): + if value < 0.0 or value > self.outerConeAngle: + raise ValueError('Invalid value for inner cone angle') + self._innerConeAngle = float(value) + + @property + def outerConeAngle(self): + """float : The outer cone angle in radians. 
+ """ + return self._outerConeAngle + + @outerConeAngle.setter + def outerConeAngle(self, value): + if value < 0.0 or value > np.pi / 2.0 + 1e-9: + raise ValueError('Invalid value for outer cone angle') + self._outerConeAngle = float(value) + + @property + def range(self): + """float : The cutoff distance for the light. + """ + return self._range + + @range.setter + def range(self, value): + if value is not None: + value = float(value) + if value <= 0: + raise ValueError('Range must be > 0') + self._range = value + self._range = value + + def _generate_shadow_texture(self, size=None): + """Generate a shadow texture for this light. + + Parameters + ---------- + size : int, optional + Size of texture map. Must be a positive power of two. + """ + if size is None: + size = SHADOW_TEX_SZ + self.shadow_texture = Texture(width=size, height=size, + source_channels='D', data_format=GL_FLOAT) + + def _get_shadow_camera(self, scene_scale): + """Generate and return a shadow mapping camera for this light. + + Parameters + ---------- + scene_scale : float + Length of scene's bounding box diagonal. + + Returns + ------- + camera : :class:`.Camera` + The camera used to render shadowmaps for this light. + """ + return PerspectiveCamera( + znear=0.01 * scene_scale, + zfar=10 * scene_scale, + yfov=np.clip(2 * self.outerConeAngle + np.pi / 16.0, 0.0, np.pi), + aspectRatio=1.0 + ) + + +__all__ = ['Light', 'DirectionalLight', 'SpotLight', 'PointLight'] diff --git a/pyrender/pyrender/material.py b/pyrender/pyrender/material.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce9c2d184ed213c84b015e36bea558cd1efc6b7 --- /dev/null +++ b/pyrender/pyrender/material.py @@ -0,0 +1,707 @@ +"""Material properties, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-material +and +https://github.com/KhronosGroup/glTF/tree/master/extensions/2.0/Khronos/KHR_materials_pbrSpecularGlossiness + +Author: Matthew Matl +""" +import abc +import numpy as np +import six + +from .constants import TexFlags +from .utils import format_color_vector, format_texture_source +from .texture import Texture + + +@six.add_metaclass(abc.ABCMeta) +class Material(object): + """Base for standard glTF 2.0 materials. + + Parameters + ---------- + name : str, optional + The user-defined name of this object. + normalTexture : (n,n,3) float or :class:`Texture`, optional + A tangent space normal map. The texture contains RGB components in + linear space. Each texel represents the XYZ components of a normal + vector in tangent space. Red [0 to 255] maps to X [-1 to 1]. Green + [0 to 255] maps to Y [-1 to 1]. Blue [128 to 255] maps to Z + [1/255 to 1]. The normal vectors use OpenGL conventions where +X is + right and +Y is up. +Z points toward the viewer. + occlusionTexture : (n,n,1) float or :class:`Texture`, optional + The occlusion map texture. The occlusion values are sampled from the R + channel. Higher values indicate areas that should receive full indirect + lighting and lower values indicate no indirect lighting. These values + are linear. If other channels are present (GBA), they are ignored for + occlusion calculations. + emissiveTexture : (n,n,3) float or :class:`Texture`, optional + The emissive map controls the color and intensity of the light being + emitted by the material. This texture contains RGB components in sRGB + color space. If a fourth component (A) is present, it is ignored. 
+ emissiveFactor : (3,) float, optional + The RGB components of the emissive color of the material. These values + are linear. If an emissiveTexture is specified, this value is + multiplied with the texel values. + alphaMode : str, optional + The material's alpha rendering mode enumeration specifying the + interpretation of the alpha value of the main factor and texture. + Allowed Values: + + - `"OPAQUE"` The alpha value is ignored and the rendered output is + fully opaque. + - `"MASK"` The rendered output is either fully opaque or fully + transparent depending on the alpha value and the specified alpha + cutoff value. + - `"BLEND"` The alpha value is used to composite the source and + destination areas. The rendered output is combined with the + background using the normal painting operation (i.e. the Porter + and Duff over operator). + + alphaCutoff : float, optional + Specifies the cutoff threshold when in MASK mode. If the alpha value is + greater than or equal to this value then it is rendered as fully + opaque, otherwise, it is rendered as fully transparent. + A value greater than 1.0 will render the entire material as fully + transparent. This value is ignored for other modes. + doubleSided : bool, optional + Specifies whether the material is double sided. When this value is + false, back-face culling is enabled. When this value is true, + back-face culling is disabled and double sided lighting is enabled. + smooth : bool, optional + If True, the material is rendered smoothly by using only one normal + per vertex and face indexing. + wireframe : bool, optional + If True, the material is rendered in wireframe mode. + """ + + def __init__(self, + name=None, + normalTexture=None, + occlusionTexture=None, + emissiveTexture=None, + emissiveFactor=None, + alphaMode=None, + alphaCutoff=None, + doubleSided=False, + smooth=True, + wireframe=False): + + # Set defaults + if alphaMode is None: + alphaMode = 'OPAQUE' + + if alphaCutoff is None: + alphaCutoff = 0.5 + + if emissiveFactor is None: + emissiveFactor = np.zeros(3).astype(np.float32) + + self.name = name + self.normalTexture = normalTexture + self.occlusionTexture = occlusionTexture + self.emissiveTexture = emissiveTexture + self.emissiveFactor = emissiveFactor + self.alphaMode = alphaMode + self.alphaCutoff = alphaCutoff + self.doubleSided = doubleSided + self.smooth = smooth + self.wireframe = wireframe + + self._tex_flags = None + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def normalTexture(self): + """(n,n,3) float or :class:`Texture` : The tangent-space normal map. + """ + return self._normalTexture + + @normalTexture.setter + def normalTexture(self, value): + # TODO TMP + self._normalTexture = self._format_texture(value, 'RGB') + self._tex_flags = None + + @property + def occlusionTexture(self): + """(n,n,1) float or :class:`Texture` : The ambient occlusion map. + """ + return self._occlusionTexture + + @occlusionTexture.setter + def occlusionTexture(self, value): + self._occlusionTexture = self._format_texture(value, 'R') + self._tex_flags = None + + @property + def emissiveTexture(self): + """(n,n,3) float or :class:`Texture` : The emission map. 
+ """ + return self._emissiveTexture + + @emissiveTexture.setter + def emissiveTexture(self, value): + self._emissiveTexture = self._format_texture(value, 'RGB') + self._tex_flags = None + + @property + def emissiveFactor(self): + """(3,) float : Base multiplier for emission colors. + """ + return self._emissiveFactor + + @emissiveFactor.setter + def emissiveFactor(self, value): + if value is None: + value = np.zeros(3) + self._emissiveFactor = format_color_vector(value, 3) + + @property + def alphaMode(self): + """str : The mode for blending. + """ + return self._alphaMode + + @alphaMode.setter + def alphaMode(self, value): + if value not in set(['OPAQUE', 'MASK', 'BLEND']): + raise ValueError('Invalid alpha mode {}'.format(value)) + self._alphaMode = value + + @property + def alphaCutoff(self): + """float : The cutoff threshold in MASK mode. + """ + return self._alphaCutoff + + @alphaCutoff.setter + def alphaCutoff(self, value): + if value < 0 or value > 1: + raise ValueError('Alpha cutoff must be in range [0,1]') + self._alphaCutoff = float(value) + + @property + def doubleSided(self): + """bool : Whether the material is double-sided. + """ + return self._doubleSided + + @doubleSided.setter + def doubleSided(self, value): + if not isinstance(value, bool): + raise TypeError('Double sided must be a boolean value') + self._doubleSided = value + + @property + def smooth(self): + """bool : Whether to render the mesh smoothly by + interpolating vertex normals. + """ + return self._smooth + + @smooth.setter + def smooth(self, value): + if not isinstance(value, bool): + raise TypeError('Double sided must be a boolean value') + self._smooth = value + + @property + def wireframe(self): + """bool : Whether to render the mesh in wireframe mode. + """ + return self._wireframe + + @wireframe.setter + def wireframe(self, value): + if not isinstance(value, bool): + raise TypeError('Wireframe must be a boolean value') + self._wireframe = value + + @property + def is_transparent(self): + """bool : If True, the object is partially transparent. + """ + return self._compute_transparency() + + @property + def tex_flags(self): + """int : Texture availability flags. + """ + if self._tex_flags is None: + self._tex_flags = self._compute_tex_flags() + return self._tex_flags + + @property + def textures(self): + """list of :class:`Texture` : The textures associated with this + material. + """ + return self._compute_textures() + + def _compute_transparency(self): + return False + + def _compute_tex_flags(self): + tex_flags = TexFlags.NONE + if self.normalTexture is not None: + tex_flags |= TexFlags.NORMAL + if self.occlusionTexture is not None: + tex_flags |= TexFlags.OCCLUSION + if self.emissiveTexture is not None: + tex_flags |= TexFlags.EMISSIVE + return tex_flags + + def _compute_textures(self): + all_textures = [ + self.normalTexture, self.occlusionTexture, self.emissiveTexture + ] + textures = set([t for t in all_textures if t is not None]) + return textures + + def _format_texture(self, texture, target_channels='RGB'): + """Format a texture as a float32 np array. + """ + if isinstance(texture, Texture) or texture is None: + return texture + else: + source = format_texture_source(texture, target_channels) + return Texture(source=source, source_channels=target_channels) + + +class MetallicRoughnessMaterial(Material): + """A material based on the metallic-roughness material model from + Physically-Based Rendering (PBR) methodology. 
+ + Parameters + ---------- + name : str, optional + The user-defined name of this object. + normalTexture : (n,n,3) float or :class:`Texture`, optional + A tangent space normal map. The texture contains RGB components in + linear space. Each texel represents the XYZ components of a normal + vector in tangent space. Red [0 to 255] maps to X [-1 to 1]. Green + [0 to 255] maps to Y [-1 to 1]. Blue [128 to 255] maps to Z + [1/255 to 1]. The normal vectors use OpenGL conventions where +X is + right and +Y is up. +Z points toward the viewer. + occlusionTexture : (n,n,1) float or :class:`Texture`, optional + The occlusion map texture. The occlusion values are sampled from the R + channel. Higher values indicate areas that should receive full indirect + lighting and lower values indicate no indirect lighting. These values + are linear. If other channels are present (GBA), they are ignored for + occlusion calculations. + emissiveTexture : (n,n,3) float or :class:`Texture`, optional + The emissive map controls the color and intensity of the light being + emitted by the material. This texture contains RGB components in sRGB + color space. If a fourth component (A) is present, it is ignored. + emissiveFactor : (3,) float, optional + The RGB components of the emissive color of the material. These values + are linear. If an emissiveTexture is specified, this value is + multiplied with the texel values. + alphaMode : str, optional + The material's alpha rendering mode enumeration specifying the + interpretation of the alpha value of the main factor and texture. + Allowed Values: + + - `"OPAQUE"` The alpha value is ignored and the rendered output is + fully opaque. + - `"MASK"` The rendered output is either fully opaque or fully + transparent depending on the alpha value and the specified alpha + cutoff value. + - `"BLEND"` The alpha value is used to composite the source and + destination areas. The rendered output is combined with the + background using the normal painting operation (i.e. the Porter + and Duff over operator). + + alphaCutoff : float, optional + Specifies the cutoff threshold when in MASK mode. If the alpha value is + greater than or equal to this value then it is rendered as fully + opaque, otherwise, it is rendered as fully transparent. + A value greater than 1.0 will render the entire material as fully + transparent. This value is ignored for other modes. + doubleSided : bool, optional + Specifies whether the material is double sided. When this value is + false, back-face culling is enabled. When this value is true, + back-face culling is disabled and double sided lighting is enabled. + smooth : bool, optional + If True, the material is rendered smoothly by using only one normal + per vertex and face indexing. + wireframe : bool, optional + If True, the material is rendered in wireframe mode. + baseColorFactor : (4,) float, optional + The RGBA components of the base color of the material. The fourth + component (A) is the alpha coverage of the material. The alphaMode + property specifies how alpha is interpreted. These values are linear. + If a baseColorTexture is specified, this value is multiplied with the + texel values. + baseColorTexture : (n,n,4) float or :class:`Texture`, optional + The base color texture. This texture contains RGB(A) components in sRGB + color space. The first three components (RGB) specify the base color of + the material. If the fourth component (A) is present, it represents the + alpha coverage of the material. Otherwise, an alpha of 1.0 is assumed. 
+ The alphaMode property specifies how alpha is interpreted. + The stored texels must not be premultiplied. + metallicFactor : float + The metalness of the material. A value of 1.0 means the material is a + metal. A value of 0.0 means the material is a dielectric. Values in + between are for blending between metals and dielectrics such as dirty + metallic surfaces. This value is linear. If a metallicRoughnessTexture + is specified, this value is multiplied with the metallic texel values. + roughnessFactor : float + The roughness of the material. A value of 1.0 means the material is + completely rough. A value of 0.0 means the material is completely + smooth. This value is linear. If a metallicRoughnessTexture is + specified, this value is multiplied with the roughness texel values. + metallicRoughnessTexture : (n,n,2) float or :class:`Texture`, optional + The metallic-roughness texture. The metalness values are sampled from + the B channel. The roughness values are sampled from the G channel. + These values are linear. If other channels are present (R or A), they + are ignored for metallic-roughness calculations. + """ + + def __init__(self, + name=None, + normalTexture=None, + occlusionTexture=None, + emissiveTexture=None, + emissiveFactor=None, + alphaMode=None, + alphaCutoff=None, + doubleSided=False, + smooth=True, + wireframe=False, + baseColorFactor=None, + baseColorTexture=None, + metallicFactor=1.0, + roughnessFactor=1.0, + metallicRoughnessTexture=None): + super(MetallicRoughnessMaterial, self).__init__( + name=name, + normalTexture=normalTexture, + occlusionTexture=occlusionTexture, + emissiveTexture=emissiveTexture, + emissiveFactor=emissiveFactor, + alphaMode=alphaMode, + alphaCutoff=alphaCutoff, + doubleSided=doubleSided, + smooth=smooth, + wireframe=wireframe + ) + + # Set defaults + if baseColorFactor is None: + baseColorFactor = np.ones(4).astype(np.float32) + + self.baseColorFactor = baseColorFactor + self.baseColorTexture = baseColorTexture + self.metallicFactor = metallicFactor + self.roughnessFactor = roughnessFactor + self.metallicRoughnessTexture = metallicRoughnessTexture + + @property + def baseColorFactor(self): + """(4,) float or :class:`Texture` : The RGBA base color multiplier. + """ + return self._baseColorFactor + + @baseColorFactor.setter + def baseColorFactor(self, value): + if value is None: + value = np.ones(4) + self._baseColorFactor = format_color_vector(value, 4) + + @property + def baseColorTexture(self): + """(n,n,4) float or :class:`Texture` : The diffuse texture. + """ + return self._baseColorTexture + + @baseColorTexture.setter + def baseColorTexture(self, value): + self._baseColorTexture = self._format_texture(value, 'RGBA') + self._tex_flags = None + + @property + def metallicFactor(self): + """float : The metalness of the material. + """ + return self._metallicFactor + + @metallicFactor.setter + def metallicFactor(self, value): + if value is None: + value = 1.0 + if value < 0 or value > 1: + raise ValueError('Metallic factor must be in range [0,1]') + self._metallicFactor = float(value) + + @property + def roughnessFactor(self): + """float : The roughness of the material. 
+ """ + return self.RoughnessFactor + + @roughnessFactor.setter + def roughnessFactor(self, value): + if value is None: + value = 1.0 + if value < 0 or value > 1: + raise ValueError('Roughness factor must be in range [0,1]') + self.RoughnessFactor = float(value) + + @property + def metallicRoughnessTexture(self): + """(n,n,2) float or :class:`Texture` : The metallic-roughness texture. + """ + return self._metallicRoughnessTexture + + @metallicRoughnessTexture.setter + def metallicRoughnessTexture(self, value): + self._metallicRoughnessTexture = self._format_texture(value, 'GB') + self._tex_flags = None + + def _compute_tex_flags(self): + tex_flags = super(MetallicRoughnessMaterial, self)._compute_tex_flags() + if self.baseColorTexture is not None: + tex_flags |= TexFlags.BASE_COLOR + if self.metallicRoughnessTexture is not None: + tex_flags |= TexFlags.METALLIC_ROUGHNESS + return tex_flags + + def _compute_transparency(self): + if self.alphaMode == 'OPAQUE': + return False + cutoff = self.alphaCutoff + if self.alphaMode == 'BLEND': + cutoff = 1.0 + if self.baseColorFactor[3] < cutoff: + return True + if (self.baseColorTexture is not None and + self.baseColorTexture.is_transparent(cutoff)): + return True + return False + + def _compute_textures(self): + textures = super(MetallicRoughnessMaterial, self)._compute_textures() + all_textures = [self.baseColorTexture, self.metallicRoughnessTexture] + all_textures = {t for t in all_textures if t is not None} + textures |= all_textures + return textures + + +class SpecularGlossinessMaterial(Material): + """A material based on the specular-glossiness material model from + Physically-Based Rendering (PBR) methodology. + + Parameters + ---------- + name : str, optional + The user-defined name of this object. + normalTexture : (n,n,3) float or :class:`Texture`, optional + A tangent space normal map. The texture contains RGB components in + linear space. Each texel represents the XYZ components of a normal + vector in tangent space. Red [0 to 255] maps to X [-1 to 1]. Green + [0 to 255] maps to Y [-1 to 1]. Blue [128 to 255] maps to Z + [1/255 to 1]. The normal vectors use OpenGL conventions where +X is + right and +Y is up. +Z points toward the viewer. + occlusionTexture : (n,n,1) float or :class:`Texture`, optional + The occlusion map texture. The occlusion values are sampled from the R + channel. Higher values indicate areas that should receive full indirect + lighting and lower values indicate no indirect lighting. These values + are linear. If other channels are present (GBA), they are ignored for + occlusion calculations. + emissiveTexture : (n,n,3) float or :class:`Texture`, optional + The emissive map controls the color and intensity of the light being + emitted by the material. This texture contains RGB components in sRGB + color space. If a fourth component (A) is present, it is ignored. + emissiveFactor : (3,) float, optional + The RGB components of the emissive color of the material. These values + are linear. If an emissiveTexture is specified, this value is + multiplied with the texel values. + alphaMode : str, optional + The material's alpha rendering mode enumeration specifying the + interpretation of the alpha value of the main factor and texture. + Allowed Values: + + - `"OPAQUE"` The alpha value is ignored and the rendered output is + fully opaque. + - `"MASK"` The rendered output is either fully opaque or fully + transparent depending on the alpha value and the specified alpha + cutoff value. 
+ - `"BLEND"` The alpha value is used to composite the source and + destination areas. The rendered output is combined with the + background using the normal painting operation (i.e. the Porter + and Duff over operator). + + alphaCutoff : float, optional + Specifies the cutoff threshold when in MASK mode. If the alpha value is + greater than or equal to this value then it is rendered as fully + opaque, otherwise, it is rendered as fully transparent. + A value greater than 1.0 will render the entire material as fully + transparent. This value is ignored for other modes. + doubleSided : bool, optional + Specifies whether the material is double sided. When this value is + false, back-face culling is enabled. When this value is true, + back-face culling is disabled and double sided lighting is enabled. + smooth : bool, optional + If True, the material is rendered smoothly by using only one normal + per vertex and face indexing. + wireframe : bool, optional + If True, the material is rendered in wireframe mode. + diffuseFactor : (4,) float + The RGBA components of the reflected diffuse color of the material. + Metals have a diffuse value of [0.0, 0.0, 0.0]. The fourth component + (A) is the opacity of the material. The values are linear. + diffuseTexture : (n,n,4) float or :class:`Texture`, optional + The diffuse texture. This texture contains RGB(A) components of the + reflected diffuse color of the material in sRGB color space. If the + fourth component (A) is present, it represents the alpha coverage of + the material. Otherwise, an alpha of 1.0 is assumed. + The alphaMode property specifies how alpha is interpreted. + The stored texels must not be premultiplied. + specularFactor : (3,) float + The specular RGB color of the material. This value is linear. + glossinessFactor : float + The glossiness or smoothness of the material. A value of 1.0 means the + material has full glossiness or is perfectly smooth. A value of 0.0 + means the material has no glossiness or is perfectly rough. This value + is linear. + specularGlossinessTexture : (n,n,4) or :class:`Texture`, optional + The specular-glossiness texture is a RGBA texture, containing the + specular color (RGB) in sRGB space and the glossiness value (A) in + linear space. + """ + + def __init__(self, + name=None, + normalTexture=None, + occlusionTexture=None, + emissiveTexture=None, + emissiveFactor=None, + alphaMode=None, + alphaCutoff=None, + doubleSided=False, + smooth=True, + wireframe=False, + diffuseFactor=None, + diffuseTexture=None, + specularFactor=None, + glossinessFactor=1.0, + specularGlossinessTexture=None): + super(SpecularGlossinessMaterial, self).__init__( + name=name, + normalTexture=normalTexture, + occlusionTexture=occlusionTexture, + emissiveTexture=emissiveTexture, + emissiveFactor=emissiveFactor, + alphaMode=alphaMode, + alphaCutoff=alphaCutoff, + doubleSided=doubleSided, + smooth=smooth, + wireframe=wireframe + ) + + # Set defaults + if diffuseFactor is None: + diffuseFactor = np.ones(4).astype(np.float32) + if specularFactor is None: + specularFactor = np.ones(3).astype(np.float32) + + self.diffuseFactor = diffuseFactor + self.diffuseTexture = diffuseTexture + self.specularFactor = specularFactor + self.glossinessFactor = glossinessFactor + self.specularGlossinessTexture = specularGlossinessTexture + + @property + def diffuseFactor(self): + """(4,) float : The diffuse base color. 
+ """ + return self._diffuseFactor + + @diffuseFactor.setter + def diffuseFactor(self, value): + self._diffuseFactor = format_color_vector(value, 4) + + @property + def diffuseTexture(self): + """(n,n,4) float or :class:`Texture` : The diffuse map. + """ + return self._diffuseTexture + + @diffuseTexture.setter + def diffuseTexture(self, value): + self._diffuseTexture = self._format_texture(value, 'RGBA') + self._tex_flags = None + + @property + def specularFactor(self): + """(3,) float : The specular color of the material. + """ + return self._specularFactor + + @specularFactor.setter + def specularFactor(self, value): + self._specularFactor = format_color_vector(value, 3) + + @property + def glossinessFactor(self): + """float : The glossiness of the material. + """ + return self.glossinessFactor + + @glossinessFactor.setter + def glossinessFactor(self, value): + if value < 0 or value > 1: + raise ValueError('glossiness factor must be in range [0,1]') + self._glossinessFactor = float(value) + + @property + def specularGlossinessTexture(self): + """(n,n,4) or :class:`Texture` : The specular-glossiness texture. + """ + return self._specularGlossinessTexture + + @specularGlossinessTexture.setter + def specularGlossinessTexture(self, value): + self._specularGlossinessTexture = self._format_texture(value, 'GB') + self._tex_flags = None + + def _compute_tex_flags(self): + flags = super(SpecularGlossinessMaterial, self)._compute_tex_flags() + if self.diffuseTexture is not None: + flags |= TexFlags.DIFFUSE + if self.specularGlossinessTexture is not None: + flags |= TexFlags.SPECULAR_GLOSSINESS + return flags + + def _compute_transparency(self): + if self.alphaMode == 'OPAQUE': + return False + cutoff = self.alphaCutoff + if self.alphaMode == 'BLEND': + cutoff = 1.0 + if self.diffuseFactor[3] < cutoff: + return True + if (self.diffuseTexture is not None and + self.diffuseTexture.is_transparent(cutoff)): + return True + return False + + def _compute_textures(self): + textures = super(SpecularGlossinessMaterial, self)._compute_textures() + all_textures = [self.diffuseTexture, self.specularGlossinessTexture] + all_textures = {t for t in all_textures if t is not None} + textures |= all_textures + return textures diff --git a/pyrender/pyrender/mesh.py b/pyrender/pyrender/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..36833ea3dfa6c095a18fc745ff34cf106e83c95d --- /dev/null +++ b/pyrender/pyrender/mesh.py @@ -0,0 +1,328 @@ +"""Meshes, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-mesh + +Author: Matthew Matl +""" +import copy + +import numpy as np +import trimesh + +from .primitive import Primitive +from .constants import GLTF +from .material import MetallicRoughnessMaterial + + +class Mesh(object): + """A set of primitives to be rendered. + + Parameters + ---------- + name : str + The user-defined name of this object. + primitives : list of :class:`Primitive` + The primitives associated with this mesh. + weights : (k,) float + Array of weights to be applied to the Morph Targets. + is_visible : bool + If False, the mesh will not be rendered. + """ + + def __init__(self, primitives, name=None, weights=None, is_visible=True): + self.primitives = primitives + self.name = name + self.weights = weights + self.is_visible = is_visible + + self._bounds = None + + @property + def name(self): + """str : The user-defined name of this object. 
+ """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def primitives(self): + """list of :class:`Primitive` : The primitives associated + with this mesh. + """ + return self._primitives + + @primitives.setter + def primitives(self, value): + self._primitives = value + + @property + def weights(self): + """(k,) float : Weights to be applied to morph targets. + """ + return self._weights + + @weights.setter + def weights(self, value): + self._weights = value + + @property + def is_visible(self): + """bool : Whether the mesh is visible. + """ + return self._is_visible + + @is_visible.setter + def is_visible(self, value): + self._is_visible = value + + @property + def bounds(self): + """(2,3) float : The axis-aligned bounds of the mesh. + """ + if self._bounds is None: + bounds = np.array([[np.infty, np.infty, np.infty], + [-np.infty, -np.infty, -np.infty]]) + for p in self.primitives: + bounds[0] = np.minimum(bounds[0], p.bounds[0]) + bounds[1] = np.maximum(bounds[1], p.bounds[1]) + self._bounds = bounds + return self._bounds + + @property + def centroid(self): + """(3,) float : The centroid of the mesh's axis-aligned bounding box + (AABB). + """ + return np.mean(self.bounds, axis=0) + + @property + def extents(self): + """(3,) float : The lengths of the axes of the mesh's AABB. + """ + return np.diff(self.bounds, axis=0).reshape(-1) + + @property + def scale(self): + """(3,) float : The length of the diagonal of the mesh's AABB. + """ + return np.linalg.norm(self.extents) + + @property + def is_transparent(self): + """bool : If True, the mesh is partially-transparent. + """ + for p in self.primitives: + if p.is_transparent: + return True + return False + + @staticmethod + def from_points(points, colors=None, normals=None, + is_visible=True, poses=None): + """Create a Mesh from a set of points. + + Parameters + ---------- + points : (n,3) float + The point positions. + colors : (n,3) or (n,4) float, optional + RGB or RGBA colors for each point. + normals : (n,3) float, optionals + The normal vectors for each point. + is_visible : bool + If False, the points will not be rendered. + poses : (x,4,4) + Array of 4x4 transformation matrices for instancing this object. + + Returns + ------- + mesh : :class:`Mesh` + The created mesh. + """ + primitive = Primitive( + positions=points, + normals=normals, + color_0=colors, + mode=GLTF.POINTS, + poses=poses + ) + mesh = Mesh(primitives=[primitive], is_visible=is_visible) + return mesh + + @staticmethod + def from_trimesh(mesh, material=None, is_visible=True, + poses=None, wireframe=False, smooth=True): + """Create a Mesh from a :class:`~trimesh.base.Trimesh`. + + Parameters + ---------- + mesh : :class:`~trimesh.base.Trimesh` or list of them + A triangular mesh or a list of meshes. + material : :class:`Material` + The material of the object. Overrides any mesh material. + If not specified and the mesh has no material, a default material + will be used. + is_visible : bool + If False, the mesh will not be rendered. + poses : (n,4,4) float + Array of 4x4 transformation matrices for instancing this object. + wireframe : bool + If `True`, the mesh will be rendered as a wireframe object + smooth : bool + If `True`, the mesh will be rendered with interpolated vertex + normals. Otherwise, the mesh edges will stay sharp. + + Returns + ------- + mesh : :class:`Mesh` + The created mesh. 
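+
+        Examples
+        --------
+        A minimal usage sketch (the file name is only illustrative):
+
+        >>> import trimesh
+        >>> import pyrender
+        >>> tm = trimesh.load('fuze.obj')
+        >>> mesh = pyrender.Mesh.from_trimesh(tm, smooth=True)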
+ """ + + if isinstance(mesh, (list, tuple, set, np.ndarray)): + meshes = list(mesh) + elif isinstance(mesh, trimesh.Trimesh): + meshes = [mesh] + else: + raise TypeError('Expected a Trimesh or a list, got a {}' + .format(type(mesh))) + + primitives = [] + for m in meshes: + positions = None + normals = None + indices = None + + # Compute positions, normals, and indices + if smooth: + positions = m.vertices.copy() + normals = m.vertex_normals.copy() + indices = m.faces.copy() + else: + positions = m.vertices[m.faces].reshape((3 * len(m.faces), 3)) + normals = np.repeat(m.face_normals, 3, axis=0) + + # Compute colors, texture coords, and material properties + color_0, texcoord_0, primitive_material = Mesh._get_trimesh_props(m, smooth=smooth, material=material) + + # Override if material is given. + if material is not None: + #primitive_material = copy.copy(material) + primitive_material = copy.deepcopy(material) # TODO + + if primitive_material is None: + # Replace material with default if needed + primitive_material = MetallicRoughnessMaterial( + alphaMode='BLEND', + baseColorFactor=[0.3, 0.3, 0.3, 1.0], + metallicFactor=0.2, + roughnessFactor=0.8 + ) + + primitive_material.wireframe = wireframe + + # Create the primitive + primitives.append(Primitive( + positions=positions, + normals=normals, + texcoord_0=texcoord_0, + color_0=color_0, + indices=indices, + material=primitive_material, + mode=GLTF.TRIANGLES, + poses=poses + )) + + return Mesh(primitives=primitives, is_visible=is_visible) + + @staticmethod + def _get_trimesh_props(mesh, smooth=False, material=None): + """Gets the vertex colors, texture coordinates, and material properties + from a :class:`~trimesh.base.Trimesh`. + """ + colors = None + texcoords = None + + # If the trimesh visual is undefined, return none for both + if not mesh.visual.defined: + return colors, texcoords, material + + # Process vertex colors + if material is None: + if mesh.visual.kind == 'vertex': + vc = mesh.visual.vertex_colors.copy() + if smooth: + colors = vc + else: + colors = vc[mesh.faces].reshape( + (3 * len(mesh.faces), vc.shape[1]) + ) + material = MetallicRoughnessMaterial( + alphaMode='BLEND', + baseColorFactor=[1.0, 1.0, 1.0, 1.0], + metallicFactor=0.2, + roughnessFactor=0.8 + ) + # Process face colors + elif mesh.visual.kind == 'face': + if smooth: + raise ValueError('Cannot use face colors with a smooth mesh') + else: + colors = np.repeat(mesh.visual.face_colors, 3, axis=0) + + material = MetallicRoughnessMaterial( + alphaMode='BLEND', + baseColorFactor=[1.0, 1.0, 1.0, 1.0], + metallicFactor=0.2, + roughnessFactor=0.8 + ) + + # Process texture colors + if mesh.visual.kind == 'texture': + # Configure UV coordinates + if mesh.visual.uv is not None and len(mesh.visual.uv) != 0: + uv = mesh.visual.uv.copy() + if smooth: + texcoords = uv + else: + texcoords = uv[mesh.faces].reshape( + (3 * len(mesh.faces), uv.shape[1]) + ) + + if material is None: + # Configure mesh material + mat = mesh.visual.material + + if isinstance(mat, trimesh.visual.texture.PBRMaterial): + material = MetallicRoughnessMaterial( + normalTexture=mat.normalTexture, + occlusionTexture=mat.occlusionTexture, + emissiveTexture=mat.emissiveTexture, + emissiveFactor=mat.emissiveFactor, + alphaMode='BLEND', + baseColorFactor=mat.baseColorFactor, + baseColorTexture=mat.baseColorTexture, + metallicFactor=mat.metallicFactor, + roughnessFactor=mat.roughnessFactor, + metallicRoughnessTexture=mat.metallicRoughnessTexture, + doubleSided=mat.doubleSided, + alphaCutoff=mat.alphaCutoff + ) + 
elif isinstance(mat, trimesh.visual.texture.SimpleMaterial): + glossiness = mat.kwargs.get('Ns', 1.0) + if isinstance(glossiness, list): + glossiness = float(glossiness[0]) + roughness = (2 / (glossiness + 2)) ** (1.0 / 4.0) + material = MetallicRoughnessMaterial( + alphaMode='BLEND', + roughnessFactor=roughness, + baseColorFactor=mat.diffuse, + baseColorTexture=mat.image, + ) + elif isinstance(mat, MetallicRoughnessMaterial): + material = mat + + return colors, texcoords, material diff --git a/pyrender/pyrender/node.py b/pyrender/pyrender/node.py new file mode 100644 index 0000000000000000000000000000000000000000..1f37f7856cc732a37dc58253022a7c331489493e --- /dev/null +++ b/pyrender/pyrender/node.py @@ -0,0 +1,263 @@ +"""Nodes, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-node + +Author: Matthew Matl +""" +import numpy as np + +import trimesh.transformations as transformations + +from .camera import Camera +from .mesh import Mesh +from .light import Light + + +class Node(object): + """A node in the node hierarchy. + + Parameters + ---------- + name : str, optional + The user-defined name of this object. + camera : :class:`Camera`, optional + The camera in this node. + children : list of :class:`Node` + The children of this node. + skin : int, optional + The index of the skin referenced by this node. + matrix : (4,4) float, optional + A floating-point 4x4 transformation matrix. + mesh : :class:`Mesh`, optional + The mesh in this node. + rotation : (4,) float, optional + The node's unit quaternion in the order (x, y, z, w), where + w is the scalar. + scale : (3,) float, optional + The node's non-uniform scale, given as the scaling factors along the x, + y, and z axes. + translation : (3,) float, optional + The node's translation along the x, y, and z axes. + weights : (n,) float + The weights of the instantiated Morph Target. Number of elements must + match number of Morph Targets of used mesh. + light : :class:`Light`, optional + The light in this node. + """ + + def __init__(self, + name=None, + camera=None, + children=None, + skin=None, + matrix=None, + mesh=None, + rotation=None, + scale=None, + translation=None, + weights=None, + light=None): + # Set defaults + if children is None: + children = [] + + self._matrix = None + self._scale = None + self._rotation = None + self._translation = None + if matrix is None: + if rotation is None: + rotation = np.array([0.0, 0.0, 0.0, 1.0]) + if translation is None: + translation = np.zeros(3) + if scale is None: + scale = np.ones(3) + self.rotation = rotation + self.translation = translation + self.scale = scale + else: + self.matrix = matrix + + self.name = name + self.camera = camera + self.children = children + self.skin = skin + self.mesh = mesh + self.weights = weights + self.light = light + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def camera(self): + """:class:`Camera` : The camera in this node. + """ + return self._camera + + @camera.setter + def camera(self, value): + if value is not None and not isinstance(value, Camera): + raise TypeError('Value must be a camera') + self._camera = value + + @property + def children(self): + """list of :class:`Node` : The children of this node. 
+ """ + return self._children + + @children.setter + def children(self, value): + self._children = value + + @property + def skin(self): + """int : The skin index for this node. + """ + return self._skin + + @skin.setter + def skin(self, value): + self._skin = value + + @property + def mesh(self): + """:class:`Mesh` : The mesh in this node. + """ + return self._mesh + + @mesh.setter + def mesh(self, value): + if value is not None and not isinstance(value, Mesh): + raise TypeError('Value must be a mesh') + self._mesh = value + + @property + def light(self): + """:class:`Light` : The light in this node. + """ + return self._light + + @light.setter + def light(self, value): + if value is not None and not isinstance(value, Light): + raise TypeError('Value must be a light') + self._light = value + + @property + def rotation(self): + """(4,) float : The xyzw quaternion for this node. + """ + return self._rotation + + @rotation.setter + def rotation(self, value): + value = np.asanyarray(value) + if value.shape != (4,): + raise ValueError('Quaternion must be a (4,) vector') + if np.abs(np.linalg.norm(value) - 1.0) > 1e-3: + raise ValueError('Quaternion must have norm == 1.0') + self._rotation = value + self._matrix = None + + @property + def translation(self): + """(3,) float : The translation for this node. + """ + return self._translation + + @translation.setter + def translation(self, value): + value = np.asanyarray(value) + if value.shape != (3,): + raise ValueError('Translation must be a (3,) vector') + self._translation = value + self._matrix = None + + @property + def scale(self): + """(3,) float : The scale for this node. + """ + return self._scale + + @scale.setter + def scale(self, value): + value = np.asanyarray(value) + if value.shape != (3,): + raise ValueError('Scale must be a (3,) vector') + self._scale = value + self._matrix = None + + @property + def matrix(self): + """(4,4) float : The homogenous transform matrix for this node. + + Note that this matrix's elements are not settable, + it's just a copy of the internal matrix. You can set the whole + matrix, but not an individual element. 
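+
+        The matrix is composed from the node's translation, rotation, and
+        scale as T * R * S; assigning a matrix decomposes it back into
+        those three components.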
+ """ + if self._matrix is None: + self._matrix = self._m_from_tqs( + self.translation, self.rotation, self.scale + ) + return self._matrix.copy() + + @matrix.setter + def matrix(self, value): + value = np.asanyarray(value) + if value.shape != (4,4): + raise ValueError('Matrix must be a 4x4 numpy ndarray') + if not np.allclose(value[3,:], np.array([0.0, 0.0, 0.0, 1.0])): + raise ValueError('Bottom row of matrix must be [0,0,0,1]') + self.rotation = Node._q_from_m(value) + self.scale = Node._s_from_m(value) + self.translation = Node._t_from_m(value) + self._matrix = value + + @staticmethod + def _t_from_m(m): + return m[:3,3] + + @staticmethod + def _r_from_m(m): + U = m[:3,:3] + norms = np.linalg.norm(U.T, axis=1) + return U / norms + + @staticmethod + def _q_from_m(m): + M = np.eye(4) + M[:3,:3] = Node._r_from_m(m) + q_wxyz = transformations.quaternion_from_matrix(M) + return np.roll(q_wxyz, -1) + + @staticmethod + def _s_from_m(m): + return np.linalg.norm(m[:3,:3].T, axis=1) + + @staticmethod + def _r_from_q(q): + q_wxyz = np.roll(q, 1) + return transformations.quaternion_matrix(q_wxyz)[:3,:3] + + @staticmethod + def _m_from_tqs(t, q, s): + S = np.eye(4) + S[:3,:3] = np.diag(s) + + R = np.eye(4) + R[:3,:3] = Node._r_from_q(q) + + T = np.eye(4) + T[:3,3] = t + + return T.dot(R.dot(S)) diff --git a/pyrender/pyrender/offscreen.py b/pyrender/pyrender/offscreen.py new file mode 100644 index 0000000000000000000000000000000000000000..340142983006cdc6f51b6d114e9b2b294aa4a919 --- /dev/null +++ b/pyrender/pyrender/offscreen.py @@ -0,0 +1,160 @@ +"""Wrapper for offscreen rendering. + +Author: Matthew Matl +""" +import os + +from .renderer import Renderer +from .constants import RenderFlags + + +class OffscreenRenderer(object): + """A wrapper for offscreen rendering. + + Parameters + ---------- + viewport_width : int + The width of the main viewport, in pixels. + viewport_height : int + The height of the main viewport, in pixels. + point_size : float + The size of screen-space points in pixels. + """ + + def __init__(self, viewport_width, viewport_height, point_size=1.0): + self.viewport_width = viewport_width + self.viewport_height = viewport_height + self.point_size = point_size + + self._platform = None + self._renderer = None + self._create() + + @property + def viewport_width(self): + """int : The width of the main viewport, in pixels. + """ + return self._viewport_width + + @viewport_width.setter + def viewport_width(self, value): + self._viewport_width = int(value) + + @property + def viewport_height(self): + """int : The height of the main viewport, in pixels. + """ + return self._viewport_height + + @viewport_height.setter + def viewport_height(self, value): + self._viewport_height = int(value) + + @property + def point_size(self): + """float : The pixel size of points in point clouds. + """ + return self._point_size + + @point_size.setter + def point_size(self, value): + self._point_size = float(value) + + def render(self, scene, flags=RenderFlags.NONE, seg_node_map=None): + """Render a scene with the given set of flags. + + Parameters + ---------- + scene : :class:`Scene` + A scene to render. + flags : int + A bitwise or of one or more flags from :class:`.RenderFlags`. + seg_node_map : dict + A map from :class:`.Node` objects to (3,) colors for each. + If specified along with flags set to :attr:`.RenderFlags.SEG`, + the color image will be a segmentation image. 
+ + Returns + ------- + color_im : (h, w, 3) uint8 or (h, w, 4) uint8 + The color buffer in RGB format, or in RGBA format if + :attr:`.RenderFlags.RGBA` is set. + Not returned if flags includes :attr:`.RenderFlags.DEPTH_ONLY`. + depth_im : (h, w) float32 + The depth buffer in linear units. + """ + self._platform.make_current() + # If platform does not support dynamically-resizing framebuffers, + # destroy it and restart it + if (self._platform.viewport_height != self.viewport_height or + self._platform.viewport_width != self.viewport_width): + if not self._platform.supports_framebuffers(): + self.delete() + self._create() + + self._platform.make_current() + self._renderer.viewport_width = self.viewport_width + self._renderer.viewport_height = self.viewport_height + self._renderer.point_size = self.point_size + + if self._platform.supports_framebuffers(): + flags |= RenderFlags.OFFSCREEN + retval = self._renderer.render(scene, flags, seg_node_map) + else: + self._renderer.render(scene, flags, seg_node_map) + depth = self._renderer.read_depth_buf() + if flags & RenderFlags.DEPTH_ONLY: + retval = depth + else: + color = self._renderer.read_color_buf() + retval = color, depth + + # Make the platform not current + self._platform.make_uncurrent() + return retval + + def delete(self): + """Free all OpenGL resources. + """ + self._platform.make_current() + self._renderer.delete() + self._platform.delete_context() + del self._renderer + del self._platform + self._renderer = None + self._platform = None + import gc + gc.collect() + + def _create(self): + if 'PYOPENGL_PLATFORM' not in os.environ: + from pyrender.platforms.pyglet_platform import PygletPlatform + self._platform = PygletPlatform(self.viewport_width, + self.viewport_height) + elif os.environ['PYOPENGL_PLATFORM'] == 'egl': + from pyrender.platforms import egl + device_id = int(os.environ.get('EGL_DEVICE_ID', '0')) + egl_device = egl.get_device_by_index(device_id) + self._platform = egl.EGLPlatform(self.viewport_width, + self.viewport_height, + device=egl_device) + elif os.environ['PYOPENGL_PLATFORM'] == 'osmesa': + from pyrender.platforms.osmesa import OSMesaPlatform + self._platform = OSMesaPlatform(self.viewport_width, + self.viewport_height) + else: + raise ValueError('Unsupported PyOpenGL platform: {}'.format( + os.environ['PYOPENGL_PLATFORM'] + )) + self._platform.init_context() + self._platform.make_current() + self._renderer = Renderer(self.viewport_width, self.viewport_height) + + def __del__(self): + try: + self.delete() + except Exception: + pass + + +__all__ = ['OffscreenRenderer'] diff --git a/pyrender/pyrender/platforms/__init__.py b/pyrender/pyrender/platforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7837fd5fdeccab5e48c85e41d20b238ea7396599 --- /dev/null +++ b/pyrender/pyrender/platforms/__init__.py @@ -0,0 +1,6 @@ +"""Platforms for generating offscreen OpenGL contexts for rendering. + +Author: Matthew Matl +""" + +from .base import Platform diff --git a/pyrender/pyrender/platforms/base.py b/pyrender/pyrender/platforms/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ecda906145e239737901809aa59db8d3e231c6 --- /dev/null +++ b/pyrender/pyrender/platforms/base.py @@ -0,0 +1,76 @@ +import abc + +import six + + +@six.add_metaclass(abc.ABCMeta) +class Platform(object): + """Base class for all OpenGL platforms. + + Parameters + ---------- + viewport_width : int + The width of the main viewport, in pixels. 
+ viewport_height : int + The height of the main viewport, in pixels + """ + + def __init__(self, viewport_width, viewport_height): + self.viewport_width = viewport_width + self.viewport_height = viewport_height + + @property + def viewport_width(self): + """int : The width of the main viewport, in pixels. + """ + return self._viewport_width + + @viewport_width.setter + def viewport_width(self, value): + self._viewport_width = value + + @property + def viewport_height(self): + """int : The height of the main viewport, in pixels. + """ + return self._viewport_height + + @viewport_height.setter + def viewport_height(self, value): + self._viewport_height = value + + @abc.abstractmethod + def init_context(self): + """Create an OpenGL context. + """ + pass + + @abc.abstractmethod + def make_current(self): + """Make the OpenGL context current. + """ + pass + + @abc.abstractmethod + def make_uncurrent(self): + """Make the OpenGL context uncurrent. + """ + pass + + @abc.abstractmethod + def delete_context(self): + """Delete the OpenGL context. + """ + pass + + @abc.abstractmethod + def supports_framebuffers(self): + """Returns True if the method supports framebuffer rendering. + """ + pass + + def __del__(self): + try: + self.delete_context() + except Exception: + pass diff --git a/pyrender/pyrender/platforms/egl.py b/pyrender/pyrender/platforms/egl.py new file mode 100644 index 0000000000000000000000000000000000000000..ae2478d29c9a538c53ad83fa31f8e2277cd897c8 --- /dev/null +++ b/pyrender/pyrender/platforms/egl.py @@ -0,0 +1,219 @@ +import ctypes +import os + +import OpenGL.platform + +from .base import Platform + +EGL_PLATFORM_DEVICE_EXT = 0x313F +EGL_DRM_DEVICE_FILE_EXT = 0x3233 + + +def _ensure_egl_loaded(): + plugin = OpenGL.platform.PlatformPlugin.by_name('egl') + if plugin is None: + raise RuntimeError("EGL platform plugin is not available.") + + plugin_class = plugin.load() + plugin.loaded = True + # create instance of this platform implementation + plugin = plugin_class() + + plugin.install(vars(OpenGL.platform)) + + +_ensure_egl_loaded() +from OpenGL import EGL as egl + + +def _get_egl_func(func_name, res_type, *arg_types): + address = egl.eglGetProcAddress(func_name) + if address is None: + return None + + proto = ctypes.CFUNCTYPE(res_type) + proto.argtypes = arg_types + func = proto(address) + return func + + +def _get_egl_struct(struct_name): + from OpenGL._opaque import opaque_pointer_cls + return opaque_pointer_cls(struct_name) + + +# These are not defined in PyOpenGL by default. +_EGLDeviceEXT = _get_egl_struct('EGLDeviceEXT') +_eglGetPlatformDisplayEXT = _get_egl_func('eglGetPlatformDisplayEXT', egl.EGLDisplay) +_eglQueryDevicesEXT = _get_egl_func('eglQueryDevicesEXT', egl.EGLBoolean) +_eglQueryDeviceStringEXT = _get_egl_func('eglQueryDeviceStringEXT', ctypes.c_char_p) + + +def query_devices(): + if _eglQueryDevicesEXT is None: + raise RuntimeError("EGL query extension is not loaded or is not supported.") + + num_devices = egl.EGLint() + success = _eglQueryDevicesEXT(0, None, ctypes.pointer(num_devices)) + if not success or num_devices.value < 1: + return [] + + devices = (_EGLDeviceEXT * num_devices.value)() # array of size num_devices + success = _eglQueryDevicesEXT(num_devices.value, devices, ctypes.pointer(num_devices)) + if not success or num_devices.value < 1: + return [] + + return [EGLDevice(devices[i]) for i in range(num_devices.value)] + + +def get_default_device(): + # Fall back to not using query extension. 
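+    # Note: if the query extension is available but reports no devices,
+    # query_devices() returns an empty list and the indexing below will
+    # raise an IndexError.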
+    if _eglQueryDevicesEXT is None:
+        return EGLDevice(None)
+
+    return query_devices()[0]
+
+
+def get_device_by_index(device_id):
+    if _eglQueryDevicesEXT is None and device_id == 0:
+        return get_default_device()
+
+    devices = query_devices()
+    if device_id >= len(devices):
+        raise ValueError('Invalid device ID ({}); only {} devices are available'.format(device_id, len(devices)))
+    return devices[device_id]
+
+
+class EGLDevice:
+
+    def __init__(self, display=None):
+        self._display = display
+
+    def get_display(self):
+        if self._display is None:
+            return egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
+
+        return _eglGetPlatformDisplayEXT(EGL_PLATFORM_DEVICE_EXT, self._display, None)
+
+    @property
+    def name(self):
+        if self._display is None:
+            return 'default'
+
+        name = _eglQueryDeviceStringEXT(self._display, EGL_DRM_DEVICE_FILE_EXT)
+        if name is None:
+            return None
+
+        return name.decode('ascii')
+
+    def __repr__(self):
+        return "<EGLDevice(name={})>".format(self.name)
+
+
+class EGLPlatform(Platform):
+    """Renders using EGL.
+    """
+
+    def __init__(self, viewport_width, viewport_height, device: EGLDevice = None):
+        super(EGLPlatform, self).__init__(viewport_width, viewport_height)
+        if device is None:
+            device = get_default_device()
+
+        self._egl_device = device
+        self._egl_display = None
+        self._egl_context = None
+
+    def init_context(self):
+        _ensure_egl_loaded()
+
+        from OpenGL.EGL import (
+            EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_BLUE_SIZE,
+            EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_DEPTH_SIZE,
+            EGL_COLOR_BUFFER_TYPE, EGL_RGB_BUFFER,
+            EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, EGL_CONFORMANT,
+            EGL_NONE, EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT,
+            EGL_OPENGL_API, EGL_CONTEXT_MAJOR_VERSION,
+            EGL_CONTEXT_MINOR_VERSION,
+            EGL_CONTEXT_OPENGL_PROFILE_MASK,
+            EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
+            eglGetDisplay, eglInitialize, eglChooseConfig,
+            eglBindAPI, eglCreateContext, EGLConfig
+        )
+        from OpenGL import arrays
+
+        config_attributes = arrays.GLintArray.asArray([
+            EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
+            EGL_BLUE_SIZE, 8,
+            EGL_RED_SIZE, 8,
+            EGL_GREEN_SIZE, 8,
+            EGL_DEPTH_SIZE, 24,
+            EGL_COLOR_BUFFER_TYPE, EGL_RGB_BUFFER,
+            EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
+            EGL_CONFORMANT, EGL_OPENGL_BIT,
+            EGL_NONE
+        ])
+        context_attributes = arrays.GLintArray.asArray([
+            EGL_CONTEXT_MAJOR_VERSION, 4,
+            EGL_CONTEXT_MINOR_VERSION, 1,
+            EGL_CONTEXT_OPENGL_PROFILE_MASK,
+            EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
+            EGL_NONE
+        ])
+        major, minor = ctypes.c_long(), ctypes.c_long()
+        num_configs = ctypes.c_long()
+        configs = (EGLConfig * 1)()
+
+        # Cache DISPLAY if necessary and get an off-screen EGL display
+        orig_dpy = None
+        if 'DISPLAY' in os.environ:
+            orig_dpy = os.environ['DISPLAY']
+            del os.environ['DISPLAY']
+
+        self._egl_display = self._egl_device.get_display()
+        if orig_dpy is not None:
+            os.environ['DISPLAY'] = orig_dpy
+
+        # Initialize EGL
+        assert eglInitialize(self._egl_display, major, minor)
+        assert eglChooseConfig(
+            self._egl_display, config_attributes, configs, 1, num_configs
+        )
+
+        # Bind EGL to the OpenGL API
+        assert eglBindAPI(EGL_OPENGL_API)
+
+        # Create an EGL context
+        self._egl_context = eglCreateContext(
+            self._egl_display, configs[0],
+            EGL_NO_CONTEXT, context_attributes
+        )
+
+        # Make it current
+        self.make_current()
+
+    def make_current(self):
+        from OpenGL.EGL import eglMakeCurrent, EGL_NO_SURFACE
+        assert eglMakeCurrent(
+            self._egl_display, EGL_NO_SURFACE, EGL_NO_SURFACE,
+            self._egl_context
+        )
+
+    def make_uncurrent(self):
+        """Make the OpenGL context uncurrent.
+ """ + pass + + def delete_context(self): + from OpenGL.EGL import eglDestroyContext, eglTerminate + if self._egl_display is not None: + if self._egl_context is not None: + eglDestroyContext(self._egl_display, self._egl_context) + self._egl_context = None + eglTerminate(self._egl_display) + self._egl_display = None + + def supports_framebuffers(self): + return True + + +__all__ = ['EGLPlatform'] diff --git a/pyrender/pyrender/platforms/osmesa.py b/pyrender/pyrender/platforms/osmesa.py new file mode 100644 index 0000000000000000000000000000000000000000..deaa5ff44031a107883913ae9a18fc425d650f3d --- /dev/null +++ b/pyrender/pyrender/platforms/osmesa.py @@ -0,0 +1,59 @@ +from .base import Platform + + +__all__ = ['OSMesaPlatform'] + + +class OSMesaPlatform(Platform): + """Renders into a software buffer using OSMesa. Requires special versions + of OSMesa to be installed, plus PyOpenGL upgrade. + """ + + def __init__(self, viewport_width, viewport_height): + super(OSMesaPlatform, self).__init__(viewport_width, viewport_height) + self._context = None + self._buffer = None + + def init_context(self): + from OpenGL import arrays + from OpenGL.osmesa import ( + OSMesaCreateContextAttribs, OSMESA_FORMAT, + OSMESA_RGBA, OSMESA_PROFILE, OSMESA_CORE_PROFILE, + OSMESA_CONTEXT_MAJOR_VERSION, OSMESA_CONTEXT_MINOR_VERSION, + OSMESA_DEPTH_BITS + ) + + attrs = arrays.GLintArray.asArray([ + OSMESA_FORMAT, OSMESA_RGBA, + OSMESA_DEPTH_BITS, 24, + OSMESA_PROFILE, OSMESA_CORE_PROFILE, + OSMESA_CONTEXT_MAJOR_VERSION, 3, + OSMESA_CONTEXT_MINOR_VERSION, 3, + 0 + ]) + self._context = OSMesaCreateContextAttribs(attrs, None) + self._buffer = arrays.GLubyteArray.zeros( + (self.viewport_height, self.viewport_width, 4) + ) + + def make_current(self): + from OpenGL import GL as gl + from OpenGL.osmesa import OSMesaMakeCurrent + assert(OSMesaMakeCurrent( + self._context, self._buffer, gl.GL_UNSIGNED_BYTE, + self.viewport_width, self.viewport_height + )) + + def make_uncurrent(self): + """Make the OpenGL context uncurrent. + """ + pass + + def delete_context(self): + from OpenGL.osmesa import OSMesaDestroyContext + OSMesaDestroyContext(self._context) + self._context = None + self._buffer = None + + def supports_framebuffers(self): + return False diff --git a/pyrender/pyrender/platforms/pyglet_platform.py b/pyrender/pyrender/platforms/pyglet_platform.py new file mode 100644 index 0000000000000000000000000000000000000000..a70cf7b659bc85a92f6c9c8ebcc360662a068507 --- /dev/null +++ b/pyrender/pyrender/platforms/pyglet_platform.py @@ -0,0 +1,90 @@ +from pyrender.constants import (TARGET_OPEN_GL_MAJOR, TARGET_OPEN_GL_MINOR, + MIN_OPEN_GL_MAJOR, MIN_OPEN_GL_MINOR) +from .base import Platform + +import OpenGL + + +__all__ = ['PygletPlatform'] + + +class PygletPlatform(Platform): + """Renders on-screen using a 1x1 hidden Pyglet window for getting + an OpenGL context. 
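+
+    This requires a running window system (e.g. an X server); for fully
+    headless rendering, use the EGL or OSMesa platforms instead.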
+ """ + + def __init__(self, viewport_width, viewport_height): + super(PygletPlatform, self).__init__(viewport_width, viewport_height) + self._window = None + + def init_context(self): + import pyglet + pyglet.options['shadow_window'] = False + + try: + pyglet.lib.x11.xlib.XInitThreads() + except Exception: + pass + + self._window = None + confs = [pyglet.gl.Config(sample_buffers=1, samples=4, + depth_size=24, + double_buffer=True, + major_version=TARGET_OPEN_GL_MAJOR, + minor_version=TARGET_OPEN_GL_MINOR), + pyglet.gl.Config(depth_size=24, + double_buffer=True, + major_version=TARGET_OPEN_GL_MAJOR, + minor_version=TARGET_OPEN_GL_MINOR), + pyglet.gl.Config(sample_buffers=1, samples=4, + depth_size=24, + double_buffer=True, + major_version=MIN_OPEN_GL_MAJOR, + minor_version=MIN_OPEN_GL_MINOR), + pyglet.gl.Config(depth_size=24, + double_buffer=True, + major_version=MIN_OPEN_GL_MAJOR, + minor_version=MIN_OPEN_GL_MINOR)] + for conf in confs: + try: + self._window = pyglet.window.Window(config=conf, visible=False, + resizable=False, + width=1, height=1) + break + except pyglet.window.NoSuchConfigException as e: + pass + + if not self._window: + raise ValueError( + 'Failed to initialize Pyglet window with an OpenGL >= 3+ ' + 'context. If you\'re logged in via SSH, ensure that you\'re ' + 'running your script with vglrun (i.e. VirtualGL). The ' + 'internal error message was "{}"'.format(e) + ) + + def make_current(self): + if self._window: + self._window.switch_to() + + def make_uncurrent(self): + try: + import pyglet + pyglet.gl.xlib.glx.glXMakeContextCurrent(self._window.context.x_display, 0, 0, None) + except Exception: + pass + + def delete_context(self): + if self._window is not None: + self.make_current() + cid = OpenGL.contextdata.getContext() + try: + self._window.context.destroy() + self._window.close() + except Exception: + pass + self._window = None + OpenGL.contextdata.cleanupContext(cid) + del cid + + def supports_framebuffers(self): + return True diff --git a/pyrender/pyrender/primitive.py b/pyrender/pyrender/primitive.py new file mode 100644 index 0000000000000000000000000000000000000000..7f83f46f532b126a4573e715dd03d079fef755ca --- /dev/null +++ b/pyrender/pyrender/primitive.py @@ -0,0 +1,489 @@ +"""Primitives, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-primitive + +Author: Matthew Matl +""" +import numpy as np + +from OpenGL.GL import * + +from .material import Material, MetallicRoughnessMaterial +from .constants import FLOAT_SZ, UINT_SZ, BufFlags, GLTF +from .utils import format_color_array + + +class Primitive(object): + """A primitive object which can be rendered. + + Parameters + ---------- + positions : (n, 3) float + XYZ vertex positions. + normals : (n, 3) float + Normalized XYZ vertex normals. + tangents : (n, 4) float + XYZW vertex tangents where the w component is a sign value + (either +1 or -1) indicating the handedness of the tangent basis. + texcoord_0 : (n, 2) float + The first set of UV texture coordinates. + texcoord_1 : (n, 2) float + The second set of UV texture coordinates. + color_0 : (n, 4) float + RGBA vertex colors. + joints_0 : (n, 4) float + Joint information. + weights_0 : (n, 4) float + Weight information for morphing. + indices : (m, 3) int + Face indices for triangle meshes or fans. + material : :class:`Material` + The material to apply to this primitive when rendering. 
+ mode : int + The type of primitives to render, one of the following: + + - ``0``: POINTS + - ``1``: LINES + - ``2``: LINE_LOOP + - ``3``: LINE_STRIP + - ``4``: TRIANGLES + - ``5``: TRIANGLES_STRIP + - ``6``: TRIANGLES_FAN + targets : (k,) int + Morph target indices. + poses : (x,4,4), float + Array of 4x4 transformation matrices for instancing this object. + """ + + def __init__(self, + positions, + normals=None, + tangents=None, + texcoord_0=None, + texcoord_1=None, + color_0=None, + joints_0=None, + weights_0=None, + indices=None, + material=None, + mode=None, + targets=None, + poses=None): + + if mode is None: + mode = GLTF.TRIANGLES + + self.positions = positions + self.normals = normals + self.tangents = tangents + self.texcoord_0 = texcoord_0 + self.texcoord_1 = texcoord_1 + self.color_0 = color_0 + self.joints_0 = joints_0 + self.weights_0 = weights_0 + self.indices = indices + self.material = material + self.mode = mode + self.targets = targets + self.poses = poses + + self._bounds = None + self._vaid = None + self._buffers = [] + self._is_transparent = None + self._buf_flags = None + + @property + def positions(self): + """(n,3) float : XYZ vertex positions. + """ + return self._positions + + @positions.setter + def positions(self, value): + value = np.asanyarray(value, dtype=np.float32) + self._positions = np.ascontiguousarray(value) + self._bounds = None + + @property + def normals(self): + """(n,3) float : Normalized XYZ vertex normals. + """ + return self._normals + + @normals.setter + def normals(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + if value.shape != self.positions.shape: + raise ValueError('Incorrect normals shape') + self._normals = value + + @property + def tangents(self): + """(n,4) float : XYZW vertex tangents. + """ + return self._tangents + + @tangents.setter + def tangents(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + if value.shape != (self.positions.shape[0], 4): + raise ValueError('Incorrect tangent shape') + self._tangents = value + + @property + def texcoord_0(self): + """(n,2) float : The first set of UV texture coordinates. + """ + return self._texcoord_0 + + @texcoord_0.setter + def texcoord_0(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + if (value.ndim != 2 or value.shape[0] != self.positions.shape[0] or + value.shape[1] < 2): + raise ValueError('Incorrect texture coordinate shape') + if value.shape[1] > 2: + value = value[:,:2] + self._texcoord_0 = value + + @property + def texcoord_1(self): + """(n,2) float : The second set of UV texture coordinates. + """ + return self._texcoord_1 + + @texcoord_1.setter + def texcoord_1(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + if (value.ndim != 2 or value.shape[0] != self.positions.shape[0] or + value.shape[1] != 2): + raise ValueError('Incorrect texture coordinate shape') + self._texcoord_1 = value + + @property + def color_0(self): + """(n,4) float : RGBA vertex colors. 
+ """ + return self._color_0 + + @color_0.setter + def color_0(self, value): + if value is not None: + value = np.ascontiguousarray( + format_color_array(value, shape=(len(self.positions), 4)) + ) + self._is_transparent = None + self._color_0 = value + + @property + def joints_0(self): + """(n,4) float : Joint information. + """ + return self._joints_0 + + @joints_0.setter + def joints_0(self, value): + self._joints_0 = value + + @property + def weights_0(self): + """(n,4) float : Weight information for morphing. + """ + return self._weights_0 + + @weights_0.setter + def weights_0(self, value): + self._weights_0 = value + + @property + def indices(self): + """(m,3) int : Face indices for triangle meshes or fans. + """ + return self._indices + + @indices.setter + def indices(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + self._indices = value + + @property + def material(self): + """:class:`Material` : The material for this primitive. + """ + return self._material + + @material.setter + def material(self, value): + # Create default material + if value is None: + value = MetallicRoughnessMaterial() + else: + if not isinstance(value, Material): + raise TypeError('Object material must be of type Material') + self._material = value + + @property + def mode(self): + """int : The type of primitive to render. + """ + return self._mode + + @mode.setter + def mode(self, value): + value = int(value) + if value < GLTF.POINTS or value > GLTF.TRIANGLE_FAN: + raise ValueError('Invalid mode') + self._mode = value + + @property + def targets(self): + """(k,) int : Morph target indices. + """ + return self._targets + + @targets.setter + def targets(self, value): + self._targets = value + + @property + def poses(self): + """(x,4,4) float : Homogenous transforms for instancing this primitive. + """ + return self._poses + + @poses.setter + def poses(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + if value.ndim == 2: + value = value[np.newaxis,:,:] + if value.shape[1] != 4 or value.shape[2] != 4: + raise ValueError('Pose matrices must be of shape (n,4,4), ' + 'got {}'.format(value.shape)) + self._poses = value + self._bounds = None + + @property + def bounds(self): + if self._bounds is None: + self._bounds = self._compute_bounds() + return self._bounds + + @property + def centroid(self): + """(3,) float : The centroid of the primitive's AABB. + """ + return np.mean(self.bounds, axis=0) + + @property + def extents(self): + """(3,) float : The lengths of the axes of the primitive's AABB. + """ + return np.diff(self.bounds, axis=0).reshape(-1) + + @property + def scale(self): + """(3,) float : The length of the diagonal of the primitive's AABB. + """ + return np.linalg.norm(self.extents) + + @property + def buf_flags(self): + """int : The flags for the render buffer. + """ + if self._buf_flags is None: + self._buf_flags = self._compute_buf_flags() + return self._buf_flags + + def delete(self): + self._unbind() + self._remove_from_context() + + @property + def is_transparent(self): + """bool : If True, the mesh is partially-transparent. 
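+
+        True if the material is transparent or if any vertex color has an
+        alpha component different from 1.0.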
+ """ + return self._compute_transparency() + + def _add_to_context(self): + if self._vaid is not None: + raise ValueError('Mesh is already bound to a context') + + # Generate and bind VAO + self._vaid = glGenVertexArrays(1) + glBindVertexArray(self._vaid) + + ####################################################################### + # Fill vertex buffer + ####################################################################### + + # Generate and bind vertex buffer + vertexbuffer = glGenBuffers(1) + self._buffers.append(vertexbuffer) + glBindBuffer(GL_ARRAY_BUFFER, vertexbuffer) + + # positions + vertex_data = self.positions + attr_sizes = [3] + + # Normals + if self.normals is not None: + vertex_data = np.hstack((vertex_data, self.normals)) + attr_sizes.append(3) + + # Tangents + if self.tangents is not None: + vertex_data = np.hstack((vertex_data, self.tangents)) + attr_sizes.append(4) + + # Texture Coordinates + if self.texcoord_0 is not None: + vertex_data = np.hstack((vertex_data, self.texcoord_0)) + attr_sizes.append(2) + if self.texcoord_1 is not None: + vertex_data = np.hstack((vertex_data, self.texcoord_1)) + attr_sizes.append(2) + + # Color + if self.color_0 is not None: + vertex_data = np.hstack((vertex_data, self.color_0)) + attr_sizes.append(4) + + # TODO JOINTS AND WEIGHTS + # PASS + + # Copy data to buffer + vertex_data = np.ascontiguousarray( + vertex_data.flatten().astype(np.float32) + ) + glBufferData( + GL_ARRAY_BUFFER, FLOAT_SZ * len(vertex_data), + vertex_data, GL_STATIC_DRAW + ) + total_sz = sum(attr_sizes) + offset = 0 + for i, sz in enumerate(attr_sizes): + glVertexAttribPointer( + i, sz, GL_FLOAT, GL_FALSE, FLOAT_SZ * total_sz, + ctypes.c_void_p(FLOAT_SZ * offset) + ) + glEnableVertexAttribArray(i) + offset += sz + + ####################################################################### + # Fill model matrix buffer + ####################################################################### + + if self.poses is not None: + pose_data = np.ascontiguousarray( + np.transpose(self.poses, [0,2,1]).flatten().astype(np.float32) + ) + else: + pose_data = np.ascontiguousarray( + np.eye(4).flatten().astype(np.float32) + ) + + modelbuffer = glGenBuffers(1) + self._buffers.append(modelbuffer) + glBindBuffer(GL_ARRAY_BUFFER, modelbuffer) + glBufferData( + GL_ARRAY_BUFFER, FLOAT_SZ * len(pose_data), + pose_data, GL_STATIC_DRAW + ) + + for i in range(0, 4): + idx = i + len(attr_sizes) + glEnableVertexAttribArray(idx) + glVertexAttribPointer( + idx, 4, GL_FLOAT, GL_FALSE, FLOAT_SZ * 4 * 4, + ctypes.c_void_p(4 * FLOAT_SZ * i) + ) + glVertexAttribDivisor(idx, 1) + + ####################################################################### + # Fill element buffer + ####################################################################### + if self.indices is not None: + elementbuffer = glGenBuffers(1) + self._buffers.append(elementbuffer) + glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, elementbuffer) + glBufferData(GL_ELEMENT_ARRAY_BUFFER, UINT_SZ * self.indices.size, + self.indices.flatten().astype(np.uint32), + GL_STATIC_DRAW) + + glBindVertexArray(0) + + def _remove_from_context(self): + if self._vaid is not None: + glDeleteVertexArrays(1, [self._vaid]) + glDeleteBuffers(len(self._buffers), self._buffers) + self._vaid = None + self._buffers = [] + + def _in_context(self): + return self._vaid is not None + + def _bind(self): + if self._vaid is None: + raise ValueError('Cannot bind a Mesh that has not been added ' + 'to a context') + glBindVertexArray(self._vaid) + + def _unbind(self): + 
glBindVertexArray(0) + + def _compute_bounds(self): + """Compute the bounds of this object. + """ + # Compute bounds of this object + bounds = np.array([np.min(self.positions, axis=0), + np.max(self.positions, axis=0)]) + + # If instanced, compute translations for approximate bounds + if self.poses is not None: + bounds += np.array([np.min(self.poses[:,:3,3], axis=0), + np.max(self.poses[:,:3,3], axis=0)]) + return bounds + + def _compute_transparency(self): + """Compute whether or not this object is transparent. + """ + if self.material.is_transparent: + return True + if self._is_transparent is None: + self._is_transparent = False + if self.color_0 is not None: + if np.any(self._color_0[:,3] != 1.0): + self._is_transparent = True + return self._is_transparent + + def _compute_buf_flags(self): + buf_flags = BufFlags.POSITION + + if self.normals is not None: + buf_flags |= BufFlags.NORMAL + if self.tangents is not None: + buf_flags |= BufFlags.TANGENT + if self.texcoord_0 is not None: + buf_flags |= BufFlags.TEXCOORD_0 + if self.texcoord_1 is not None: + buf_flags |= BufFlags.TEXCOORD_1 + if self.color_0 is not None: + buf_flags |= BufFlags.COLOR_0 + if self.joints_0 is not None: + buf_flags |= BufFlags.JOINTS_0 + if self.weights_0 is not None: + buf_flags |= BufFlags.WEIGHTS_0 + + return buf_flags diff --git a/pyrender/pyrender/renderer.py b/pyrender/pyrender/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae14c5cdb1785226a52ae6b71b08f01de069962 --- /dev/null +++ b/pyrender/pyrender/renderer.py @@ -0,0 +1,1339 @@ +"""PBR renderer for Python. + +Author: Matthew Matl +""" +import sys + +import numpy as np +import PIL + +from .constants import (RenderFlags, TextAlign, GLTF, BufFlags, TexFlags, + ProgramFlags, DEFAULT_Z_FAR, DEFAULT_Z_NEAR, + SHADOW_TEX_SZ, MAX_N_LIGHTS) +from .shader_program import ShaderProgramCache +from .material import MetallicRoughnessMaterial, SpecularGlossinessMaterial +from .light import PointLight, SpotLight, DirectionalLight +from .font import FontCache +from .utils import format_color_vector + +from OpenGL.GL import * + + +class Renderer(object): + """Class for handling all rendering operations on a scene. + + Note + ---- + This renderer relies on the existence of an OpenGL context and + does not create one on its own. + + Parameters + ---------- + viewport_width : int + Width of the viewport in pixels. + viewport_height : int + Width of the viewport height in pixels. + point_size : float, optional + Size of points in pixels. Defaults to 1.0. + """ + + def __init__(self, viewport_width, viewport_height, point_size=1.0): + self.dpscale = 1 + # Scaling needed on retina displays + if sys.platform == 'darwin': + self.dpscale = 2 + + self.viewport_width = viewport_width + self.viewport_height = viewport_height + self.point_size = point_size + + # Optional framebuffer for offscreen renders + self._main_fb = None + self._main_cb = None + self._main_db = None + self._main_fb_ms = None + self._main_cb_ms = None + self._main_db_ms = None + self._main_fb_dims = (None, None) + self._shadow_fb = None + self._latest_znear = DEFAULT_Z_NEAR + self._latest_zfar = DEFAULT_Z_FAR + + # Shader Program Cache + self._program_cache = ShaderProgramCache() + self._font_cache = FontCache() + self._meshes = set() + self._mesh_textures = set() + self._shadow_textures = set() + self._texture_alloc_idx = 0 + + @property + def viewport_width(self): + """int : The width of the main viewport, in pixels. 
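+
+        On macOS retina displays this value is twice the width passed to
+        the constructor, to account for display scaling.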
+ """ + return self._viewport_width + + @viewport_width.setter + def viewport_width(self, value): + self._viewport_width = self.dpscale * value + + @property + def viewport_height(self): + """int : The height of the main viewport, in pixels. + """ + return self._viewport_height + + @viewport_height.setter + def viewport_height(self, value): + self._viewport_height = self.dpscale * value + + @property + def point_size(self): + """float : The size of screen-space points, in pixels. + """ + return self._point_size + + @point_size.setter + def point_size(self, value): + self._point_size = float(value) + + def render(self, scene, flags, seg_node_map=None): + """Render a scene with the given set of flags. + + Parameters + ---------- + scene : :class:`Scene` + A scene to render. + flags : int + A specification from :class:`.RenderFlags`. + seg_node_map : dict + A map from :class:`.Node` objects to (3,) colors for each. + If specified along with flags set to :attr:`.RenderFlags.SEG`, + the color image will be a segmentation image. + + Returns + ------- + color_im : (h, w, 3) uint8 or (h, w, 4) uint8 + If :attr:`RenderFlags.OFFSCREEN` is set, the color buffer. This is + normally an RGB buffer, but if :attr:`.RenderFlags.RGBA` is set, + the buffer will be a full RGBA buffer. + depth_im : (h, w) float32 + If :attr:`RenderFlags.OFFSCREEN` is set, the depth buffer + in linear units. + """ + # Update context with meshes and textures + self._update_context(scene, flags) + + # Render necessary shadow maps + if not bool(flags & RenderFlags.DEPTH_ONLY or flags & RenderFlags.SEG): + for ln in scene.light_nodes: + take_pass = False + if (isinstance(ln.light, DirectionalLight) and + bool(flags & RenderFlags.SHADOWS_DIRECTIONAL)): + take_pass = True + elif (isinstance(ln.light, SpotLight) and + bool(flags & RenderFlags.SHADOWS_SPOT)): + take_pass = True + elif (isinstance(ln.light, PointLight) and + bool(flags & RenderFlags.SHADOWS_POINT)): + take_pass = True + if take_pass: + self._shadow_mapping_pass(scene, ln, flags) + + # Make forward pass + retval = self._forward_pass(scene, flags, seg_node_map=seg_node_map) + + # If necessary, make normals pass + if flags & (RenderFlags.VERTEX_NORMALS | RenderFlags.FACE_NORMALS): + self._normals_pass(scene, flags) + + # Update camera settings for retrieving depth buffers + self._latest_znear = scene.main_camera_node.camera.znear + self._latest_zfar = scene.main_camera_node.camera.zfar + + return retval + + def render_text(self, text, x, y, font_name='OpenSans-Regular', + font_pt=40, color=None, scale=1.0, + align=TextAlign.BOTTOM_LEFT): + """Render text into the current viewport. + + Note + ---- + This cannot be done into an offscreen buffer. + + Parameters + ---------- + text : str + The text to render. + x : int + Horizontal pixel location of text. + y : int + Vertical pixel location of text. + font_name : str + Name of font, from the ``pyrender/fonts`` folder, or + a path to a ``.ttf`` file. + font_pt : int + Height of the text, in font points. + color : (4,) float + The color of the text. Default is black. + scale : int + Scaling factor for text. + align : int + One of the :class:`TextAlign` options which specifies where the + ``x`` and ``y`` parameters lie on the text. For example, + :attr:`TextAlign.BOTTOM_LEFT` means that ``x`` and ``y`` indicate + the position of the bottom-left corner of the textbox. 
+ """ + x *= self.dpscale + y *= self.dpscale + font_pt *= self.dpscale + + if color is None: + color = np.array([0.0, 0.0, 0.0, 1.0]) + else: + color = format_color_vector(color, 4) + + # Set up viewport for render + self._configure_forward_pass_viewport(0) + + # Load font + font = self._font_cache.get_font(font_name, font_pt) + if not font._in_context(): + font._add_to_context() + + # Load program + program = self._get_text_program() + program._bind() + + # Set uniforms + p = np.eye(4) + p[0,0] = 2.0 / self.viewport_width + p[0,3] = -1.0 + p[1,1] = 2.0 / self.viewport_height + p[1,3] = -1.0 + program.set_uniform('projection', p) + program.set_uniform('text_color', color) + + # Draw text + font.render_string(text, x, y, scale, align) + + def read_color_buf(self): + """Read and return the current viewport's color buffer. + + Alpha cannot be computed for an on-screen buffer. + + Returns + ------- + color_im : (h, w, 3) uint8 + The color buffer in RGB byte format. + """ + # Extract color image from frame buffer + width, height = self.viewport_width, self.viewport_height + glBindFramebuffer(GL_READ_FRAMEBUFFER, 0) + glReadBuffer(GL_FRONT) + color_buf = glReadPixels(0, 0, width, height, GL_RGB, GL_UNSIGNED_BYTE) + + # Re-format them into numpy arrays + color_im = np.frombuffer(color_buf, dtype=np.uint8) + color_im = color_im.reshape((height, width, 3)) + color_im = np.flip(color_im, axis=0) + + # Resize for macos if needed + if sys.platform == 'darwin': + color_im = self._resize_image(color_im, True) + + return color_im + + def read_depth_buf(self): + """Read and return the current viewport's color buffer. + + Returns + ------- + depth_im : (h, w) float32 + The depth buffer in linear units. + """ + width, height = self.viewport_width, self.viewport_height + glBindFramebuffer(GL_READ_FRAMEBUFFER, 0) + glReadBuffer(GL_FRONT) + depth_buf = glReadPixels( + 0, 0, width, height, GL_DEPTH_COMPONENT, GL_FLOAT + ) + + depth_im = np.frombuffer(depth_buf, dtype=np.float32) + depth_im = depth_im.reshape((height, width)) + depth_im = np.flip(depth_im, axis=0) + + inf_inds = (depth_im == 1.0) + depth_im = 2.0 * depth_im - 1.0 + z_near, z_far = self._latest_znear, self._latest_zfar + noninf = np.logical_not(inf_inds) + if z_far is None: + depth_im[noninf] = 2 * z_near / (1.0 - depth_im[noninf]) + else: + depth_im[noninf] = ((2.0 * z_near * z_far) / + (z_far + z_near - depth_im[noninf] * + (z_far - z_near))) + depth_im[inf_inds] = 0.0 + + # Resize for macos if needed + if sys.platform == 'darwin': + depth_im = self._resize_image(depth_im) + + return depth_im + + def delete(self): + """Free all allocated OpenGL resources. 
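+
+        This is also called automatically when the renderer is garbage
+        collected.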
+ """ + # Free shaders + self._program_cache.clear() + + # Free fonts + self._font_cache.clear() + + # Free meshes + for mesh in self._meshes: + for p in mesh.primitives: + p.delete() + + # Free textures + for mesh_texture in self._mesh_textures: + mesh_texture.delete() + + for shadow_texture in self._shadow_textures: + shadow_texture.delete() + + self._meshes = set() + self._mesh_textures = set() + self._shadow_textures = set() + self._texture_alloc_idx = 0 + + self._delete_main_framebuffer() + self._delete_shadow_framebuffer() + + def __del__(self): + try: + self.delete() + except Exception: + pass + + ########################################################################### + # Rendering passes + ########################################################################### + + def _forward_pass(self, scene, flags, seg_node_map=None): + # Set up viewport for render + self._configure_forward_pass_viewport(flags) + + # Clear it + if bool(flags & RenderFlags.SEG): + glClearColor(0.0, 0.0, 0.0, 1.0) + if seg_node_map is None: + seg_node_map = {} + else: + glClearColor(*scene.bg_color) + + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) + + if not bool(flags & RenderFlags.SEG): + glEnable(GL_MULTISAMPLE) + else: + glDisable(GL_MULTISAMPLE) + + # Set up camera matrices + V, P = self._get_camera_matrices(scene) + + program = None + # Now, render each object in sorted order + for node in self._sorted_mesh_nodes(scene): + mesh = node.mesh + + # Skip the mesh if it's not visible + if not mesh.is_visible: + continue + + # If SEG, set color + if bool(flags & RenderFlags.SEG): + if node not in seg_node_map: + continue + color = seg_node_map[node] + if not isinstance(color, (list, tuple, np.ndarray)): + color = np.repeat(color, 3) + else: + color = np.asanyarray(color) + color = color / 255.0 + + for primitive in mesh.primitives: + + # First, get and bind the appropriate program + program = self._get_primitive_program( + primitive, flags, ProgramFlags.USE_MATERIAL + ) + program._bind() + + # Set the camera uniforms + program.set_uniform('V', V) + program.set_uniform('P', P) + program.set_uniform( + 'cam_pos', scene.get_pose(scene.main_camera_node)[:3,3] + ) + if bool(flags & RenderFlags.SEG): + program.set_uniform('color', color) + + # Next, bind the lighting + if not (flags & RenderFlags.DEPTH_ONLY or flags & RenderFlags.FLAT or + flags & RenderFlags.SEG): + self._bind_lighting(scene, program, node, flags) + + # Finally, bind and draw the primitive + self._bind_and_draw_primitive( + primitive=primitive, + pose=scene.get_pose(node), + program=program, + flags=flags + ) + self._reset_active_textures() + + # Unbind the shader and flush the output + if program is not None: + program._unbind() + glFlush() + + # If doing offscreen render, copy result from framebuffer and return + if flags & RenderFlags.OFFSCREEN: + return self._read_main_framebuffer(scene, flags) + else: + return + + def _shadow_mapping_pass(self, scene, light_node, flags): + light = light_node.light + + # Set up viewport for render + self._configure_shadow_mapping_viewport(light, flags) + + # Set up camera matrices + V, P = self._get_light_cam_matrices(scene, light_node, flags) + + # Now, render each object in sorted order + for node in self._sorted_mesh_nodes(scene): + mesh = node.mesh + + # Skip the mesh if it's not visible + if not mesh.is_visible: + continue + + for primitive in mesh.primitives: + + # First, get and bind the appropriate program + program = self._get_primitive_program( + primitive, flags, ProgramFlags.NONE + ) + 
program._bind() + + # Set the camera uniforms + program.set_uniform('V', V) + program.set_uniform('P', P) + program.set_uniform( + 'cam_pos', scene.get_pose(scene.main_camera_node)[:3,3] + ) + + # Finally, bind and draw the primitive + self._bind_and_draw_primitive( + primitive=primitive, + pose=scene.get_pose(node), + program=program, + flags=RenderFlags.DEPTH_ONLY + ) + self._reset_active_textures() + + # Unbind the shader and flush the output + if program is not None: + program._unbind() + glFlush() + + def _normals_pass(self, scene, flags): + # Set up viewport for render + self._configure_forward_pass_viewport(flags) + program = None + + # Set up camera matrices + V, P = self._get_camera_matrices(scene) + + # Now, render each object in sorted order + for node in self._sorted_mesh_nodes(scene): + mesh = node.mesh + + # Skip the mesh if it's not visible + if not mesh.is_visible: + continue + + for primitive in mesh.primitives: + + # Skip objects that don't have normals + if not primitive.buf_flags & BufFlags.NORMAL: + continue + + # First, get and bind the appropriate program + pf = ProgramFlags.NONE + if flags & RenderFlags.VERTEX_NORMALS: + pf = pf | ProgramFlags.VERTEX_NORMALS + if flags & RenderFlags.FACE_NORMALS: + pf = pf | ProgramFlags.FACE_NORMALS + program = self._get_primitive_program(primitive, flags, pf) + program._bind() + + # Set the camera uniforms + program.set_uniform('V', V) + program.set_uniform('P', P) + program.set_uniform('normal_magnitude', 0.05 * primitive.scale) + program.set_uniform( + 'normal_color', np.array([0.1, 0.1, 1.0, 1.0]) + ) + + # Finally, bind and draw the primitive + self._bind_and_draw_primitive( + primitive=primitive, + pose=scene.get_pose(node), + program=program, + flags=RenderFlags.DEPTH_ONLY + ) + self._reset_active_textures() + + # Unbind the shader and flush the output + if program is not None: + program._unbind() + glFlush() + + ########################################################################### + # Handlers for binding uniforms and drawing primitives + ########################################################################### + + def _bind_and_draw_primitive(self, primitive, pose, program, flags): + # Set model pose matrix + program.set_uniform('M', pose) + + # Bind mesh buffers + primitive._bind() + + # Bind mesh material + if not (flags & RenderFlags.DEPTH_ONLY or flags & RenderFlags.SEG): + material = primitive.material + + # Bind textures + tf = material.tex_flags + if tf & TexFlags.NORMAL: + self._bind_texture(material.normalTexture, + 'material.normal_texture', program) + if tf & TexFlags.OCCLUSION: + self._bind_texture(material.occlusionTexture, + 'material.occlusion_texture', program) + if tf & TexFlags.EMISSIVE: + self._bind_texture(material.emissiveTexture, + 'material.emissive_texture', program) + if tf & TexFlags.BASE_COLOR: + self._bind_texture(material.baseColorTexture, + 'material.base_color_texture', program) + if tf & TexFlags.METALLIC_ROUGHNESS: + self._bind_texture(material.metallicRoughnessTexture, + 'material.metallic_roughness_texture', + program) + if tf & TexFlags.DIFFUSE: + self._bind_texture(material.diffuseTexture, + 'material.diffuse_texture', program) + if tf & TexFlags.SPECULAR_GLOSSINESS: + self._bind_texture(material.specularGlossinessTexture, + 'material.specular_glossiness_texture', + program) + + # Bind other uniforms + b = 'material.{}' + program.set_uniform(b.format('emissive_factor'), + material.emissiveFactor) + if isinstance(material, MetallicRoughnessMaterial): + 
program.set_uniform(b.format('base_color_factor'), + material.baseColorFactor) + program.set_uniform(b.format('metallic_factor'), + material.metallicFactor) + program.set_uniform(b.format('roughness_factor'), + material.roughnessFactor) + elif isinstance(material, SpecularGlossinessMaterial): + program.set_uniform(b.format('diffuse_factor'), + material.diffuseFactor) + program.set_uniform(b.format('specular_factor'), + material.specularFactor) + program.set_uniform(b.format('glossiness_factor'), + material.glossinessFactor) + + # Set blending options + if material.alphaMode == 'BLEND': + glEnable(GL_BLEND) + glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) + else: + glEnable(GL_BLEND) + glBlendFunc(GL_ONE, GL_ZERO) + + # Set wireframe mode + wf = material.wireframe + if flags & RenderFlags.FLIP_WIREFRAME: + wf = not wf + if (flags & RenderFlags.ALL_WIREFRAME) or wf: + glPolygonMode(GL_FRONT_AND_BACK, GL_LINE) + else: + glPolygonMode(GL_FRONT_AND_BACK, GL_FILL) + + # Set culling mode + if material.doubleSided or flags & RenderFlags.SKIP_CULL_FACES: + glDisable(GL_CULL_FACE) + else: + glEnable(GL_CULL_FACE) + glCullFace(GL_BACK) + else: + glEnable(GL_CULL_FACE) + glEnable(GL_BLEND) + glCullFace(GL_BACK) + glBlendFunc(GL_ONE, GL_ZERO) + glPolygonMode(GL_FRONT_AND_BACK, GL_FILL) + + # Set point size if needed + glDisable(GL_PROGRAM_POINT_SIZE) + if primitive.mode == GLTF.POINTS: + glEnable(GL_PROGRAM_POINT_SIZE) + glPointSize(self.point_size) + + # Render mesh + n_instances = 1 + if primitive.poses is not None: + n_instances = len(primitive.poses) + + if primitive.indices is not None: + glDrawElementsInstanced( + primitive.mode, primitive.indices.size, GL_UNSIGNED_INT, + ctypes.c_void_p(0), n_instances + ) + else: + glDrawArraysInstanced( + primitive.mode, 0, len(primitive.positions), n_instances + ) + + # Unbind mesh buffers + primitive._unbind() + + def _bind_lighting(self, scene, program, node, flags): + """Bind all lighting uniform values for a scene. 
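The instanced draw issued at the end of this method takes its instance count from `primitive.poses`, so a mesh built with N poses is rendered in a single `glDrawElementsInstanced`/`glDrawArraysInstanced` call. A usage sketch, assuming the usual top-level pyrender exports and that `Mesh.from_trimesh` accepts a `poses` array as in upstream pyrender:

```python
import numpy as np
import trimesh
from pyrender import Mesh, Scene  # assumed top-level exports

# Three translated copies of one sphere, drawn as a single instanced call.
poses = np.stack([np.eye(4)] * 3)
poses[0, 0, 3], poses[1, 0, 3], poses[2, 0, 3] = -1.0, 0.0, 1.0  # x offsets

mesh = Mesh.from_trimesh(trimesh.creation.icosphere(radius=0.2), poses=poses)
scene = Scene()
scene.add(mesh)
```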
+ """ + max_n_lights = self._compute_max_n_lights(flags) + + n_d = min(len(scene.directional_light_nodes), max_n_lights[0]) + n_s = min(len(scene.spot_light_nodes), max_n_lights[1]) + n_p = min(len(scene.point_light_nodes), max_n_lights[2]) + program.set_uniform('ambient_light', scene.ambient_light) + program.set_uniform('n_directional_lights', n_d) + program.set_uniform('n_spot_lights', n_s) + program.set_uniform('n_point_lights', n_p) + plc = 0 + slc = 0 + dlc = 0 + + light_nodes = scene.light_nodes + if (len(scene.directional_light_nodes) > max_n_lights[0] or + len(scene.spot_light_nodes) > max_n_lights[1] or + len(scene.point_light_nodes) > max_n_lights[2]): + light_nodes = self._sorted_nodes_by_distance( + scene, scene.light_nodes, node + ) + + for n in light_nodes: + light = n.light + pose = scene.get_pose(n) + position = pose[:3,3] + direction = -pose[:3,2] + + if isinstance(light, PointLight): + if plc == max_n_lights[2]: + continue + b = 'point_lights[{}].'.format(plc) + plc += 1 + shadow = bool(flags & RenderFlags.SHADOWS_POINT) + program.set_uniform(b + 'position', position) + elif isinstance(light, SpotLight): + if slc == max_n_lights[1]: + continue + b = 'spot_lights[{}].'.format(slc) + slc += 1 + shadow = bool(flags & RenderFlags.SHADOWS_SPOT) + las = 1.0 / max(0.001, np.cos(light.innerConeAngle) - + np.cos(light.outerConeAngle)) + lao = -np.cos(light.outerConeAngle) * las + program.set_uniform(b + 'direction', direction) + program.set_uniform(b + 'position', position) + program.set_uniform(b + 'light_angle_scale', las) + program.set_uniform(b + 'light_angle_offset', lao) + else: + if dlc == max_n_lights[0]: + continue + b = 'directional_lights[{}].'.format(dlc) + dlc += 1 + shadow = bool(flags & RenderFlags.SHADOWS_DIRECTIONAL) + program.set_uniform(b + 'direction', direction) + + program.set_uniform(b + 'color', light.color) + program.set_uniform(b + 'intensity', light.intensity) + # if light.range is not None: + # program.set_uniform(b + 'range', light.range) + # else: + # program.set_uniform(b + 'range', 0) + + if shadow: + self._bind_texture(light.shadow_texture, + b + 'shadow_map', program) + if not isinstance(light, PointLight): + V, P = self._get_light_cam_matrices(scene, n, flags) + program.set_uniform(b + 'light_matrix', P.dot(V)) + else: + raise NotImplementedError( + 'Point light shadows not implemented' + ) + + def _sorted_mesh_nodes(self, scene): + cam_loc = scene.get_pose(scene.main_camera_node)[:3,3] + solid_nodes = [] + trans_nodes = [] + for node in scene.mesh_nodes: + mesh = node.mesh + if mesh.is_transparent: + trans_nodes.append(node) + else: + solid_nodes.append(node) + + # TODO BETTER SORTING METHOD + trans_nodes.sort( + key=lambda n: -np.linalg.norm(scene.get_pose(n)[:3,3] - cam_loc) + ) + solid_nodes.sort( + key=lambda n: -np.linalg.norm(scene.get_pose(n)[:3,3] - cam_loc) + ) + + return solid_nodes + trans_nodes + + def _sorted_nodes_by_distance(self, scene, nodes, compare_node): + nodes = list(nodes) + compare_posn = scene.get_pose(compare_node)[:3,3] + nodes.sort(key=lambda n: np.linalg.norm( + scene.get_pose(n)[:3,3] - compare_posn) + ) + return nodes + + ########################################################################### + # Context Management + ########################################################################### + + def _update_context(self, scene, flags): + + # Update meshes + scene_meshes = scene.meshes + + # Add new meshes to context + for mesh in scene_meshes - self._meshes: + for p in mesh.primitives: + p._add_to_context() + + 
# Remove old meshes from context + for mesh in self._meshes - scene_meshes: + for p in mesh.primitives: + p.delete() + + self._meshes = scene_meshes.copy() + + # Update mesh textures + mesh_textures = set() + for m in scene_meshes: + for p in m.primitives: + mesh_textures |= p.material.textures + + # Add new textures to context + for texture in mesh_textures - self._mesh_textures: + texture._add_to_context() + + # Remove old textures from context + for texture in self._mesh_textures - mesh_textures: + texture.delete() + + self._mesh_textures = mesh_textures.copy() + + shadow_textures = set() + for l in scene.lights: + # Create if needed + active = False + if (isinstance(l, DirectionalLight) and + flags & RenderFlags.SHADOWS_DIRECTIONAL): + active = True + elif (isinstance(l, PointLight) and + flags & RenderFlags.SHADOWS_POINT): + active = True + elif isinstance(l, SpotLight) and flags & RenderFlags.SHADOWS_SPOT: + active = True + + if active and l.shadow_texture is None: + l._generate_shadow_texture() + if l.shadow_texture is not None: + shadow_textures.add(l.shadow_texture) + + # Add new textures to context + for texture in shadow_textures - self._shadow_textures: + texture._add_to_context() + + # Remove old textures from context + for texture in self._shadow_textures - shadow_textures: + texture.delete() + + self._shadow_textures = shadow_textures.copy() + + ########################################################################### + # Texture Management + ########################################################################### + + def _bind_texture(self, texture, uniform_name, program): + """Bind a texture to an active texture unit and return + the texture unit index that was used. + """ + tex_id = self._get_next_active_texture() + glActiveTexture(GL_TEXTURE0 + tex_id) + texture._bind() + program.set_uniform(uniform_name, tex_id) + + def _get_next_active_texture(self): + val = self._texture_alloc_idx + self._texture_alloc_idx += 1 + return val + + def _reset_active_textures(self): + self._texture_alloc_idx = 0 + + ########################################################################### + # Camera Matrix Management + ########################################################################### + + def _get_camera_matrices(self, scene): + main_camera_node = scene.main_camera_node + if main_camera_node is None: + raise ValueError('Cannot render scene without a camera') + P = main_camera_node.camera.get_projection_matrix( + width=self.viewport_width, height=self.viewport_height + ) + pose = scene.get_pose(main_camera_node) + V = np.linalg.inv(pose) # V maps from world to camera + return V, P + + def _get_light_cam_matrices(self, scene, light_node, flags): + light = light_node.light + pose = scene.get_pose(light_node).copy() + s = scene.scale + camera = light._get_shadow_camera(s) + P = camera.get_projection_matrix() + if isinstance(light, DirectionalLight): + direction = -pose[:3,2] + c = scene.centroid + loc = c - direction * s + pose[:3,3] = loc + V = np.linalg.inv(pose) # V maps from world to camera + return V, P + + ########################################################################### + # Shader Program Management + ########################################################################### + + def _get_text_program(self): + program = self._program_cache.get_program( + vertex_shader='text.vert', + fragment_shader='text.frag' + ) + + if not program._in_context(): + program._add_to_context() + + return program + + def _compute_max_n_lights(self, flags): + max_n_lights = 
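The camera-matrix convention above is worth keeping in mind when reading the shaders: the view matrix is the inverse of the camera node's world pose, and the projection comes from the camera object itself. A minimal sketch using the API defined in this diff:

```python
import numpy as np

def camera_matrices(scene, width, height):
    # Mirrors _get_camera_matrices: P from the camera, V = inverse(camera pose).
    cam_node = scene.main_camera_node
    P = cam_node.camera.get_projection_matrix(width=width, height=height)
    pose = scene.get_pose(cam_node)   # camera-to-world transform
    V = np.linalg.inv(pose)           # world-to-camera (view) matrix
    return V, P
```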
[MAX_N_LIGHTS, MAX_N_LIGHTS, MAX_N_LIGHTS] + n_tex_units = glGetIntegerv(GL_MAX_TEXTURE_IMAGE_UNITS) + + # Reserved texture units: 6 + # Normal Map + # Occlusion Map + # Emissive Map + # Base Color or Diffuse Map + # MR or SG Map + # Environment cubemap + + n_reserved_textures = 6 + n_available_textures = n_tex_units - n_reserved_textures + + # Distribute textures evenly among lights with shadows, with + # a preference for directional lights + n_shadow_types = 0 + if flags & RenderFlags.SHADOWS_DIRECTIONAL: + n_shadow_types += 1 + if flags & RenderFlags.SHADOWS_SPOT: + n_shadow_types += 1 + if flags & RenderFlags.SHADOWS_POINT: + n_shadow_types += 1 + + if n_shadow_types > 0: + tex_per_light = n_available_textures // n_shadow_types + + if flags & RenderFlags.SHADOWS_DIRECTIONAL: + max_n_lights[0] = ( + tex_per_light + + (n_available_textures - tex_per_light * n_shadow_types) + ) + if flags & RenderFlags.SHADOWS_SPOT: + max_n_lights[1] = tex_per_light + if flags & RenderFlags.SHADOWS_POINT: + max_n_lights[2] = tex_per_light + + return max_n_lights + + def _get_primitive_program(self, primitive, flags, program_flags): + vertex_shader = None + fragment_shader = None + geometry_shader = None + defines = {} + + if (bool(program_flags & ProgramFlags.USE_MATERIAL) and + not flags & RenderFlags.DEPTH_ONLY and + not flags & RenderFlags.FLAT and + not flags & RenderFlags.SEG): + vertex_shader = 'mesh.vert' + fragment_shader = 'mesh.frag' + elif bool(program_flags & (ProgramFlags.VERTEX_NORMALS | + ProgramFlags.FACE_NORMALS)): + vertex_shader = 'vertex_normals.vert' + if primitive.mode == GLTF.POINTS: + geometry_shader = 'vertex_normals_pc.geom' + else: + geometry_shader = 'vertex_normals.geom' + fragment_shader = 'vertex_normals.frag' + elif flags & RenderFlags.FLAT: + vertex_shader = 'flat.vert' + fragment_shader = 'flat.frag' + elif flags & RenderFlags.SEG: + vertex_shader = 'segmentation.vert' + fragment_shader = 'segmentation.frag' + else: + vertex_shader = 'mesh_depth.vert' + fragment_shader = 'mesh_depth.frag' + + # Set up vertex buffer DEFINES + bf = primitive.buf_flags + buf_idx = 1 + if bf & BufFlags.NORMAL: + defines['NORMAL_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.TANGENT: + defines['TANGENT_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.TEXCOORD_0: + defines['TEXCOORD_0_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.TEXCOORD_1: + defines['TEXCOORD_1_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.COLOR_0: + defines['COLOR_0_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.JOINTS_0: + defines['JOINTS_0_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.WEIGHTS_0: + defines['WEIGHTS_0_LOC'] = buf_idx + buf_idx += 1 + defines['INST_M_LOC'] = buf_idx + + # Set up shadow mapping defines + if flags & RenderFlags.SHADOWS_DIRECTIONAL: + defines['DIRECTIONAL_LIGHT_SHADOWS'] = 1 + if flags & RenderFlags.SHADOWS_SPOT: + defines['SPOT_LIGHT_SHADOWS'] = 1 + if flags & RenderFlags.SHADOWS_POINT: + defines['POINT_LIGHT_SHADOWS'] = 1 + max_n_lights = self._compute_max_n_lights(flags) + defines['MAX_DIRECTIONAL_LIGHTS'] = max_n_lights[0] + defines['MAX_SPOT_LIGHTS'] = max_n_lights[1] + defines['MAX_POINT_LIGHTS'] = max_n_lights[2] + + # Set up vertex normal defines + if program_flags & ProgramFlags.VERTEX_NORMALS: + defines['VERTEX_NORMALS'] = 1 + if program_flags & ProgramFlags.FACE_NORMALS: + defines['FACE_NORMALS'] = 1 + + # Set up material texture defines + if bool(program_flags & ProgramFlags.USE_MATERIAL): + tf = primitive.material.tex_flags + if tf & TexFlags.NORMAL: + 
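The light budget above can be re-derived outside the renderer. A sketch of the same arithmetic; `MAX_N_LIGHTS` is set to 8 here purely as a placeholder (the real constant lives in pyrender's constants module):

```python
MAX_N_LIGHTS = 8  # placeholder; the real value comes from pyrender.constants

def light_budget(n_tex_units, shadows_directional, shadows_spot, shadows_point):
    # [directional, spot, point]; 6 texture units stay reserved for material
    # and environment maps, the rest are split among shadow-casting light
    # types, with any remainder handed to directional lights.
    max_n_lights = [MAX_N_LIGHTS, MAX_N_LIGHTS, MAX_N_LIGHTS]
    n_available = n_tex_units - 6
    n_types = sum([shadows_directional, shadows_spot, shadows_point])
    if n_types > 0:
        per_light = n_available // n_types
        if shadows_directional:
            max_n_lights[0] = per_light + (n_available - per_light * n_types)
        if shadows_spot:
            max_n_lights[1] = per_light
        if shadows_point:
            max_n_lights[2] = per_light
    return max_n_lights

# 32 texture units, directional + spot shadows: 26 available, 13 apiece.
print(light_budget(32, True, True, False))  # [13, 13, 8]
```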
defines['HAS_NORMAL_TEX'] = 1 + if tf & TexFlags.OCCLUSION: + defines['HAS_OCCLUSION_TEX'] = 1 + if tf & TexFlags.EMISSIVE: + defines['HAS_EMISSIVE_TEX'] = 1 + if tf & TexFlags.BASE_COLOR: + defines['HAS_BASE_COLOR_TEX'] = 1 + if tf & TexFlags.METALLIC_ROUGHNESS: + defines['HAS_METALLIC_ROUGHNESS_TEX'] = 1 + if tf & TexFlags.DIFFUSE: + defines['HAS_DIFFUSE_TEX'] = 1 + if tf & TexFlags.SPECULAR_GLOSSINESS: + defines['HAS_SPECULAR_GLOSSINESS_TEX'] = 1 + if isinstance(primitive.material, MetallicRoughnessMaterial): + defines['USE_METALLIC_MATERIAL'] = 1 + elif isinstance(primitive.material, SpecularGlossinessMaterial): + defines['USE_GLOSSY_MATERIAL'] = 1 + + program = self._program_cache.get_program( + vertex_shader=vertex_shader, + fragment_shader=fragment_shader, + geometry_shader=geometry_shader, + defines=defines + ) + + if not program._in_context(): + program._add_to_context() + + return program + + ########################################################################### + # Viewport Management + ########################################################################### + + def _configure_forward_pass_viewport(self, flags): + + # If using offscreen render, bind main framebuffer + if flags & RenderFlags.OFFSCREEN: + self._configure_main_framebuffer() + glBindFramebuffer(GL_DRAW_FRAMEBUFFER, self._main_fb_ms) + else: + glBindFramebuffer(GL_DRAW_FRAMEBUFFER, 0) + + glViewport(0, 0, self.viewport_width, self.viewport_height) + glEnable(GL_DEPTH_TEST) + glDepthMask(GL_TRUE) + glDepthFunc(GL_LESS) + glDepthRange(0.0, 1.0) + + def _configure_shadow_mapping_viewport(self, light, flags): + self._configure_shadow_framebuffer() + glBindFramebuffer(GL_FRAMEBUFFER, self._shadow_fb) + light.shadow_texture._bind() + light.shadow_texture._bind_as_depth_attachment() + glActiveTexture(GL_TEXTURE0) + light.shadow_texture._bind() + glDrawBuffer(GL_NONE) + glReadBuffer(GL_NONE) + + glClear(GL_DEPTH_BUFFER_BIT) + glViewport(0, 0, SHADOW_TEX_SZ, SHADOW_TEX_SZ) + glEnable(GL_DEPTH_TEST) + glDepthMask(GL_TRUE) + glDepthFunc(GL_LESS) + glDepthRange(0.0, 1.0) + glDisable(GL_CULL_FACE) + glDisable(GL_BLEND) + + ########################################################################### + # Framebuffer Management + ########################################################################### + + def _configure_shadow_framebuffer(self): + if self._shadow_fb is None: + self._shadow_fb = glGenFramebuffers(1) + + def _delete_shadow_framebuffer(self): + if self._shadow_fb is not None: + glDeleteFramebuffers(1, [self._shadow_fb]) + + def _configure_main_framebuffer(self): + # If mismatch with prior framebuffer, delete it + if (self._main_fb is not None and + self.viewport_width != self._main_fb_dims[0] or + self.viewport_height != self._main_fb_dims[1]): + self._delete_main_framebuffer() + + # If framebuffer doesn't exist, create it + if self._main_fb is None: + # Generate standard buffer + self._main_cb, self._main_db = glGenRenderbuffers(2) + + glBindRenderbuffer(GL_RENDERBUFFER, self._main_cb) + glRenderbufferStorage( + GL_RENDERBUFFER, GL_RGBA, + self.viewport_width, self.viewport_height + ) + + glBindRenderbuffer(GL_RENDERBUFFER, self._main_db) + glRenderbufferStorage( + GL_RENDERBUFFER, GL_DEPTH_COMPONENT24, + self.viewport_width, self.viewport_height + ) + + self._main_fb = glGenFramebuffers(1) + glBindFramebuffer(GL_DRAW_FRAMEBUFFER, self._main_fb) + glFramebufferRenderbuffer( + GL_DRAW_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + GL_RENDERBUFFER, self._main_cb + ) + glFramebufferRenderbuffer( + GL_DRAW_FRAMEBUFFER, 
GL_DEPTH_ATTACHMENT,
+                GL_RENDERBUFFER, self._main_db
+            )
+
+            # Generate multisample buffer
+            self._main_cb_ms, self._main_db_ms = glGenRenderbuffers(2)
+            glBindRenderbuffer(GL_RENDERBUFFER, self._main_cb_ms)
+            # glRenderbufferStorageMultisample(
+            #     GL_RENDERBUFFER, 4, GL_RGBA,
+            #     self.viewport_width, self.viewport_height
+            # )
+            # glBindRenderbuffer(GL_RENDERBUFFER, self._main_db_ms)
+            # glRenderbufferStorageMultisample(
+            #     GL_RENDERBUFFER, 4, GL_DEPTH_COMPONENT24,
+            #     self.viewport_width, self.viewport_height
+            # )
+            # Add this line
+            num_samples = min(glGetIntegerv(GL_MAX_SAMPLES), 4)  # No more than GL_MAX_SAMPLES
+
+            # Simply replace 4 with num_samples; everything else stays the same
+            glRenderbufferStorageMultisample(GL_RENDERBUFFER, num_samples, GL_RGBA, self.viewport_width, self.viewport_height)
+
+            glBindRenderbuffer(GL_RENDERBUFFER, self._main_db_ms)  # unchanged
+
+            # This line also replaces 4 with num_samples
+            glRenderbufferStorageMultisample(GL_RENDERBUFFER, num_samples, GL_DEPTH_COMPONENT24, self.viewport_width, self.viewport_height)
+
+            self._main_fb_ms = glGenFramebuffers(1)
+            glBindFramebuffer(GL_DRAW_FRAMEBUFFER, self._main_fb_ms)
+            glFramebufferRenderbuffer(
+                GL_DRAW_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
+                GL_RENDERBUFFER, self._main_cb_ms
+            )
+            glFramebufferRenderbuffer(
+                GL_DRAW_FRAMEBUFFER, GL_DEPTH_ATTACHMENT,
+                GL_RENDERBUFFER, self._main_db_ms
+            )
+
+            self._main_fb_dims = (self.viewport_width, self.viewport_height)
+
+    def _delete_main_framebuffer(self):
+        if self._main_fb is not None:
+            glDeleteFramebuffers(2, [self._main_fb, self._main_fb_ms])
+        if self._main_cb is not None:
+            glDeleteRenderbuffers(2, [self._main_cb, self._main_cb_ms])
+        if self._main_db is not None:
+            glDeleteRenderbuffers(2, [self._main_db, self._main_db_ms])
+
+        self._main_fb = None
+        self._main_cb = None
+        self._main_db = None
+        self._main_fb_ms = None
+        self._main_cb_ms = None
+        self._main_db_ms = None
+        self._main_fb_dims = (None, None)
+
+    def _read_main_framebuffer(self, scene, flags):
+        width, height = self._main_fb_dims[0], self._main_fb_dims[1]
+
+        # Bind framebuffer and blit buffers
+        glBindFramebuffer(GL_READ_FRAMEBUFFER, self._main_fb_ms)
+        glBindFramebuffer(GL_DRAW_FRAMEBUFFER, self._main_fb)
+        glBlitFramebuffer(
+            0, 0, width, height, 0, 0, width, height,
+            GL_COLOR_BUFFER_BIT, GL_LINEAR
+        )
+        glBlitFramebuffer(
+            0, 0, width, height, 0, 0, width, height,
+            GL_DEPTH_BUFFER_BIT, GL_NEAREST
+        )
+        glBindFramebuffer(GL_READ_FRAMEBUFFER, self._main_fb)
+
+        # Read depth
+        depth_buf = glReadPixels(
+            0, 0, width, height, GL_DEPTH_COMPONENT, GL_FLOAT
+        )
+        depth_im = np.frombuffer(depth_buf, dtype=np.float32)
+        depth_im = depth_im.reshape((height, width))
+        depth_im = np.flip(depth_im, axis=0)
+        inf_inds = (depth_im == 1.0)
+        depth_im = 2.0 * depth_im - 1.0
+        z_near = scene.main_camera_node.camera.znear
+        z_far = scene.main_camera_node.camera.zfar
+        noninf = np.logical_not(inf_inds)
+        if z_far is None:
+            depth_im[noninf] = 2 * z_near / (1.0 - depth_im[noninf])
+        else:
+            depth_im[noninf] = ((2.0 * z_near * z_far) /
+                                (z_far + z_near - depth_im[noninf] *
+                                (z_far - z_near)))
+        depth_im[inf_inds] = 0.0
+
+        # Resize for macos if needed
+        if sys.platform == 'darwin':
+            depth_im = self._resize_image(depth_im)
+
+        if flags & RenderFlags.DEPTH_ONLY:
+            return depth_im
+
+        # Read color
+        if flags & RenderFlags.RGBA:
+            color_buf = glReadPixels(
+                0, 0, width, height, GL_RGBA, GL_UNSIGNED_BYTE
+            )
+            color_im = np.frombuffer(color_buf, dtype=np.uint8)
+            color_im = color_im.reshape((height, width, 4))
+        else:
+            color_buf = glReadPixels(
+                0, 0, width, height, GL_RGB,
GL_UNSIGNED_BYTE + ) + color_im = np.frombuffer(color_buf, dtype=np.uint8) + color_im = color_im.reshape((height, width, 3)) + color_im = np.flip(color_im, axis=0) + + # Resize for macos if needed + if sys.platform == 'darwin': + color_im = self._resize_image(color_im, True) + + return color_im, depth_im + + def _resize_image(self, value, antialias=False): + """If needed, rescale the render for MacOS.""" + img = PIL.Image.fromarray(value) + resample = PIL.Image.NEAREST + if antialias: + resample = PIL.Image.BILINEAR + size = (self.viewport_width // self.dpscale, + self.viewport_height // self.dpscale) + img = img.resize(size, resample=resample) + return np.array(img) + + ########################################################################### + # Shadowmap Debugging + ########################################################################### + + def _forward_pass_no_reset(self, scene, flags): + # Set up camera matrices + V, P = self._get_camera_matrices(scene) + + # Now, render each object in sorted order + for node in self._sorted_mesh_nodes(scene): + mesh = node.mesh + + # Skip the mesh if it's not visible + if not mesh.is_visible: + continue + + for primitive in mesh.primitives: + + # First, get and bind the appropriate program + program = self._get_primitive_program( + primitive, flags, ProgramFlags.USE_MATERIAL + ) + program._bind() + + # Set the camera uniforms + program.set_uniform('V', V) + program.set_uniform('P', P) + program.set_uniform( + 'cam_pos', scene.get_pose(scene.main_camera_node)[:3,3] + ) + + # Next, bind the lighting + if not flags & RenderFlags.DEPTH_ONLY and not flags & RenderFlags.FLAT: + self._bind_lighting(scene, program, node, flags) + + # Finally, bind and draw the primitive + self._bind_and_draw_primitive( + primitive=primitive, + pose=scene.get_pose(node), + program=program, + flags=flags + ) + self._reset_active_textures() + + # Unbind the shader and flush the output + if program is not None: + program._unbind() + glFlush() + + def _render_light_shadowmaps(self, scene, light_nodes, flags, tile=False): + glBindFramebuffer(GL_DRAW_FRAMEBUFFER, 0) + glClearColor(*scene.bg_color) + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) + glEnable(GL_DEPTH_TEST) + glDepthMask(GL_TRUE) + glDepthFunc(GL_LESS) + glDepthRange(0.0, 1.0) + + w = self.viewport_width + h = self.viewport_height + + num_nodes = len(light_nodes) + viewport_dims = { + (0, 2): [0, h // 2, w // 2, h], + (1, 2): [w // 2, h // 2, w, h], + (0, 3): [0, h // 2, w // 2, h], + (1, 3): [w // 2, h // 2, w, h], + (2, 3): [0, 0, w // 2, h // 2], + (0, 4): [0, h // 2, w // 2, h], + (1, 4): [w // 2, h // 2, w, h], + (2, 4): [0, 0, w // 2, h // 2], + (3, 4): [w // 2, 0, w, h // 2] + } + + if tile: + for i, ln in enumerate(light_nodes): + light = ln.light + + if light.shadow_texture is None: + raise ValueError('Light does not have a shadow texture') + + glViewport(*viewport_dims[(i, num_nodes + 1)]) + + program = self._get_debug_quad_program() + program._bind() + self._bind_texture(light.shadow_texture, 'depthMap', program) + self._render_debug_quad() + self._reset_active_textures() + glFlush() + i += 1 + glViewport(*viewport_dims[(i, num_nodes + 1)]) + self._forward_pass_no_reset(scene, flags) + else: + for i, ln in enumerate(light_nodes): + light = ln.light + + if light.shadow_texture is None: + raise ValueError('Light does not have a shadow texture') + + glViewport(0, 0, self.viewport_width, self.viewport_height) + + program = self._get_debug_quad_program() + program._bind() + 
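The depth readback in `_read_main_framebuffer` above converts window-space depth in [0, 1] back to metric eye-space depth. A standalone numpy version of the same conversion, handy when post-processing the depth image returned by an offscreen render:

```python
import numpy as np

def linearize_depth(depth_buf, z_near, z_far):
    # Same math as _read_main_framebuffer: pixels at exactly 1.0 were never
    # written and are returned as 0; everything else is un-projected.
    depth = np.asarray(depth_buf, dtype=np.float32)
    inf_mask = depth == 1.0
    ndc = 2.0 * depth - 1.0                      # window [0, 1] -> NDC [-1, 1]
    out = np.zeros_like(depth)
    valid = ~inf_mask
    if z_far is None:                            # infinite far plane
        out[valid] = 2.0 * z_near / (1.0 - ndc[valid])
    else:
        out[valid] = (2.0 * z_near * z_far) / (
            z_far + z_near - ndc[valid] * (z_far - z_near))
    return out
```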
self._bind_texture(light.shadow_texture, 'depthMap', program) + self._render_debug_quad() + self._reset_active_textures() + glFlush() + return + + def _get_debug_quad_program(self): + program = self._program_cache.get_program( + vertex_shader='debug_quad.vert', + fragment_shader='debug_quad.frag' + ) + if not program._in_context(): + program._add_to_context() + return program + + def _render_debug_quad(self): + x = glGenVertexArrays(1) + glBindVertexArray(x) + glDrawArrays(GL_TRIANGLES, 0, 6) + glBindVertexArray(0) + glDeleteVertexArrays(1, [x]) diff --git a/pyrender/pyrender/sampler.py b/pyrender/pyrender/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..e4784d068f808a40a56c8e748d83175f7f4e6233 --- /dev/null +++ b/pyrender/pyrender/sampler.py @@ -0,0 +1,102 @@ +"""Samplers, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-sampler + +Author: Matthew Matl +""" +from .constants import GLTF + + +class Sampler(object): + """Texture sampler properties for filtering and wrapping modes. + + Parameters + ---------- + name : str, optional + The user-defined name of this object. + magFilter : int, optional + Magnification filter. Valid values: + - :attr:`.GLTF.NEAREST` + - :attr:`.GLTF.LINEAR` + minFilter : int, optional + Minification filter. Valid values: + - :attr:`.GLTF.NEAREST` + - :attr:`.GLTF.LINEAR` + - :attr:`.GLTF.NEAREST_MIPMAP_NEAREST` + - :attr:`.GLTF.LINEAR_MIPMAP_NEAREST` + - :attr:`.GLTF.NEAREST_MIPMAP_LINEAR` + - :attr:`.GLTF.LINEAR_MIPMAP_LINEAR` + wrapS : int, optional + S (U) wrapping mode. Valid values: + - :attr:`.GLTF.CLAMP_TO_EDGE` + - :attr:`.GLTF.MIRRORED_REPEAT` + - :attr:`.GLTF.REPEAT` + wrapT : int, optional + T (V) wrapping mode. Valid values: + - :attr:`.GLTF.CLAMP_TO_EDGE` + - :attr:`.GLTF.MIRRORED_REPEAT` + - :attr:`.GLTF.REPEAT` + """ + + def __init__(self, + name=None, + magFilter=None, + minFilter=None, + wrapS=GLTF.REPEAT, + wrapT=GLTF.REPEAT): + self.name = name + self.magFilter = magFilter + self.minFilter = minFilter + self.wrapS = wrapS + self.wrapT = wrapT + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def magFilter(self): + """int : Magnification filter type. + """ + return self._magFilter + + @magFilter.setter + def magFilter(self, value): + self._magFilter = value + + @property + def minFilter(self): + """int : Minification filter type. + """ + return self._minFilter + + @minFilter.setter + def minFilter(self, value): + self._minFilter = value + + @property + def wrapS(self): + """int : S (U) wrapping mode. + """ + return self._wrapS + + @wrapS.setter + def wrapS(self, value): + self._wrapS = value + + @property + def wrapT(self): + """int : T (V) wrapping mode. 
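A short usage sketch for the `Sampler` class introduced above, assuming the vendored package is importable as `pyrender` (the module paths match the files added in this diff). Any field left as `None` falls back to the renderer's defaults.

```python
from pyrender.constants import GLTF
from pyrender.sampler import Sampler

# Trilinear minification, linear magnification, and clamped U/V wrapping.
sampler = Sampler(
    name='clamped_trilinear',
    magFilter=GLTF.LINEAR,
    minFilter=GLTF.LINEAR_MIPMAP_LINEAR,
    wrapS=GLTF.CLAMP_TO_EDGE,
    wrapT=GLTF.CLAMP_TO_EDGE,
)
```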
+ """ + return self._wrapT + + @wrapT.setter + def wrapT(self, value): + self._wrapT = value diff --git a/pyrender/pyrender/scene.py b/pyrender/pyrender/scene.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe057ec66f52f2dd9c1363aacf72a7c6cec4e6c --- /dev/null +++ b/pyrender/pyrender/scene.py @@ -0,0 +1,585 @@ +"""Scenes, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-scene + +Author: Matthew Matl +""" +import numpy as np +import networkx as nx +import trimesh + +from .mesh import Mesh +from .camera import Camera +from .light import Light, PointLight, DirectionalLight, SpotLight +from .node import Node +from .utils import format_color_vector + + +class Scene(object): + """A hierarchical scene graph. + + Parameters + ---------- + nodes : list of :class:`Node` + The set of all nodes in the scene. + bg_color : (4,) float, optional + Background color of scene. + ambient_light : (3,) float, optional + Color of ambient light. Defaults to no ambient light. + name : str, optional + The user-defined name of this object. + """ + + def __init__(self, + nodes=None, + bg_color=None, + ambient_light=None, + name=None): + + if bg_color is None: + bg_color = np.ones(4) + else: + bg_color = format_color_vector(bg_color, 4) + + if ambient_light is None: + ambient_light = np.zeros(3) + + if nodes is None: + nodes = set() + self._nodes = set() # Will be added at the end of this function + + self.bg_color = bg_color + self.ambient_light = ambient_light + self.name = name + + self._name_to_nodes = {} + self._obj_to_nodes = {} + self._obj_name_to_nodes = {} + self._mesh_nodes = set() + self._point_light_nodes = set() + self._spot_light_nodes = set() + self._directional_light_nodes = set() + self._camera_nodes = set() + self._main_camera_node = None + self._bounds = None + + # Transform tree + self._digraph = nx.DiGraph() + self._digraph.add_node('world') + self._path_cache = {} + + # Find root nodes and add them + if len(nodes) > 0: + node_parent_map = {n: None for n in nodes} + for node in nodes: + for child in node.children: + if node_parent_map[child] is not None: + raise ValueError('Nodes may not have more than ' + 'one parent') + node_parent_map[child] = node + for node in node_parent_map: + if node_parent_map[node] is None: + self.add_node(node) + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def nodes(self): + """set of :class:`Node` : Set of nodes in the scene. + """ + return self._nodes + + @property + def bg_color(self): + """(3,) float : The scene background color. + """ + return self._bg_color + + @bg_color.setter + def bg_color(self, value): + if value is None: + value = np.ones(4) + else: + value = format_color_vector(value, 4) + self._bg_color = value + + @property + def ambient_light(self): + """(3,) float : The ambient light in the scene. + """ + return self._ambient_light + + @ambient_light.setter + def ambient_light(self, value): + if value is None: + value = np.zeros(3) + else: + value = format_color_vector(value, 3) + self._ambient_light = value + + @property + def meshes(self): + """set of :class:`Mesh` : The meshes in the scene. + """ + return set([n.mesh for n in self.mesh_nodes]) + + @property + def mesh_nodes(self): + """set of :class:`Node` : The nodes containing meshes. 
+ """ + return self._mesh_nodes + + @property + def lights(self): + """set of :class:`Light` : The lights in the scene. + """ + return self.point_lights | self.spot_lights | self.directional_lights + + @property + def light_nodes(self): + """set of :class:`Node` : The nodes containing lights. + """ + return (self.point_light_nodes | self.spot_light_nodes | + self.directional_light_nodes) + + @property + def point_lights(self): + """set of :class:`PointLight` : The point lights in the scene. + """ + return set([n.light for n in self.point_light_nodes]) + + @property + def point_light_nodes(self): + """set of :class:`Node` : The nodes containing point lights. + """ + return self._point_light_nodes + + @property + def spot_lights(self): + """set of :class:`SpotLight` : The spot lights in the scene. + """ + return set([n.light for n in self.spot_light_nodes]) + + @property + def spot_light_nodes(self): + """set of :class:`Node` : The nodes containing spot lights. + """ + return self._spot_light_nodes + + @property + def directional_lights(self): + """set of :class:`DirectionalLight` : The directional lights in + the scene. + """ + return set([n.light for n in self.directional_light_nodes]) + + @property + def directional_light_nodes(self): + """set of :class:`Node` : The nodes containing directional lights. + """ + return self._directional_light_nodes + + @property + def cameras(self): + """set of :class:`Camera` : The cameras in the scene. + """ + return set([n.camera for n in self.camera_nodes]) + + @property + def camera_nodes(self): + """set of :class:`Node` : The nodes containing cameras in the scene. + """ + return self._camera_nodes + + @property + def main_camera_node(self): + """set of :class:`Node` : The node containing the main camera in the + scene. + """ + return self._main_camera_node + + @main_camera_node.setter + def main_camera_node(self, value): + if value not in self.nodes: + raise ValueError('New main camera node must already be in scene') + self._main_camera_node = value + + @property + def bounds(self): + """(2,3) float : The axis-aligned bounds of the scene. + """ + if self._bounds is None: + # Compute corners + corners = [] + for mesh_node in self.mesh_nodes: + mesh = mesh_node.mesh + pose = self.get_pose(mesh_node) + corners_local = trimesh.bounds.corners(mesh.bounds) + corners_world = pose[:3,:3].dot(corners_local.T).T + pose[:3,3] + corners.append(corners_world) + if len(corners) == 0: + self._bounds = np.zeros((2,3)) + else: + corners = np.vstack(corners) + self._bounds = np.array([np.min(corners, axis=0), + np.max(corners, axis=0)]) + return self._bounds + + @property + def centroid(self): + """(3,) float : The centroid of the scene's axis-aligned bounding box + (AABB). + """ + return np.mean(self.bounds, axis=0) + + @property + def extents(self): + """(3,) float : The lengths of the axes of the scene's AABB. + """ + return np.diff(self.bounds, axis=0).reshape(-1) + + @property + def scale(self): + """(3,) float : The length of the diagonal of the scene's AABB. + """ + return np.linalg.norm(self.extents) + + def add(self, obj, name=None, pose=None, + parent_node=None, parent_name=None): + """Add an object (mesh, light, or camera) to the scene. + + Parameters + ---------- + obj : :class:`Mesh`, :class:`Light`, or :class:`Camera` + The object to add to the scene. + name : str + A name for the new node to be created. + pose : (4,4) float + The local pose of this node relative to its parent node. + parent_node : :class:`Node` + The parent of this Node. 
If None, the new node is a root node. + parent_name : str + The name of the parent node, can be specified instead of + `parent_node`. + + Returns + ------- + node : :class:`Node` + The newly-created and inserted node. + """ + if isinstance(obj, Mesh): + node = Node(name=name, matrix=pose, mesh=obj) + elif isinstance(obj, Light): + node = Node(name=name, matrix=pose, light=obj) + elif isinstance(obj, Camera): + node = Node(name=name, matrix=pose, camera=obj) + else: + raise TypeError('Unrecognized object type') + + if parent_node is None and parent_name is not None: + parent_nodes = self.get_nodes(name=parent_name) + if len(parent_nodes) == 0: + raise ValueError('No parent node with name {} found' + .format(parent_name)) + elif len(parent_nodes) > 1: + raise ValueError('More than one parent node with name {} found' + .format(parent_name)) + parent_node = list(parent_nodes)[0] + + self.add_node(node, parent_node=parent_node) + + return node + + def get_nodes(self, node=None, name=None, obj=None, obj_name=None): + """Search for existing nodes. Only nodes matching all specified + parameters is returned, or None if no such node exists. + + Parameters + ---------- + node : :class:`Node`, optional + If present, returns this node if it is in the scene. + name : str + A name for the Node. + obj : :class:`Mesh`, :class:`Light`, or :class:`Camera` + An object that is attached to the node. + obj_name : str + The name of an object that is attached to the node. + + Returns + ------- + nodes : set of :class:`.Node` + The nodes that match all query terms. + """ + if node is not None: + if node in self.nodes: + return set([node]) + else: + return set() + nodes = set(self.nodes) + if name is not None: + matches = set() + if name in self._name_to_nodes: + matches = self._name_to_nodes[name] + nodes = nodes & matches + if obj is not None: + matches = set() + if obj in self._obj_to_nodes: + matches = self._obj_to_nodes[obj] + nodes = nodes & matches + if obj_name is not None: + matches = set() + if obj_name in self._obj_name_to_nodes: + matches = self._obj_name_to_nodes[obj_name] + nodes = nodes & matches + + return nodes + + def add_node(self, node, parent_node=None): + """Add a Node to the scene. + + Parameters + ---------- + node : :class:`Node` + The node to be added. + parent_node : :class:`Node` + The parent of this Node. If None, the new node is a root node. 
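Typical usage of `add()` with the `parent_name` lookup described above; the lookup requires exactly one existing node with that name. Imports assume the usual top-level exports:

```python
import numpy as np
import trimesh
from pyrender import DirectionalLight, Mesh, Scene  # assumed top-level exports

scene = Scene()

# A named root node, then a light parented to it by name.
scene.add(Mesh.from_trimesh(trimesh.creation.icosphere()), name='body')
scene.add(DirectionalLight(intensity=3.0), pose=np.eye(4), parent_name='body')

# get_nodes intersects all supplied query terms and returns a (possibly empty) set.
body_nodes = scene.get_nodes(name='body')
```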
+ """ + if node in self.nodes: + raise ValueError('Node already in scene') + self.nodes.add(node) + + # Add node to sets + if node.name is not None: + if node.name not in self._name_to_nodes: + self._name_to_nodes[node.name] = set() + self._name_to_nodes[node.name].add(node) + for obj in [node.mesh, node.camera, node.light]: + if obj is not None: + if obj not in self._obj_to_nodes: + self._obj_to_nodes[obj] = set() + self._obj_to_nodes[obj].add(node) + if obj.name is not None: + if obj.name not in self._obj_name_to_nodes: + self._obj_name_to_nodes[obj.name] = set() + self._obj_name_to_nodes[obj.name].add(node) + if node.mesh is not None: + self._mesh_nodes.add(node) + if node.light is not None: + if isinstance(node.light, PointLight): + self._point_light_nodes.add(node) + if isinstance(node.light, SpotLight): + self._spot_light_nodes.add(node) + if isinstance(node.light, DirectionalLight): + self._directional_light_nodes.add(node) + if node.camera is not None: + self._camera_nodes.add(node) + if self._main_camera_node is None: + self._main_camera_node = node + + if parent_node is None: + parent_node = 'world' + elif parent_node not in self.nodes: + raise ValueError('Parent node must already be in scene') + elif node not in parent_node.children: + parent_node.children.append(node) + + # Create node in graph + self._digraph.add_node(node) + self._digraph.add_edge(node, parent_node) + + # Iterate over children + for child in node.children: + self.add_node(child, node) + + self._path_cache = {} + self._bounds = None + + def has_node(self, node): + """Check if a node is already in the scene. + + Parameters + ---------- + node : :class:`Node` + The node to be checked. + + Returns + ------- + has_node : bool + True if the node is already in the scene and false otherwise. + """ + return node in self.nodes + + def remove_node(self, node): + """Remove a node and all its children from the scene. + + Parameters + ---------- + node : :class:`Node` + The node to be removed. + """ + # Disconnect self from parent who is staying in the graph + parent = list(self._digraph.neighbors(node))[0] + self._remove_node(node) + if isinstance(parent, Node): + parent.children.remove(node) + self._path_cache = {} + self._bounds = None + + def get_pose(self, node): + """Get the world-frame pose of a node in the scene. + + Parameters + ---------- + node : :class:`Node` + The node to find the pose of. + + Returns + ------- + pose : (4,4) float + The transform matrix for this node. + """ + if node not in self.nodes: + raise ValueError('Node must already be in scene') + if node in self._path_cache: + path = self._path_cache[node] + else: + # Get path from from_frame to to_frame + path = nx.shortest_path(self._digraph, node, 'world') + self._path_cache[node] = path + + # Traverse from from_node to to_node + pose = np.eye(4) + for n in path[:-1]: + pose = np.dot(n.matrix, pose) + + return pose + + def set_pose(self, node, pose): + """Set the local-frame pose of a node in the scene. + + Parameters + ---------- + node : :class:`Node` + The node to set the pose of. + pose : (4,4) float + The pose to set the node to. + """ + if node not in self.nodes: + raise ValueError('Node must already be in scene') + node._matrix = pose + if node.mesh is not None: + self._bounds = None + + def clear(self): + """Clear out all nodes to form an empty scene. 
+ """ + self._nodes = set() + + self._name_to_nodes = {} + self._obj_to_nodes = {} + self._obj_name_to_nodes = {} + self._mesh_nodes = set() + self._point_light_nodes = set() + self._spot_light_nodes = set() + self._directional_light_nodes = set() + self._camera_nodes = set() + self._main_camera_node = None + self._bounds = None + + # Transform tree + self._digraph = nx.DiGraph() + self._digraph.add_node('world') + self._path_cache = {} + + def _remove_node(self, node): + """Remove a node and all its children from the scene. + + Parameters + ---------- + node : :class:`Node` + The node to be removed. + """ + + # Remove self from nodes + self.nodes.remove(node) + + # Remove children + for child in node.children: + self._remove_node(child) + + # Remove self from the graph + self._digraph.remove_node(node) + + # Remove from maps + if node.name in self._name_to_nodes: + self._name_to_nodes[node.name].remove(node) + if len(self._name_to_nodes[node.name]) == 0: + self._name_to_nodes.pop(node.name) + for obj in [node.mesh, node.camera, node.light]: + if obj is None: + continue + self._obj_to_nodes[obj].remove(node) + if len(self._obj_to_nodes[obj]) == 0: + self._obj_to_nodes.pop(obj) + if obj.name is not None: + self._obj_name_to_nodes[obj.name].remove(node) + if len(self._obj_name_to_nodes[obj.name]) == 0: + self._obj_name_to_nodes.pop(obj.name) + if node.mesh is not None: + self._mesh_nodes.remove(node) + if node.light is not None: + if isinstance(node.light, PointLight): + self._point_light_nodes.remove(node) + if isinstance(node.light, SpotLight): + self._spot_light_nodes.remove(node) + if isinstance(node.light, DirectionalLight): + self._directional_light_nodes.remove(node) + if node.camera is not None: + self._camera_nodes.remove(node) + if self._main_camera_node == node: + if len(self._camera_nodes) > 0: + self._main_camera_node = next(iter(self._camera_nodes)) + else: + self._main_camera_node = None + + @staticmethod + def from_trimesh_scene(trimesh_scene, + bg_color=None, ambient_light=None): + """Create a :class:`.Scene` from a :class:`trimesh.scene.scene.Scene`. + + Parameters + ---------- + trimesh_scene : :class:`trimesh.scene.scene.Scene` + Scene with :class:~`trimesh.base.Trimesh` objects. + bg_color : (4,) float + Background color for the created scene. + ambient_light : (3,) float or None + Ambient light in the scene. + + Returns + ------- + scene_pr : :class:`Scene` + A scene containing the same geometry as the trimesh scene. + """ + # convert trimesh geometries to pyrender geometries + geometries = {name: Mesh.from_trimesh(geom) + for name, geom in trimesh_scene.geometry.items()} + + # create the pyrender scene object + scene_pr = Scene(bg_color=bg_color, ambient_light=ambient_light) + + # add every node with geometry to the pyrender scene + for node in trimesh_scene.graph.nodes_geometry: + pose, geom_name = trimesh_scene.graph[node] + scene_pr.add(geometries[geom_name], pose=pose) + + return scene_pr diff --git a/pyrender/pyrender/shader_program.py b/pyrender/pyrender/shader_program.py new file mode 100644 index 0000000000000000000000000000000000000000..c1803f280c98033abe0769771a9ad8ecfec942e3 --- /dev/null +++ b/pyrender/pyrender/shader_program.py @@ -0,0 +1,283 @@ +"""OpenGL shader program wrapper. +""" +import numpy as np +import os +import re + +import OpenGL +from OpenGL.GL import * +from OpenGL.GL import shaders as gl_shader_utils + + +class ShaderProgramCache(object): + """A cache for shader programs. 
+ """ + + def __init__(self, shader_dir=None): + self._program_cache = {} + self.shader_dir = shader_dir + if self.shader_dir is None: + base_dir, _ = os.path.split(os.path.realpath(__file__)) + self.shader_dir = os.path.join(base_dir, 'shaders') + + def get_program(self, vertex_shader, fragment_shader, + geometry_shader=None, defines=None): + """Get a program via a list of shader files to include in the program. + + Parameters + ---------- + vertex_shader : str + The vertex shader filename. + fragment_shader : str + The fragment shader filename. + geometry_shader : str + The geometry shader filename. + defines : dict + Defines and their values for the shader. + + Returns + ------- + program : :class:`.ShaderProgram` + The program. + """ + shader_names = [] + if defines is None: + defines = {} + shader_filenames = [ + x for x in [vertex_shader, fragment_shader, geometry_shader] + if x is not None + ] + for fn in shader_filenames: + if fn is None: + continue + _, name = os.path.split(fn) + shader_names.append(name) + cid = OpenGL.contextdata.getContext() + key = tuple([cid] + sorted( + [(s,1) for s in shader_names] + [(d, defines[d]) for d in defines] + )) + + if key not in self._program_cache: + shader_filenames = [ + os.path.join(self.shader_dir, fn) for fn in shader_filenames + ] + if len(shader_filenames) == 2: + shader_filenames.append(None) + vs, fs, gs = shader_filenames + self._program_cache[key] = ShaderProgram( + vertex_shader=vs, fragment_shader=fs, + geometry_shader=gs, defines=defines + ) + return self._program_cache[key] + + def clear(self): + for key in self._program_cache: + self._program_cache[key].delete() + self._program_cache = {} + + +class ShaderProgram(object): + """A thin wrapper about OpenGL shader programs that supports easy creation, + binding, and uniform-setting. + + Parameters + ---------- + vertex_shader : str + The vertex shader filename. + fragment_shader : str + The fragment shader filename. + geometry_shader : str + The geometry shader filename. + defines : dict + Defines and their values for the shader. 
+ """ + + def __init__(self, vertex_shader, fragment_shader, + geometry_shader=None, defines=None): + + self.vertex_shader = vertex_shader + self.fragment_shader = fragment_shader + self.geometry_shader = geometry_shader + + self.defines = defines + if self.defines is None: + self.defines = {} + + self._program_id = None + self._vao_id = None # PYOPENGL BUG + + # DEBUG + # self._unif_map = {} + + def _add_to_context(self): + if self._program_id is not None: + raise ValueError('Shader program already in context') + shader_ids = [] + + # Load vert shader + shader_ids.append(gl_shader_utils.compileShader( + self._load(self.vertex_shader), GL_VERTEX_SHADER) + ) + # Load frag shader + shader_ids.append(gl_shader_utils.compileShader( + self._load(self.fragment_shader), GL_FRAGMENT_SHADER) + ) + # Load geometry shader + if self.geometry_shader is not None: + shader_ids.append(gl_shader_utils.compileShader( + self._load(self.geometry_shader), GL_GEOMETRY_SHADER) + ) + + # Bind empty VAO PYOPENGL BUG + if self._vao_id is None: + self._vao_id = glGenVertexArrays(1) + glBindVertexArray(self._vao_id) + + # Compile program + self._program_id = gl_shader_utils.compileProgram(*shader_ids) + + # Unbind empty VAO PYOPENGL BUG + glBindVertexArray(0) + + def _in_context(self): + return self._program_id is not None + + def _remove_from_context(self): + if self._program_id is not None: + glDeleteProgram(self._program_id) + glDeleteVertexArrays(1, [self._vao_id]) + self._program_id = None + self._vao_id = None + + def _load(self, shader_filename): + path, _ = os.path.split(shader_filename) + + with open(shader_filename) as f: + text = f.read() + + def ifdef(matchobj): + if matchobj.group(1) in self.defines: + return '#if 1' + else: + return '#if 0' + + def ifndef(matchobj): + if matchobj.group(1) in self.defines: + return '#if 0' + else: + return '#if 1' + + ifdef_regex = re.compile( + '#ifdef\\s+([a-zA-Z_][a-zA-Z_0-9]*)\\s*$', re.MULTILINE + ) + ifndef_regex = re.compile( + '#ifndef\\s+([a-zA-Z_][a-zA-Z_0-9]*)\\s*$', re.MULTILINE + ) + text = re.sub(ifdef_regex, ifdef, text) + text = re.sub(ifndef_regex, ifndef, text) + + for define in self.defines: + value = str(self.defines[define]) + text = text.replace(define, value) + + return text + + def _bind(self): + """Bind this shader program to the current OpenGL context. + """ + if self._program_id is None: + raise ValueError('Cannot bind program that is not in context') + # glBindVertexArray(self._vao_id) + glUseProgram(self._program_id) + + def _unbind(self): + """Unbind this shader program from the current OpenGL context. + """ + glUseProgram(0) + + def delete(self): + """Delete this shader program from the current OpenGL context. + """ + self._remove_from_context() + + def set_uniform(self, name, value, unsigned=False): + """Set a uniform value in the current shader program. + + Parameters + ---------- + name : str + Name of the uniform to set. + value : int, float, or ndarray + Value to set the uniform to. + unsigned : bool + If True, ints will be treated as unsigned values. 
+ """ + try: + # DEBUG + # self._unif_map[name] = 1, (1,) + loc = glGetUniformLocation(self._program_id, name) + + if loc == -1: + raise ValueError('Invalid shader variable: {}'.format(name)) + + if isinstance(value, np.ndarray): + # DEBUG + # self._unif_map[name] = value.size, value.shape + if value.ndim == 1: + if (np.issubdtype(value.dtype, np.unsignedinteger) or + unsigned): + dtype = 'u' + value = value.astype(np.uint32) + elif np.issubdtype(value.dtype, np.integer): + dtype = 'i' + value = value.astype(np.int32) + else: + dtype = 'f' + value = value.astype(np.float32) + self._FUNC_MAP[(value.shape[0], dtype)](loc, 1, value) + else: + self._FUNC_MAP[(value.shape[0], value.shape[1])]( + loc, 1, GL_TRUE, value + ) + + # Call correct uniform function + elif isinstance(value, float): + glUniform1f(loc, value) + elif isinstance(value, int): + if unsigned: + glUniform1ui(loc, value) + else: + glUniform1i(loc, value) + elif isinstance(value, bool): + if unsigned: + glUniform1ui(loc, int(value)) + else: + glUniform1i(loc, int(value)) + else: + raise ValueError('Invalid data type') + except Exception: + pass + + _FUNC_MAP = { + (1,'u'): glUniform1uiv, + (2,'u'): glUniform2uiv, + (3,'u'): glUniform3uiv, + (4,'u'): glUniform4uiv, + (1,'i'): glUniform1iv, + (2,'i'): glUniform2iv, + (3,'i'): glUniform3iv, + (4,'i'): glUniform4iv, + (1,'f'): glUniform1fv, + (2,'f'): glUniform2fv, + (3,'f'): glUniform3fv, + (4,'f'): glUniform4fv, + (2,2): glUniformMatrix2fv, + (2,3): glUniformMatrix2x3fv, + (2,4): glUniformMatrix2x4fv, + (3,2): glUniformMatrix3x2fv, + (3,3): glUniformMatrix3fv, + (3,4): glUniformMatrix3x4fv, + (4,2): glUniformMatrix4x2fv, + (4,3): glUniformMatrix4x3fv, + (4,4): glUniformMatrix4fv, + } diff --git a/pyrender/pyrender/shaders/debug_quad.frag b/pyrender/pyrender/shaders/debug_quad.frag new file mode 100644 index 0000000000000000000000000000000000000000..4647bb50dfa1e4510e2d4afb37959c7f57532eca --- /dev/null +++ b/pyrender/pyrender/shaders/debug_quad.frag @@ -0,0 +1,23 @@ +#version 330 core +out vec4 FragColor; + +in vec2 TexCoords; + +uniform sampler2D depthMap; +//uniform float near_plane; +//uniform float far_plane; +// +//// required when using a perspective projection matrix +//float LinearizeDepth(float depth) +//{ +// float z = depth * 2.0 - 1.0; // Back to NDC +// return (2.0 * near_plane * far_plane) / (far_plane + near_plane - z * (far_plane - near_plane)); +//} + +void main() +{ + float depthValue = texture(depthMap, TexCoords).r; + // FragColor = vec4(vec3(LinearizeDepth(depthValue) / far_plane), 1.0); // perspective + FragColor = vec4(vec3(depthValue), 1.0); // orthographic + //FragColor = vec4(1.0, 1.0, 0.0, 1.0); +} diff --git a/pyrender/pyrender/shaders/debug_quad.vert b/pyrender/pyrender/shaders/debug_quad.vert new file mode 100644 index 0000000000000000000000000000000000000000..d2f2fcb7626f6c22e0d52bf4d6c91251cbdb9f52 --- /dev/null +++ b/pyrender/pyrender/shaders/debug_quad.vert @@ -0,0 +1,25 @@ +#version 330 core +//layout (location = 0) in vec3 aPos; +//layout (location = 1) in vec2 aTexCoords; +// +//out vec2 TexCoords; +// +//void main() +//{ +// TexCoords = aTexCoords; +// gl_Position = vec4(aPos, 1.0); +//} +// +// +//layout(location = 0) out vec2 uv; + +out vec2 TexCoords; + +void main() +{ + float x = float(((uint(gl_VertexID) + 2u) / 3u)%2u); + float y = float(((uint(gl_VertexID) + 1u) / 3u)%2u); + + gl_Position = vec4(-1.0f + x*2.0f, -1.0f+y*2.0f, 0.0f, 1.0f); + TexCoords = vec2(x, y); +} diff --git a/pyrender/pyrender/shaders/flat.frag 
b/pyrender/pyrender/shaders/flat.frag new file mode 100644 index 0000000000000000000000000000000000000000..7ec01c6d095ec5dacc693accd3ad507ced61a79a --- /dev/null +++ b/pyrender/pyrender/shaders/flat.frag @@ -0,0 +1,126 @@ +#version 330 core +/////////////////////////////////////////////////////////////////////////////// +// Structs +/////////////////////////////////////////////////////////////////////////////// + +struct Material { + vec3 emissive_factor; + +#ifdef USE_METALLIC_MATERIAL + vec4 base_color_factor; + float metallic_factor; + float roughness_factor; +#endif + +#ifdef USE_GLOSSY_MATERIAL + vec4 diffuse_factor; + vec3 specular_factor; + float glossiness_factor; +#endif + +#ifdef HAS_NORMAL_TEX + sampler2D normal_texture; +#endif +#ifdef HAS_OCCLUSION_TEX + sampler2D occlusion_texture; +#endif +#ifdef HAS_EMISSIVE_TEX + sampler2D emissive_texture; +#endif +#ifdef HAS_BASE_COLOR_TEX + sampler2D base_color_texture; +#endif +#ifdef HAS_METALLIC_ROUGHNESS_TEX + sampler2D metallic_roughness_texture; +#endif +#ifdef HAS_DIFFUSE_TEX + sampler2D diffuse_texture; +#endif +#ifdef HAS_SPECULAR_GLOSSINESS_TEX + sampler2D specular_glossiness; +#endif +}; + +/////////////////////////////////////////////////////////////////////////////// +// Uniforms +/////////////////////////////////////////////////////////////////////////////// +uniform Material material; +uniform vec3 cam_pos; + +#ifdef USE_IBL +uniform samplerCube diffuse_env; +uniform samplerCube specular_env; +#endif + +/////////////////////////////////////////////////////////////////////////////// +// Inputs +/////////////////////////////////////////////////////////////////////////////// + +in vec3 frag_position; +#ifdef NORMAL_LOC +in vec3 frag_normal; +#endif +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC +in mat3 tbn; +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC +in vec2 uv_0; +#endif +#ifdef TEXCOORD_1_LOC +in vec2 uv_1; +#endif +#ifdef COLOR_0_LOC +in vec4 color_multiplier; +#endif + +/////////////////////////////////////////////////////////////////////////////// +// OUTPUTS +/////////////////////////////////////////////////////////////////////////////// + +out vec4 frag_color; + +/////////////////////////////////////////////////////////////////////////////// +// Constants +/////////////////////////////////////////////////////////////////////////////// +const float PI = 3.141592653589793; +const float min_roughness = 0.04; + +/////////////////////////////////////////////////////////////////////////////// +// Utility Functions +/////////////////////////////////////////////////////////////////////////////// +vec4 srgb_to_linear(vec4 srgb) +{ +#ifndef SRGB_CORRECTED + // Fast Approximation + //vec3 linOut = pow(srgbIn.xyz,vec3(2.2)); + // + vec3 b_less = step(vec3(0.04045),srgb.xyz); + vec3 lin_out = mix( srgb.xyz/vec3(12.92), pow((srgb.xyz+vec3(0.055))/vec3(1.055),vec3(2.4)), b_less ); + return vec4(lin_out, srgb.w); +#else + return srgb; +#endif +} + +/////////////////////////////////////////////////////////////////////////////// +// MAIN +/////////////////////////////////////////////////////////////////////////////// +void main() +{ + + // Compute albedo + vec4 base_color = material.base_color_factor; +#ifdef HAS_BASE_COLOR_TEX + base_color = base_color * texture(material.base_color_texture, uv_0); +#endif + +#ifdef COLOR_0_LOC + base_color *= color_multiplier; +#endif + + frag_color = clamp(base_color, 0.0, 1.0); +} diff --git a/pyrender/pyrender/shaders/flat.vert b/pyrender/pyrender/shaders/flat.vert new file mode 
100644 index 0000000000000000000000000000000000000000..cfd241c3544718a261f961c3aa3c03aa13c97761 --- /dev/null +++ b/pyrender/pyrender/shaders/flat.vert @@ -0,0 +1,86 @@ +#version 330 core + +// Vertex Attributes +layout(location = 0) in vec3 position; +#ifdef NORMAL_LOC +layout(location = NORMAL_LOC) in vec3 normal; +#endif +#ifdef TANGENT_LOC +layout(location = TANGENT_LOC) in vec4 tangent; +#endif +#ifdef TEXCOORD_0_LOC +layout(location = TEXCOORD_0_LOC) in vec2 texcoord_0; +#endif +#ifdef TEXCOORD_1_LOC +layout(location = TEXCOORD_1_LOC) in vec2 texcoord_1; +#endif +#ifdef COLOR_0_LOC +layout(location = COLOR_0_LOC) in vec4 color_0; +#endif +#ifdef JOINTS_0_LOC +layout(location = JOINTS_0_LOC) in vec4 joints_0; +#endif +#ifdef WEIGHTS_0_LOC +layout(location = WEIGHTS_0_LOC) in vec4 weights_0; +#endif +layout(location = INST_M_LOC) in mat4 inst_m; + +// Uniforms +uniform mat4 M; +uniform mat4 V; +uniform mat4 P; + +// Outputs +out vec3 frag_position; +#ifdef NORMAL_LOC +out vec3 frag_normal; +#endif +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC +out mat3 tbn; +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC +out vec2 uv_0; +#endif +#ifdef TEXCOORD_1_LOC +out vec2 uv_1; +#endif +#ifdef COLOR_0_LOC +out vec4 color_multiplier; +#endif + + +void main() +{ + gl_Position = P * V * M * inst_m * vec4(position, 1); + frag_position = vec3(M * inst_m * vec4(position, 1.0)); + + mat4 N = transpose(inverse(M * inst_m)); + +#ifdef NORMAL_LOC + frag_normal = normalize(vec3(N * vec4(normal, 0.0))); +#endif + +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC + vec3 normal_w = normalize(vec3(N * vec4(normal, 0.0))); + vec3 tangent_w = normalize(vec3(N * vec4(tangent.xyz, 0.0))); + vec3 bitangent_w = cross(normal_w, tangent_w) * tangent.w; + tbn = mat3(tangent_w, bitangent_w, normal_w); +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC + uv_0 = texcoord_0; +#endif +#ifdef TEXCOORD_1_LOC + uv_1 = texcoord_1; +#endif +#ifdef COLOR_0_LOC + color_multiplier = color_0; +#endif +} diff --git a/pyrender/pyrender/shaders/mesh.frag b/pyrender/pyrender/shaders/mesh.frag new file mode 100644 index 0000000000000000000000000000000000000000..43187621b4388b18badf4e562a7ad300e59b029d --- /dev/null +++ b/pyrender/pyrender/shaders/mesh.frag @@ -0,0 +1,456 @@ +#version 330 core +/////////////////////////////////////////////////////////////////////////////// +// Structs +/////////////////////////////////////////////////////////////////////////////// + +struct SpotLight { + vec3 color; + float intensity; + float range; + vec3 position; + vec3 direction; + float light_angle_scale; + float light_angle_offset; + + #ifdef SPOT_LIGHT_SHADOWS + sampler2D shadow_map; + mat4 light_matrix; + #endif +}; + +struct DirectionalLight { + vec3 color; + float intensity; + vec3 direction; + + #ifdef DIRECTIONAL_LIGHT_SHADOWS + sampler2D shadow_map; + mat4 light_matrix; + #endif +}; + +struct PointLight { + vec3 color; + float intensity; + float range; + vec3 position; + + #ifdef POINT_LIGHT_SHADOWS + samplerCube shadow_map; + #endif +}; + +struct Material { + vec3 emissive_factor; + +#ifdef USE_METALLIC_MATERIAL + vec4 base_color_factor; + float metallic_factor; + float roughness_factor; +#endif + +#ifdef USE_GLOSSY_MATERIAL + vec4 diffuse_factor; + vec3 specular_factor; + float glossiness_factor; +#endif + +#ifdef HAS_NORMAL_TEX + sampler2D normal_texture; +#endif +#ifdef HAS_OCCLUSION_TEX + sampler2D occlusion_texture; +#endif +#ifdef HAS_EMISSIVE_TEX + sampler2D emissive_texture; +#endif +#ifdef HAS_BASE_COLOR_TEX + 
sampler2D base_color_texture; +#endif +#ifdef HAS_METALLIC_ROUGHNESS_TEX + sampler2D metallic_roughness_texture; +#endif +#ifdef HAS_DIFFUSE_TEX + sampler2D diffuse_texture; +#endif +#ifdef HAS_SPECULAR_GLOSSINESS_TEX + sampler2D specular_glossiness; +#endif +}; + +struct PBRInfo { + float nl; + float nv; + float nh; + float lh; + float vh; + float roughness; + float metallic; + vec3 f0; + vec3 c_diff; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Uniforms +/////////////////////////////////////////////////////////////////////////////// +uniform Material material; +uniform PointLight point_lights[MAX_POINT_LIGHTS]; +uniform int n_point_lights; +uniform DirectionalLight directional_lights[MAX_DIRECTIONAL_LIGHTS]; +uniform int n_directional_lights; +uniform SpotLight spot_lights[MAX_SPOT_LIGHTS]; +uniform int n_spot_lights; +uniform vec3 cam_pos; +uniform vec3 ambient_light; + +#ifdef USE_IBL +uniform samplerCube diffuse_env; +uniform samplerCube specular_env; +#endif + +/////////////////////////////////////////////////////////////////////////////// +// Inputs +/////////////////////////////////////////////////////////////////////////////// + +in vec3 frag_position; +#ifdef NORMAL_LOC +in vec3 frag_normal; +#endif +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC +in mat3 tbn; +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC +in vec2 uv_0; +#endif +#ifdef TEXCOORD_1_LOC +in vec2 uv_1; +#endif +#ifdef COLOR_0_LOC +in vec4 color_multiplier; +#endif + +/////////////////////////////////////////////////////////////////////////////// +// OUTPUTS +/////////////////////////////////////////////////////////////////////////////// + +out vec4 frag_color; + +/////////////////////////////////////////////////////////////////////////////// +// Constants +/////////////////////////////////////////////////////////////////////////////// +const float PI = 3.141592653589793; +const float min_roughness = 0.04; + +/////////////////////////////////////////////////////////////////////////////// +// Utility Functions +/////////////////////////////////////////////////////////////////////////////// +vec4 srgb_to_linear(vec4 srgb) +{ +#ifndef SRGB_CORRECTED + // Fast Approximation + //vec3 linOut = pow(srgbIn.xyz,vec3(2.2)); + // + vec3 b_less = step(vec3(0.04045),srgb.xyz); + vec3 lin_out = mix( srgb.xyz/vec3(12.92), pow((srgb.xyz+vec3(0.055))/vec3(1.055),vec3(2.4)), b_less ); + return vec4(lin_out, srgb.w); +#else + return srgb; +#endif +} + +// Normal computation +vec3 get_normal() +{ +#ifdef HAS_NORMAL_TEX + +#ifndef HAS_TANGENTS + vec3 pos_dx = dFdx(frag_position); + vec3 pos_dy = dFdy(frag_position); + vec3 tex_dx = dFdx(vec3(uv_0, 0.0)); + vec3 tex_dy = dFdy(vec3(uv_0, 0.0)); + vec3 t = (tex_dy.t * pos_dx - tex_dx.t * pos_dy) / (tex_dx.s * tex_dy.t - tex_dy.s * tex_dx.t); + +#ifdef NORMAL_LOC + vec3 ng = normalize(frag_normal); +#else + vec3 = cross(pos_dx, pos_dy); +#endif + + t = normalize(t - ng * dot(ng, t)); + vec3 b = normalize(cross(ng, t)); + mat3 tbn_n = mat3(t, b, ng); + +#else + + mat3 tbn_n = tbn; + +#endif + + vec3 n = texture(material.normal_texture, uv_0).rgb; + n = normalize(tbn_n * ((2.0 * n - 1.0) * vec3(1.0, 1.0, 1.0))); + return n; // TODO NORMAL MAPPING + +#else + +#ifdef NORMAL_LOC + return frag_normal; +#else + return normalize(cam_pos - frag_position); +#endif + +#endif +} + +// Fresnel +vec3 specular_reflection(PBRInfo info) +{ + vec3 res = info.f0 + (1.0 - info.f0) * pow(clamp(1.0 - info.vh, 0.0, 1.0), 5.0); + return res; +} + +// Smith 
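+// Schlick-GGX form of the Smith geometric/visibility term,
+// using the k = (roughness + 1)^2 / 8 remapping for analytic lights.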
+float geometric_occlusion(PBRInfo info) +{ + float r = info.roughness + 1.0; + float k = r * r / 8.0; + float g1 = info.nv / (info.nv * (1.0 - k) + k); + float g2 = info.nl / (info.nl * (1.0 - k) + k); + //float k = info.roughness * sqrt(2.0 / PI); + //float g1 = info.lh / (info.lh * (1.0 - k) + k); + //float g2 = info.nh / (info.nh * (1.0 - k) + k); + return g1 * g2; +} + +float microfacet_distribution(PBRInfo info) +{ + float a = info.roughness * info.roughness; + float a2 = a * a; + float nh2 = info.nh * info.nh; + + float denom = (nh2 * (a2 - 1.0) + 1.0); + return a2 / (PI * denom * denom); +} + +vec3 compute_brdf(vec3 n, vec3 v, vec3 l, + float roughness, float metalness, + vec3 f0, vec3 c_diff, vec3 albedo, + vec3 radiance) +{ + vec3 h = normalize(l+v); + float nl = clamp(dot(n, l), 0.001, 1.0); + float nv = clamp(abs(dot(n, v)), 0.001, 1.0); + float nh = clamp(dot(n, h), 0.0, 1.0); + float lh = clamp(dot(l, h), 0.0, 1.0); + float vh = clamp(dot(v, h), 0.0, 1.0); + + PBRInfo info = PBRInfo(nl, nv, nh, lh, vh, roughness, metalness, f0, c_diff); + + // Compute PBR terms + vec3 F = specular_reflection(info); + float G = geometric_occlusion(info); + float D = microfacet_distribution(info); + + // Compute BRDF + vec3 diffuse_contrib = (1.0 - F) * c_diff / PI; + vec3 spec_contrib = F * G * D / (4.0 * nl * nv + 0.001); + + vec3 color = nl * radiance * (diffuse_contrib + spec_contrib); + return color; +} + +float texture2DCompare(sampler2D depths, vec2 uv, float compare) { + return compare > texture(depths, uv.xy).r ? 1.0 : 0.0; +} + +float texture2DShadowLerp(sampler2D depths, vec2 size, vec2 uv, float compare) { + vec2 texelSize = vec2(1.0)/size; + vec2 f = fract(uv*size+0.5); + vec2 centroidUV = floor(uv*size+0.5)/size; + + float lb = texture2DCompare(depths, centroidUV+texelSize*vec2(0.0, 0.0), compare); + float lt = texture2DCompare(depths, centroidUV+texelSize*vec2(0.0, 1.0), compare); + float rb = texture2DCompare(depths, centroidUV+texelSize*vec2(1.0, 0.0), compare); + float rt = texture2DCompare(depths, centroidUV+texelSize*vec2(1.0, 1.0), compare); + float a = mix(lb, lt, f.y); + float b = mix(rb, rt, f.y); + float c = mix(a, b, f.x); + return c; +} + +float PCF(sampler2D depths, vec2 size, vec2 uv, float compare){ + float result = 0.0; + for(int x=-1; x<=1; x++){ + for(int y=-1; y<=1; y++){ + vec2 off = vec2(x,y)/size; + result += texture2DShadowLerp(depths, size, uv+off, compare); + } + } + return result/9.0; +} + +float shadow_calc(mat4 light_matrix, sampler2D shadow_map, float nl) +{ + // Compute light texture UV coords + vec4 proj_coords = vec4(light_matrix * vec4(frag_position.xyz, 1.0)); + vec3 light_coords = proj_coords.xyz / proj_coords.w; + light_coords = light_coords * 0.5 + 0.5; + float current_depth = light_coords.z; + float bias = max(0.001 * (1.0 - nl), 0.0001) / proj_coords.w; + float compare = (current_depth - bias); + float shadow = PCF(shadow_map, textureSize(shadow_map, 0), light_coords.xy, compare); + if (light_coords.z > 1.0) { + shadow = 0.0; + } + return shadow; +} + +/////////////////////////////////////////////////////////////////////////////// +// MAIN +/////////////////////////////////////////////////////////////////////////////// +void main() +{ + + vec4 color = vec4(vec3(0.0), 1.0); +/////////////////////////////////////////////////////////////////////////////// +// Handle Metallic Materials +/////////////////////////////////////////////////////////////////////////////// +#ifdef USE_METALLIC_MATERIAL + + // Compute metallic/roughness factors + float 
roughness = material.roughness_factor; + float metallic = material.metallic_factor; +#ifdef HAS_METALLIC_ROUGHNESS_TEX + vec2 mr = texture(material.metallic_roughness_texture, uv_0).rg; + roughness = roughness * mr.r; + metallic = metallic * mr.g; +#endif + roughness = clamp(roughness, min_roughness, 1.0); + metallic = clamp(metallic, 0.0, 1.0); + // In convention, material roughness is perceputal roughness ^ 2 + float alpha_roughness = roughness * roughness; + + // Compute albedo + vec4 base_color = material.base_color_factor; +#ifdef HAS_BASE_COLOR_TEX + base_color = base_color * srgb_to_linear(texture(material.base_color_texture, uv_0)); +#endif + + // Compute specular and diffuse colors + vec3 dialectric_spec = vec3(min_roughness); + vec3 c_diff = mix(vec3(0.0), base_color.rgb * (1 - min_roughness), 1.0 - metallic); + vec3 f0 = mix(dialectric_spec, base_color.rgb, metallic); + + // Compute normal + vec3 n = normalize(get_normal()); + + // Loop over lights + for (int i = 0; i < n_directional_lights; i++) { + vec3 direction = directional_lights[i].direction; + vec3 v = normalize(cam_pos - frag_position); // Vector towards camera + vec3 l = normalize(-1.0 * direction); // Vector towards light + + // Compute attenuation and radiance + float attenuation = directional_lights[i].intensity; + vec3 radiance = attenuation * directional_lights[i].color; + + // Compute outbound color + vec3 res = compute_brdf(n, v, l, roughness, metallic, + f0, c_diff, base_color.rgb, radiance); + + // Compute shadow +#ifdef DIRECTIONAL_LIGHT_SHADOWS + float nl = clamp(dot(n,l), 0.0, 1.0); + float shadow = shadow_calc( + directional_lights[i].light_matrix, + directional_lights[i].shadow_map, + nl + ); + res = res * (1.0 - shadow); +#endif + color.xyz += res; + } + + for (int i = 0; i < n_point_lights; i++) { + vec3 position = point_lights[i].position; + vec3 v = normalize(cam_pos - frag_position); // Vector towards camera + vec3 l = normalize(position - frag_position); // Vector towards light + + // Compute attenuation and radiance + float dist = length(position - frag_position); + float attenuation = point_lights[i].intensity / (dist * dist); + vec3 radiance = attenuation * point_lights[i].color; + + // Compute outbound color + vec3 res = compute_brdf(n, v, l, roughness, metallic, + f0, c_diff, base_color.rgb, radiance); + color.xyz += res; + } + for (int i = 0; i < n_spot_lights; i++) { + vec3 position = spot_lights[i].position; + vec3 v = normalize(cam_pos - frag_position); // Vector towards camera + vec3 l = normalize(position - frag_position); // Vector towards light + + // Compute attenuation and radiance + vec3 direction = spot_lights[i].direction; + float las = spot_lights[i].light_angle_scale; + float lao = spot_lights[i].light_angle_offset; + float dist = length(position - frag_position); + float cd = clamp(dot(direction, -l), 0.0, 1.0); + float attenuation = clamp(cd * las + lao, 0.0, 1.0); + attenuation = attenuation * attenuation * spot_lights[i].intensity; + attenuation = attenuation / (dist * dist); + vec3 radiance = attenuation * spot_lights[i].color; + + // Compute outbound color + vec3 res = compute_brdf(n, v, l, roughness, metallic, + f0, c_diff, base_color.rgb, radiance); +#ifdef SPOT_LIGHT_SHADOWS + float nl = clamp(dot(n,l), 0.0, 1.0); + float shadow = shadow_calc( + spot_lights[i].light_matrix, + spot_lights[i].shadow_map, + nl + ); + res = res * (1.0 - shadow); +#endif + color.xyz += res; + } + color.xyz += base_color.xyz * ambient_light; + + // Calculate lighting from environment +#ifdef 
USE_IBL + // TODO +#endif + + // Apply occlusion +#ifdef HAS_OCCLUSION_TEX + float ao = texture(material.occlusion_texture, uv_0).r; + color.xyz *= ao; +#endif + + // Apply emissive map + vec3 emissive = material.emissive_factor; +#ifdef HAS_EMISSIVE_TEX + emissive *= srgb_to_linear(texture(material.emissive_texture, uv_0)).rgb; +#endif + color.xyz += emissive * material.emissive_factor; + +#ifdef COLOR_0_LOC + color *= color_multiplier; +#endif + + frag_color = clamp(vec4(pow(color.xyz, vec3(1.0/2.2)), color.a * base_color.a), 0.0, 1.0); + +#else + // TODO GLOSSY MATERIAL BRDF +#endif + +/////////////////////////////////////////////////////////////////////////////// +// Handle Glossy Materials +/////////////////////////////////////////////////////////////////////////////// + +} diff --git a/pyrender/pyrender/shaders/mesh.vert b/pyrender/pyrender/shaders/mesh.vert new file mode 100644 index 0000000000000000000000000000000000000000..cfd241c3544718a261f961c3aa3c03aa13c97761 --- /dev/null +++ b/pyrender/pyrender/shaders/mesh.vert @@ -0,0 +1,86 @@ +#version 330 core + +// Vertex Attributes +layout(location = 0) in vec3 position; +#ifdef NORMAL_LOC +layout(location = NORMAL_LOC) in vec3 normal; +#endif +#ifdef TANGENT_LOC +layout(location = TANGENT_LOC) in vec4 tangent; +#endif +#ifdef TEXCOORD_0_LOC +layout(location = TEXCOORD_0_LOC) in vec2 texcoord_0; +#endif +#ifdef TEXCOORD_1_LOC +layout(location = TEXCOORD_1_LOC) in vec2 texcoord_1; +#endif +#ifdef COLOR_0_LOC +layout(location = COLOR_0_LOC) in vec4 color_0; +#endif +#ifdef JOINTS_0_LOC +layout(location = JOINTS_0_LOC) in vec4 joints_0; +#endif +#ifdef WEIGHTS_0_LOC +layout(location = WEIGHTS_0_LOC) in vec4 weights_0; +#endif +layout(location = INST_M_LOC) in mat4 inst_m; + +// Uniforms +uniform mat4 M; +uniform mat4 V; +uniform mat4 P; + +// Outputs +out vec3 frag_position; +#ifdef NORMAL_LOC +out vec3 frag_normal; +#endif +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC +out mat3 tbn; +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC +out vec2 uv_0; +#endif +#ifdef TEXCOORD_1_LOC +out vec2 uv_1; +#endif +#ifdef COLOR_0_LOC +out vec4 color_multiplier; +#endif + + +void main() +{ + gl_Position = P * V * M * inst_m * vec4(position, 1); + frag_position = vec3(M * inst_m * vec4(position, 1.0)); + + mat4 N = transpose(inverse(M * inst_m)); + +#ifdef NORMAL_LOC + frag_normal = normalize(vec3(N * vec4(normal, 0.0))); +#endif + +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC + vec3 normal_w = normalize(vec3(N * vec4(normal, 0.0))); + vec3 tangent_w = normalize(vec3(N * vec4(tangent.xyz, 0.0))); + vec3 bitangent_w = cross(normal_w, tangent_w) * tangent.w; + tbn = mat3(tangent_w, bitangent_w, normal_w); +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC + uv_0 = texcoord_0; +#endif +#ifdef TEXCOORD_1_LOC + uv_1 = texcoord_1; +#endif +#ifdef COLOR_0_LOC + color_multiplier = color_0; +#endif +} diff --git a/pyrender/pyrender/shaders/mesh_depth.frag b/pyrender/pyrender/shaders/mesh_depth.frag new file mode 100644 index 0000000000000000000000000000000000000000..d8b1fac6091cfa457ba835ae0758e955f06d8754 --- /dev/null +++ b/pyrender/pyrender/shaders/mesh_depth.frag @@ -0,0 +1,8 @@ +#version 330 core + +out vec4 frag_color; + +void main() +{ + frag_color = vec4(1.0); +} diff --git a/pyrender/pyrender/shaders/mesh_depth.vert b/pyrender/pyrender/shaders/mesh_depth.vert new file mode 100644 index 0000000000000000000000000000000000000000..e534c058fb3e7b0efbec090513d55982db68ccaf --- /dev/null +++ b/pyrender/pyrender/shaders/mesh_depth.vert 
@@ -0,0 +1,13 @@ +#version 330 core +layout(location = 0) in vec3 position; +layout(location = INST_M_LOC) in mat4 inst_m; + +uniform mat4 P; +uniform mat4 V; +uniform mat4 M; + +void main() +{ + mat4 light_matrix = P * V; + gl_Position = light_matrix * M * inst_m * vec4(position, 1.0); +} diff --git a/pyrender/pyrender/shaders/segmentation.frag b/pyrender/pyrender/shaders/segmentation.frag new file mode 100644 index 0000000000000000000000000000000000000000..40deb92cbdef3ec9fd952632624cd5f4b5ce0c84 --- /dev/null +++ b/pyrender/pyrender/shaders/segmentation.frag @@ -0,0 +1,13 @@ +#version 330 core + +uniform vec3 color; +out vec4 frag_color; + +/////////////////////////////////////////////////////////////////////////////// +// MAIN +/////////////////////////////////////////////////////////////////////////////// +void main() +{ + frag_color = vec4(color, 1.0); + //frag_color = vec4(1.0, 0.5, 0.5, 1.0); +} diff --git a/pyrender/pyrender/shaders/segmentation.vert b/pyrender/pyrender/shaders/segmentation.vert new file mode 100644 index 0000000000000000000000000000000000000000..503382599dae3c9415845f35b99d6678cfc7f716 --- /dev/null +++ b/pyrender/pyrender/shaders/segmentation.vert @@ -0,0 +1,14 @@ +#version 330 core +layout(location = 0) in vec3 position; +layout(location = INST_M_LOC) in mat4 inst_m; + +uniform mat4 P; +uniform mat4 V; +uniform mat4 M; + +void main() +{ + mat4 light_matrix = P * V; + gl_Position = light_matrix * M * inst_m * vec4(position, 1.0); +} + diff --git a/pyrender/pyrender/shaders/text.frag b/pyrender/pyrender/shaders/text.frag new file mode 100644 index 0000000000000000000000000000000000000000..486c97dc94ed5e9083ae348bc1e85c5cb26c44dc --- /dev/null +++ b/pyrender/pyrender/shaders/text.frag @@ -0,0 +1,12 @@ +#version 330 core +in vec2 uv; +out vec4 color; + +uniform sampler2D text; +uniform vec4 text_color; + +void main() +{ + vec4 sampled = vec4(1.0, 1.0, 1.0, texture(text, uv).r); + color = text_color * sampled; +} diff --git a/pyrender/pyrender/shaders/text.vert b/pyrender/pyrender/shaders/text.vert new file mode 100644 index 0000000000000000000000000000000000000000..005bc439b3d63522df99e5db2088953eb8defcf4 --- /dev/null +++ b/pyrender/pyrender/shaders/text.vert @@ -0,0 +1,12 @@ +#version 330 core +layout (location = 0) in vec4 vertex; + +out vec2 uv; + +uniform mat4 projection; + +void main() +{ + gl_Position = projection * vec4(vertex.xy, 0.0, 1.0); + uv = vertex.zw; +} diff --git a/pyrender/pyrender/shaders/vertex_normals.frag b/pyrender/pyrender/shaders/vertex_normals.frag new file mode 100644 index 0000000000000000000000000000000000000000..edf5beb7f283dd67e1710bff922555539966cee4 --- /dev/null +++ b/pyrender/pyrender/shaders/vertex_normals.frag @@ -0,0 +1,10 @@ +#version 330 core + +out vec4 frag_color; + +uniform vec4 normal_color; + +void main() +{ + frag_color = normal_color; +} diff --git a/pyrender/pyrender/shaders/vertex_normals.geom b/pyrender/pyrender/shaders/vertex_normals.geom new file mode 100644 index 0000000000000000000000000000000000000000..57f0b0e645e72d41116f5767d66fc37d01ed2714 --- /dev/null +++ b/pyrender/pyrender/shaders/vertex_normals.geom @@ -0,0 +1,74 @@ +#version 330 core + +layout (triangles) in; + +#ifdef FACE_NORMALS + +#ifdef VERTEX_NORMALS + layout (line_strip, max_vertices = 8) out; +#else + layout (line_strip, max_vertices = 2) out; +#endif + +#else + + layout (line_strip, max_vertices = 6) out; + +#endif + +in VS_OUT { + vec3 position; + vec3 normal; + mat4 mvp; +} gs_in[]; + +uniform float normal_magnitude; + +void 
GenerateVertNormal(int index) +{ + + vec4 p0 = gs_in[index].mvp * vec4(gs_in[index].position, 1.0); + vec4 p1 = gs_in[index].mvp * vec4(normal_magnitude * normalize(gs_in[index].normal) + gs_in[index].position, 1.0); + gl_Position = p0; + EmitVertex(); + gl_Position = p1; + EmitVertex(); + EndPrimitive(); +} + +void GenerateFaceNormal() +{ + vec3 p0 = gs_in[0].position.xyz; + vec3 p1 = gs_in[1].position.xyz; + vec3 p2 = gs_in[2].position.xyz; + + vec3 v0 = p0 - p1; + vec3 v1 = p2 - p1; + + vec3 N = normalize(cross(v1, v0)); + vec3 P = (p0 + p1 + p2) / 3.0; + + vec4 np0 = gs_in[0].mvp * vec4(P, 1.0); + vec4 np1 = gs_in[0].mvp * vec4(normal_magnitude * N + P, 1.0); + + gl_Position = np0; + EmitVertex(); + gl_Position = np1; + EmitVertex(); + EndPrimitive(); +} + +void main() +{ + +#ifdef FACE_NORMALS + GenerateFaceNormal(); +#endif + +#ifdef VERTEX_NORMALS + GenerateVertNormal(0); + GenerateVertNormal(1); + GenerateVertNormal(2); +#endif + +} diff --git a/pyrender/pyrender/shaders/vertex_normals.vert b/pyrender/pyrender/shaders/vertex_normals.vert new file mode 100644 index 0000000000000000000000000000000000000000..be22eed2a0e904bcaf1ac5a4721558e574cddc62 --- /dev/null +++ b/pyrender/pyrender/shaders/vertex_normals.vert @@ -0,0 +1,27 @@ +#version 330 core + +// Inputs +layout(location = 0) in vec3 position; +layout(location = NORMAL_LOC) in vec3 normal; +layout(location = INST_M_LOC) in mat4 inst_m; + +// Output data +out VS_OUT { + vec3 position; + vec3 normal; + mat4 mvp; +} vs_out; + +// Uniform data +uniform mat4 M; +uniform mat4 V; +uniform mat4 P; + +// Render loop +void main() { + vs_out.mvp = P * V * M * inst_m; + vs_out.position = position; + vs_out.normal = normal; + + gl_Position = vec4(position, 1.0); +} diff --git a/pyrender/pyrender/shaders/vertex_normals_pc.geom b/pyrender/pyrender/shaders/vertex_normals_pc.geom new file mode 100644 index 0000000000000000000000000000000000000000..4ea4e7b8542703f64b8d28fd187e425137861fe4 --- /dev/null +++ b/pyrender/pyrender/shaders/vertex_normals_pc.geom @@ -0,0 +1,29 @@ +#version 330 core + +layout (points) in; + +layout (line_strip, max_vertices = 2) out; + +in VS_OUT { + vec3 position; + vec3 normal; + mat4 mvp; +} gs_in[]; + +uniform float normal_magnitude; + +void GenerateVertNormal(int index) +{ + vec4 p0 = gs_in[index].mvp * vec4(gs_in[index].position, 1.0); + vec4 p1 = gs_in[index].mvp * vec4(normal_magnitude * normalize(gs_in[index].normal) + gs_in[index].position, 1.0); + gl_Position = p0; + EmitVertex(); + gl_Position = p1; + EmitVertex(); + EndPrimitive(); +} + +void main() +{ + GenerateVertNormal(0); +} diff --git a/pyrender/pyrender/texture.py b/pyrender/pyrender/texture.py new file mode 100644 index 0000000000000000000000000000000000000000..477759729d7b995a4f276e81d649617d045a066e --- /dev/null +++ b/pyrender/pyrender/texture.py @@ -0,0 +1,259 @@ +"""Textures, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-texture + +Author: Matthew Matl +""" +import numpy as np + +from OpenGL.GL import * + +from .utils import format_texture_source +from .sampler import Sampler + + +class Texture(object): + """A texture and its sampler. + + Parameters + ---------- + name : str, optional + The user-defined name of this object. + sampler : :class:`Sampler` + The sampler used by this texture. + source : (h,w,c) uint8 or (h,w,c) float or :class:`PIL.Image.Image` + The image used by this texture. 
If None, the texture is created + empty and width and height must be specified. + source_channels : str + Either `D`, `R`, `RG`, `GB`, `RGB`, or `RGBA`. Indicates the + channels to extract from `source`. Any missing channels will be filled + with `1.0`. + width : int, optional + For empty textures, the width of the texture buffer. + height : int, optional + For empty textures, the height of the texture buffer. + tex_type : int + Either GL_TEXTURE_2D or GL_TEXTURE_CUBE. + data_format : int + For now, just GL_FLOAT. + """ + + def __init__(self, + name=None, + sampler=None, + source=None, + source_channels=None, + width=None, + height=None, + tex_type=GL_TEXTURE_2D, + data_format=GL_UNSIGNED_BYTE): + self.source_channels = source_channels + self.name = name + self.sampler = sampler + self.source = source + self.width = width + self.height = height + self.tex_type = tex_type + self.data_format = data_format + + self._texid = None + self._is_transparent = False + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def sampler(self): + """:class:`Sampler` : The sampler used by this texture. + """ + return self._sampler + + @sampler.setter + def sampler(self, value): + if value is None: + value = Sampler() + self._sampler = value + + @property + def source(self): + """(h,w,c) uint8 or float or :class:`PIL.Image.Image` : The image + used in this texture. + """ + return self._source + + @source.setter + def source(self, value): + if value is None: + self._source = None + else: + self._source = format_texture_source(value, self.source_channels) + self._is_transparent = False + + @property + def source_channels(self): + """str : The channels that were extracted from the original source. + """ + return self._source_channels + + @source_channels.setter + def source_channels(self, value): + self._source_channels = value + + @property + def width(self): + """int : The width of the texture buffer. + """ + return self._width + + @width.setter + def width(self, value): + self._width = value + + @property + def height(self): + """int : The height of the texture buffer. + """ + return self._height + + @height.setter + def height(self, value): + self._height = value + + @property + def tex_type(self): + """int : The type of the texture. + """ + return self._tex_type + + @tex_type.setter + def tex_type(self, value): + self._tex_type = value + + @property + def data_format(self): + """int : The format of the texture data. + """ + return self._data_format + + @data_format.setter + def data_format(self, value): + self._data_format = value + + def is_transparent(self, cutoff=1.0): + """bool : If True, the texture is partially transparent. + """ + if self._is_transparent is None: + self._is_transparent = False + if self.source_channels == 'RGBA' and self.source is not None: + if np.any(self.source[:,:,3] < cutoff): + self._is_transparent = True + return self._is_transparent + + def delete(self): + """Remove this texture from the OpenGL context. 
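+
+        The GL texture id is only deleted if one was actually allocated
+        for this texture.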
+ """ + self._unbind() + self._remove_from_context() + + ################## + # OpenGL code + ################## + def _add_to_context(self): + if self._texid is not None: + raise ValueError('Texture already loaded into OpenGL context') + + fmt = GL_DEPTH_COMPONENT + if self.source_channels == 'R': + fmt = GL_RED + elif self.source_channels == 'RG' or self.source_channels == 'GB': + fmt = GL_RG + elif self.source_channels == 'RGB': + fmt = GL_RGB + elif self.source_channels == 'RGBA': + fmt = GL_RGBA + + # Generate the OpenGL texture + self._texid = glGenTextures(1) + glBindTexture(self.tex_type, self._texid) + + # Flip data for OpenGL buffer + data = None + width = self.width + height = self.height + if self.source is not None: + data = np.ascontiguousarray(np.flip(self.source, axis=0).flatten()) + width = self.source.shape[1] + height = self.source.shape[0] + + # Bind texture and generate mipmaps + glTexImage2D( + self.tex_type, 0, fmt, width, height, 0, fmt, + self.data_format, data + ) + if self.source is not None: + glGenerateMipmap(self.tex_type) + + if self.sampler.magFilter is not None: + glTexParameteri( + self.tex_type, GL_TEXTURE_MAG_FILTER, self.sampler.magFilter + ) + else: + if self.source is not None: + glTexParameteri(self.tex_type, GL_TEXTURE_MAG_FILTER, GL_LINEAR) + else: + glTexParameteri(self.tex_type, GL_TEXTURE_MAG_FILTER, GL_NEAREST) + if self.sampler.minFilter is not None: + glTexParameteri( + self.tex_type, GL_TEXTURE_MIN_FILTER, self.sampler.minFilter + ) + else: + if self.source is not None: + glTexParameteri(self.tex_type, GL_TEXTURE_MIN_FILTER, GL_LINEAR_MIPMAP_LINEAR) + else: + glTexParameteri(self.tex_type, GL_TEXTURE_MIN_FILTER, GL_NEAREST) + + glTexParameteri(self.tex_type, GL_TEXTURE_WRAP_S, self.sampler.wrapS) + glTexParameteri(self.tex_type, GL_TEXTURE_WRAP_T, self.sampler.wrapT) + border_color = 255 * np.ones(4).astype(np.uint8) + if self.data_format == GL_FLOAT: + border_color = np.ones(4).astype(np.float32) + glTexParameterfv( + self.tex_type, GL_TEXTURE_BORDER_COLOR, + border_color + ) + + # Unbind texture + glBindTexture(self.tex_type, 0) + + def _remove_from_context(self): + if self._texid is not None: + # TODO OPENGL BUG? + # glDeleteTextures(1, [self._texid]) + glDeleteTextures([self._texid]) + self._texid = None + + def _in_context(self): + return self._texid is not None + + def _bind(self): + # TODO HANDLE INDEXING INTO OTHER UV's + glBindTexture(self.tex_type, self._texid) + + def _unbind(self): + glBindTexture(self.tex_type, 0) + + def _bind_as_depth_attachment(self): + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, + self.tex_type, self._texid, 0) + + def _bind_as_color_attachment(self): + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + self.tex_type, self._texid, 0) diff --git a/pyrender/pyrender/trackball.py b/pyrender/pyrender/trackball.py new file mode 100644 index 0000000000000000000000000000000000000000..3e57a0e82d3f07b80754f575c28a0e05cb73fc50 --- /dev/null +++ b/pyrender/pyrender/trackball.py @@ -0,0 +1,216 @@ +"""Trackball class for 3D manipulation of viewpoints. +""" +import numpy as np + +import trimesh.transformations as transformations + + +class Trackball(object): + """A trackball class for creating camera transforms from mouse movements. + """ + STATE_ROTATE = 0 + STATE_PAN = 1 + STATE_ROLL = 2 + STATE_ZOOM = 3 + + def __init__(self, pose, size, scale, + target=np.array([0.0, 0.0, 0.0])): + """Initialize a trackball with an initial camera-to-world pose + and the given parameters. 
+ + Parameters + ---------- + pose : [4,4] + An initial camera-to-world pose for the trackball. + + size : (float, float) + The width and height of the camera image in pixels. + + scale : float + The diagonal of the scene's bounding box -- + used for ensuring translation motions are sufficiently + fast for differently-sized scenes. + + target : (3,) float + The center of the scene in world coordinates. + The trackball will revolve around this point. + """ + self._size = np.array(size) + self._scale = float(scale) + + self._pose = pose + self._n_pose = pose + + self._target = target + self._n_target = target + + self._state = Trackball.STATE_ROTATE + + @property + def pose(self): + """autolab_core.RigidTransform : The current camera-to-world pose. + """ + return self._n_pose + + def set_state(self, state): + """Set the state of the trackball in order to change the effect of + dragging motions. + + Parameters + ---------- + state : int + One of Trackball.STATE_ROTATE, Trackball.STATE_PAN, + Trackball.STATE_ROLL, and Trackball.STATE_ZOOM. + """ + self._state = state + + def resize(self, size): + """Resize the window. + + Parameters + ---------- + size : (float, float) + The new width and height of the camera image in pixels. + """ + self._size = np.array(size) + + def down(self, point): + """Record an initial mouse press at a given point. + + Parameters + ---------- + point : (2,) int + The x and y pixel coordinates of the mouse press. + """ + self._pdown = np.array(point, dtype=np.float32) + self._pose = self._n_pose + self._target = self._n_target + + def drag(self, point): + """Update the tracball during a drag. + + Parameters + ---------- + point : (2,) int + The current x and y pixel coordinates of the mouse during a drag. + This will compute a movement for the trackball with the relative + motion between this point and the one marked by down(). 
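+
+        The effect of the drag (rotate, roll, pan, or zoom) is determined by
+        the state most recently set with :meth:`set_state`.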
+ """ + point = np.array(point, dtype=np.float32) + dx, dy = point - self._pdown + mindim = 0.3 * np.min(self._size) + + target = self._target + x_axis = self._pose[:3,0].flatten() + y_axis = self._pose[:3,1].flatten() + z_axis = self._pose[:3,2].flatten() + eye = self._pose[:3,3].flatten() + + # Interpret drag as a rotation + if self._state == Trackball.STATE_ROTATE: + x_angle = -dx / mindim + x_rot_mat = transformations.rotation_matrix( + x_angle, y_axis, target + ) + + y_angle = dy / mindim + y_rot_mat = transformations.rotation_matrix( + y_angle, x_axis, target + ) + + self._n_pose = y_rot_mat.dot(x_rot_mat.dot(self._pose)) + + # Interpret drag as a roll about the camera axis + elif self._state == Trackball.STATE_ROLL: + center = self._size / 2.0 + v_init = self._pdown - center + v_curr = point - center + v_init = v_init / np.linalg.norm(v_init) + v_curr = v_curr / np.linalg.norm(v_curr) + + theta = (-np.arctan2(v_curr[1], v_curr[0]) + + np.arctan2(v_init[1], v_init[0])) + + rot_mat = transformations.rotation_matrix(theta, z_axis, target) + + self._n_pose = rot_mat.dot(self._pose) + + # Interpret drag as a camera pan in view plane + elif self._state == Trackball.STATE_PAN: + dx = -dx / (5.0 * mindim) * self._scale + dy = -dy / (5.0 * mindim) * self._scale + + translation = dx * x_axis + dy * y_axis + self._n_target = self._target + translation + t_tf = np.eye(4) + t_tf[:3,3] = translation + self._n_pose = t_tf.dot(self._pose) + + # Interpret drag as a zoom motion + elif self._state == Trackball.STATE_ZOOM: + radius = np.linalg.norm(eye - target) + ratio = 0.0 + if dy > 0: + ratio = np.exp(abs(dy) / (0.5 * self._size[1])) - 1.0 + elif dy < 0: + ratio = 1.0 - np.exp(dy / (0.5 * (self._size[1]))) + translation = -np.sign(dy) * ratio * radius * z_axis + t_tf = np.eye(4) + t_tf[:3,3] = translation + self._n_pose = t_tf.dot(self._pose) + + def scroll(self, clicks): + """Zoom using a mouse scroll wheel motion. + + Parameters + ---------- + clicks : int + The number of clicks. Positive numbers indicate forward wheel + movement. + """ + target = self._target + ratio = 0.90 + + mult = 1.0 + if clicks > 0: + mult = ratio**clicks + elif clicks < 0: + mult = (1.0 / ratio)**abs(clicks) + + z_axis = self._n_pose[:3,2].flatten() + eye = self._n_pose[:3,3].flatten() + radius = np.linalg.norm(eye - target) + translation = (mult * radius - radius) * z_axis + t_tf = np.eye(4) + t_tf[:3,3] = translation + self._n_pose = t_tf.dot(self._n_pose) + + z_axis = self._pose[:3,2].flatten() + eye = self._pose[:3,3].flatten() + radius = np.linalg.norm(eye - target) + translation = (mult * radius - radius) * z_axis + t_tf = np.eye(4) + t_tf[:3,3] = translation + self._pose = t_tf.dot(self._pose) + + def rotate(self, azimuth, axis=None): + """Rotate the trackball about the "Up" axis by azimuth radians. + + Parameters + ---------- + azimuth : float + The number of radians to rotate. 
+ """ + target = self._target + + y_axis = self._n_pose[:3,1].flatten() + if axis is not None: + y_axis = axis + x_rot_mat = transformations.rotation_matrix(azimuth, y_axis, target) + self._n_pose = x_rot_mat.dot(self._n_pose) + + y_axis = self._pose[:3,1].flatten() + if axis is not None: + y_axis = axis + x_rot_mat = transformations.rotation_matrix(azimuth, y_axis, target) + self._pose = x_rot_mat.dot(self._pose) diff --git a/pyrender/pyrender/utils.py b/pyrender/pyrender/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..48a11faf991606ad7fb0691582f0bc6f06101a45 --- /dev/null +++ b/pyrender/pyrender/utils.py @@ -0,0 +1,115 @@ +import numpy as np +from PIL import Image + + +def format_color_vector(value, length): + """Format a color vector. + """ + if isinstance(value, int): + value = value / 255.0 + if isinstance(value, float): + value = np.repeat(value, length) + if isinstance(value, list) or isinstance(value, tuple): + value = np.array(value) + if isinstance(value, np.ndarray): + value = value.squeeze() + if np.issubdtype(value.dtype, np.integer): + value = (value / 255.0).astype(np.float32) + if value.ndim != 1: + raise ValueError('Format vector takes only 1-D vectors') + if length > value.shape[0]: + value = np.hstack((value, np.ones(length - value.shape[0]))) + elif length < value.shape[0]: + value = value[:length] + else: + raise ValueError('Invalid vector data type') + + return value.squeeze().astype(np.float32) + + +def format_color_array(value, shape): + """Format an array of colors. + """ + # Convert uint8 to floating + value = np.asanyarray(value) + if np.issubdtype(value.dtype, np.integer): + value = (value / 255.0).astype(np.float32) + + # Match up shapes + if value.ndim == 1: + value = np.tile(value, (shape[0],1)) + if value.shape[1] < shape[1]: + nc = shape[1] - value.shape[1] + value = np.column_stack((value, np.ones((value.shape[0], nc)))) + elif value.shape[1] > shape[1]: + value = value[:,:shape[1]] + return value.astype(np.float32) + + +def format_texture_source(texture, target_channels='RGB'): + """Format a texture as a float32 np array. 
+ """ + + # Pass through None + if texture is None: + return None + + # Convert PIL images into numpy arrays + if isinstance(texture, Image.Image): + if texture.mode == 'P' and target_channels in ('RGB', 'RGBA'): + texture = np.array(texture.convert(target_channels)) + else: + texture = np.array(texture) + + # Format numpy arrays + if isinstance(texture, np.ndarray): + if np.issubdtype(texture.dtype, np.floating): + texture = np.array(texture * 255.0, dtype=np.uint8) + elif np.issubdtype(texture.dtype, np.integer): + texture = texture.astype(np.uint8) + else: + raise TypeError('Invalid type {} for texture'.format( + type(texture) + )) + + # Format array by picking out correct texture channels or padding + if texture.ndim == 2: + texture = texture[:,:,np.newaxis] + if target_channels == 'R': + texture = texture[:,:,0] + texture = texture.squeeze() + elif target_channels == 'RG': + if texture.shape[2] == 1: + texture = np.repeat(texture, 2, axis=2) + else: + texture = texture[:,:,(0,1)] + elif target_channels == 'GB': + if texture.shape[2] == 1: + texture = np.repeat(texture, 2, axis=2) + elif texture.shape[2] > 2: + texture = texture[:,:,(1,2)] + elif target_channels == 'RGB': + if texture.shape[2] == 1: + texture = np.repeat(texture, 3, axis=2) + elif texture.shape[2] == 2: + raise ValueError('Cannot reformat 2-channel texture into RGB') + else: + texture = texture[:,:,(0,1,2)] + elif target_channels == 'RGBA': + if texture.shape[2] == 1: + texture = np.repeat(texture, 4, axis=2) + texture[:,:,3] = 255 + elif texture.shape[2] == 2: + raise ValueError('Cannot reformat 2-channel texture into RGBA') + elif texture.shape[2] == 3: + tx = np.empty((texture.shape[0], texture.shape[1], 4), dtype=np.uint8) + tx[:,:,:3] = texture + tx[:,:,3] = 255 + texture = tx + else: + raise ValueError('Invalid texture channel specification: {}' + .format(target_channels)) + else: + raise TypeError('Invalid type {} for texture'.format(type(texture))) + + return texture diff --git a/pyrender/pyrender/version.py b/pyrender/pyrender/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a33fc87f61f528780e3319a5160769cc84512b1b --- /dev/null +++ b/pyrender/pyrender/version.py @@ -0,0 +1 @@ +__version__ = '0.1.45' diff --git a/pyrender/pyrender/viewer.py b/pyrender/pyrender/viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..d2326c38205c6eaddb4f567e3b088329187af258 --- /dev/null +++ b/pyrender/pyrender/viewer.py @@ -0,0 +1,1160 @@ +"""A pyglet-based interactive 3D scene viewer. +""" +import copy +import os +import sys +from threading import Thread, RLock +import time + +import imageio +import numpy as np +import OpenGL +import trimesh + +try: + from Tkinter import Tk, tkFileDialog as filedialog +except Exception: + try: + from tkinter import Tk, filedialog as filedialog + except Exception: + pass + +from .constants import (TARGET_OPEN_GL_MAJOR, TARGET_OPEN_GL_MINOR, + MIN_OPEN_GL_MAJOR, MIN_OPEN_GL_MINOR, + TEXT_PADDING, DEFAULT_SCENE_SCALE, + DEFAULT_Z_FAR, DEFAULT_Z_NEAR, RenderFlags, TextAlign) +from .light import DirectionalLight +from .node import Node +from .camera import PerspectiveCamera, OrthographicCamera, IntrinsicsCamera +from .trackball import Trackball +from .renderer import Renderer +from .mesh import Mesh + +import pyglet +from pyglet import clock +pyglet.options['shadow_window'] = False + + +class Viewer(pyglet.window.Window): + """An interactive viewer for 3D scenes. 
+ + The viewer's camera is separate from the scene's, but will take on + the parameters of the scene's main view camera and start in the same pose. + If the scene does not have a camera, a suitable default will be provided. + + Parameters + ---------- + scene : :class:`Scene` + The scene to visualize. + viewport_size : (2,) int + The width and height of the initial viewing window. + render_flags : dict + A set of flags for rendering the scene. Described in the note below. + viewer_flags : dict + A set of flags for controlling the viewer's behavior. + Described in the note below. + registered_keys : dict + A map from ASCII key characters to tuples containing: + + - A function to be called whenever the key is pressed, + whose first argument will be the viewer itself. + - (Optionally) A list of additional positional arguments + to be passed to the function. + - (Optionally) A dict of keyword arguments to be passed + to the function. + + kwargs : dict + Any keyword arguments left over will be interpreted as belonging to + either the :attr:`.Viewer.render_flags` or :attr:`.Viewer.viewer_flags` + dictionaries. Those flag sets will be updated appropriately. + + Note + ---- + The basic commands for moving about the scene are given as follows: + + - **Rotating about the scene**: Hold the left mouse button and + drag the cursor. + - **Rotating about the view axis**: Hold ``CTRL`` and the left mouse + button and drag the cursor. + - **Panning**: + + - Hold SHIFT, then hold the left mouse button and drag the cursor, or + - Hold the middle mouse button and drag the cursor. + + - **Zooming**: + + - Scroll the mouse wheel, or + - Hold the right mouse button and drag the cursor. + + Other keyboard commands are as follows: + + - ``a``: Toggles rotational animation mode. + - ``c``: Toggles backface culling. + - ``f``: Toggles fullscreen mode. + - ``h``: Toggles shadow rendering. + - ``i``: Toggles axis display mode + (no axes, world axis, mesh axes, all axes). + - ``l``: Toggles lighting mode + (scene lighting, Raymond lighting, or direct lighting). + - ``m``: Toggles face normal visualization. + - ``n``: Toggles vertex normal visualization. + - ``o``: Toggles orthographic mode. + - ``q``: Quits the viewer. + - ``r``: Starts recording a GIF, and pressing again stops recording + and opens a file dialog. + - ``s``: Opens a file dialog to save the current view as an image. + - ``w``: Toggles wireframe mode + (scene default, flip wireframes, all wireframe, or all solid). + - ``z``: Resets the camera to the initial view. + + Note + ---- + The valid keys for ``render_flags`` are as follows: + + - ``flip_wireframe``: `bool`, If `True`, all objects will have their + wireframe modes flipped from what their material indicates. + Defaults to `False`. + - ``all_wireframe``: `bool`, If `True`, all objects will be rendered + in wireframe mode. Defaults to `False`. + - ``all_solid``: `bool`, If `True`, all objects will be rendered in + solid mode. Defaults to `False`. + - ``shadows``: `bool`, If `True`, shadows will be rendered. + Defaults to `False`. + - ``vertex_normals``: `bool`, If `True`, vertex normals will be + rendered as blue lines. Defaults to `False`. + - ``face_normals``: `bool`, If `True`, face normals will be rendered as + blue lines. Defaults to `False`. + - ``cull_faces``: `bool`, If `True`, backfaces will be culled. + Defaults to `True`. + - ``point_size`` : float, The point size in pixels. Defaults to 1px. 
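+
+    Any of these keys (and the ``viewer_flags`` keys below) may also be passed
+    directly as keyword arguments, e.g.
+    ``Viewer(scene, shadows=True, use_raymond_lighting=True)``.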
+ + Note + ---- + The valid keys for ``viewer_flags`` are as follows: + + - ``rotate``: `bool`, If `True`, the scene's camera will rotate + about an axis. Defaults to `False`. + - ``rotate_rate``: `float`, The rate of rotation in radians per second. + Defaults to `PI / 3.0`. + - ``rotate_axis``: `(3,) float`, The axis in world coordinates to rotate + about. Defaults to ``[0,0,1]``. + - ``view_center``: `(3,) float`, The position to rotate the scene about. + Defaults to the scene's centroid. + - ``use_raymond_lighting``: `bool`, If `True`, an additional set of three + directional lights that move with the camera will be added to the scene. + Defaults to `False`. + - ``use_direct_lighting``: `bool`, If `True`, an additional directional + light that moves with the camera and points out of it will be added to + the scene. Defaults to `False`. + - ``lighting_intensity``: `float`, The overall intensity of the + viewer's additional lights (when they're in use). Defaults to 3.0. + - ``use_perspective_cam``: `bool`, If `True`, a perspective camera will + be used. Otherwise, an orthographic camera is used. Defaults to `True`. + - ``save_directory``: `str`, A directory to open the file dialogs in. + Defaults to `None`. + - ``window_title``: `str`, A title for the viewer's application window. + Defaults to `"Scene Viewer"`. + - ``refresh_rate``: `float`, A refresh rate for rendering, in Hertz. + Defaults to `30.0`. + - ``fullscreen``: `bool`, Whether to make viewer fullscreen. + Defaults to `False`. + - ``show_world_axis``: `bool`, Whether to show the world axis. + Defaults to `False`. + - ``show_mesh_axes``: `bool`, Whether to show the individual mesh axes. + Defaults to `False`. + - ``caption``: `list of dict`, Text caption(s) to display on the viewer. + Defaults to `None`. + + Note + ---- + Animation can be accomplished by running the viewer with ``run_in_thread`` + enabled. Then, just run a loop in your main thread, updating the scene as + needed. Before updating the scene, be sure to acquire the + :attr:`.Viewer.render_lock`, and release it when your update is done. 
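+
+    Example
+    -------
+    A minimal sketch of that animation pattern, assuming ``scene`` is a
+    :class:`.Scene` that already contains a node ``node`` and
+    ``compute_pose()`` stands in for your own per-frame pose update::
+
+        viewer = Viewer(scene, run_in_thread=True)
+        while viewer.is_active:
+            pose = compute_pose()
+            viewer.render_lock.acquire()
+            scene.set_pose(node, pose)
+            viewer.render_lock.release()
+            time.sleep(1.0 / viewer.viewer_flags['refresh_rate'])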
+ """ + + def __init__(self, scene, viewport_size=None, + render_flags=None, viewer_flags=None, + registered_keys=None, run_in_thread=False, + auto_start=True, + **kwargs): + + ####################################################################### + # Save attributes and flags + ####################################################################### + if viewport_size is None: + viewport_size = (640, 480) + self._scene = scene + self._viewport_size = viewport_size + self._render_lock = RLock() + self._is_active = False + self._should_close = False + self._run_in_thread = run_in_thread + self._auto_start = auto_start + + self._default_render_flags = { + 'flip_wireframe': False, + 'all_wireframe': False, + 'all_solid': False, + 'shadows': False, + 'vertex_normals': False, + 'face_normals': False, + 'cull_faces': True, + 'point_size': 1.0, + } + self._default_viewer_flags = { + 'mouse_pressed': False, + 'rotate': False, + 'rotate_rate': np.pi / 3.0, + 'rotate_axis': np.array([0.0, 0.0, 1.0]), + 'view_center': None, + 'record': False, + 'use_raymond_lighting': False, + 'use_direct_lighting': False, + 'lighting_intensity': 3.0, + 'use_perspective_cam': True, + 'save_directory': None, + 'window_title': 'Scene Viewer', + 'refresh_rate': 30.0, + 'fullscreen': False, + 'show_world_axis': False, + 'show_mesh_axes': False, + 'caption': None + } + self._render_flags = self._default_render_flags.copy() + self._viewer_flags = self._default_viewer_flags.copy() + self._viewer_flags['rotate_axis'] = ( + self._default_viewer_flags['rotate_axis'].copy() + ) + + if render_flags is not None: + self._render_flags.update(render_flags) + if viewer_flags is not None: + self._viewer_flags.update(viewer_flags) + + for key in kwargs: + if key in self.render_flags: + self._render_flags[key] = kwargs[key] + elif key in self.viewer_flags: + self._viewer_flags[key] = kwargs[key] + + # TODO MAC OS BUG FOR SHADOWS + if sys.platform == 'darwin': + self._render_flags['shadows'] = False + + self._registered_keys = {} + if registered_keys is not None: + self._registered_keys = { + ord(k.lower()): registered_keys[k] for k in registered_keys + } + + ####################################################################### + # Save internal settings + ####################################################################### + + # Set up caption stuff + self._message_text = None + self._ticks_till_fade = 2.0 / 3.0 * self.viewer_flags['refresh_rate'] + self._message_opac = 1.0 + self._ticks_till_fade + + # Set up raymond lights and direct lights + self._raymond_lights = self._create_raymond_lights() + self._direct_light = self._create_direct_light() + + # Set up axes + self._axes = {} + self._axis_mesh = Mesh.from_trimesh( + trimesh.creation.axis(origin_size=0.1, axis_radius=0.05, + axis_length=1.0), smooth=False) + if self.viewer_flags['show_world_axis']: + self._set_axes(world=self.viewer_flags['show_world_axis'], + mesh=self.viewer_flags['show_mesh_axes']) + + ####################################################################### + # Set up camera node + ####################################################################### + self._camera_node = None + self._prior_main_camera_node = None + self._default_camera_pose = None + self._default_persp_cam = None + self._default_orth_cam = None + self._trackball = None + self._saved_frames = [] + + # Extract main camera from scene and set up our mirrored copy + znear = None + zfar = None + if scene.main_camera_node is not None: + n = scene.main_camera_node + camera = copy.copy(n.camera) + 
if isinstance(camera, (PerspectiveCamera, IntrinsicsCamera)): + self._default_persp_cam = camera + znear = camera.znear + zfar = camera.zfar + elif isinstance(camera, OrthographicCamera): + self._default_orth_cam = camera + znear = camera.znear + zfar = camera.zfar + self._default_camera_pose = scene.get_pose(scene.main_camera_node) + self._prior_main_camera_node = n + + # Set defaults as needed + if zfar is None: + zfar = max(scene.scale * 10.0, DEFAULT_Z_FAR) + if znear is None or znear == 0: + if scene.scale == 0: + znear = DEFAULT_Z_NEAR + else: + znear = min(scene.scale / 10.0, DEFAULT_Z_NEAR) + + if self._default_persp_cam is None: + self._default_persp_cam = PerspectiveCamera( + yfov=np.pi / 3.0, znear=znear, zfar=zfar + ) + if self._default_orth_cam is None: + xmag = ymag = scene.scale + if scene.scale == 0: + xmag = ymag = 1.0 + self._default_orth_cam = OrthographicCamera( + xmag=xmag, ymag=ymag, + znear=znear, + zfar=zfar + ) + if self._default_camera_pose is None: + self._default_camera_pose = self._compute_initial_camera_pose() + + # Pick camera + if self.viewer_flags['use_perspective_cam']: + camera = self._default_persp_cam + else: + camera = self._default_orth_cam + + self._camera_node = Node( + matrix=self._default_camera_pose, camera=camera + ) + scene.add_node(self._camera_node) + scene.main_camera_node = self._camera_node + self._reset_view() + + ####################################################################### + # Initialize OpenGL context and renderer + ####################################################################### + self._renderer = Renderer( + self._viewport_size[0], self._viewport_size[1], + self.render_flags['point_size'] + ) + self._is_active = True + + if self.run_in_thread: + self._thread = Thread(target=self._init_and_start_app) + self._thread.start() + else: + if auto_start: + self._init_and_start_app() + + def start(self): + self._init_and_start_app() + + @property + def scene(self): + """:class:`.Scene` : The scene being visualized. + """ + return self._scene + + @property + def viewport_size(self): + """(2,) int : The width and height of the viewing window. + """ + return self._viewport_size + + @property + def render_lock(self): + """:class:`threading.RLock` : If acquired, prevents the viewer from + rendering until released. + + Run :meth:`.Viewer.render_lock.acquire` before making updates to + the scene in a different thread, and run + :meth:`.Viewer.render_lock.release` once you're done to let the viewer + continue. + """ + return self._render_lock + + @property + def is_active(self): + """bool : `True` if the viewer is active, or `False` if it has + been closed. + """ + return self._is_active + + @property + def run_in_thread(self): + """bool : Whether the viewer was run in a separate thread. + """ + return self._run_in_thread + + @property + def render_flags(self): + """dict : Flags for controlling the renderer's behavior. + + - ``flip_wireframe``: `bool`, If `True`, all objects will have their + wireframe modes flipped from what their material indicates. + Defaults to `False`. + - ``all_wireframe``: `bool`, If `True`, all objects will be rendered + in wireframe mode. Defaults to `False`. + - ``all_solid``: `bool`, If `True`, all objects will be rendered in + solid mode. Defaults to `False`. + - ``shadows``: `bool`, If `True`, shadows will be rendered. + Defaults to `False`. + - ``vertex_normals``: `bool`, If `True`, vertex normals will be + rendered as blue lines. Defaults to `False`. 
+ - ``face_normals``: `bool`, If `True`, face normals will be rendered as + blue lines. Defaults to `False`. + - ``cull_faces``: `bool`, If `True`, backfaces will be culled. + Defaults to `True`. + - ``point_size`` : float, The point size in pixels. Defaults to 1px. + + """ + return self._render_flags + + @render_flags.setter + def render_flags(self, value): + self._render_flags = value + + @property + def viewer_flags(self): + """dict : Flags for controlling the viewer's behavior. + + The valid keys for ``viewer_flags`` are as follows: + + - ``rotate``: `bool`, If `True`, the scene's camera will rotate + about an axis. Defaults to `False`. + - ``rotate_rate``: `float`, The rate of rotation in radians per second. + Defaults to `PI / 3.0`. + - ``rotate_axis``: `(3,) float`, The axis in world coordinates to + rotate about. Defaults to ``[0,0,1]``. + - ``view_center``: `(3,) float`, The position to rotate the scene + about. Defaults to the scene's centroid. + - ``use_raymond_lighting``: `bool`, If `True`, an additional set of + three directional lights that move with the camera will be added to + the scene. Defaults to `False`. + - ``use_direct_lighting``: `bool`, If `True`, an additional directional + light that moves with the camera and points out of it will be + added to the scene. Defaults to `False`. + - ``lighting_intensity``: `float`, The overall intensity of the + viewer's additional lights (when they're in use). Defaults to 3.0. + - ``use_perspective_cam``: `bool`, If `True`, a perspective camera will + be used. Otherwise, an orthographic camera is used. Defaults to + `True`. + - ``save_directory``: `str`, A directory to open the file dialogs in. + Defaults to `None`. + - ``window_title``: `str`, A title for the viewer's application window. + Defaults to `"Scene Viewer"`. + - ``refresh_rate``: `float`, A refresh rate for rendering, in Hertz. + Defaults to `30.0`. + - ``fullscreen``: `bool`, Whether to make viewer fullscreen. + Defaults to `False`. + - ``show_world_axis``: `bool`, Whether to show the world axis. + Defaults to `False`. + - ``show_mesh_axes``: `bool`, Whether to show the individual mesh axes. + Defaults to `False`. + - ``caption``: `list of dict`, Text caption(s) to display on + the viewer. Defaults to `None`. + + """ + return self._viewer_flags + + @viewer_flags.setter + def viewer_flags(self, value): + self._viewer_flags = value + + @property + def registered_keys(self): + """dict : Map from ASCII key character to a handler function. + + This is a map from ASCII key characters to tuples containing: + + - A function to be called whenever the key is pressed, + whose first argument will be the viewer itself. + - (Optionally) A list of additional positional arguments + to be passed to the function. + - (Optionally) A dict of keyword arguments to be passed + to the function. + + """ + return self._registered_keys + + @registered_keys.setter + def registered_keys(self, value): + self._registered_keys = value + + def close_external(self): + """Close the viewer from another thread. + + This function will wait for the actual close, so you immediately + manipulate the scene afterwards. + """ + self._should_close = True + while self.is_active: + time.sleep(1.0 / self.viewer_flags['refresh_rate']) + + def save_gif(self, filename=None): + """Save the stored GIF frames to a file. + + To use this asynchronously, run the viewer with the ``record`` + flag and the ``run_in_thread`` flags set. 
+ Kill the viewer after your desired time with + :meth:`.Viewer.close_external`, and then call :meth:`.Viewer.save_gif`. + + Parameters + ---------- + filename : str + The file to save the GIF to. If not specified, + a file dialog will be opened to ask the user where + to save the GIF file. + """ + if filename is None: + filename = self._get_save_filename(['gif', 'all']) + if filename is not None: + self.viewer_flags['save_directory'] = os.path.dirname(filename) + imageio.mimwrite(filename, self._saved_frames, + fps=self.viewer_flags['refresh_rate'], + palettesize=128, subrectangles=True) + self._saved_frames = [] + + def on_close(self): + """Exit the event loop when the window is closed. + """ + # Remove our camera and restore the prior one + if self._camera_node is not None: + self.scene.remove_node(self._camera_node) + if self._prior_main_camera_node is not None: + self.scene.main_camera_node = self._prior_main_camera_node + + # Delete any lighting nodes that we've attached + if self.viewer_flags['use_raymond_lighting']: + for n in self._raymond_lights: + if self.scene.has_node(n): + self.scene.remove_node(n) + if self.viewer_flags['use_direct_lighting']: + if self.scene.has_node(self._direct_light): + self.scene.remove_node(self._direct_light) + + # Delete any axis nodes that we've attached + self._remove_axes() + + # Delete renderer + if self._renderer is not None: + self._renderer.delete() + self._renderer = None + + # Force clean-up of OpenGL context data + try: + OpenGL.contextdata.cleanupContext() + self.close() + except Exception: + pass + finally: + self._is_active = False + super(Viewer, self).on_close() + pyglet.app.exit() + + def on_draw(self): + """Redraw the scene into the viewing window. + """ + if self._renderer is None: + return + + if self.run_in_thread or not self._auto_start: + self.render_lock.acquire() + + # Make OpenGL context current + self.switch_to() + + # Render the scene + self.clear() + self._render() + + if self._message_text is not None: + self._renderer.render_text( + self._message_text, + self.viewport_size[0] - TEXT_PADDING, + TEXT_PADDING, + font_pt=20, + color=np.array([0.1, 0.7, 0.2, + np.clip(self._message_opac, 0.0, 1.0)]), + align=TextAlign.BOTTOM_RIGHT + ) + + if self.viewer_flags['caption'] is not None: + for caption in self.viewer_flags['caption']: + xpos, ypos = self._location_to_x_y(caption['location']) + self._renderer.render_text( + caption['text'], + xpos, + ypos, + font_name=caption['font_name'], + font_pt=caption['font_pt'], + color=caption['color'], + scale=caption['scale'], + align=caption['location'] + ) + + if self.run_in_thread or not self._auto_start: + self.render_lock.release() + + def on_resize(self, width, height): + """Resize the camera and trackball when the window is resized. + """ + if self._renderer is None: + return + + self._viewport_size = (width, height) + self._trackball.resize(self._viewport_size) + self._renderer.viewport_width = self._viewport_size[0] + self._renderer.viewport_height = self._viewport_size[1] + self.on_draw() + + def on_mouse_press(self, x, y, buttons, modifiers): + """Record an initial mouse press. 
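+
+        A left-button press starts a rotation (CTRL: roll, SHIFT: pan,
+        CTRL+SHIFT: zoom); the middle button pans and the right button zooms.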
+ """ + self._trackball.set_state(Trackball.STATE_ROTATE) + if (buttons == pyglet.window.mouse.LEFT): + ctrl = (modifiers & pyglet.window.key.MOD_CTRL) + shift = (modifiers & pyglet.window.key.MOD_SHIFT) + if (ctrl and shift): + self._trackball.set_state(Trackball.STATE_ZOOM) + elif ctrl: + self._trackball.set_state(Trackball.STATE_ROLL) + elif shift: + self._trackball.set_state(Trackball.STATE_PAN) + elif (buttons == pyglet.window.mouse.MIDDLE): + self._trackball.set_state(Trackball.STATE_PAN) + elif (buttons == pyglet.window.mouse.RIGHT): + self._trackball.set_state(Trackball.STATE_ZOOM) + + self._trackball.down(np.array([x, y])) + + # Stop animating while using the mouse + self.viewer_flags['mouse_pressed'] = True + + def on_mouse_drag(self, x, y, dx, dy, buttons, modifiers): + """Record a mouse drag. + """ + self._trackball.drag(np.array([x, y])) + + def on_mouse_release(self, x, y, button, modifiers): + """Record a mouse release. + """ + self.viewer_flags['mouse_pressed'] = False + + def on_mouse_scroll(self, x, y, dx, dy): + """Record a mouse scroll. + """ + if self.viewer_flags['use_perspective_cam']: + self._trackball.scroll(dy) + else: + spfc = 0.95 + spbc = 1.0 / 0.95 + sf = 1.0 + if dy > 0: + sf = spfc * dy + elif dy < 0: + sf = - spbc * dy + + c = self._camera_node.camera + xmag = max(c.xmag * sf, 1e-8) + ymag = max(c.ymag * sf, 1e-8 * c.ymag / c.xmag) + c.xmag = xmag + c.ymag = ymag + + def on_key_press(self, symbol, modifiers): + """Record a key press. + """ + # First, check for registered key callbacks + if symbol in self.registered_keys: + tup = self.registered_keys[symbol] + callback = None + args = [] + kwargs = {} + if not isinstance(tup, (list, tuple, np.ndarray)): + callback = tup + else: + callback = tup[0] + if len(tup) == 2: + args = tup[1] + if len(tup) == 3: + kwargs = tup[2] + callback(self, *args, **kwargs) + return + + # Otherwise, use default key functions + + # A causes the frame to rotate + self._message_text = None + if symbol == pyglet.window.key.A: + self.viewer_flags['rotate'] = not self.viewer_flags['rotate'] + if self.viewer_flags['rotate']: + self._message_text = 'Rotation On' + else: + self._message_text = 'Rotation Off' + + # C toggles backface culling + elif symbol == pyglet.window.key.C: + self.render_flags['cull_faces'] = ( + not self.render_flags['cull_faces'] + ) + if self.render_flags['cull_faces']: + self._message_text = 'Cull Faces On' + else: + self._message_text = 'Cull Faces Off' + + # F toggles face normals + elif symbol == pyglet.window.key.F: + self.viewer_flags['fullscreen'] = ( + not self.viewer_flags['fullscreen'] + ) + self.set_fullscreen(self.viewer_flags['fullscreen']) + self.activate() + if self.viewer_flags['fullscreen']: + self._message_text = 'Fullscreen On' + else: + self._message_text = 'Fullscreen Off' + + # S toggles shadows + elif symbol == pyglet.window.key.H and sys.platform != 'darwin': + self.render_flags['shadows'] = not self.render_flags['shadows'] + if self.render_flags['shadows']: + self._message_text = 'Shadows On' + else: + self._message_text = 'Shadows Off' + + elif symbol == pyglet.window.key.I: + if (self.viewer_flags['show_world_axis'] and not + self.viewer_flags['show_mesh_axes']): + self.viewer_flags['show_world_axis'] = False + self.viewer_flags['show_mesh_axes'] = True + self._set_axes(False, True) + self._message_text = 'Mesh Axes On' + elif (not self.viewer_flags['show_world_axis'] and + self.viewer_flags['show_mesh_axes']): + self.viewer_flags['show_world_axis'] = True + 
self.viewer_flags['show_mesh_axes'] = True + self._set_axes(True, True) + self._message_text = 'All Axes On' + elif (self.viewer_flags['show_world_axis'] and + self.viewer_flags['show_mesh_axes']): + self.viewer_flags['show_world_axis'] = False + self.viewer_flags['show_mesh_axes'] = False + self._set_axes(False, False) + self._message_text = 'All Axes Off' + else: + self.viewer_flags['show_world_axis'] = True + self.viewer_flags['show_mesh_axes'] = False + self._set_axes(True, False) + self._message_text = 'World Axis On' + + # L toggles the lighting mode + elif symbol == pyglet.window.key.L: + if self.viewer_flags['use_raymond_lighting']: + self.viewer_flags['use_raymond_lighting'] = False + self.viewer_flags['use_direct_lighting'] = True + self._message_text = 'Direct Lighting' + elif self.viewer_flags['use_direct_lighting']: + self.viewer_flags['use_raymond_lighting'] = False + self.viewer_flags['use_direct_lighting'] = False + self._message_text = 'Default Lighting' + else: + self.viewer_flags['use_raymond_lighting'] = True + self.viewer_flags['use_direct_lighting'] = False + self._message_text = 'Raymond Lighting' + + # M toggles face normals + elif symbol == pyglet.window.key.M: + self.render_flags['face_normals'] = ( + not self.render_flags['face_normals'] + ) + if self.render_flags['face_normals']: + self._message_text = 'Face Normals On' + else: + self._message_text = 'Face Normals Off' + + # N toggles vertex normals + elif symbol == pyglet.window.key.N: + self.render_flags['vertex_normals'] = ( + not self.render_flags['vertex_normals'] + ) + if self.render_flags['vertex_normals']: + self._message_text = 'Vert Normals On' + else: + self._message_text = 'Vert Normals Off' + + # O toggles orthographic camera mode + elif symbol == pyglet.window.key.O: + self.viewer_flags['use_perspective_cam'] = ( + not self.viewer_flags['use_perspective_cam'] + ) + if self.viewer_flags['use_perspective_cam']: + camera = self._default_persp_cam + self._message_text = 'Perspective View' + else: + camera = self._default_orth_cam + self._message_text = 'Orthographic View' + + cam_pose = self._camera_node.matrix.copy() + cam_node = Node(matrix=cam_pose, camera=camera) + self.scene.remove_node(self._camera_node) + self.scene.add_node(cam_node) + self.scene.main_camera_node = cam_node + self._camera_node = cam_node + + # Q quits the viewer + elif symbol == pyglet.window.key.Q: + self.on_close() + + # R starts recording frames + elif symbol == pyglet.window.key.R: + if self.viewer_flags['record']: + self.save_gif() + self.set_caption(self.viewer_flags['window_title']) + else: + self.set_caption( + '{} (RECORDING)'.format(self.viewer_flags['window_title']) + ) + self.viewer_flags['record'] = not self.viewer_flags['record'] + + # S saves the current frame as an image + elif symbol == pyglet.window.key.S: + self._save_image() + + # W toggles through wireframe modes + elif symbol == pyglet.window.key.W: + if self.render_flags['flip_wireframe']: + self.render_flags['flip_wireframe'] = False + self.render_flags['all_wireframe'] = True + self.render_flags['all_solid'] = False + self._message_text = 'All Wireframe' + elif self.render_flags['all_wireframe']: + self.render_flags['flip_wireframe'] = False + self.render_flags['all_wireframe'] = False + self.render_flags['all_solid'] = True + self._message_text = 'All Solid' + elif self.render_flags['all_solid']: + self.render_flags['flip_wireframe'] = False + self.render_flags['all_wireframe'] = False + self.render_flags['all_solid'] = False + self._message_text = 
'Default Wireframe' + else: + self.render_flags['flip_wireframe'] = True + self.render_flags['all_wireframe'] = False + self.render_flags['all_solid'] = False + self._message_text = 'Flip Wireframe' + + # Z resets the camera viewpoint + elif symbol == pyglet.window.key.Z: + self._reset_view() + + if self._message_text is not None: + self._message_opac = 1.0 + self._ticks_till_fade + + @staticmethod + def _time_event(dt, self): + """The timer callback. + """ + # Don't run old dead events after we've already closed + if not self._is_active: + return + + if self.viewer_flags['record']: + self._record() + if (self.viewer_flags['rotate'] and not + self.viewer_flags['mouse_pressed']): + self._rotate() + + # Manage message opacity + if self._message_text is not None: + if self._message_opac > 1.0: + self._message_opac -= 1.0 + else: + self._message_opac *= 0.90 + if self._message_opac < 0.05: + self._message_opac = 1.0 + self._ticks_till_fade + self._message_text = None + + if self._should_close: + self.on_close() + else: + self.on_draw() + + def _reset_view(self): + """Reset the view to a good initial state. + + The view is initially along the positive x-axis at a + sufficient distance from the scene. + """ + scale = self.scene.scale + if scale == 0.0: + scale = DEFAULT_SCENE_SCALE + centroid = self.scene.centroid + + if self.viewer_flags['view_center'] is not None: + centroid = self.viewer_flags['view_center'] + + self._camera_node.matrix = self._default_camera_pose + self._trackball = Trackball( + self._default_camera_pose, self.viewport_size, scale, centroid + ) + + def _get_save_filename(self, file_exts): + file_types = { + 'png': ('png files', '*.png'), + 'jpg': ('jpeg files', '*.jpg'), + 'gif': ('gif files', '*.gif'), + 'all': ('all files', '*'), + } + filetypes = [file_types[x] for x in file_exts] + try: + root = Tk() + save_dir = self.viewer_flags['save_directory'] + if save_dir is None: + save_dir = os.getcwd() + filename = filedialog.asksaveasfilename( + initialdir=save_dir, title='Select file save location', + filetypes=filetypes + ) + except Exception: + return None + + root.destroy() + if filename == (): + return None + return filename + + def _save_image(self): + filename = self._get_save_filename(['png', 'jpg', 'gif', 'all']) + if filename is not None: + self.viewer_flags['save_directory'] = os.path.dirname(filename) + imageio.imwrite(filename, self._renderer.read_color_buf()) + + def _record(self): + """Save another frame for the GIF. + """ + data = self._renderer.read_color_buf() + if not np.all(data == 0.0): + self._saved_frames.append(data) + + def _rotate(self): + """Animate the scene by rotating the camera. + """ + az = (self.viewer_flags['rotate_rate'] / + self.viewer_flags['refresh_rate']) + self._trackball.rotate(az, self.viewer_flags['rotate_axis']) + + def _render(self): + """Render the scene into the framebuffer and flip. 
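The same ``RenderFlags`` composition used by ``_render`` below also works with the offscreen renderer; here is a minimal headless sketch (scene contents and camera pose are illustrative only).

```python
import numpy as np
import trimesh
import pyrender
from pyrender import RenderFlags

scene = pyrender.Scene(ambient_light=np.array([0.05, 0.05, 0.05]))
scene.add(pyrender.Mesh.from_trimesh(trimesh.creation.box()))

cam_pose = np.eye(4)
cam_pose[2, 3] = 2.0  # back the camera off along +z so the box is in view
scene.add(pyrender.PerspectiveCamera(yfov=np.pi / 3.0), pose=cam_pose)
scene.add(pyrender.DirectionalLight(color=np.ones(3), intensity=3.0), pose=cam_pose)

flags = RenderFlags.NONE
flags |= RenderFlags.SHADOWS_DIRECTIONAL  # same bit the 'shadows' render flag sets
flags |= RenderFlags.SKIP_CULL_FACES      # same bit that 'cull_faces' = False sets

r = pyrender.OffscreenRenderer(viewport_width=640, viewport_height=480)
color, depth = r.render(scene, flags=flags)  # (480, 640, 3) uint8 and (480, 640) float
r.delete()
```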
+ """ + scene = self.scene + self._camera_node.matrix = self._trackball.pose.copy() + + # Set lighting + vli = self.viewer_flags['lighting_intensity'] + if self.viewer_flags['use_raymond_lighting']: + for n in self._raymond_lights: + n.light.intensity = vli / 3.0 + if not self.scene.has_node(n): + scene.add_node(n, parent_node=self._camera_node) + else: + self._direct_light.light.intensity = vli + for n in self._raymond_lights: + if self.scene.has_node(n): + self.scene.remove_node(n) + + if self.viewer_flags['use_direct_lighting']: + if not self.scene.has_node(self._direct_light): + scene.add_node( + self._direct_light, parent_node=self._camera_node + ) + elif self.scene.has_node(self._direct_light): + self.scene.remove_node(self._direct_light) + + flags = RenderFlags.NONE + if self.render_flags['flip_wireframe']: + flags |= RenderFlags.FLIP_WIREFRAME + elif self.render_flags['all_wireframe']: + flags |= RenderFlags.ALL_WIREFRAME + elif self.render_flags['all_solid']: + flags |= RenderFlags.ALL_SOLID + + if self.render_flags['shadows']: + flags |= RenderFlags.SHADOWS_DIRECTIONAL | RenderFlags.SHADOWS_SPOT + if self.render_flags['vertex_normals']: + flags |= RenderFlags.VERTEX_NORMALS + if self.render_flags['face_normals']: + flags |= RenderFlags.FACE_NORMALS + if not self.render_flags['cull_faces']: + flags |= RenderFlags.SKIP_CULL_FACES + + self._renderer.render(self.scene, flags) + + def _init_and_start_app(self): + # Try multiple configs starting with target OpenGL version + # and multisampling and removing these options if exception + # Note: multisampling not available on all hardware + from pyglet.gl import Config + confs = [Config(sample_buffers=1, samples=4, + depth_size=24, + double_buffer=True, + major_version=TARGET_OPEN_GL_MAJOR, + minor_version=TARGET_OPEN_GL_MINOR), + Config(depth_size=24, + double_buffer=True, + major_version=TARGET_OPEN_GL_MAJOR, + minor_version=TARGET_OPEN_GL_MINOR), + Config(sample_buffers=1, samples=4, + depth_size=24, + double_buffer=True, + major_version=MIN_OPEN_GL_MAJOR, + minor_version=MIN_OPEN_GL_MINOR), + Config(depth_size=24, + double_buffer=True, + major_version=MIN_OPEN_GL_MAJOR, + minor_version=MIN_OPEN_GL_MINOR)] + for conf in confs: + try: + super(Viewer, self).__init__(config=conf, resizable=True, + width=self._viewport_size[0], + height=self._viewport_size[1]) + break + except pyglet.window.NoSuchConfigException: + pass + + if not self.context: + raise ValueError('Unable to initialize an OpenGL 3+ context') + clock.schedule_interval( + Viewer._time_event, 1.0 / self.viewer_flags['refresh_rate'], self + ) + self.switch_to() + self.set_caption(self.viewer_flags['window_title']) + pyglet.app.run() + + def _compute_initial_camera_pose(self): + centroid = self.scene.centroid + if self.viewer_flags['view_center'] is not None: + centroid = self.viewer_flags['view_center'] + scale = self.scene.scale + if scale == 0.0: + scale = DEFAULT_SCENE_SCALE + + s2 = 1.0 / np.sqrt(2.0) + cp = np.eye(4) + cp[:3,:3] = np.array([ + [0.0, -s2, s2], + [1.0, 0.0, 0.0], + [0.0, s2, s2] + ]) + hfov = np.pi / 6.0 + dist = scale / (2.0 * np.tan(hfov)) + cp[:3,3] = dist * np.array([1.0, 0.0, 1.0]) + centroid + + return cp + + def _create_raymond_lights(self): + thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) + phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) + + nodes = [] + + for phi, theta in zip(phis, thetas): + xp = np.sin(theta) * np.cos(phi) + yp = np.sin(theta) * np.sin(phi) + zp = np.cos(theta) + + z = np.array([xp, yp, zp]) + z = z / 
np.linalg.norm(z) + x = np.array([-z[1], z[0], 0.0]) + if np.linalg.norm(x) == 0: + x = np.array([1.0, 0.0, 0.0]) + x = x / np.linalg.norm(x) + y = np.cross(z, x) + + matrix = np.eye(4) + matrix[:3,:3] = np.c_[x,y,z] + nodes.append(Node( + light=DirectionalLight(color=np.ones(3), intensity=1.0), + matrix=matrix + )) + + return nodes + + def _create_direct_light(self): + light = DirectionalLight(color=np.ones(3), intensity=1.0) + n = Node(light=light, matrix=np.eye(4)) + return n + + def _set_axes(self, world, mesh): + scale = self.scene.scale + if world: + if 'scene' not in self._axes: + n = Node(mesh=self._axis_mesh, scale=np.ones(3) * scale * 0.3) + self.scene.add_node(n) + self._axes['scene'] = n + else: + if 'scene' in self._axes: + self.scene.remove_node(self._axes['scene']) + self._axes.pop('scene') + + if mesh: + old_nodes = [] + existing_axes = set([self._axes[k] for k in self._axes]) + for node in self.scene.mesh_nodes: + if node not in existing_axes: + old_nodes.append(node) + + for node in old_nodes: + if node in self._axes: + continue + n = Node( + mesh=self._axis_mesh, + scale=np.ones(3) * node.mesh.scale * 0.5 + ) + self.scene.add_node(n, parent_node=node) + self._axes[node] = n + else: + to_remove = set() + for main_node in self._axes: + if main_node in self.scene.mesh_nodes: + self.scene.remove_node(self._axes[main_node]) + to_remove.add(main_node) + for main_node in to_remove: + self._axes.pop(main_node) + + def _remove_axes(self): + for main_node in self._axes: + axis_node = self._axes[main_node] + self.scene.remove_node(axis_node) + self._axes = {} + + def _location_to_x_y(self, location): + if location == TextAlign.CENTER: + return (self.viewport_size[0] / 2.0, self.viewport_size[1] / 2.0) + elif location == TextAlign.CENTER_LEFT: + return (TEXT_PADDING, self.viewport_size[1] / 2.0) + elif location == TextAlign.CENTER_RIGHT: + return (self.viewport_size[0] - TEXT_PADDING, + self.viewport_size[1] / 2.0) + elif location == TextAlign.BOTTOM_LEFT: + return (TEXT_PADDING, TEXT_PADDING) + elif location == TextAlign.BOTTOM_RIGHT: + return (self.viewport_size[0] - TEXT_PADDING, TEXT_PADDING) + elif location == TextAlign.BOTTOM_CENTER: + return (self.viewport_size[0] / 2.0, TEXT_PADDING) + elif location == TextAlign.TOP_LEFT: + return (TEXT_PADDING, self.viewport_size[1] - TEXT_PADDING) + elif location == TextAlign.TOP_RIGHT: + return (self.viewport_size[0] - TEXT_PADDING, + self.viewport_size[1] - TEXT_PADDING) + elif location == TextAlign.TOP_CENTER: + return (self.viewport_size[0] / 2.0, + self.viewport_size[1] - TEXT_PADDING) + + +__all__ = ['Viewer'] diff --git a/pyrender/requirements.txt b/pyrender/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8c40b74256f0dc6697754bb8609f69a39d51beba --- /dev/null +++ b/pyrender/requirements.txt @@ -0,0 +1,14 @@ +freetype-py +imageio +networkx +numpy +Pillow +pyglet==1.4.0a1 +PyOpenGL +PyOpenGL_accelerate +six +trimesh +sphinx +sphinx_rtd_theme +sphinx-automodapi + diff --git a/pyrender/setup.py b/pyrender/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b5ba0da2b0f17b759e5556597981096a80bda8 --- /dev/null +++ b/pyrender/setup.py @@ -0,0 +1,76 @@ +""" +Setup of pyrender Python codebase. 
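``_location_to_x_y`` above anchors the optional on-screen captions; the sketch below shows the ``caption`` viewer flag built from the dictionary keys read in ``on_draw``. The font name is an assumption about the default font bundled with this copy of pyrender.

```python
import numpy as np
import trimesh
import pyrender
from pyrender.constants import TextAlign

scene = pyrender.Scene()
scene.add(pyrender.Mesh.from_trimesh(trimesh.creation.box()))

caption = [{
    'text': 'EMAGE viewer',
    'location': TextAlign.TOP_LEFT,   # mapped to pixel coordinates by _location_to_x_y
    'font_name': 'OpenSans-Regular',  # assumed bundled font
    'font_pt': 30,
    'color': np.array([0.0, 0.0, 0.0, 1.0]),
    'scale': 1.0,
}]

pyrender.Viewer(scene, viewer_flags={'caption': caption, 'use_raymond_lighting': True})
```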
+ +Author: Matthew Matl +""" +import sys +from setuptools import setup + +# load __version__ +exec(open('pyrender/version.py').read()) + +def get_imageio_dep(): + if sys.version[0] == "2": + return 'imageio<=2.6.1' + return 'imageio' + +requirements = [ + 'freetype-py', # For font loading + get_imageio_dep(), # For Image I/O + 'networkx', # For the scene graph + 'numpy', # Numpy + 'Pillow', # For Trimesh texture conversions + 'pyglet>=1.4.10', # For the pyglet viewer + 'PyOpenGL~=3.1.0', # For OpenGL +# 'PyOpenGL_accelerate~=3.1.0', # For OpenGL + 'scipy', # Because of trimesh missing dep + 'six', # For Python 2/3 interop + 'trimesh', # For meshes +] + +dev_requirements = [ + 'flake8', # Code formatting checker + 'pre-commit', # Pre-commit hooks + 'pytest', # Code testing + 'pytest-cov', # Coverage testing + 'tox', # Automatic virtualenv testing +] + +docs_requirements = [ + 'sphinx', # General doc library + 'sphinx_rtd_theme', # RTD theme for sphinx + 'sphinx-automodapi' # For generating nice tables +] + + +setup( + name = 'pyrender', + version=__version__, + description='Easy-to-use Python renderer for 3D visualization', + long_description='A simple implementation of Physically-Based Rendering ' + '(PBR) in Python. Compliant with the glTF 2.0 standard.', + author='Matthew Matl', + author_email='matthewcmatl@gmail.com', + license='MIT License', + url = 'https://github.com/mmatl/pyrender', + classifiers = [ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: MIT License', + 'Operating System :: POSIX :: Linux', + 'Operating System :: MacOS :: MacOS X', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Natural Language :: English', + 'Topic :: Scientific/Engineering' + ], + keywords = 'rendering graphics opengl 3d visualization pbr gltf', + packages = ['pyrender', 'pyrender.platforms'], + setup_requires = requirements, + install_requires = requirements, + extras_require={ + 'dev': dev_requirements, + 'docs': docs_requirements, + }, + include_package_data=True +) diff --git a/pyrender/tests/__init__.py b/pyrender/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyrender/tests/conftest.py b/pyrender/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyrender/tests/pytest.ini b/pyrender/tests/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyrender/tests/unit/__init__.py b/pyrender/tests/unit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyrender/tests/unit/test_cameras.py b/pyrender/tests/unit/test_cameras.py new file mode 100644 index 0000000000000000000000000000000000000000..7544ad8f8e3ee55236fd2e32dbc12065153cbe5b --- /dev/null +++ b/pyrender/tests/unit/test_cameras.py @@ -0,0 +1,164 @@ +import numpy as np +import pytest + +from pyrender import PerspectiveCamera, OrthographicCamera + + +def test_perspective_camera(): + + # Set up constants + znear = 0.05 + zfar = 100 + yfov = np.pi / 3.0 + width = 1000.0 + height = 500.0 + aspectRatio = 640.0 / 480.0 + + # Test basics + with pytest.raises(TypeError): + p = PerspectiveCamera() + + p = PerspectiveCamera(yfov=yfov) + assert p.yfov == yfov + assert p.znear == 0.05 + assert p.zfar is None + assert 
p.aspectRatio is None + p.name = 'asdf' + p.name = None + + with pytest.raises(ValueError): + p.yfov = 0.0 + + with pytest.raises(ValueError): + p.yfov = -1.0 + + with pytest.raises(ValueError): + p.znear = -1.0 + + p.znear = 0.0 + p.znear = 0.05 + p.zfar = 100.0 + assert p.zfar == 100.0 + + with pytest.raises(ValueError): + p.zfar = 0.03 + + with pytest.raises(ValueError): + p.zfar = 0.05 + + p.aspectRatio = 10.0 + assert p.aspectRatio == 10.0 + + with pytest.raises(ValueError): + p.aspectRatio = 0.0 + + with pytest.raises(ValueError): + p.aspectRatio = -1.0 + + # Test matrix getting/setting + + # NF + p.znear = 0.05 + p.zfar = 100 + p.aspectRatio = None + + with pytest.raises(ValueError): + p.get_projection_matrix() + + assert np.allclose( + p.get_projection_matrix(width, height), + np.array([ + [1.0 / (width / height * np.tan(yfov / 2.0)), 0.0, 0.0, 0.0], + [0.0, 1.0 / np.tan(yfov / 2.0), 0.0, 0.0], + [0.0, 0.0, (zfar + znear) / (znear - zfar), + (2 * zfar * znear) / (znear - zfar)], + [0.0, 0.0, -1.0, 0.0] + ]) + ) + + # NFA + p.aspectRatio = aspectRatio + assert np.allclose( + p.get_projection_matrix(width, height), + np.array([ + [1.0 / (aspectRatio * np.tan(yfov / 2.0)), 0.0, 0.0, 0.0], + [0.0, 1.0 / np.tan(yfov / 2.0), 0.0, 0.0], + [0.0, 0.0, (zfar + znear) / (znear - zfar), + (2 * zfar * znear) / (znear - zfar)], + [0.0, 0.0, -1.0, 0.0] + ]) + ) + assert np.allclose( + p.get_projection_matrix(), p.get_projection_matrix(width, height) + ) + + # N + p.zfar = None + p.aspectRatio = None + assert np.allclose( + p.get_projection_matrix(width, height), + np.array([ + [1.0 / (width / height * np.tan(yfov / 2.0)), 0.0, 0.0, 0.0], + [0.0, 1.0 / np.tan(yfov / 2.0), 0.0, 0.0], + [0.0, 0.0, -1.0, -2.0 * znear], + [0.0, 0.0, -1.0, 0.0] + ]) + ) + + +def test_orthographic_camera(): + xm = 1.0 + ym = 2.0 + n = 0.05 + f = 100.0 + + with pytest.raises(TypeError): + c = OrthographicCamera() + + c = OrthographicCamera(xmag=xm, ymag=ym) + + assert c.xmag == xm + assert c.ymag == ym + assert c.znear == 0.05 + assert c.zfar == 100.0 + assert c.name is None + + with pytest.raises(TypeError): + c.ymag = None + + with pytest.raises(ValueError): + c.ymag = 0.0 + + with pytest.raises(ValueError): + c.ymag = -1.0 + + with pytest.raises(TypeError): + c.xmag = None + + with pytest.raises(ValueError): + c.xmag = 0.0 + + with pytest.raises(ValueError): + c.xmag = -1.0 + + with pytest.raises(TypeError): + c.znear = None + + with pytest.raises(ValueError): + c.znear = 0.0 + + with pytest.raises(ValueError): + c.znear = -1.0 + + with pytest.raises(ValueError): + c.zfar = 0.01 + + assert np.allclose( + c.get_projection_matrix(), + np.array([ + [1.0 / xm, 0, 0, 0], + [0, 1.0 / ym, 0, 0], + [0, 0, 2.0 / (n - f), (f + n) / (n - f)], + [0, 0, 0, 1.0] + ]) + ) diff --git a/pyrender/tests/unit/test_egl.py b/pyrender/tests/unit/test_egl.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f4bef39e33c2794e6837b5a1bb127d8d4dba06 --- /dev/null +++ b/pyrender/tests/unit/test_egl.py @@ -0,0 +1,16 @@ +# from pyrender.platforms import egl + + +def tmp_test_default_device(): + egl.get_default_device() + + +def tmp_test_query_device(): + devices = egl.query_devices() + assert len(devices) > 0 + + +def tmp_test_init_context(): + device = egl.query_devices()[0] + platform = egl.EGLPlatform(128, 128, device=device) + platform.init_context() diff --git a/pyrender/tests/unit/test_lights.py b/pyrender/tests/unit/test_lights.py new file mode 100644 index 
0000000000000000000000000000000000000000..ffde856b21e8cce9532f0308fcd1c7eb2d1eba90 --- /dev/null +++ b/pyrender/tests/unit/test_lights.py @@ -0,0 +1,104 @@ +import numpy as np +import pytest + +from pyrender import (DirectionalLight, SpotLight, PointLight, Texture, + PerspectiveCamera, OrthographicCamera) +from pyrender.constants import SHADOW_TEX_SZ + + +def test_directional_light(): + + d = DirectionalLight() + assert d.name is None + assert np.all(d.color == 1.0) + assert d.intensity == 1.0 + + d.name = 'direc' + with pytest.raises(ValueError): + d.color = None + with pytest.raises(TypeError): + d.intensity = None + + d = DirectionalLight(color=[0.0, 0.0, 0.0]) + assert np.all(d.color == 0.0) + + d._generate_shadow_texture() + st = d.shadow_texture + assert isinstance(st, Texture) + assert st.width == st.height == SHADOW_TEX_SZ + + sc = d._get_shadow_camera(scene_scale=5.0) + assert isinstance(sc, OrthographicCamera) + assert sc.xmag == sc.ymag == 5.0 + assert sc.znear == 0.01 * 5.0 + assert sc.zfar == 10 * 5.0 + + +def test_spot_light(): + + s = SpotLight() + assert s.name is None + assert np.all(s.color == 1.0) + assert s.intensity == 1.0 + assert s.innerConeAngle == 0.0 + assert s.outerConeAngle == np.pi / 4.0 + assert s.range is None + + with pytest.raises(ValueError): + s.range = -1.0 + + with pytest.raises(ValueError): + s.range = 0.0 + + with pytest.raises(ValueError): + s.innerConeAngle = -1.0 + + with pytest.raises(ValueError): + s.innerConeAngle = np.pi / 3.0 + + with pytest.raises(ValueError): + s.outerConeAngle = -1.0 + + with pytest.raises(ValueError): + s.outerConeAngle = np.pi + + s.range = 5.0 + s.outerConeAngle = np.pi / 2 - 0.05 + s.innerConeAngle = np.pi / 3 + s.innerConeAngle = 0.0 + s.outerConeAngle = np.pi / 4.0 + + s._generate_shadow_texture() + st = s.shadow_texture + assert isinstance(st, Texture) + assert st.width == st.height == SHADOW_TEX_SZ + + sc = s._get_shadow_camera(scene_scale=5.0) + assert isinstance(sc, PerspectiveCamera) + assert sc.znear == 0.01 * 5.0 + assert sc.zfar == 10 * 5.0 + assert sc.aspectRatio == 1.0 + assert np.allclose(sc.yfov, np.pi / 16.0 * 9.0) # Plus pi / 16 + + +def test_point_light(): + + s = PointLight() + assert s.name is None + assert np.all(s.color == 1.0) + assert s.intensity == 1.0 + assert s.range is None + + with pytest.raises(ValueError): + s.range = -1.0 + + with pytest.raises(ValueError): + s.range = 0.0 + + s.range = 5.0 + + with pytest.raises(NotImplementedError): + s._generate_shadow_texture() + + with pytest.raises(NotImplementedError): + s._get_shadow_camera(scene_scale=5.0) diff --git a/pyrender/tests/unit/test_meshes.py b/pyrender/tests/unit/test_meshes.py new file mode 100644 index 0000000000000000000000000000000000000000..7070b01171c97069fa013c6eba8eee217017f08e --- /dev/null +++ b/pyrender/tests/unit/test_meshes.py @@ -0,0 +1,133 @@ +import numpy as np +import pytest +import trimesh + +from pyrender import (Mesh, Primitive) + + +def test_meshes(): + + with pytest.raises(TypeError): + x = Mesh() + with pytest.raises(TypeError): + x = Primitive() + with pytest.raises(ValueError): + x = Primitive([], mode=10) + + # Basics + x = Mesh([]) + assert x.name is None + assert x.is_visible + assert x.weights is None + + x.name = 'str' + + # From Trimesh + x = Mesh.from_trimesh(trimesh.creation.box()) + assert isinstance(x, Mesh) + assert len(x.primitives) == 1 + assert x.is_visible + assert np.allclose(x.bounds, np.array([ + [-0.5, -0.5, -0.5], + [0.5, 0.5, 0.5] + ])) + assert np.allclose(x.centroid, np.zeros(3)) + assert 
np.allclose(x.extents, np.ones(3)) + assert np.allclose(x.scale, np.sqrt(3)) + assert not x.is_transparent + + # Test some primitive functions + x = x.primitives[0] + with pytest.raises(ValueError): + x.normals = np.zeros(10) + with pytest.raises(ValueError): + x.tangents = np.zeros(10) + with pytest.raises(ValueError): + x.texcoord_0 = np.zeros(10) + with pytest.raises(ValueError): + x.texcoord_1 = np.zeros(10) + with pytest.raises(TypeError): + x.material = np.zeros(10) + assert x.targets is None + assert np.allclose(x.bounds, np.array([ + [-0.5, -0.5, -0.5], + [0.5, 0.5, 0.5] + ])) + assert np.allclose(x.centroid, np.zeros(3)) + assert np.allclose(x.extents, np.ones(3)) + assert np.allclose(x.scale, np.sqrt(3)) + x.material.baseColorFactor = np.array([0.0, 0.0, 0.0, 0.0]) + assert x.is_transparent + + # From two trimeshes + x = Mesh.from_trimesh([trimesh.creation.box(), + trimesh.creation.cylinder(radius=0.1, height=2.0)], + smooth=False) + assert isinstance(x, Mesh) + assert len(x.primitives) == 2 + assert x.is_visible + assert np.allclose(x.bounds, np.array([ + [-0.5, -0.5, -1.0], + [0.5, 0.5, 1.0] + ])) + assert np.allclose(x.centroid, np.zeros(3)) + assert np.allclose(x.extents, [1.0, 1.0, 2.0]) + assert np.allclose(x.scale, np.sqrt(6)) + assert not x.is_transparent + + # From bad data + with pytest.raises(TypeError): + x = Mesh.from_trimesh(None) + + # With instancing + poses = np.tile(np.eye(4), (5,1,1)) + poses[:,0,3] = np.array([0,1,2,3,4]) + x = Mesh.from_trimesh(trimesh.creation.box(), poses=poses) + assert np.allclose(x.bounds, np.array([ + [-0.5, -0.5, -0.5], + [4.5, 0.5, 0.5] + ])) + poses = np.eye(4) + x = Mesh.from_trimesh(trimesh.creation.box(), poses=poses) + poses = np.eye(3) + with pytest.raises(ValueError): + x = Mesh.from_trimesh(trimesh.creation.box(), poses=poses) + + # From textured meshes + fm = trimesh.load('tests/data/fuze.obj') + x = Mesh.from_trimesh(fm) + assert isinstance(x, Mesh) + assert len(x.primitives) == 1 + assert x.is_visible + assert not x.is_transparent + assert x.primitives[0].material.baseColorTexture is not None + + x = Mesh.from_trimesh(fm, smooth=False) + fm.visual = fm.visual.to_color() + fm.visual.face_colors = np.array([1.0, 0.0, 0.0, 1.0]) + x = Mesh.from_trimesh(fm, smooth=False) + with pytest.raises(ValueError): + x = Mesh.from_trimesh(fm, smooth=True) + + fm.visual.vertex_colors = np.array([1.0, 0.0, 0.0, 0.5]) + x = Mesh.from_trimesh(fm, smooth=False) + x = Mesh.from_trimesh(fm, smooth=True) + assert x.primitives[0].color_0 is not None + assert x.is_transparent + + bm = trimesh.load('tests/data/WaterBottle.glb').dump()[0] + x = Mesh.from_trimesh(bm) + assert x.primitives[0].material.baseColorTexture is not None + assert x.primitives[0].material.emissiveTexture is not None + assert x.primitives[0].material.metallicRoughnessTexture is not None + + # From point cloud + x = Mesh.from_points(fm.vertices) + +# def test_duck(): +# bm = trimesh.load('tests/data/Duck.glb').dump()[0] +# x = Mesh.from_trimesh(bm) +# assert x.primitives[0].material.baseColorTexture is not None +# pixel = x.primitives[0].material.baseColorTexture.source[100, 100] +# yellowish = np.array([1.0, 0.7411765, 0.0, 1.0]) +# assert np.allclose(pixel, yellowish) diff --git a/pyrender/tests/unit/test_nodes.py b/pyrender/tests/unit/test_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..9857c8221b7f6fb8530699bdf5593f8f0b74e152 --- /dev/null +++ b/pyrender/tests/unit/test_nodes.py @@ -0,0 +1,124 @@ +import numpy as np +import pytest +from trimesh 
import transformations + +from pyrender import (DirectionalLight, PerspectiveCamera, Mesh, Node) + + +def test_nodes(): + + x = Node() + assert x.name is None + assert x.camera is None + assert x.children == [] + assert x.skin is None + assert np.allclose(x.matrix, np.eye(4)) + assert x.mesh is None + assert np.allclose(x.rotation, [0,0,0,1]) + assert np.allclose(x.scale, np.ones(3)) + assert np.allclose(x.translation, np.zeros(3)) + assert x.weights is None + assert x.light is None + + x.name = 'node' + + # Test node light/camera/mesh tests + c = PerspectiveCamera(yfov=2.0) + m = Mesh([]) + d = DirectionalLight() + x.camera = c + assert x.camera == c + with pytest.raises(TypeError): + x.camera = m + x.camera = d + x.camera = None + x.mesh = m + assert x.mesh == m + with pytest.raises(TypeError): + x.mesh = c + x.mesh = d + x.light = d + assert x.light == d + with pytest.raises(TypeError): + x.light = m + x.light = c + + # Test transformations getters/setters/etc... + # Set up test values + x = np.array([1.0, 0.0, 0.0]) + y = np.array([0.0, 1.0, 0.0]) + t = np.array([1.0, 2.0, 3.0]) + s = np.array([0.5, 2.0, 1.0]) + + Mx = transformations.rotation_matrix(np.pi / 2.0, x) + qx = np.roll(transformations.quaternion_about_axis(np.pi / 2.0, x), -1) + Mxt = Mx.copy() + Mxt[:3,3] = t + S = np.eye(4) + S[:3,:3] = np.diag(s) + Mxts = Mxt.dot(S) + + My = transformations.rotation_matrix(np.pi / 2.0, y) + qy = np.roll(transformations.quaternion_about_axis(np.pi / 2.0, y), -1) + Myt = My.copy() + Myt[:3,3] = t + + x = Node(matrix=Mx) + assert np.allclose(x.matrix, Mx) + assert np.allclose(x.rotation, qx) + assert np.allclose(x.translation, np.zeros(3)) + assert np.allclose(x.scale, np.ones(3)) + + x.matrix = My + assert np.allclose(x.matrix, My) + assert np.allclose(x.rotation, qy) + assert np.allclose(x.translation, np.zeros(3)) + assert np.allclose(x.scale, np.ones(3)) + x.translation = t + assert np.allclose(x.matrix, Myt) + assert np.allclose(x.rotation, qy) + x.rotation = qx + assert np.allclose(x.matrix, Mxt) + x.scale = s + assert np.allclose(x.matrix, Mxts) + + x = Node(matrix=Mxt) + assert np.allclose(x.matrix, Mxt) + assert np.allclose(x.rotation, qx) + assert np.allclose(x.translation, t) + assert np.allclose(x.scale, np.ones(3)) + + x = Node(matrix=Mxts) + assert np.allclose(x.matrix, Mxts) + assert np.allclose(x.rotation, qx) + assert np.allclose(x.translation, t) + assert np.allclose(x.scale, s) + + # Individual element getters + x.scale[0] = 0 + assert np.allclose(x.scale[0], 0) + + x.translation[0] = 0 + assert np.allclose(x.translation[0], 0) + + x.matrix = np.eye(4) + x.matrix[0,0] = 500 + assert x.matrix[0,0] == 1.0 + + # Failures + with pytest.raises(ValueError): + x.matrix = 5 * np.eye(4) + with pytest.raises(ValueError): + x.matrix = np.eye(5) + with pytest.raises(ValueError): + x.matrix = np.eye(4).dot([5,1,1,1]) + with pytest.raises(ValueError): + x.rotation = np.array([1,2]) + with pytest.raises(ValueError): + x.rotation = np.array([1,2,3]) + with pytest.raises(ValueError): + x.rotation = np.array([1,2,3,4]) + with pytest.raises(ValueError): + x.translation = np.array([1,2,3,4]) + with pytest.raises(ValueError): + x.scale = np.array([1,2,3,4]) diff --git a/pyrender/tests/unit/test_offscreen.py b/pyrender/tests/unit/test_offscreen.py new file mode 100644 index 0000000000000000000000000000000000000000..88983b0ff4e2ab6f5ef252c51f2ac669c3a0e0ca --- /dev/null +++ b/pyrender/tests/unit/test_offscreen.py @@ -0,0 +1,92 @@ +import numpy as np +import trimesh + +from pyrender import 
(OffscreenRenderer, PerspectiveCamera, DirectionalLight, + SpotLight, Mesh, Node, Scene) + + +def test_offscreen_renderer(tmpdir): + + # Fuze trimesh + fuze_trimesh = trimesh.load('examples/models/fuze.obj') + fuze_mesh = Mesh.from_trimesh(fuze_trimesh) + + # Drill trimesh + drill_trimesh = trimesh.load('examples/models/drill.obj') + drill_mesh = Mesh.from_trimesh(drill_trimesh) + drill_pose = np.eye(4) + drill_pose[0,3] = 0.1 + drill_pose[2,3] = -np.min(drill_trimesh.vertices[:,2]) + + # Wood trimesh + wood_trimesh = trimesh.load('examples/models/wood.obj') + wood_mesh = Mesh.from_trimesh(wood_trimesh) + + # Water bottle trimesh + bottle_gltf = trimesh.load('examples/models/WaterBottle.glb') + bottle_trimesh = bottle_gltf.geometry[list(bottle_gltf.geometry.keys())[0]] + bottle_mesh = Mesh.from_trimesh(bottle_trimesh) + bottle_pose = np.array([ + [1.0, 0.0, 0.0, 0.1], + [0.0, 0.0, -1.0, -0.16], + [0.0, 1.0, 0.0, 0.13], + [0.0, 0.0, 0.0, 1.0], + ]) + + boxv_trimesh = trimesh.creation.box(extents=0.1 * np.ones(3)) + boxv_vertex_colors = np.random.uniform(size=(boxv_trimesh.vertices.shape)) + boxv_trimesh.visual.vertex_colors = boxv_vertex_colors + boxv_mesh = Mesh.from_trimesh(boxv_trimesh, smooth=False) + boxf_trimesh = trimesh.creation.box(extents=0.1 * np.ones(3)) + boxf_face_colors = np.random.uniform(size=boxf_trimesh.faces.shape) + boxf_trimesh.visual.face_colors = boxf_face_colors + # Instanced + poses = np.tile(np.eye(4), (2,1,1)) + poses[0,:3,3] = np.array([-0.1, -0.10, 0.05]) + poses[1,:3,3] = np.array([-0.15, -0.10, 0.05]) + boxf_mesh = Mesh.from_trimesh(boxf_trimesh, poses=poses, smooth=False) + + points = trimesh.creation.icosphere(radius=0.05).vertices + point_colors = np.random.uniform(size=points.shape) + points_mesh = Mesh.from_points(points, colors=point_colors) + + direc_l = DirectionalLight(color=np.ones(3), intensity=1.0) + spot_l = SpotLight(color=np.ones(3), intensity=10.0, + innerConeAngle=np.pi / 16, outerConeAngle=np.pi / 6) + + cam = PerspectiveCamera(yfov=(np.pi / 3.0)) + cam_pose = np.array([ + [0.0, -np.sqrt(2) / 2, np.sqrt(2) / 2, 0.5], + [1.0, 0.0, 0.0, 0.0], + [0.0, np.sqrt(2) / 2, np.sqrt(2) / 2, 0.4], + [0.0, 0.0, 0.0, 1.0] + ]) + + scene = Scene(ambient_light=np.array([0.02, 0.02, 0.02])) + + fuze_node = Node(mesh=fuze_mesh, translation=np.array([ + 0.1, 0.15, -np.min(fuze_trimesh.vertices[:,2]) + ])) + scene.add_node(fuze_node) + boxv_node = Node(mesh=boxv_mesh, translation=np.array([-0.1, 0.10, 0.05])) + scene.add_node(boxv_node) + boxf_node = Node(mesh=boxf_mesh) + scene.add_node(boxf_node) + + _ = scene.add(drill_mesh, pose=drill_pose) + _ = scene.add(bottle_mesh, pose=bottle_pose) + _ = scene.add(wood_mesh) + _ = scene.add(direc_l, pose=cam_pose) + _ = scene.add(spot_l, pose=cam_pose) + _ = scene.add(points_mesh) + + _ = scene.add(cam, pose=cam_pose) + + r = OffscreenRenderer(viewport_width=640, viewport_height=480) + color, depth = r.render(scene) + + assert color.shape == (480, 640, 3) + assert depth.shape == (480, 640) + assert np.max(depth.data) > 0.05 + assert np.count_nonzero(depth.data) > (0.2 * depth.size) + r.delete() diff --git a/pyrender/tests/unit/test_scenes.py b/pyrender/tests/unit/test_scenes.py new file mode 100644 index 0000000000000000000000000000000000000000..d85dd714cb5d842ea12dee4140adfd7db55c9c01 --- /dev/null +++ b/pyrender/tests/unit/test_scenes.py @@ -0,0 +1,235 @@ +import numpy as np +import pytest +import trimesh + +from pyrender import (Mesh, PerspectiveCamera, DirectionalLight, + SpotLight, PointLight, Scene, Node, 
OrthographicCamera) + + +def test_scenes(): + + # Basics + s = Scene() + assert np.allclose(s.bg_color, np.ones(4)) + assert np.allclose(s.ambient_light, np.zeros(3)) + assert len(s.nodes) == 0 + assert s.name is None + s.name = 'asdf' + s.bg_color = None + s.ambient_light = None + assert np.allclose(s.bg_color, np.ones(4)) + assert np.allclose(s.ambient_light, np.zeros(3)) + + assert s.nodes == set() + assert s.cameras == set() + assert s.lights == set() + assert s.point_lights == set() + assert s.spot_lights == set() + assert s.directional_lights == set() + assert s.meshes == set() + assert s.camera_nodes == set() + assert s.light_nodes == set() + assert s.point_light_nodes == set() + assert s.spot_light_nodes == set() + assert s.directional_light_nodes == set() + assert s.mesh_nodes == set() + assert s.main_camera_node is None + assert np.all(s.bounds == 0) + assert np.all(s.centroid == 0) + assert np.all(s.extents == 0) + assert np.all(s.scale == 0) + + # From trimesh scene + tms = trimesh.load('tests/data/WaterBottle.glb') + s = Scene.from_trimesh_scene(tms) + assert len(s.meshes) == 1 + assert len(s.mesh_nodes) == 1 + + # Test bg color formatting + s = Scene(bg_color=[0, 1.0, 0]) + assert np.allclose(s.bg_color, np.array([0.0, 1.0, 0.0, 1.0])) + + # Test constructor for nodes + n1 = Node() + n2 = Node() + n3 = Node() + nodes = [n1, n2, n3] + s = Scene(nodes=nodes) + n1.children.append(n2) + s = Scene(nodes=nodes) + n3.children.append(n2) + with pytest.raises(ValueError): + s = Scene(nodes=nodes) + n3.children = [] + n2.children.append(n3) + n3.children.append(n2) + with pytest.raises(ValueError): + s = Scene(nodes=nodes) + + # Test node accessors + n1 = Node() + n2 = Node() + n3 = Node() + nodes = [n1, n2] + s = Scene(nodes=nodes) + assert s.has_node(n1) + assert s.has_node(n2) + assert not s.has_node(n3) + + # Test node poses + for n in nodes: + assert np.allclose(s.get_pose(n), np.eye(4)) + with pytest.raises(ValueError): + s.get_pose(n3) + with pytest.raises(ValueError): + s.set_pose(n3, np.eye(4)) + tf = np.eye(4) + tf[:3,3] = np.ones(3) + s.set_pose(n1, tf) + assert np.allclose(s.get_pose(n1), tf) + assert np.allclose(s.get_pose(n2), np.eye(4)) + + nodes = [n1, n2, n3] + tf2 = np.eye(4) + tf2[:3,:3] = np.diag([-1,-1,1]) + n1.children.append(n2) + n1.matrix = tf + n2.matrix = tf2 + s = Scene(nodes=nodes) + assert np.allclose(s.get_pose(n1), tf) + assert np.allclose(s.get_pose(n2), tf.dot(tf2)) + assert np.allclose(s.get_pose(n3), np.eye(4)) + + n1 = Node() + n2 = Node() + n3 = Node() + n1.children.append(n2) + s = Scene() + s.add_node(n1) + with pytest.raises(ValueError): + s.add_node(n2) + s.set_pose(n1, tf) + assert np.allclose(s.get_pose(n1), tf) + assert np.allclose(s.get_pose(n2), tf) + s.set_pose(n2, tf2) + assert np.allclose(s.get_pose(n2), tf.dot(tf2)) + + # Test node removal + n1 = Node() + n2 = Node() + n3 = Node() + n1.children.append(n2) + n2.children.append(n3) + s = Scene(nodes=[n1, n2, n3]) + s.remove_node(n2) + assert len(s.nodes) == 1 + assert n1 in s.nodes + assert len(n1.children) == 0 + assert len(n2.children) == 1 + s.add_node(n2, parent_node=n1) + assert len(n1.children) == 1 + n1.matrix = tf + n3.matrix = tf2 + assert np.allclose(s.get_pose(n3), tf.dot(tf2)) + + # Now test ADD function + s = Scene() + m = Mesh([], name='m') + cp = PerspectiveCamera(yfov=2.0) + co = OrthographicCamera(xmag=1.0, ymag=1.0) + dl = DirectionalLight() + pl = PointLight() + sl = SpotLight() + + n1 = s.add(m, name='mn') + assert n1.mesh == m + assert len(s.nodes) == 1 + assert 
len(s.mesh_nodes) == 1 + assert n1 in s.mesh_nodes + assert len(s.meshes) == 1 + assert m in s.meshes + assert len(s.get_nodes(node=n2)) == 0 + n2 = s.add(m, pose=tf) + assert len(s.nodes) == len(s.mesh_nodes) == 2 + assert len(s.meshes) == 1 + assert len(s.get_nodes(node=n1)) == 1 + assert len(s.get_nodes(node=n1, name='mn')) == 1 + assert len(s.get_nodes(name='mn')) == 1 + assert len(s.get_nodes(obj=m)) == 2 + assert len(s.get_nodes(obj=m, obj_name='m')) == 2 + assert len(s.get_nodes(obj=co)) == 0 + nsl = s.add(sl, name='sln') + npl = s.add(pl, parent_name='sln') + assert nsl.children[0] == npl + ndl = s.add(dl, parent_node=npl) + assert npl.children[0] == ndl + nco = s.add(co) + ncp = s.add(cp) + + assert len(s.light_nodes) == len(s.lights) == 3 + assert len(s.point_light_nodes) == len(s.point_lights) == 1 + assert npl in s.point_light_nodes + assert len(s.spot_light_nodes) == len(s.spot_lights) == 1 + assert nsl in s.spot_light_nodes + assert len(s.directional_light_nodes) == len(s.directional_lights) == 1 + assert ndl in s.directional_light_nodes + assert len(s.cameras) == len(s.camera_nodes) == 2 + assert s.main_camera_node == nco + s.main_camera_node = ncp + s.remove_node(ncp) + assert len(s.cameras) == len(s.camera_nodes) == 1 + assert s.main_camera_node == nco + s.remove_node(n2) + assert len(s.meshes) == 1 + s.remove_node(n1) + assert len(s.meshes) == 0 + s.remove_node(nsl) + assert len(s.lights) == 0 + s.remove_node(nco) + assert s.main_camera_node is None + + s.add_node(n1) + s.clear() + assert len(s.nodes) == 0 + + # Trigger final errors + with pytest.raises(ValueError): + s.main_camera_node = None + with pytest.raises(ValueError): + s.main_camera_node = ncp + with pytest.raises(ValueError): + s.add(m, parent_node=n1) + with pytest.raises(ValueError): + s.add(m, name='asdf') + s.add(m, name='asdf') + s.add(m, parent_name='asdf') + with pytest.raises(ValueError): + s.add(m, parent_name='asfd') + with pytest.raises(TypeError): + s.add(None) + + s.clear() + # Test bounds + m1 = Mesh.from_trimesh(trimesh.creation.box()) + m2 = Mesh.from_trimesh(trimesh.creation.box()) + m3 = Mesh.from_trimesh(trimesh.creation.box()) + n1 = Node(mesh=m1) + n2 = Node(mesh=m2, translation=[1.0, 0.0, 0.0]) + n3 = Node(mesh=m3, translation=[0.5, 0.0, 1.0]) + s.add_node(n1) + s.add_node(n2) + s.add_node(n3) + assert np.allclose(s.bounds, [[-0.5, -0.5, -0.5], [1.5, 0.5, 1.5]]) + s.clear() + s.add_node(n1) + s.add_node(n2, parent_node=n1) + s.add_node(n3, parent_node=n2) + assert np.allclose(s.bounds, [[-0.5, -0.5, -0.5], [2.0, 0.5, 1.5]]) + tf = np.eye(4) + tf[:3,3] = np.ones(3) + s.set_pose(n3, tf) + assert np.allclose(s.bounds, [[-0.5, -0.5, -0.5], [2.5, 1.5, 1.5]]) + s.remove_node(n2) + assert np.allclose(s.bounds, [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]]) + s.clear() + assert np.allclose(s.bounds, 0.0) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..924c282b799630240602eb639d8fb25d0505d74c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,39 @@ +ffmpeg +ConfigArgParse==1.7 +fasttext==0.9.2 +h5py==3.10.0 +imageio==2.31.4 +ipython==8.12.3 +joblib==1.3.2 +librosa==0.10.1 +lmdb==1.4.1 +loguru==0.7.2 +matplotlib==3.7.3 +moviepy==1.0.3 +gradio +fasttext-wheel +opencv_contrib_python==4.8.1.78 +opencv_python==4.8.1.78 +pandas==1.5.3 +peakutils==1.3.4 +ptflops==0.7.1.2 +python_igraph==0.11.3 +pyvirtualdisplay==3.0 +PyYAML==6.0.1 +replicate==0.15.4 +scikit_learn==1.3.2 +scipy +soundfile==0.12.1 +termcolor==2.4.0 +textgrid==1.5 +torch==2.1.0 
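A small sanity check for the environment these pins describe: pyrender's interactive viewer is written against pyglet 1.x (hence ``pyglet<2``), and headless use normally selects the EGL platform before pyrender is imported. The platform selection and the version checks below are the only assumptions.

```python
import os
os.environ.setdefault('PYOPENGL_PLATFORM', 'egl')  # pick EGL for headless rendering

import pyglet
import pyrender
import torch

assert pyglet.version.startswith('1.'), 'the pyrender viewer expects pyglet < 2'
print('torch', torch.__version__, 'cuda available:', torch.cuda.is_available())
print('pyrender', pyrender.__version__)
```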
+torchvision +tqdm==4.66.1 +transformers==4.35.2 +trimesh==3.23.5 +wandb==0.16.0 +pyglet<2 +smplx +tensorboard +pyrender +pyarrow \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..dda6f4be674b353b49a59de12bb20882dbeba756 --- /dev/null +++ b/test.py @@ -0,0 +1,223 @@ +import os +import signal +import time +import csv +import sys +import warnings +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp +import numpy as np +import time +import pprint +from loguru import logger +import smplx +from torch.utils.tensorboard import SummaryWriter +import wandb +import matplotlib.pyplot as plt +from utils import config, logger_tools, other_tools, metric +from dataloaders import data_tools +from dataloaders.build_vocab import Vocab +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func + + +class BaseTrainer(object): + def __init__(self, args): + self.args = args + self.rank = dist.get_rank() + self.checkpoint_path = args.out_path + "custom/" + args.name + args.notes + "/" #wandb.run.dir #args.cache_path+args.out_path+"/"+args.name + if self.rank==0: + if self.args.stat == "ts": + self.writer = SummaryWriter(log_dir=args.out_path + "custom/" + args.name + args.notes + "/") + else: + wandb.init(project=args.project, entity="liu1997", dir=args.out_path, name=args.name[12:] + args.notes) + wandb.config.update(args) + self.writer = None + #self.test_demo = args.data_path + args.test_data_path + "bvh_full/" + # self.train_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "train") + # self.train_loader = torch.utils.data.DataLoader( + # self.train_data, + # batch_size=args.batch_size, + # shuffle=False if args.ddp else True, + # num_workers=args.loader_workers, + # drop_last=True, + # sampler=torch.utils.data.distributed.DistributedSampler(self.train_data) if args.ddp else None, + # ) + # self.train_length = len(self.train_loader) + # logger.info(f"Init train dataloader success") + + # self.val_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "val") + # self.val_loader = torch.utils.data.DataLoader( + # self.val_data, + # batch_size=args.batch_size, + # shuffle=False, + # num_workers=args.loader_workers, + # drop_last=False, + # sampler=torch.utils.data.distributed.DistributedSampler(self.val_data) if args.ddp else None, + # ) + # logger.info(f"Init val dataloader success") + if self.rank == 0: + self.test_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "test") + self.test_loader = torch.utils.data.DataLoader( + self.test_data, + batch_size=1, + shuffle=False, + num_workers=args.loader_workers, + drop_last=False, + ) + logger.info(f"Init test dataloader success") + model_module = __import__(f"models.{args.model}", fromlist=["something"]) + + if args.ddp: + self.model = getattr(model_module, args.g_name)(args).to(self.rank) + process_group = torch.distributed.new_group() + self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model, process_group) + self.model = DDP(self.model, device_ids=[self.rank], output_device=self.rank, + broadcast_buffers=False, find_unused_parameters=False) + else: + self.model = 
torch.nn.DataParallel(getattr(model_module, args.g_name)(args), args.gpus).cuda() + + if self.rank == 0: + logger.info(self.model) + logger.info(f"init {args.g_name} success") + if args.stat == "wandb": + wandb.watch(self.model) + + # if args.d_name is not None: + # if args.ddp: + # self.d_model = getattr(model_module, args.d_name)(args).to(self.rank) + # self.d_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.d_model, process_group) + # self.d_model = DDP(self.d_model, device_ids=[self.rank], output_device=self.rank, + # broadcast_buffers=False, find_unused_parameters=False) + # else: + # self.d_model = torch.nn.DataParallel(getattr(model_module, args.d_name)(args), args.gpus).cuda() + # if self.rank == 0: + # logger.info(self.d_model) + # logger.info(f"init {args.d_name} success") + # if args.stat == "wandb": + # wandb.watch(self.d_model) + # self.opt_d = create_optimizer(args, self.d_model, lr_weight=args.d_lr_weight) + # self.opt_d_s = create_scheduler(args, self.opt_d) + + if args.e_name is not None: + """ + bugs on DDP training using eval_model, using additional eval_copy for evaluation + """ + eval_model_module = __import__(f"models.{args.eval_model}", fromlist=["something"]) + # eval copy is for single card evaluation + if self.args.ddp: + self.eval_model = getattr(eval_model_module, args.e_name)(args).to(self.rank) + self.eval_copy = getattr(eval_model_module, args.e_name)(args).to(self.rank) + else: + self.eval_model = getattr(eval_model_module, args.e_name)(args) + self.eval_copy = getattr(eval_model_module, args.e_name)(args).to(self.rank) + + #if self.rank == 0: + other_tools.load_checkpoints(self.eval_copy, args.data_path+args.e_path, args.e_name) + other_tools.load_checkpoints(self.eval_model, args.data_path+args.e_path, args.e_name) + if self.args.ddp: + self.eval_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.eval_model, process_group) + self.eval_model = DDP(self.eval_model, device_ids=[self.rank], output_device=self.rank, + broadcast_buffers=False, find_unused_parameters=False) + self.eval_model.eval() + self.eval_copy.eval() + if self.rank == 0: + logger.info(self.eval_model) + logger.info(f"init {args.e_name} success") + if args.stat == "wandb": + wandb.watch(self.eval_model) + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).to(self.rank).eval() + self.alignmenter = metric.alignment(0.3, 7, self.train_data.avg_vel, upper_body=[3,6,9,12,13,14,15,16,17,18,19,20,21]) if self.rank == 0 else None + self.align_mask = 60 + self.l1_calculator = metric.L1div() if self.rank == 0 else None + + def train_recording(self, epoch, its, t_data, t_train, mem_cost, lr_g, lr_d=None): + pstr = "[%03d][%03d/%03d] "%(epoch, its, self.train_length) + for name, states in self.tracker.loss_meters.items(): + metric = states['train'] + if metric.count > 0: + pstr += "{}: {:.3f}\t".format(name, metric.avg) + self.writer.add_scalar(f"train/{name}", metric.avg, epoch*self.train_length+its) if self.args.stat == "ts" else wandb.log({name: metric.avg}, step=epoch*self.train_length+its) + pstr += "glr: {:.1e}\t".format(lr_g) + self.writer.add_scalar("lr/glr", lr_g, epoch*self.train_length+its) if self.args.stat == "ts" else wandb.log({'glr': lr_g}, step=epoch*self.train_length+its) + if lr_d is not None: + pstr += "dlr: {:.1e}\t".format(lr_d) + self.writer.add_scalar("lr/dlr", lr_d, epoch*self.train_length+its) 
if self.args.stat == "ts" else wandb.log({'dlr': lr_d}, step=epoch*self.train_length+its) + pstr += "dtime: %04d\t"%(t_data*1000) + pstr += "ntime: %04d\t"%(t_train*1000) + pstr += "mem: {:.2f} ".format(mem_cost*len(self.args.gpus)) + logger.info(pstr) + + def val_recording(self, epoch): + pstr_curr = "Curr info >>>> " + pstr_best = "Best info >>>> " + for name, states in self.tracker.loss_meters.items(): + metric = states['val'] + if metric.count > 0: + pstr_curr += "{}: {:.3f} \t".format(name, metric.avg) + if epoch != 0: + if self.args.stat == "ts": + self.writer.add_scalars(f"val/{name}", {name+"_val":metric.avg, name+"_train":states['train'].avg}, epoch*self.train_length) + else: + wandb.log({name+"_val": metric.avg, name+"_train":states['train'].avg}, step=epoch*self.train_length) + new_best_train, new_best_val = self.tracker.update_and_plot(name, epoch, self.checkpoint_path+f"{name}_{self.args.name+self.args.notes}.png") + if new_best_val: + other_tools.save_checkpoints(os.path.join(self.checkpoint_path, f"{name}.bin"), self.model, opt=None, epoch=None, lrs=None) + for k, v in self.tracker.values.items(): + metric = v['val']['best'] + if self.tracker.loss_meters[k]['val'].count > 0: + pstr_best += "{}: {:.3f}({:03d})\t".format(k, metric['value'], metric['epoch']) + logger.info(pstr_curr) + logger.info(pstr_best) + + def test_recording(self, dict_name, value, epoch): + self.tracker.update_meter(dict_name, "test", value) + _ = self.tracker.update_values(dict_name, 'test', epoch) + +@logger.catch +def main_worker(rank, world_size, args): + #os.environ['TRANSFORMERS_CACHE'] = args.data_path_1 + "hub/" + if not sys.warnoptions: + warnings.simplefilter("ignore") + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + logger_tools.set_args_and_logger(args, rank) + other_tools.set_random_seed(args) + other_tools.print_exp_info(args) + + # return one intance of trainer + trainer = __import__(f"{args.trainer}_trainer", fromlist=["something"]).CustomTrainer(args) if args.trainer != "base" else BaseTrainer(args) + other_tools.load_checkpoints(trainer.model, args.test_ckpt, args.g_name) + trainer.test(999) + + + +if __name__ == "__main__": + os.environ["MASTER_ADDR"]='127.0.0.1' + os.environ["MASTER_PORT"]='8675' + #os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + args = config.parse_args() + if args.ddp: + mp.set_start_method("spawn", force=True) + mp.spawn( + main_worker, + args=(len(args.gpus), args,), + nprocs=len(args.gpus), + ) + else: + main_worker(0, 1, args) \ No newline at end of file diff --git a/test_demo.py b/test_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad0b83c01947ccba896e8dd4d20a5b36860a060 --- /dev/null +++ b/test_demo.py @@ -0,0 +1,581 @@ +import os +import signal +import time +import csv +import sys +import warnings +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp +import numpy as np +import time +import pprint +from loguru import logger +import smplx +from torch.utils.tensorboard import SummaryWriter +import wandb +import matplotlib.pyplot as plt +from utils import config, logger_tools, other_tools, metric, data_transfer +from dataloaders import data_tools +from dataloaders.build_vocab import Vocab +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import 
get_loss_func +from dataloaders.data_tools import joints_list +from utils import rotation_conversions as rc + +class BaseTrainer(object): + def __init__(self, args): + self.args = args + self.rank = dist.get_rank() + self.checkpoint_path = args.out_path + "custom/" + args.name + args.notes + "/" #wandb.run.dir #args.cache_path+args.out_path+"/"+args.name + if self.rank == 0: + self.test_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "test") + self.test_loader = torch.utils.data.DataLoader( + self.test_data, + batch_size=1, + shuffle=False, + num_workers=args.loader_workers, + drop_last=False, + ) + logger.info(f"Init test dataloader success") + model_module = __import__(f"models.{args.model}", fromlist=["something"]) + + if args.ddp: + self.model = getattr(model_module, args.g_name)(args).to(self.rank) + process_group = torch.distributed.new_group() + self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model, process_group) + self.model = DDP(self.model, device_ids=[self.rank], output_device=self.rank, + broadcast_buffers=False, find_unused_parameters=False) + else: + self.model = torch.nn.DataParallel(getattr(model_module, args.g_name)(args), args.gpus).cuda() + + if self.rank == 0: + logger.info(self.model) + logger.info(f"init {args.g_name} success") + + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).to(self.rank).eval() + + self.args = args + self.joints = self.test_data.joints + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list_face = joints_list["beat_smplx_face"] + self.tar_joint_list_upper = joints_list["beat_smplx_upper"] + self.tar_joint_list_hands = joints_list["beat_smplx_hands"] + self.tar_joint_list_lower = joints_list["beat_smplx_lower"] + + self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = 55 + for joint_name in self.tar_joint_list_face: + self.joint_mask_face[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_upper: + self.joint_mask_upper[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_hands: + self.joint_mask_hands[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_lower: + self.joint_mask_lower[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + + self.tracker = other_tools.EpochTracker(["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'exp', 'lvd', 'mse', "cls", "rec_face", "latent", "cls_full", "cls_self", "cls_word", "latent_word","latent_self"], [False,True,True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False,False,False,False]) + + vq_model_module = __import__(f"models.motion_representation", fromlist=["something"]) + self.args.vae_layer = 2 + self.args.vae_length = 256 + 
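Before the per-part VQ-VAEs are built below, it may help to see where the various ``vae_test_dim`` values come from; the following arithmetic sketch is consistent with the 6D-rotation slicing in ``_load_data`` further down (the constant names are illustrative only).

```python
# Per-part feature dimensions, matching the reshapes in _load_data below.
JAW_6D = 1 * 6        # jaw rotation in 6D representation
EXPRESSIONS = 100     # facial expression coefficients
UPPER_6D = 13 * 6     # 13 upper-body joints x 6D
HANDS_6D = 30 * 6     # 30 hand joints x 6D
LEG_6D = 9 * 6        # 9 leg joints x 6D
TRANS = 3             # root translation
CONTACT = 4           # foot-contact channels

assert JAW_6D + EXPRESSIONS == 106      # face VQ-VAE input dim
assert UPPER_6D == 78                   # upper-body VQ-VAE input dim
assert HANDS_6D == 180                  # hands VQ-VAE input dim
assert LEG_6D + TRANS + CONTACT == 61   # lower-body VQ-VAE / global motion dim
assert 55 * 6 == 330                    # full-body 6D pose dim used for latent_all
```

Because each constructor reads its input size from ``args.vae_test_dim``, the code below mutates that field right before every ``getattr(vq_model_module, ...)`` call, which is why the assignments are interleaved with the model construction.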
self.args.vae_test_dim = 106 + self.vq_model_face = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + # print(self.vq_model_face) + other_tools.load_checkpoints(self.vq_model_face, self.args.data_path_1 + "pretrained_vq/last_790_face_v2.bin", args.e_name) + self.args.vae_test_dim = 78 + self.vq_model_upper = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.vq_model_upper, self.args.data_path_1 + "pretrained_vq/upper_vertex_1layer_710.bin", args.e_name) + self.args.vae_test_dim = 180 + self.vq_model_hands = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.vq_model_hands, self.args.data_path_1 + "pretrained_vq/hands_vertex_1layer_710.bin", args.e_name) + self.args.vae_test_dim = 61 + self.args.vae_layer = 4 + self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.vq_model_lower, self.args.data_path_1 + "pretrained_vq/lower_foot_600.bin", args.e_name) + self.args.vae_test_dim = 61 + self.args.vae_layer = 4 + self.global_motion = getattr(vq_model_module, "VAEConvZero")(self.args).to(self.rank) + other_tools.load_checkpoints(self.global_motion, self.args.data_path_1 + "pretrained_vq/last_1700_foot.bin", args.e_name) + self.args.vae_test_dim = 330 + self.args.vae_layer = 4 + self.args.vae_length = 240 + + self.vq_model_face.eval() + self.vq_model_upper.eval() + self.vq_model_hands.eval() + self.vq_model_lower.eval() + self.global_motion.eval() + + self.cls_loss = nn.NLLLoss().to(self.rank) + self.reclatent_loss = nn.MSELoss().to(self.rank) + self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank) + self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank) + self.log_softmax = nn.LogSoftmax(dim=2).to(self.rank) + + + def inverse_selection(self, filtered_t, selection_array, n): + original_shape_t = np.zeros((n, selection_array.size)) + selected_indices = np.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + selection_array = torch.from_numpy(selection_array).cuda() + original_shape_t = torch.zeros((n, 165)).cuda() + selected_indices = torch.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def _load_data(self, dict_data): + tar_pose_raw = dict_data["pose"] + tar_pose = tar_pose_raw[:, :, :165].to(self.rank) + tar_contact = tar_pose_raw[:, :, 165:169].to(self.rank) + tar_trans = dict_data["trans"].to(self.rank) + tar_exps = dict_data["facial"].to(self.rank) + in_audio = dict_data["audio"].to(self.rank) + in_word = dict_data["word"].to(self.rank) + tar_beta = dict_data["beta"].to(self.rank) + tar_id = dict_data["id"].to(self.rank).long() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) + tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)] + tar_pose_upper = 
rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + # tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3)) + # tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2) + + tar_index_value_face_top = self.vq_model_face.map2index(tar_pose_face) # bs*n/4 + tar_index_value_upper_top = self.vq_model_upper.map2index(tar_pose_upper) # bs*n/4 + tar_index_value_hands_top = self.vq_model_hands.map2index(tar_pose_hands) # bs*n/4 + tar_index_value_lower_top = self.vq_model_lower.map2index(tar_pose_lower) # bs*n/4 + + latent_face_top = self.vq_model_face.map2latent(tar_pose_face) # bs*n/4 + latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper) # bs*n/4 + latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands) # bs*n/4 + latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower) # bs*n/4 + + latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) + + index_in = torch.stack([tar_index_value_upper_top, tar_index_value_hands_top, tar_index_value_lower_top], dim=-1).long() + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + # print(tar_index_value_upper_top.shape, index_in.shape) + return { + "tar_pose_jaw": tar_pose_jaw, + "tar_pose_face": tar_pose_face, + "tar_pose_upper": tar_pose_upper, + "tar_pose_lower": tar_pose_lower, + "tar_pose_hands": tar_pose_hands, + 'tar_pose_leg': tar_pose_leg, + "in_audio": in_audio, + "in_word": in_word, + "tar_trans": tar_trans, + "tar_exps": tar_exps, + "tar_beta": tar_beta, + "tar_pose": tar_pose, + "tar4dis": tar4dis, + "tar_index_value_face_top": tar_index_value_face_top, + "tar_index_value_upper_top": tar_index_value_upper_top, + "tar_index_value_hands_top": tar_index_value_hands_top, + "tar_index_value_lower_top": tar_index_value_lower_top, + "latent_face_top": latent_face_top, + "latent_upper_top": latent_upper_top, + "latent_hands_top": latent_hands_top, + "latent_lower_top": latent_lower_top, + "latent_in": latent_in, + "index_in": index_in, + "tar_id": tar_id, + "latent_all": latent_all, + "tar_pose_6d": tar_pose_6d, + "tar_contact": tar_contact, + } + + def _g_test(self, loaded_data): + mode = 'test' + bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints + tar_pose = loaded_data["tar_pose"] + tar_beta = loaded_data["tar_beta"] + in_word = loaded_data["in_word"] + tar_exps = loaded_data["tar_exps"] + tar_contact = loaded_data["tar_contact"] + in_audio = loaded_data["in_audio"] + tar_trans = loaded_data["tar_trans"] + + remain = n%8 + if remain != 0: + tar_pose = tar_pose[:, :-remain, :] + tar_beta = tar_beta[:, :-remain, :] + tar_trans = tar_trans[:, :-remain, :] + in_word = in_word[:, :-remain] + tar_exps = tar_exps[:, :-remain, :] + tar_contact = tar_contact[:, :-remain, :] + n = n - remain + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 
3)) + tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + + rec_index_all_face = [] + rec_index_all_upper = [] + rec_index_all_lower = [] + rec_index_all_hands = [] + + roundt = (n - self.args.pre_frames) // (self.args.pose_length - self.args.pre_frames) + remain = (n - self.args.pre_frames) % (self.args.pose_length - self.args.pre_frames) + round_l = self.args.pose_length - self.args.pre_frames + + for i in range(0, roundt): + in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames] + # audio fps is 16000 and pose fps is 30 + in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*self.args.pre_frames] + in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames] + mask_val = torch.ones(bs, self.args.pose_length, self.args.pose_dims+3+4).float().cuda() + mask_val[:, :self.args.pre_frames, :] = 0.0 + if i == 0: + latent_all_tmp = latent_all[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames, :] + else: + latent_all_tmp = latent_all[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames, :] + # print(latent_all_tmp.shape, latent_last.shape) + latent_all_tmp[:, :self.args.pre_frames, :] = latent_last[:, -self.args.pre_frames:, :] + + net_out_val = self.model( + in_audio = in_audio_tmp, + in_word=in_word_tmp, + mask=mask_val, + in_motion = latent_all_tmp, + in_id = in_id_tmp, + use_attentions=True,) + + if self.args.cu != 0: + rec_index_upper = self.log_softmax(net_out_val["cls_upper"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_upper = torch.max(rec_index_upper.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_upper = self.vq_model_upper.decode(rec_index_upper) + else: + _, rec_index_upper, _, _ = self.vq_model_upper.quantizer(net_out_val["rec_upper"]) + #rec_upper = self.vq_model_upper.decoder(rec_index_upper) + if self.args.cl != 0: + rec_index_lower = self.log_softmax(net_out_val["cls_lower"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_lower = torch.max(rec_index_lower.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_lower = self.vq_model_lower.decode(rec_index_lower) + else: + _, rec_index_lower, _, _ = self.vq_model_lower.quantizer(net_out_val["rec_lower"]) + #rec_lower = self.vq_model_lower.decoder(rec_index_lower) + if self.args.ch != 0: + rec_index_hands = self.log_softmax(net_out_val["cls_hands"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_hands = torch.max(rec_index_hands.reshape(-1, 
self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_hands = self.vq_model_hands.decode(rec_index_hands) + else: + _, rec_index_hands, _, _ = self.vq_model_hands.quantizer(net_out_val["rec_hands"]) + #rec_hands = self.vq_model_hands.decoder(rec_index_hands) + if self.args.cf != 0: + rec_index_face = self.log_softmax(net_out_val["cls_face"]).reshape(-1, self.args.vae_codebook_size) + _, rec_index_face = torch.max(rec_index_face.reshape(-1, self.args.pose_length, self.args.vae_codebook_size), dim=2) + #rec_face = self.vq_model_face.decoder(rec_index_face) + else: + _, rec_index_face, _, _ = self.vq_model_face.quantizer(net_out_val["rec_face"]) + #rec_face = self.vq_model_face.decoder(rec_index_face) + + if i == 0: + rec_index_all_face.append(rec_index_face) + rec_index_all_upper.append(rec_index_upper) + rec_index_all_lower.append(rec_index_lower) + rec_index_all_hands.append(rec_index_hands) + else: + rec_index_all_face.append(rec_index_face[:, self.args.pre_frames:]) + rec_index_all_upper.append(rec_index_upper[:, self.args.pre_frames:]) + rec_index_all_lower.append(rec_index_lower[:, self.args.pre_frames:]) + rec_index_all_hands.append(rec_index_hands[:, self.args.pre_frames:]) + + if self.args.cu != 0: + rec_upper_last = self.vq_model_upper.decode(rec_index_upper) + else: + rec_upper_last = self.vq_model_upper.decoder(rec_index_upper) + if self.args.cl != 0: + rec_lower_last = self.vq_model_lower.decode(rec_index_lower) + else: + rec_lower_last = self.vq_model_lower.decoder(rec_index_lower) + if self.args.ch != 0: + rec_hands_last = self.vq_model_hands.decode(rec_index_hands) + else: + rec_hands_last = self.vq_model_hands.decoder(rec_index_hands) + # if self.args.cf != 0: + # rec_face_last = self.vq_model_face.decode(rec_index_face) + # else: + # rec_face_last = self.vq_model_face.decoder(rec_index_face) + + rec_pose_legs = rec_lower_last[:, :, :54] + bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1] + rec_pose_upper = rec_upper_last.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n) + rec_pose_hands = rec_hands_last.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + rec_trans_v_s = rec_lower_last[:, :, 54:57] + rec_x_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + latent_last = torch.cat([rec_pose, rec_trans, rec_lower_last[:, :, 57:61]], dim=-1) + + 
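# After the sliding-window loop: concatenate the per-window codebook indices
# (windows after the first were appended without their `pre_frames` overlap)
# and decode each body part once over the full sequence.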
rec_index_face = torch.cat(rec_index_all_face, dim=1) + rec_index_upper = torch.cat(rec_index_all_upper, dim=1) + rec_index_lower = torch.cat(rec_index_all_lower, dim=1) + rec_index_hands = torch.cat(rec_index_all_hands, dim=1) + if self.args.cu != 0: + rec_upper = self.vq_model_upper.decode(rec_index_upper) + else: + rec_upper = self.vq_model_upper.decoder(rec_index_upper) + if self.args.cl != 0: + rec_lower = self.vq_model_lower.decode(rec_index_lower) + else: + rec_lower = self.vq_model_lower.decoder(rec_index_lower) + if self.args.ch != 0: + rec_hands = self.vq_model_hands.decode(rec_index_hands) + else: + rec_hands = self.vq_model_hands.decoder(rec_index_hands) + if self.args.cf != 0: + rec_face = self.vq_model_face.decode(rec_index_face) + else: + rec_face = self.vq_model_face.decoder(rec_index_face) + + rec_exps = rec_face[:, :, 6:] + rec_pose_jaw = rec_face[:, :, :6] + rec_pose_legs = rec_lower[:, :, :54] + bs, n = rec_pose_jaw.shape[0], rec_pose_jaw.shape[1] + rec_pose_upper = rec_upper.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n) + rec_pose_hands = rec_hands.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n) + rec_pose_jaw = rec_pose_jaw.reshape(bs*n, 6) + rec_pose_jaw = rc.rotation_6d_to_matrix(rec_pose_jaw) + rec_pose_jaw = rc.matrix_to_axis_angle(rec_pose_jaw).reshape(bs*n, 1*3) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose[:, 66:69] = rec_pose_jaw + + to_global = rec_lower + to_global[:, :, 54:57] = 0.0 + to_global[:, :, :54] = rec_lower2global + rec_global = self.global_motion(to_global) + + rec_trans_v_s = rec_global["rec_pose"][:, :, 54:57] + rec_x_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1]) + rec_z_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:,:,1:2] + rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + tar_pose = tar_pose[:, :n, :] + tar_exps = tar_exps[:, :n, :] + tar_trans = tar_trans[:, :n, :] + tar_beta = tar_beta[:, :n, :] + + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + + return { + 'rec_pose': rec_pose, + 'rec_trans': rec_trans, + 'tar_pose': tar_pose, + 'tar_exps': tar_exps, + 'tar_beta': tar_beta, + 'tar_trans': tar_trans, + 'rec_exps': rec_exps, + } + + + def test_demo(self, epoch): + ''' + input audio and text, output motion + do not calculate loss and metric + save video + ''' + results_save_path = 
self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + return 0 + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + align = 0 + latent_out = [] + latent_ori = [] + l2_all = 0 + lvel = 0 + self.model.eval() + self.smplx.eval() + # self.eval_copy.eval() + with torch.no_grad(): + for its, batch_data in enumerate(self.test_loader): + loaded_data = self._load_data(batch_data) + net_out = self._g_test(loaded_data) + tar_pose = net_out['tar_pose'] + rec_pose = net_out['rec_pose'] + tar_exps = net_out['tar_exps'] + tar_beta = net_out['tar_beta'] + rec_trans = net_out['rec_trans'] + tar_trans = net_out['tar_trans'] + rec_exps = net_out['rec_exps'] + # print(rec_pose.shape, tar_pose.shape) + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + + # interpolate to 30fps + if (30/self.args.pose_fps) != 1: + assert 30%self.args.pose_fps == 0 + n *= int(30/self.args.pose_fps) + tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + + # print(rec_pose.shape, tar_pose.shape) + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + + tar_pose_np = tar_pose.detach().cpu().numpy() + rec_pose_np = rec_pose.detach().cpu().numpy() + rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3) + rec_exp_np = rec_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3) + + gt_npz = np.load(self.args.data_path+self.args.pose_rep +"/"+test_seq_list.iloc[its]['id']+".npz", allow_pickle=True) + np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=tar_pose_np, + expressions=tar_exp_np, + trans=tar_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30 , + ) + np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + betas=gt_npz["betas"], + poses=rec_pose_np, + expressions=rec_exp_np, + trans=rec_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30, + ) + total_length += n + # other_tools.render_one_sequence( + # results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz', + # results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + # results_save_path, + # self.args.data_path+"wave16k/"+test_seq_list.iloc[its]['id']+".wav", + # self.args.data_path_1+"smplx_models/", + # use_matplotlib = False, + # args = self.args, + # ) + end_time = time.time() - start_time + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") + +@logger.catch +def main_worker(rank, world_size, args): + #os.environ['TRANSFORMERS_CACHE'] = args.data_path_1 + "hub/" + if not sys.warnoptions: + warnings.simplefilter("ignore") + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + logger_tools.set_args_and_logger(args, rank) + other_tools.set_random_seed(args) + other_tools.print_exp_info(args) + + # return one intance of trainer + other_tools.write_wav_names_to_csv(args.data_path, 
args.data_path+"test.csv") + trainer = BaseTrainer(args) + other_tools.load_checkpoints(trainer.model, args.test_ckpt, args.g_name) + trainer.test_demo(999) + + + +if __name__ == "__main__": + os.environ["MASTER_ADDR"]='127.0.0.1' + os.environ["MASTER_PORT"]='8675' + #os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + args = config.parse_args() + if args.ddp: + mp.set_start_method("spawn", force=True) + mp.spawn( + main_worker, + args=(len(args.gpus), args,), + nprocs=len(args.gpus), + ) + else: + main_worker(0, 1, args) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..3a47475400185fc5042ccd3e1ce0f66fe2a1cae5 --- /dev/null +++ b/train.py @@ -0,0 +1,307 @@ +import os +import signal +import time +import csv +import sys +import warnings +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp +import numpy as np +import time +import pprint +from loguru import logger +import smplx +from torch.utils.tensorboard import SummaryWriter +import wandb +import matplotlib.pyplot as plt +from utils import config, logger_tools, other_tools, metric +from dataloaders import data_tools +from dataloaders.build_vocab import Vocab +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func + + +class BaseTrainer(object): + def __init__(self, args): + self.args = args + self.rank = dist.get_rank() + self.checkpoint_path = args.out_path + "custom/" + args.name + args.notes + "/" #wandb.run.dir #args.cache_path+args.out_path+"/"+args.name + if self.rank==0: + if self.args.stat == "ts": + self.writer = SummaryWriter(log_dir=args.out_path + "custom/" + args.name + args.notes + "/") + else: + wandb.init(project=args.project, entity="liu1997", dir=args.out_path, name=args.name[12:] + args.notes) + wandb.config.update(args) + self.writer = None + #self.test_demo = args.data_path + args.test_data_path + "bvh_full/" + self.train_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "train") + self.train_loader = torch.utils.data.DataLoader( + self.train_data, + batch_size=args.batch_size, + shuffle=False if args.ddp else True, + num_workers=args.loader_workers, + drop_last=True, + sampler=torch.utils.data.distributed.DistributedSampler(self.train_data) if args.ddp else None, + ) + self.train_length = len(self.train_loader) + logger.info(f"Init train dataloader success") + + self.val_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "val") + self.val_loader = torch.utils.data.DataLoader( + self.val_data, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.loader_workers, + drop_last=False, + sampler=torch.utils.data.distributed.DistributedSampler(self.val_data) if args.ddp else None, + ) + logger.info(f"Init val dataloader success") + if self.rank == 0: + self.test_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "test") + self.test_loader = torch.utils.data.DataLoader( + self.test_data, + batch_size=1, + shuffle=False, + num_workers=args.loader_workers, + drop_last=False, + ) + logger.info(f"Init test dataloader success") + model_module = __import__(f"models.{args.model}", fromlist=["something"]) + + if args.ddp: + self.model = 
getattr(model_module, args.g_name)(args).to(self.rank) + process_group = torch.distributed.new_group() + self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model, process_group) + self.model = DDP(self.model, device_ids=[self.rank], output_device=self.rank, + broadcast_buffers=False, find_unused_parameters=False) + else: + self.model = torch.nn.DataParallel(getattr(model_module, args.g_name)(args), args.gpus).cuda() + + if self.rank == 0: + logger.info(self.model) + logger.info(f"init {args.g_name} success") + if args.stat == "wandb": + wandb.watch(self.model) + + if args.d_name is not None: + if args.ddp: + self.d_model = getattr(model_module, args.d_name)(args).to(self.rank) + self.d_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.d_model, process_group) + self.d_model = DDP(self.d_model, device_ids=[self.rank], output_device=self.rank, + broadcast_buffers=False, find_unused_parameters=False) + else: + self.d_model = torch.nn.DataParallel(getattr(model_module, args.d_name)(args), args.gpus).cuda() + if self.rank == 0: + logger.info(self.d_model) + logger.info(f"init {args.d_name} success") + if args.stat == "wandb": + wandb.watch(self.d_model) + self.opt_d = create_optimizer(args, self.d_model, lr_weight=args.d_lr_weight) + self.opt_d_s = create_scheduler(args, self.opt_d) + + if args.e_name is not None: + """ + bugs on DDP training using eval_model, using additional eval_copy for evaluation + """ + eval_model_module = __import__(f"models.{args.eval_model}", fromlist=["something"]) + # eval copy is for single card evaluation + if self.args.ddp: + self.eval_model = getattr(eval_model_module, args.e_name)(args).to(self.rank) + self.eval_copy = getattr(eval_model_module, args.e_name)(args).to(self.rank) + else: + self.eval_model = getattr(eval_model_module, args.e_name)(args) + self.eval_copy = getattr(eval_model_module, args.e_name)(args).to(self.rank) + + #if self.rank == 0: + other_tools.load_checkpoints(self.eval_copy, args.data_path+args.e_path, args.e_name) + other_tools.load_checkpoints(self.eval_model, args.data_path+args.e_path, args.e_name) + if self.args.ddp: + self.eval_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.eval_model, process_group) + self.eval_model = DDP(self.eval_model, device_ids=[self.rank], output_device=self.rank, + broadcast_buffers=False, find_unused_parameters=False) + self.eval_model.eval() + self.eval_copy.eval() + if self.rank == 0: + logger.info(self.eval_model) + logger.info(f"init {args.e_name} success") + if args.stat == "wandb": + wandb.watch(self.eval_model) + self.opt = create_optimizer(args, self.model) + self.opt_s = create_scheduler(args, self.opt) + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).to(self.rank).eval() + self.alignmenter = metric.alignment(0.3, 7, self.train_data.avg_vel, upper_body=[3,6,9,12,13,14,15,16,17,18,19,20,21]) if self.rank == 0 else None + self.align_mask = 60 + self.l1_calculator = metric.L1div() if self.rank == 0 else None + + + def inverse_selection(self, filtered_t, selection_array, n): + original_shape_t = np.zeros((n, selection_array.size)) + selected_indices = np.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + # def inverse_selection_6d(self, filtered_t, selection_array, n): + # original_shape_t = np.zeros((n, 
selection_array.size)) + # selected_indices = np.where(selection_array == 1)[0] + # new_selected_indices = np.zeros((n, selected_indices.size*2)) + # new_selected_indices[:, ::2] = selected_indices + # new_selected_indices[:, 1::2] = selected_indices + # selected_indices = new_selected_indices.astype(np.bool) + # for i in range(n): + # original_shape_t[i, selected_indices] = filtered_t[i] + # return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + selection_array = torch.from_numpy(selection_array).cuda() + selected_indices = torch.where(selection_array == 1)[0] + if len(filtered_t.shape) == 2: + original_shape_t = torch.zeros((n, 165)).cuda() + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + elif len(filtered_t.shape) == 3: + bs, n, _ = filtered_t.shape + original_shape_t = torch.zeros((bs, n, 165), device='cuda') + expanded_indices = selected_indices.unsqueeze(0).unsqueeze(0).expand(bs, n, -1) + original_shape_t.scatter_(2, expanded_indices, filtered_t) + return original_shape_t + + def inverse_selection_tensor_6d(self, filtered_t, selection_array, n): + new_selected_array = np.zeros((330)) + new_selected_array[::2] = selection_array + new_selected_array[1::2] = selection_array + selection_array = new_selected_array + selection_array = torch.from_numpy(selection_array).cuda() + selected_indices = torch.where(selection_array == 1)[0] + if len(filtered_t.shape) == 2: + original_shape_t = torch.zeros((n, 330)).cuda() + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + elif len(filtered_t.shape) == 3: + bs, n, _ = filtered_t.shape + original_shape_t = torch.zeros((bs, n, 330), device='cuda') + expanded_indices = selected_indices.unsqueeze(0).unsqueeze(0).expand(bs, n, -1) + original_shape_t.scatter_(2, expanded_indices, filtered_t) + return original_shape_t + + def train_recording(self, epoch, its, t_data, t_train, mem_cost, lr_g, lr_d=None): + pstr = "[%03d][%03d/%03d] "%(epoch, its, self.train_length) + for name, states in self.tracker.loss_meters.items(): + metric = states['train'] + if metric.count > 0: + pstr += "{}: {:.3f}\t".format(name, metric.avg) + self.writer.add_scalar(f"train/{name}", metric.avg, epoch*self.train_length+its) if self.args.stat == "ts" else wandb.log({name: metric.avg}, step=epoch*self.train_length+its) + pstr += "glr: {:.1e}\t".format(lr_g) + self.writer.add_scalar("lr/glr", lr_g, epoch*self.train_length+its) if self.args.stat == "ts" else wandb.log({'glr': lr_g}, step=epoch*self.train_length+its) + if lr_d is not None: + pstr += "dlr: {:.1e}\t".format(lr_d) + self.writer.add_scalar("lr/dlr", lr_d, epoch*self.train_length+its) if self.args.stat == "ts" else wandb.log({'dlr': lr_d}, step=epoch*self.train_length+its) + pstr += "dtime: %04d\t"%(t_data*1000) + pstr += "ntime: %04d\t"%(t_train*1000) + pstr += "mem: {:.2f} ".format(mem_cost*len(self.args.gpus)) + logger.info(pstr) + + def val_recording(self, epoch): + pstr_curr = "Curr info >>>> " + pstr_best = "Best info >>>> " + for name, states in self.tracker.loss_meters.items(): + metric = states['val'] + if metric.count > 0: + pstr_curr += "{}: {:.3f} \t".format(name, metric.avg) + if epoch != 0: + if self.args.stat == "ts": + self.writer.add_scalars(f"val/{name}", {name+"_val":metric.avg, name+"_train":states['train'].avg}, epoch*self.train_length) + else: + wandb.log({name+"_val": metric.avg, name+"_train":states['train'].avg}, step=epoch*self.train_length) + new_best_train, new_best_val = 
self.tracker.update_and_plot(name, epoch, self.checkpoint_path+f"{name}_{self.args.name+self.args.notes}.png") + if new_best_val: + other_tools.save_checkpoints(os.path.join(self.checkpoint_path, f"{name}.bin"), self.model, opt=None, epoch=None, lrs=None) + for k, v in self.tracker.values.items(): + metric = v['val']['best'] + if self.tracker.loss_meters[k]['val'].count > 0: + pstr_best += "{}: {:.3f}({:03d})\t".format(k, metric['value'], metric['epoch']) + logger.info(pstr_curr) + logger.info(pstr_best) + + def test_recording(self, dict_name, value, epoch): + self.tracker.update_meter(dict_name, "test", value) + _ = self.tracker.update_values(dict_name, 'test', epoch) + +@logger.catch +def main_worker(rank, world_size, args): + #os.environ['TRANSFORMERS_CACHE'] = args.data_path_1 + "hub/" + if not sys.warnoptions: + warnings.simplefilter("ignore") + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + logger_tools.set_args_and_logger(args, rank) + other_tools.set_random_seed(args) + other_tools.print_exp_info(args) + + # return one intance of trainer + trainer = __import__(f"{args.trainer}_trainer", fromlist=["something"]).CustomTrainer(args) if args.trainer != "base" else BaseTrainer(args) + logger.info("Training from scratch ...") + start_time = time.time() + for epoch in range(args.epochs+1): + if args.ddp: trainer.val_loader.sampler.set_epoch(epoch) + trainer.val(epoch) + # if (epoch) % args.test_period == 1: trainer.val(epoch) + epoch_time = time.time()-start_time + if trainer.rank == 0: logger.info("Time info >>>> elapsed: %.2f mins\t"%(epoch_time/60)+"remain: %.2f mins"%((args.epochs/(epoch+1e-7)-1)*epoch_time/60)) + if epoch != args.epochs: + if args.ddp: trainer.train_loader.sampler.set_epoch(epoch) + trainer.tracker.reset() + trainer.train(epoch) + if args.debug: + other_tools.save_checkpoints(os.path.join(trainer.checkpoint_path, f"last_{epoch}.bin"), trainer.model, opt=None, epoch=None, lrs=None) + other_tools.load_checkpoints(trainer.model, os.path.join(trainer.checkpoint_path, f"last_{epoch}.bin"), args.g_name) + #other_tools.load_checkpoints(trainer.model, "/home/s24273/datasets/hub/pretrained_vq/last_140.bin", args.g_name) + trainer.test(epoch) + if (epoch) % args.test_period == 0 and epoch !=0: + if rank == 0: + other_tools.save_checkpoints(os.path.join(trainer.checkpoint_path, f"last_{epoch}.bin"), trainer.model, opt=None, epoch=None, lrs=None) + trainer.test(epoch) + + if rank == 0: + for k, v in trainer.tracker.values.items(): + if trainer.tracker.loss_meters[k]['val'].count > 0: + other_tools.load_checkpoints(trainer.model, os.path.join(trainer.checkpoint_path, f"{k}.bin"), args.g_name) + logger.info(f"inference on ckpt {k}_val_{v['val']['best']['epoch']}:") + trainer.test(v['val']['best']['epoch']) + other_tools.record_trial(args, trainer.tracker) + wandb.log({"fid_test": trainer.tracker["fid"]["test"]["best"]}) + if args.stat == "ts": + trainer.writer.close() + else: + wandb.finish() + + +if __name__ == "__main__": + os.environ["MASTER_ADDR"]='127.0.0.1' + os.environ["MASTER_PORT"]='8675' + #os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + args = config.parse_args() + if args.ddp: + mp.set_start_method("spawn", force=True) + mp.spawn( + main_worker, + args=(len(args.gpus), args,), + nprocs=len(args.gpus), + ) + else: + main_worker(0, 1, args) \ No newline at end of file diff --git a/utils/.ipynb_checkpoints/config-checkpoint.py b/utils/.ipynb_checkpoints/config-checkpoint.py new file mode 100644 index 
0000000000000000000000000000000000000000..7ac207d7270c15532fa55b9184af3d64ceae9fc6 --- /dev/null +++ b/utils/.ipynb_checkpoints/config-checkpoint.py @@ -0,0 +1,275 @@ +import configargparse +import time +import json +import yaml +import os + + +def str2bool(v): + """ from https://stackoverflow.com/a/43357954/1361529 """ + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise configargparse.ArgumentTypeError('Boolean value expected.') + + +def parse_args(): + """ + requirement for config + 1. command > yaml > default + 2. avoid re-definition + 3. lowercase letters is better + 4. hierarchical is not necessary + """ + parser = configargparse.ArgParser() + parser.add("-c", "--config", default='./configs/emage_test_hf.yaml', is_config_file=True) + parser.add("--project", default="audio2pose", type=str) # wandb project name + parser.add("--stat", default="ts", type=str) + parser.add("--csv_name", default="a2g_0", type=str) # local device id + parser.add("--notes", default="", type=str) + parser.add("--trainer", default="camn", type=str) + + parser.add("--l", default=4, type=int) + # ------------- path and save name ---------------- # + parser.add("--is_train", default=True, type=str2bool) + parser.add("--debug", default=False, type=str2bool) + # different between environments + parser.add("--root_path", default="/home/ma-user/work/") + parser.add("--cache_path", default="/outputs/audio2pose/", type=str) + parser.add("--out_path", default="/outputs/audio2pose/", type=str) + parser.add("--data_path", default="/outputs/audio2pose/", type=str) + parser.add("--train_data_path", default="/datasets/trinity/train/", type=str) + parser.add("--val_data_path", default="/datasets/trinity/val/", type=str) + parser.add("--test_data_path", default="/datasets/trinity/test/", type=str) + parser.add("--mean_pose_path", default="/datasets/trinity/train/", type=str) + parser.add("--std_pose_path", default="/datasets/trinity/train/", type=str) + # for pretrian weights + parser.add("--data_path_1", default="../../datasets/checkpoints/", type=str) + # ------------------- evaluation ----------------------- # + parser.add("--test_ckpt", default="/datasets/beat_cache/beat_4english_15_141/last.bin") + parser.add("--eval_model", default="vae", type=str) + parser.add("--e_name", default=None, type=str) #HalfEmbeddingNet + parser.add("--e_path", default="/datasets/beat/generated_data/self_vae_128.bin") + parser.add("--variational", default=False, type=str2bool) + parser.add("--vae_length", default=256, type=int) + parser.add("--vae_test_dim", default=141, type=int) + parser.add("--vae_test_len", default=34, type=int) + parser.add("--vae_test_stride", default=10, type=int) + #parser.add("--vae_pose_length", default=34, type=int) + parser.add("--test_period", default=20, type=int) + parser.add("--vae_codebook_size", default=1024, type=int) + parser.add("--vae_quantizer_lambda", default=1., type=float) + + parser.add("--vae_layer", default=2, type=int) + parser.add("--vae_grow", default=[1,1,2,1], type=int, nargs="*") + parser.add("--lf", default=0., type=float) + parser.add("--ll", default=0., type=float) + parser.add("--lu", default=0., type=float) + parser.add("--lh", default=0., type=float) + parser.add("--cf", default=0., type=float) + parser.add("--cl", default=0., type=float) + parser.add("--cu", default=0., type=float) + parser.add("--ch", default=0., type=float) + + + # --------------- data 
---------------------------- # + parser.add("--additional_data", default=False, type=str2bool) + parser.add("--train_trans", default=True, type=str2bool) + parser.add("--dataset", default="beat", type=str) + parser.add("--rot6d", default=True, type=str2bool) + parser.add("--ori_joints", default="spine_neck_141", type=str) + parser.add("--tar_joints", default="spine_neck_141", type=str) + parser.add("--training_speakers", default=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], type=int, nargs="*") + #parser.add("--pose_version", default="spine_neck_141", type=str) + parser.add("--new_cache", default=True, type=str2bool) + parser.add("--beat_align", default=True, type=str2bool) + parser.add("--cache_only", default=False, type=str2bool) + parser.add("--word_cache", default=False, type=str2bool) + parser.add("--use_aug", default=False, type=str2bool) + parser.add("--disable_filtering", default=False, type=str2bool) + parser.add("--clean_first_seconds", default=0, type=int) + parser.add("--clean_final_seconds", default=0, type=int) + + parser.add("--audio_rep", default=None, type=str) + parser.add("--audio_sr", default=16000, type=int) + parser.add("--word_rep", default=None, type=str) + parser.add("--emo_rep", default=None, type=str) + parser.add("--sem_rep", default=None, type=str) + parser.add("--facial_rep", default=None, type=str) + parser.add("--pose_rep", default="bvhrot", type=str) + parser.add("--id_rep", default="onehot", type=str) + parser.add("--speaker_id", default="onehot", type=str) + + parser.add("--a_pre_encoder", default=None, type=str) + parser.add("--a_encoder", default=None, type=str) + parser.add("--a_fix_pre", default=False, type=str2bool) + parser.add("--t_pre_encoder", default=None, type=str) + parser.add("--t_encoder", default=None, type=str) + parser.add("--t_fix_pre", default=False, type=str2bool) + parser.add("--m_pre_encoder", default=None, type=str) + parser.add("--m_encoder", default=None, type=str) + parser.add("--m_fix_pre", default=False, type=str2bool) + parser.add("--f_pre_encoder", default=None, type=str) + parser.add("--f_encoder", default=None, type=str) + parser.add("--f_fix_pre", default=False, type=str2bool) + parser.add("--m_decoder", default=None, type=str) + parser.add("--decode_fusion", default=None, type=str) + parser.add("--atmr", default=0.0, type=float) + parser.add("--ttmr", default=0., type=float) + parser.add("--mtmr", default=0., type=float) + parser.add("--ftmr", default=0., type=float) + parser.add("--asmr", default=0., type=float) + parser.add("--tsmr", default=0., type=float) + parser.add("--msmr", default=0., type=float) + parser.add("--fsmr", default=0., type=float) +# parser.add("--m_encoder", default=None, type=str) +# parser.add("--m_pre_fix", default=None, type=str) +# parser.add("--id_rep", default=None, type=str) + + parser.add("--freeze_wordembed", default=True, type=str2bool) + parser.add("--audio_fps", default=16000, type=int) + parser.add("--facial_fps", default=15, type=int) + parser.add("--pose_fps", default=15, type=int) + + parser.add("--audio_dims", default=1, type=int) + parser.add("--facial_dims", default=39, type=int) + parser.add("--pose_dims", default=123, type=int) + parser.add("--word_index_num", default=5793, type=int) + parser.add("--word_dims", default=300, type=int) + parser.add("--speaker_dims", default=4, type=int) + parser.add("--emotion_dims", default=8, type=int) + + parser.add("--audio_norm", default=False, type=str2bool) + parser.add("--facial_norm", default=False, 
type=str2bool) + parser.add("--pose_norm", default=False, type=str2bool) + + parser.add("--pose_length", default=34, type=int) + parser.add("--pre_frames", default=4, type=int) + parser.add("--stride", default=10, type=int) + parser.add("--pre_type", default="zero", type=str) + + parser.add("--audio_f", default=0, type=int) + parser.add("--motion_f", default=0, type=int) + parser.add("--facial_f", default=0, type=int) + parser.add("--speaker_f", default=0, type=int) + parser.add("--word_f", default=0, type=int) + parser.add("--emotion_f", default=0, type=int) + parser.add("--aud_prob", default=1.0, type=float) + parser.add("--pos_prob", default=1.0, type=float) + parser.add("--txt_prob", default=1.0, type=float) + parser.add("--fac_prob", default=1.0, type=float) + parser.add("--multi_length_training", default=[1.0], type=float, nargs="*") + # --------------- model ---------------------------- # + parser.add("--pretrain", default=False, type=str2bool) + parser.add("--model", default="camn", type=str) + parser.add("--g_name", default="CaMN", type=str) + parser.add("--d_name", default=None, type=str) #ConvDiscriminator + parser.add("--dropout_prob", default=0.3, type=float) + parser.add("--n_layer", default=4, type=int) + parser.add("--hidden_size", default=300, type=int) + #parser.add("--period", default=34, type=int) + parser.add("--test_length", default=34, type=int) + # Self-designed "Multi-Stage", "Seprate", or "Original" + parser.add("--finger_net", default="original", type=str) + parser.add("--pos_encoding_type", default="sin", type=str) + parser.add("--queue_size", default=1024, type=int) + + # --------------- training ------------------------- # + parser.add("--epochs", default=120, type=int) + parser.add("--epoch_stage", default=0, type=int) + parser.add("--grad_norm", default=0, type=float) + parser.add("--no_adv_epoch", default=999, type=int) + parser.add("--batch_size", default=128, type=int) + parser.add("--opt", default="adam", type=str) + parser.add("--lr_base", default=0.00025, type=float) + parser.add("--opt_betas", default=[0.5, 0.999], type=float, nargs="*") + parser.add("--weight_decay", default=0., type=float) + # for warmup and cosine + parser.add("--lr_min", default=1e-7, type=float) + parser.add("--warmup_lr", default=5e-4, type=float) + parser.add("--warmup_epochs", default=0, type=int) + parser.add("--decay_epochs", default=9999, type=int) + parser.add("--decay_rate", default=0.1, type=float) + parser.add("--lr_policy", default="step", type=str) + # for sgd + parser.add("--momentum", default=0.8, type=float) + parser.add("--nesterov", default=True, type=str2bool) + parser.add("--amsgrad", default=False, type=str2bool) + parser.add("--d_lr_weight", default=0.2, type=float) + parser.add("--rec_weight", default=500, type=float) + parser.add("--adv_weight", default=20.0, type=float) + parser.add("--fid_weight", default=0.0, type=float) + parser.add("--vel_weight", default=0.0, type=float) + parser.add("--acc_weight", default=0.0, type=float) + parser.add("--kld_weight", default=0.0, type=float) + parser.add("--kld_aud_weight", default=0.0, type=float) + parser.add("--kld_fac_weight", default=0.0, type=float) + parser.add("--ali_weight", default=0.0, type=float) + parser.add("--ita_weight", default=0.0, type=float) + parser.add("--iwa_weight", default=0.0, type=float) + parser.add("--wei_weight", default=0.0, type=float) + parser.add("--gap_weight", default=0.0, type=float) + parser.add("--atcont", default=0.0, type=float) + parser.add("--fusion_mode", default="sum", 
type=str) + + parser.add("--div_reg_weight", default=0.0, type=float) + parser.add("--rec_aud_weight", default=0.0, type=float) + parser.add("--rec_ver_weight", default=0.0, type=float) + parser.add("--rec_pos_weight", default=0.0, type=float) + parser.add("--rec_fac_weight", default=0.0, type=float) + parser.add("--rec_txt_weight", default=0.0, type=float) +# parser.add("--gan_noise_size", default=0, type=int) + # --------------- ha2g -------------------------- # + parser.add("--n_pre_poses", default=4, type=int) + parser.add("--n_poses", default=34, type=int) + parser.add("--input_context", default="both", type=str) + parser.add("--loss_contrastive_pos_weight", default=0.2, type=float) + parser.add("--loss_contrastive_neg_weight", default=0.005, type=float) + parser.add("--loss_physical_weight", default=0.0, type=float) + parser.add("--loss_reg_weight", default=0.05, type=float) + parser.add("--loss_regression_weight", default=70.0, type=float) + parser.add("--loss_gan_weight", default=5.0, type=float) + parser.add("--loss_kld_weight", default=0.1, type=float) + parser.add("--z_type", default="speaker", type=str) + # --------------- device -------------------------- # + parser.add("--random_seed", default=2021, type=int) + parser.add("--deterministic", default=True, type=str2bool) + parser.add("--benchmark", default=True, type=str2bool) + parser.add("--cudnn_enabled", default=True, type=str2bool) + # mix precision + parser.add("--apex", default=False, type=str2bool) + parser.add("--gpus", default=[0], type=int, nargs="*") + parser.add("--loader_workers", default=0, type=int) + parser.add("--ddp", default=False, type=str2bool) + parser.add("--sparse", default=1, type=int) + #parser.add("--world_size") + parser.add("--render_video_fps", default=30, type=int) + parser.add("--render_video_width", default=1920, type=int) + parser.add("--render_video_height", default=720, type=int) + cpu_cores = os.cpu_count() if os.cpu_count() is not None else 1 + default_concurrent = max(1, cpu_cores // 2) + parser.add("--render_concurrent_num", default=default_concurrent, type=int) + parser.add("--render_tmp_img_filetype", default="bmp", type=str) + + # logging + parser.add("--log_period", default=10, type=int) + + + args = parser.parse_args() + idc = 0 + for i, char in enumerate(args.config): + if char == "/": idc = i + args.name = args.config[idc+1:-5] + + is_train = args.is_train + + if is_train: + time_local = time.localtime() + name_expend = "%02d%02d_%02d%02d%02d_"%(time_local[1], time_local[2],time_local[3], time_local[4], time_local[5]) + args.name = name_expend + args.name + + return args \ No newline at end of file diff --git a/utils/.ipynb_checkpoints/fast_render-checkpoint.py b/utils/.ipynb_checkpoints/fast_render-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..78311cbcc99ac414c7724e1081615825ad20843a --- /dev/null +++ b/utils/.ipynb_checkpoints/fast_render-checkpoint.py @@ -0,0 +1,266 @@ +import os +import time +import numpy as np +import pyrender +import trimesh +import queue +import imageio +import threading +import multiprocessing +import utils.media +import glob + +def deg_to_rad(degrees): + return degrees * np.pi / 180 + +def create_pose_camera(angle_deg): + angle_rad = deg_to_rad(angle_deg) + return np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 1.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 5.0], + [0.0, 0.0, 0.0, 1.0] + ]) + +def create_pose_light(angle_deg): + angle_rad = deg_to_rad(angle_deg) + return 
np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 0.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 3.0], + [0.0, 0.0, 0.0, 1.0] + ]) + +def create_scene_with_mesh(vertices, faces, uniform_color, pose_camera, pose_light): + trimesh_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=uniform_color) + mesh = pyrender.Mesh.from_trimesh(trimesh_mesh, smooth=True) + scene = pyrender.Scene() + scene.add(mesh) + camera = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0) + scene.add(camera, pose=pose_camera) + light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=4.0) + scene.add(light, pose=pose_light) + return scene + +def do_render_one_frame(renderer, frame_idx, vertices, vertices1, faces): + if frame_idx % 100 == 0: + print('processed', frame_idx, 'frames') + + uniform_color = [220, 220, 220, 255] + pose_camera = create_pose_camera(angle_deg=-2) + pose_light = create_pose_light(angle_deg=-30) + + figs = [] + for vtx in [vertices, vertices1]: + # print(vtx.shape) + scene = create_scene_with_mesh(vtx, faces, uniform_color, pose_camera, pose_light) + fig, _ = renderer.render(scene) + figs.append(fig) + + return figs[0], figs[1] + +def do_render_one_frame_no_gt(renderer, frame_idx, vertices, faces): + if frame_idx % 100 == 0: + print('processed', frame_idx, 'frames') + + uniform_color = [220, 220, 220, 255] + pose_camera = create_pose_camera(angle_deg=-2) + pose_light = create_pose_light(angle_deg=-30) + + figs = [] + # for vtx in [vertices]: + # print(vtx.shape) + # print(vertices.shape) + scene = create_scene_with_mesh(vertices, faces, uniform_color, pose_camera, pose_light) + fig, _ = renderer.render(scene) + figs.append(fig) + + return figs[0] + +def write_images_from_queue(fig_queue, output_dir, img_filetype): + while True: + e = fig_queue.get() + if e is None: + break + fid, fig1, fig2 = e + filename = os.path.join(output_dir, f"frame_{fid}.{img_filetype}") + merged_fig = np.hstack((fig1, fig2)) + try: + imageio.imwrite(filename, merged_fig) + except Exception as ex: + print(f"Error writing image {filename}: {ex}") + raise ex + +def write_images_from_queue_no_gt(fig_queue, output_dir, img_filetype): + while True: + e = fig_queue.get() + if e is None: + break + fid, fig1, fig2 = e + filename = os.path.join(output_dir, f"frame_{fid}.{img_filetype}") + merged_fig = fig1 #np.hstack((fig1)) + try: + imageio.imwrite(filename, merged_fig) + except Exception as ex: + print(f"Error writing image {filename}: {ex}") + raise ex + + +def render_frames_and_enqueue(fids, frame_vertex_pairs, faces, render_width, render_height, fig_queue): + fig_resolution = (render_width // 2, render_height) + renderer = pyrender.OffscreenRenderer(*fig_resolution) + + for idx, fid in enumerate(fids): + fig1, fig2 = do_render_one_frame(renderer, fid, frame_vertex_pairs[idx][0], frame_vertex_pairs[idx][1], faces) + fig_queue.put((fid, fig1, fig2)) + + renderer.delete() + +def render_frames_and_enqueue_no_gt(fids, frame_vertex_pairs, faces, render_width, render_height, fig_queue): + fig_resolution = (render_width // 2, render_height) + renderer = pyrender.OffscreenRenderer(*fig_resolution) + + for idx, fid in enumerate(fids): + fig1 = do_render_one_frame_no_gt(renderer, fid, frame_vertex_pairs[idx][0], faces) + fig_queue.put((fid, fig1)) + + renderer.delete() + +def sub_process_process_frame(subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, fids, frame_vertex_pairs, faces, output_dir): + begin_ts = time.time() + 
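# Each render worker draws its assigned frames into an in-memory queue, then a writer
# thread drains the queue to image files; the timestamps here are only used for the
# profiling printout.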
print(f"subprocess_index={subprocess_index} begin_ts={begin_ts}") + + fig_queue = queue.Queue() + render_frames_and_enqueue(fids, frame_vertex_pairs, faces, render_video_width, render_video_height, fig_queue) + fig_queue.put(None) + render_end_ts = time.time() + + image_writer_thread = threading.Thread(target=write_images_from_queue, args=(fig_queue, output_dir, render_tmp_img_filetype)) + image_writer_thread.start() + image_writer_thread.join() + + write_end_ts = time.time() + print( + f"subprocess_index={subprocess_index} " + f"render={render_end_ts - begin_ts:.2f} " + f"all={write_end_ts - begin_ts:.2f} " + f"begin_ts={begin_ts:.2f} " + f"render_end_ts={render_end_ts:.2f} " + f"write_end_ts={write_end_ts:.2f}" + ) + +def sub_process_process_frame_no_gt(subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, fids, frame_vertex_pairs, faces, output_dir): + begin_ts = time.time() + print(f"subprocess_index={subprocess_index} begin_ts={begin_ts}") + + fig_queue = queue.Queue() + render_frames_and_enqueue(fids, frame_vertex_pairs, faces, render_video_width, render_video_height, fig_queue) + fig_queue.put(None) + render_end_ts = time.time() + + image_writer_thread = threading.Thread(target=write_images_from_queue_no_gt, args=(fig_queue, output_dir, render_tmp_img_filetype)) + image_writer_thread.start() + image_writer_thread.join() + + write_end_ts = time.time() + print( + f"subprocess_index={subprocess_index} " + f"render={render_end_ts - begin_ts:.2f} " + f"all={write_end_ts - begin_ts:.2f} " + f"begin_ts={begin_ts:.2f} " + f"render_end_ts={render_end_ts:.2f} " + f"write_end_ts={write_end_ts:.2f}" + ) + +def distribute_frames(frames, render_video_fps, render_concurent_nums, vertices_all, vertices1_all): + sample_interval = max(1, int(30 // render_video_fps)) + subproc_frame_ids = [[] for _ in range(render_concurent_nums)] + subproc_vertices = [[] for _ in range(render_concurent_nums)] + sampled_frame_id = 0 + + for i in range(frames): + if i % sample_interval != 0: + continue + subprocess_index = sampled_frame_id % render_concurent_nums + subproc_frame_ids[subprocess_index].append(sampled_frame_id) + subproc_vertices[subprocess_index].append((vertices_all[i], vertices1_all[i])) + sampled_frame_id += 1 + + return subproc_frame_ids, subproc_vertices + +def distribute_frames_no_gt(frames, render_video_fps, render_concurent_nums, vertices_all): + sample_interval = max(1, int(30 // render_video_fps)) + subproc_frame_ids = [[] for _ in range(render_concurent_nums)] + subproc_vertices = [[] for _ in range(render_concurent_nums)] + sampled_frame_id = 0 + + for i in range(frames): + if i % sample_interval != 0: + continue + subprocess_index = sampled_frame_id % render_concurent_nums + subproc_frame_ids[subprocess_index].append(sampled_frame_id) + subproc_vertices[subprocess_index].append((vertices_all[i], vertices_all[i])) + sampled_frame_id += 1 + + return subproc_frame_ids, subproc_vertices + +def generate_silent_videos(render_video_fps, + render_video_width, + render_video_height, + render_concurent_nums, + render_tmp_img_filetype, + frames, + vertices_all, + vertices1_all, + faces, + output_dir): + + subproc_frame_ids, subproc_vertices = distribute_frames(frames, render_video_fps, render_concurent_nums, vertices_all, vertices1_all) + + print(f"generate_silent_videos concurrentNum={render_concurent_nums} time={time.time()}") + with multiprocessing.Pool(render_concurent_nums) as pool: + pool.starmap( + sub_process_process_frame, + [ + (subprocess_index, 
render_video_width, render_video_height, render_tmp_img_filetype, subproc_frame_ids[subprocess_index], subproc_vertices[subprocess_index], faces, output_dir) + for subprocess_index in range(render_concurent_nums) + ] + ) + + output_file = os.path.join(output_dir, "silence_video.mp4") + utils.media.convert_img_to_mp4(os.path.join(output_dir, f"frame_%d.{render_tmp_img_filetype}"), output_file, render_video_fps) + filenames = glob.glob(os.path.join(output_dir, f"*.{render_tmp_img_filetype}")) + for filename in filenames: + os.remove(filename) + + return output_file + +def generate_silent_videos_no_gt(render_video_fps, + render_video_width, + render_video_height, + render_concurent_nums, + render_tmp_img_filetype, + frames, + vertices_all, + faces, + output_dir): + + subproc_frame_ids, subproc_vertices = distribute_frames_no_gt(frames, render_video_fps, render_concurent_nums, vertices_all) + + print(f"generate_silent_videos concurrentNum={render_concurent_nums} time={time.time()}") + with multiprocessing.Pool(render_concurent_nums) as pool: + pool.starmap( + sub_process_process_frame_no_gt, + [ + (subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, subproc_frame_ids[subprocess_index], subproc_vertices[subprocess_index], faces, output_dir) + for subprocess_index in range(render_concurent_nums) + ] + ) + + output_file = os.path.join(output_dir, "silence_video.mp4") + utils.media.convert_img_to_mp4(os.path.join(output_dir, f"frame_%d.{render_tmp_img_filetype}"), output_file, render_video_fps) + filenames = glob.glob(os.path.join(output_dir, f"*.{render_tmp_img_filetype}")) + for filename in filenames: + os.remove(filename) + + return output_file \ No newline at end of file diff --git a/utils/.ipynb_checkpoints/media-checkpoint.py b/utils/.ipynb_checkpoints/media-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd21e079a9e48f97f1511bd289d39f4aeccc40e --- /dev/null +++ b/utils/.ipynb_checkpoints/media-checkpoint.py @@ -0,0 +1,39 @@ +import numpy as np +import subprocess + +def add_audio_to_video(silent_video_path, audio_path, output_video_path): + command = [ + 'ffmpeg', + '-y', + '-i', silent_video_path, + '-i', audio_path, + '-map', '0:v', + '-map', '1:a', + '-c:v', 'copy', + '-shortest', + output_video_path + ] + + try: + subprocess.run(command, check=True) + print(f"Video with audio generated successfully: {output_video_path}") + except subprocess.CalledProcessError as e: + print(f"Error occurred: {e}") + + +def convert_img_to_mp4(input_pattern, output_file, framerate=30): + command = [ + 'ffmpeg', + '-framerate', str(framerate), + '-i', input_pattern, + '-c:v', 'libx264', + '-pix_fmt', 'yuv420p', + output_file, + '-y' + ] + + try: + subprocess.run(command, check=True) + print(f"Video conversion successful. 
Output file: {output_file}") + except subprocess.CalledProcessError as e: + print(f"Error during video conversion: {e}") diff --git a/utils/.ipynb_checkpoints/other_tools_hf-checkpoint.py b/utils/.ipynb_checkpoints/other_tools_hf-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b9f943120954e5745f1beca4e37ff461f663bb --- /dev/null +++ b/utils/.ipynb_checkpoints/other_tools_hf-checkpoint.py @@ -0,0 +1,975 @@ +import os +import numpy as np +import random +import torch +import shutil +import csv +import pprint +import pandas as pd +from loguru import logger +from collections import OrderedDict +import matplotlib.pyplot as plt +import pickle +import time +import hashlib +from scipy.spatial.transform import Rotation as R +from scipy.spatial.transform import Slerp +import cv2 +import utils.media +import utils.fast_render + +def write_wav_names_to_csv(folder_path, csv_path): + """ + Traverse a folder and write the base names of all .wav files to a CSV file. + + :param folder_path: Path to the folder to traverse. + :param csv_path: Path to the CSV file to write. + """ + # Open the CSV file for writing + with open(csv_path, mode='w', newline='') as file: + writer = csv.writer(file) + # Write the header + writer.writerow(['id', 'type']) + + # Walk through the folder + for root, dirs, files in os.walk(folder_path): + for file in files: + # Check if the file ends with .wav + if file.endswith('.wav'): + # Extract the base name without the extension + base_name = os.path.splitext(file)[0] + # Write the base name and type to the CSV + writer.writerow([base_name, 'test']) + +def resize_motion_sequence_tensor(sequence, target_frames): + """ + Resize a batch of 8-frame motion sequences to a specified number of frames using interpolation. + + :param sequence: A (bs, 8, 165) tensor representing a batch of 8-frame motion sequences + :param target_frames: An integer representing the desired number of frames in the output sequences + :return: A (bs, target_frames, 165) tensor representing the resized motion sequences + """ + bs, _, _ = sequence.shape + + # Create a time vector for the original and target sequences + original_time = torch.linspace(0, 1, 8, device=sequence.device).view(1, -1, 1) + target_time = torch.linspace(0, 1, target_frames, device=sequence.device).view(1, -1, 1) + + # Permute the dimensions to (bs, 165, 8) for interpolation + sequence = sequence.permute(0, 2, 1) + + # Interpolate each joint's motion to the target number of frames + resized_sequence = torch.nn.functional.interpolate(sequence, size=target_frames, mode='linear', align_corners=True) + + # Permute the dimensions back to (bs, target_frames, 165) + resized_sequence = resized_sequence.permute(0, 2, 1) + + return resized_sequence + +def adjust_speed_according_to_ratio_tensor(chunks): + """ + Adjust the playback speed within a batch of 32-frame chunks according to random intervals. 
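+    The chunk is first split into four equal 8-frame segments; three cut points are
+    then drawn at random from [1, 32) and each segment is re-timed with
+    resize_motion_sequence_tensor to the corresponding new segment length, so the
+    concatenated result is still 32 frames long. For example, sampled boundaries
+    [0, 5, 14, 26, 32] re-time the segments to 5, 9, 12 and 6 frames, and all four
+    speed ratios (8/5, 8/9, 8/12, 8/6) fall inside the accepted 0.33-3 range.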
+ + :param chunks: A (bs, 32, 165) tensor representing a batch of motion chunks + :return: A (bs, 32, 165) tensor representing the motion chunks after speed adjustment + """ + bs, _, _ = chunks.shape + + # Step 1: Divide the chunk into 4 equal intervals of 8 frames + equal_intervals = torch.chunk(chunks, 4, dim=1) + + # Step 2: Randomly sample 3 points within the chunk to determine new intervals + success = 0 + all_success = [] + #sample_points = torch.sort(torch.randint(1, 32, (bs, 3), device=chunks.device), dim=1).values + # new_intervals_boundaries = torch.cat([torch.zeros((bs, 1), device=chunks.device, dtype=torch.long), sample_points, 32*torch.ones((bs, 1), device=chunks.device, dtype=torch.long)], dim=1) + while success != 1: + sample_points = sorted(random.sample(range(1, 32), 3)) + new_intervals_boundaries = [0] + sample_points + [32] + new_intervals = [chunks[0][new_intervals_boundaries[i]:new_intervals_boundaries[i+1]] for i in range(4)] + speed_ratios = [8 / len(new_interval) for new_interval in new_intervals] + # if any of the speed ratios is greater than 3 or less than 0.33, resample + if all([0.33 <= speed_ratio <= 3 for speed_ratio in speed_ratios]): + success += 1 + all_success.append(new_intervals_boundaries) + new_intervals_boundaries = torch.from_numpy(np.array(all_success)) + # print(new_intervals_boundaries) + all_shapes = new_intervals_boundaries[:, 1:] - new_intervals_boundaries[:, :-1] + # Step 4: Adjust the speed of each new interval + adjusted_intervals = [] + # print(equal_intervals[0].shape) + for i in range(4): + adjusted_interval = resize_motion_sequence_tensor(equal_intervals[i], all_shapes[0, i]) + adjusted_intervals.append(adjusted_interval) + + # Step 5: Concatenate the adjusted intervals + adjusted_chunk = torch.cat(adjusted_intervals, dim=1) + + return adjusted_chunk + +def compute_exact_iou(bbox1, bbox2): + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2]) + y2 = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3]) + + intersection_area = max(0, x2 - x1) * max(0, y2 - y1) + bbox1_area = bbox1[2] * bbox1[3] + bbox2_area = bbox2[2] * bbox2[3] + union_area = bbox1_area + bbox2_area - intersection_area + + if union_area == 0: + return 0 + + return intersection_area / union_area + +def compute_iou(mask1, mask2): + # Compute the intersection + intersection = np.logical_and(mask1, mask2).sum() + + # Compute the union + union = np.logical_or(mask1, mask2).sum() + + # Compute the IoU + iou = intersection / union + + return iou + +def blankblending(all_frames, x, n): + return all_frames[x:x+n+1] + +def synthesize_intermediate_frames_FILM(frame1, frame2, t, name, save_path): + import replicate + from urllib.request import urlretrieve + import os + cv2.imwrite(save_path[:-9]+name+"_frame1.png", frame1) + cv2.imwrite(save_path[:-9]+name+"_frame2.png", frame2) + os.environ["REPLICATE_API_TOKEN"] = "r8_He1rkPk9GAxNQ3LpOohK8sYw1SUfMYV3Fxk9b" + output = replicate.run( + "google-research/frame-interpolation:4f88a16a13673a8b589c18866e540556170a5bcb2ccdc12de556e800e9456d3d", + input={ + "frame1": open(save_path[:-9]+name+"_frame1.png", "rb"), + "frame2": open(save_path[:-9]+name+"_frame2.png", "rb"), + "times_to_interpolate": t, + } + ) + print(output) + urlretrieve(output, save_path[:-9]+name+"_inter.mp4") + return load_video_as_numpy_array(save_path[:-9]+name+"_inter.mp4") + +def load_video_as_numpy_array(video_path): + cap = cv2.VideoCapture(video_path) + + # Using list comprehension to read frames and store in a 
list + frames = [frame for ret, frame in iter(lambda: cap.read(), (False, None)) if ret] + + cap.release() + + return np.array(frames) + +def synthesize_intermediate_frames_bidirectional(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + # Convert the frames to grayscale + gray1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) + gray2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) + + # Calculate the forward and backward optical flow + forward_flow = cv2.calcOpticalFlowFarneback(gray1, gray2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + backward_flow = cv2.calcOpticalFlowFarneback(gray2, gray1, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / n # Interpolation factor + + # Compute the intermediate forward and backward flow + intermediate_forward_flow = forward_flow * alpha + intermediate_backward_flow = backward_flow * (1 - alpha) + + # Warp the frames based on the intermediate flow + h, w = frame1.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + forward_displacement = flow_map + intermediate_forward_flow.reshape(-1, 2) + backward_displacement = flow_map - intermediate_backward_flow.reshape(-1, 2) + + # Use cv2.remap for efficient warping + remap_x_forward, remap_y_forward = np.clip(forward_displacement[:, 1], 0, w - 1), np.clip(forward_displacement[:, 0], 0, h - 1) + remap_x_backward, remap_y_backward = np.clip(backward_displacement[:, 1], 0, w - 1), np.clip(backward_displacement[:, 0], 0, h - 1) + + warped_forward = cv2.remap(frame1, remap_x_forward.reshape(h, w).astype(np.float32), remap_y_forward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + warped_backward = cv2.remap(frame2, remap_x_backward.reshape(h, w).astype(np.float32), remap_y_backward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + + # Blend the warped frames to generate the intermediate frame + intermediate_frame = cv2.addWeighted(warped_forward, 1 - alpha, warped_backward, alpha, 0) + synthesized_frames.append(intermediate_frame) + + return synthesized_frames # Return n-2 synthesized intermediate frames + + +def linear_interpolate_frames(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / (n) # Correct interpolation factor + inter_frame = cv2.addWeighted(frame1, 1 - alpha, frame2, alpha, 0) + synthesized_frames.append(inter_frame) + return synthesized_frames[:-1] + +def warp_frame(src_frame, flow): + h, w = flow.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + displacement = flow_map + flow.reshape(-1, 2) + + # Extract x and y coordinates of the displacement + x_coords = np.clip(displacement[:, 1], 0, w - 1).reshape(h, w).astype(np.float32) + y_coords = np.clip(displacement[:, 0], 0, h - 1).reshape(h, w).astype(np.float32) + + # Use cv2.remap for efficient warping + warped_frame = cv2.remap(src_frame, x_coords, y_coords, interpolation=cv2.INTER_LINEAR) + + return warped_frame + +def synthesize_intermediate_frames(all_frames, x, n): + # Calculate Optical Flow between the first and last frame + frame1 = cv2.cvtColor(all_frames[x], cv2.COLOR_BGR2GRAY) + frame2 = cv2.cvtColor(all_frames[x + n], cv2.COLOR_BGR2GRAY) + flow = cv2.calcOpticalFlowFarneback(frame1, frame2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For 
each intermediate frame + alpha = i / (n) # Interpolation factor + intermediate_flow = flow * alpha # Interpolate the flow + intermediate_frame = warp_frame(all_frames[x], intermediate_flow) # Warp the first frame + synthesized_frames.append(intermediate_frame) + + return synthesized_frames + + +def map2color(s): + m = hashlib.md5() + m.update(s.encode('utf-8')) + color_code = m.hexdigest()[:6] + return '#' + color_code + +def euclidean_distance(a, b): + return np.sqrt(np.sum((a - b)**2)) + +def adjust_array(x, k): + len_x = len(x) + len_k = len(k) + + # If x is shorter than k, pad with zeros + if len_x < len_k: + return np.pad(x, (0, len_k - len_x), 'constant') + + # If x is longer than k, truncate x + elif len_x > len_k: + return x[:len_k] + + # If both are of same length + else: + return x + +def onset_to_frame(onset_times, audio_length, fps): + # Calculate total number of frames for the given audio length + total_frames = int(audio_length * fps) + + # Create an array of zeros of shape (total_frames,) + frame_array = np.zeros(total_frames, dtype=np.int32) + + # For each onset time, calculate the frame number and set it to 1 + for onset in onset_times: + frame_num = int(onset * fps) + # Check if the frame number is within the array bounds + if 0 <= frame_num < total_frames: + frame_array[frame_num] = 1 + + return frame_array + +# def np_slerp(q1, q2, t): +# dot_product = np.sum(q1 * q2, axis=-1) +# q2_flip = np.where(dot_product[:, None] < 0, -q2, q2) # Flip quaternions where dot_product is negative +# dot_product = np.abs(dot_product) + +# angle = np.arccos(np.clip(dot_product, -1, 1)) +# sin_angle = np.sin(angle) + +# t1 = np.sin((1.0 - t) * angle) / sin_angle +# t2 = np.sin(t * angle) / sin_angle + +# return t1 * q1 + t2 * q2_flip + + +def smooth_rotvec_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using SLERP. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. + - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + n1, k1 = animation1.shape + n2, k2 = animation2.shape + animation1 = animation1.reshape(n1, k1//3, 3) + animation2 = animation2.reshape(n2, k2//3, 3) + blend_frames = min(blend_frames, len(animation1), len(animation2)) + all_int = [] + for i in range(k1//3): + # Convert rotation vectors to quaternion for the overlapping part + q = R.from_rotvec(np.concatenate([animation1[0:1, i], animation2[-2:-1, i]], axis=0))#.as_quat() + # q2 = R.from_rotvec()#.as_quat() + times = [0, blend_frames * 2 - 1] + slerp = Slerp(times, q) + interpolated = slerp(np.arange(blend_frames * 2)) + interpolated_rotvecs = interpolated.as_rotvec() + all_int.append(interpolated_rotvecs) + interpolated_rotvecs = np.concatenate(all_int, axis=1) + # result = np.vstack((animation1[:-blend_frames], interpolated_rotvecs, animation2[blend_frames:])) + result = interpolated_rotvecs.reshape(2*n1, k1) + return result + +def smooth_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using linear interpolation. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. 
+ - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + blend_frames = min(blend_frames, len(animation1), len(animation2)) + + # Extract overlapping sections + overlap_a1 = animation1[-blend_frames:-blend_frames+1, :] + overlap_a2 = animation2[blend_frames-1:blend_frames, :] + + # Create blend weights for linear interpolation + alpha = np.linspace(0, 1, 2 * blend_frames).reshape(-1, 1) + + # Linearly interpolate between overlapping sections + blended_overlap = overlap_a1 * (1 - alpha) + overlap_a2 * alpha + + # Extend the animations to form the result with 2n frames + if blend_frames == len(animation1) and blend_frames == len(animation2): + result = blended_overlap + else: + before_blend = animation1[:-blend_frames] + after_blend = animation2[blend_frames:] + result = np.vstack((before_blend, blended_overlap, after_blend)) + return result + +def interpolate_sequence(quaternions): + bs, n, j, _ = quaternions.shape + new_n = 2 * n + new_quaternions = torch.zeros((bs, new_n, j, 4), device=quaternions.device, dtype=quaternions.dtype) + + for i in range(n): + q1 = quaternions[:, i, :, :] + new_quaternions[:, 2*i, :, :] = q1 + + if i < n - 1: + q2 = quaternions[:, i + 1, :, :] + new_quaternions[:, 2*i + 1, :, :] = slerp(q1, q2, 0.5) + else: + # For the last point, duplicate the value + new_quaternions[:, 2*i + 1, :, :] = q1 + + return new_quaternions + +def quaternion_multiply(q1, q2): + w1, x1, y1, z1 = q1 + w2, x2, y2, z2 = q2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + return w, x, y, z + +def quaternion_conjugate(q): + w, x, y, z = q + return (w, -x, -y, -z) + +def slerp(q1, q2, t): + dot = torch.sum(q1 * q2, dim=-1, keepdim=True) + + flip = (dot < 0).float() + q2 = (1 - flip * 2) * q2 + dot = dot * (1 - flip * 2) + + DOT_THRESHOLD = 0.9995 + mask = (dot > DOT_THRESHOLD).float() + + theta_0 = torch.acos(dot) + theta = theta_0 * t + + q3 = q2 - q1 * dot + q3 = q3 / torch.norm(q3, dim=-1, keepdim=True) + + interpolated = (torch.cos(theta) * q1 + torch.sin(theta) * q3) + + return mask * (q1 + t * (q2 - q1)) + (1 - mask) * interpolated + +def estimate_linear_velocity(data_seq, dt): + ''' + Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates + the velocity for the middle T-2 steps using a second order central difference scheme. + The first and last frames are with forward and backward first-order + differences, respectively + - h : step size + ''' + # first steps is forward diff (t+1 - t) / dt + init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt + + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) + return vel_seq + +def velocity2position(data_seq, dt, init_pos): + res_trans = [] + for i in range(data_seq.shape[1]): + if i == 0: + res_trans.append(init_pos.unsqueeze(1)) + else: + res = data_seq[:, i-1:i] * dt + res_trans[-1] + res_trans.append(res) + return torch.cat(res_trans, dim=1) + +def estimate_angular_velocity(rot_seq, dt): + ''' + Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. 
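+    The velocity is read off the (approximately) skew-symmetric matrix
+    W = dR/dt @ R^T, e.g. w_x = (W[..., 2, 1] - W[..., 1, 2]) / 2.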
+ Input sequence should be of shape (B, T, ..., 3, 3) + ''' + # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix + dRdt = estimate_linear_velocity(rot_seq, dt) + R = rot_seq + RT = R.transpose(-1, -2) + # compute skew-symmetric angular velocity tensor + w_mat = torch.matmul(dRdt, RT) + # pull out angular velocity vector by averaging symmetric entries + w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 + w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 + w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 + w = torch.stack([w_x, w_y, w_z], axis=-1) + return w + +def image_from_bytes(image_bytes): + import matplotlib.image as mpimg + from io import BytesIO + return mpimg.imread(BytesIO(image_bytes), format='PNG') + +def process_frame(i, vertices_all, vertices1_all, faces, output_dir, filenames): + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + import trimesh + import pyrender + + def deg_to_rad(degrees): + return degrees * np.pi / 180 + + uniform_color = [220, 220, 220, 255] + resolution = (1000, 1000) + figsize = (10, 10) + + fig, axs = plt.subplots( + nrows=1, + ncols=2, + figsize=(figsize[0] * 2, figsize[1] * 1) + ) + axs = axs.flatten() + + vertices = vertices_all[i] + vertices1 = vertices1_all[i] + filename = f"{output_dir}frame_{i}.png" + filenames.append(filename) + if i%100 == 0: + print('processed', i, 'frames') + #time_s = time.time() + #print(vertices.shape) + angle_rad = deg_to_rad(-2) + pose_camera = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 1.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 5.0], + [0.0, 0.0, 0.0, 1.0] + ]) + angle_rad = deg_to_rad(-30) + pose_light = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 0.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 3.0], + [0.0, 0.0, 0.0, 1.0] + ]) + + for vtx_idx, vtx in enumerate([vertices, vertices1]): + trimesh_mesh = trimesh.Trimesh( + vertices=vtx, + faces=faces, + vertex_colors=uniform_color + ) + mesh = pyrender.Mesh.from_trimesh( + trimesh_mesh, smooth=True + ) + scene = pyrender.Scene() + scene.add(mesh) + camera = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0) + scene.add(camera, pose=pose_camera) + light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=4.0) + scene.add(light, pose=pose_light) + renderer = pyrender.OffscreenRenderer(*resolution) + color, _ = renderer.render(scene) + axs[vtx_idx].imshow(color) + axs[vtx_idx].axis('off') + renderer.delete() + + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + +def generate_images(frames, vertices_all, vertices1_all, faces, output_dir, filenames): + import multiprocessing + # import trimesh + num_cores = multiprocessing.cpu_count() - 1 # This will get the number of cores on your machine. 
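+    # Note: num_cores is computed above, but the active loop further down renders
+    # serially and only every 3rd frame index (process_frame(i*3, ...)); the
+    # commented-out multiprocessing.Pool blocks below are a parallel variant of
+    # the same per-frame call.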
+ # mesh = trimesh.Trimesh(vertices_all[0], faces) + # scene = mesh.scene() + # fov = scene.camera.fov.copy() + # fov[0] = 80.0 + # fov[1] = 60.0 + # camera_params = { + # 'fov': fov, + # 'resolution': scene.camera.resolution, + # 'focal': scene.camera.focal, + # 'z_near': scene.camera.z_near, + # "z_far": scene.camera.z_far, + # 'transform': scene.graph[scene.camera.name][0] + # } + # mesh1 = trimesh.Trimesh(vertices1_all[0], faces) + # scene1 = mesh1.scene() + # camera_params1 = { + # 'fov': fov, + # 'resolution': scene1.camera.resolution, + # 'focal': scene1.camera.focal, + # 'z_near': scene1.camera.z_near, + # "z_far": scene1.camera.z_far, + # 'transform': scene1.graph[scene1.camera.name][0] + # } + # Use a Pool to manage the processes + # print(num_cores) + # for i in range(frames): + # process_frame(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) + for i in range(frames): + process_frame(i*3, vertices_all, vertices1_all, faces, output_dir, filenames) + + # progress = multiprocessing.Value('i', 0) + # lock = multiprocessing.Lock() + # with multiprocessing.Pool(num_cores) as pool: + # # pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + # pool.starmap( + # process_frame, + # [ + # (i, vertices_all, vertices1_all, faces, output_dir, filenames) + # for i in range(frames) + # ] + # ) + + # progress = multiprocessing.Value('i', 0) + # lock = multiprocessing.Lock() + # with multiprocessing.Pool(num_cores) as pool: + # # pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + # pool.starmap( + # process_frame, + # [ + # (i, vertices_all, vertices1_all, faces, output_dir, filenames) + # for i in range(frames) + # ] + # ) + +def render_one_sequence( + res_npz_path, + gt_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import moviepy.editor as mp + import librosa + + model = smplx.create(model_folder, model_type=model_type, + gender=gender, use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, use_pca=False).cuda() + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + # if not use_matplotlib: + # import trimesh + #import pyrender + from pyvirtualdisplay import Display + #''' + #display = Display(visible=0, size=(1000, 1000)) + #display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + #data_npz["jaw_pose"].shape[0] + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + beta = beta.repeat(n, 1) + expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).cuda() + pose = 
torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).cuda() + transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).cuda() + # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape) + output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose, + global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3], + leye_pose=pose[:, 69:72], + reye_pose=pose[:, 72:75], + return_verts=True) + vertices_all = output["vertices"].cpu().detach().numpy() + + beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).cuda() + pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).cuda() + transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).cuda() + output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3], + leye_pose=pose1[:, 69:72], + reye_pose=pose1[:, 72:75],return_verts=True) + vertices1_all = output1["vertices"].cpu().detach().numpy() + if args.debug: + seconds = 1 + else: + seconds = vertices_all.shape[0]//30 + silent_video_file_path = utils.fast_render.generate_silent_videos(args.render_video_fps, + args.render_video_width, + args.render_video_height, + args.render_concurrent_num, + args.render_tmp_img_filetype, + int(seconds*args.render_video_fps), + vertices_all, + vertices1_all, + faces, + output_dir) + base_filename_without_ext = os.path.splitext(os.path.basename(res_npz_path))[0] + final_clip = os.path.join(output_dir, f"{base_filename_without_ext}.mp4") + utils.media.add_audio_to_video(silent_video_file_path, audio_path, final_clip) + os.remove(silent_video_file_path) + return final_clip + +def render_one_sequence_no_gt( + res_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import moviepy.editor as mp + import librosa + + model = smplx.create(model_folder, model_type=model_type, + gender=gender, use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, use_pca=False).cuda() + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + # gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + # if not use_matplotlib: + # import trimesh + #import pyrender + from pyvirtualdisplay import Display + #''' + #display = Display(visible=0, size=(1000, 1000)) + #display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + #data_npz["jaw_pose"].shape[0] + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + beta = beta.repeat(n, 1) + expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose = 
torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).cuda() + pose = torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).cuda() + transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).cuda() + # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape) + output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose, + global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3], + leye_pose=pose[:, 69:72], + reye_pose=pose[:, 72:75], + return_verts=True) + vertices_all = output["vertices"].cpu().detach().numpy() + + # beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + # expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).cuda() + # jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).cuda() + # pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).cuda() + # transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).cuda() + # output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3], + # leye_pose=pose1[:, 69:72], + # reye_pose=pose1[:, 72:75],return_verts=True) + # vertices1_all = output1["vertices"].cpu().detach().numpy() + if args.debug: + seconds = 1 + else: + seconds = vertices_all.shape[0]//30 + silent_video_file_path = utils.fast_render.generate_silent_videos_no_gt(args.render_video_fps, + args.render_video_width, + args.render_video_height, + args.render_concurrent_num, + args.render_tmp_img_filetype, + int(seconds*args.render_video_fps), + vertices_all, + faces, + output_dir) + base_filename_without_ext = os.path.splitext(os.path.basename(res_npz_path))[0] + final_clip = os.path.join(output_dir, f"{base_filename_without_ext}.mp4") + utils.media.add_audio_to_video(silent_video_file_path, audio_path, final_clip) + os.remove(silent_video_file_path) + return final_clip + +def print_exp_info(args): + logger.info(pprint.pformat(vars(args))) + logger.info(f"# ------------ {args.name} ----------- #") + logger.info("PyTorch version: {}".format(torch.__version__)) + logger.info("CUDA version: {}".format(torch.version.cuda)) + logger.info("{} GPUs".format(torch.cuda.device_count())) + logger.info(f"Random Seed: {args.random_seed}") + +def args2csv(args, get_head=False, list4print=[]): + for k, v in args.items(): + if isinstance(args[k], dict): + args2csv(args[k], get_head, list4print) + else: list4print.append(k) if get_head else list4print.append(v) + return list4print + +class EpochTracker: + def __init__(self, metric_names, metric_directions): + assert len(metric_names) == len(metric_directions), "Metric names and directions should have the same length" + + + self.metric_names = metric_names + self.states = ['train', 'val', 'test'] + self.types = ['last', 'best'] + + + self.values = {name: {state: {type_: {'value': np.inf if not is_higher_better else -np.inf, 'epoch': 0} + for type_ in self.types} + for state in self.states} + for name, is_higher_better in zip(metric_names, metric_directions)} + + self.loss_meters = {name: {state: AverageMeter(f"{name}_{state}") + for state in self.states} + for name in metric_names} + + + self.is_higher_better = {name: direction for name, direction in zip(metric_names, metric_directions)} + self.train_history = {name: [] for 
name in metric_names} + self.val_history = {name: [] for name in metric_names} + + + def update_meter(self, name, state, value): + self.loss_meters[name][state].update(value) + + + def update_values(self, name, state, epoch): + value_avg = self.loss_meters[name][state].avg + new_best = False + + + if ((value_avg < self.values[name][state]['best']['value'] and not self.is_higher_better[name]) or + (value_avg > self.values[name][state]['best']['value'] and self.is_higher_better[name])): + self.values[name][state]['best']['value'] = value_avg + self.values[name][state]['best']['epoch'] = epoch + new_best = True + self.values[name][state]['last']['value'] = value_avg + self.values[name][state]['last']['epoch'] = epoch + return new_best + + + def get(self, name, state, type_): + return self.values[name][state][type_] + + + def reset(self): + for name in self.metric_names: + for state in self.states: + self.loss_meters[name][state].reset() + + + def flatten_values(self): + flat_dict = {} + for name in self.metric_names: + for state in self.states: + for type_ in self.types: + value_key = f"{name}_{state}_{type_}" + epoch_key = f"{name}_{state}_{type_}_epoch" + flat_dict[value_key] = self.values[name][state][type_]['value'] + flat_dict[epoch_key] = self.values[name][state][type_]['epoch'] + return flat_dict + + def update_and_plot(self, name, epoch, save_path): + new_best_train = self.update_values(name, 'train', epoch) + new_best_val = self.update_values(name, 'val', epoch) + + + self.train_history[name].append(self.loss_meters[name]['train'].avg) + self.val_history[name].append(self.loss_meters[name]['val'].avg) + + + train_values = self.train_history[name] + val_values = self.val_history[name] + epochs = list(range(1, len(train_values) + 1)) + + + plt.figure(figsize=(10, 6)) + plt.plot(epochs, train_values, label='Train') + plt.plot(epochs, val_values, label='Val') + plt.title(f'Train vs Val {name} over epochs') + plt.xlabel('Epochs') + plt.ylabel(name) + plt.legend() + plt.savefig(save_path) + plt.close() + + + return new_best_train, new_best_val + +def record_trial(args, tracker): + """ + 1. 
record notes, score, env_name, experments_path, + """ + csv_path = args.out_path + "custom/" +args.csv_name+".csv" + all_print_dict = vars(args) + all_print_dict.update(tracker.flatten_values()) + if not os.path.exists(csv_path): + pd.DataFrame([all_print_dict]).to_csv(csv_path, index=False) + else: + df_existing = pd.read_csv(csv_path) + df_new = pd.DataFrame([all_print_dict]) + df_aligned = df_existing.append(df_new).fillna("") + df_aligned.to_csv(csv_path, index=False) + +def set_random_seed(args): + os.environ['PYTHONHASHSEED'] = str(args.random_seed) + random.seed(args.random_seed) + np.random.seed(args.random_seed) + torch.manual_seed(args.random_seed) + torch.cuda.manual_seed_all(args.random_seed) + torch.cuda.manual_seed(args.random_seed) + torch.backends.cudnn.deterministic = args.deterministic #args.CUDNN_DETERMINISTIC + torch.backends.cudnn.benchmark = args.benchmark + torch.backends.cudnn.enabled = args.cudnn_enabled + +def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None): + if lrs is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(), + 'lrs':lrs.state_dict(),} + elif opt is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(),} + else: + states = { 'model_state': model.state_dict(),} + torch.save(states, save_path) + +def load_checkpoints(model, save_path, load_name='model'): + states = torch.load(save_path) + new_weights = OrderedDict() + flag=False + for k, v in states['model_state'].items(): + #print(k) + if "module" not in k: + break + else: + new_weights[k[7:]]=v + flag=True + if flag: + try: + model.load_state_dict(new_weights) + except: + #print(states['model_state']) + model.load_state_dict(states['model_state']) + else: + model.load_state_dict(states['model_state']) + logger.info(f"load self-pretrained checkpoints for {load_name}") + +def model_complexity(model, args): + from ptflops import get_model_complexity_info + flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN), + as_strings=False, print_per_layer_stat=False) + logging.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9)) + logging.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6)) + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) \ No newline at end of file diff --git a/utils/__pycache__/config.cpython-310.pyc b/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccb99e557a2f04c99edf6886eab1aa73a6e428cb Binary files /dev/null and b/utils/__pycache__/config.cpython-310.pyc differ diff --git a/utils/__pycache__/config.cpython-38.pyc b/utils/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1705878047f033c56b6a971989da04fc68b6487a Binary files /dev/null and b/utils/__pycache__/config.cpython-38.pyc differ diff --git a/utils/__pycache__/data_transfer.cpython-310.pyc b/utils/__pycache__/data_transfer.cpython-310.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..c805fd8d2788e0b9d089d2cbd4e0c120d5d621e4 Binary files /dev/null and b/utils/__pycache__/data_transfer.cpython-310.pyc differ diff --git a/utils/__pycache__/data_transfer.cpython-38.pyc b/utils/__pycache__/data_transfer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a362ad170dee4be8d1af64031d61167815538b4 Binary files /dev/null and b/utils/__pycache__/data_transfer.cpython-38.pyc differ diff --git a/utils/__pycache__/fast_render.cpython-310.pyc b/utils/__pycache__/fast_render.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcfd4eb98ed471d13401d6a9838d57d37bc00700 Binary files /dev/null and b/utils/__pycache__/fast_render.cpython-310.pyc differ diff --git a/utils/__pycache__/fast_render.cpython-38.pyc b/utils/__pycache__/fast_render.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fdd15ee756a1b24520fc7f0cc4510aaf6c3bd7e Binary files /dev/null and b/utils/__pycache__/fast_render.cpython-38.pyc differ diff --git a/utils/__pycache__/logger_tools.cpython-310.pyc b/utils/__pycache__/logger_tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86143e276b726736951c74c3b9e11a0e23d0d82a Binary files /dev/null and b/utils/__pycache__/logger_tools.cpython-310.pyc differ diff --git a/utils/__pycache__/logger_tools.cpython-38.pyc b/utils/__pycache__/logger_tools.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90d54863ed1d53f243e77a6927d3180b6f92a6c8 Binary files /dev/null and b/utils/__pycache__/logger_tools.cpython-38.pyc differ diff --git a/utils/__pycache__/media.cpython-310.pyc b/utils/__pycache__/media.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f503c310af12630fe9b8187739a96531cf5366ea Binary files /dev/null and b/utils/__pycache__/media.cpython-310.pyc differ diff --git a/utils/__pycache__/media.cpython-38.pyc b/utils/__pycache__/media.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8dc1965a5b0d4604cdcb51e146e3923e7bffc77 Binary files /dev/null and b/utils/__pycache__/media.cpython-38.pyc differ diff --git a/utils/__pycache__/metric.cpython-310.pyc b/utils/__pycache__/metric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf224063392388937113a4cff8428b9204338ee0 Binary files /dev/null and b/utils/__pycache__/metric.cpython-310.pyc differ diff --git a/utils/__pycache__/metric.cpython-38.pyc b/utils/__pycache__/metric.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbe2362798419f4f9a966be130ee90aded3d0a93 Binary files /dev/null and b/utils/__pycache__/metric.cpython-38.pyc differ diff --git a/utils/__pycache__/other_tools_hf.cpython-310.pyc b/utils/__pycache__/other_tools_hf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bce850dbcfc5f9a3539a75d28b48ad5991602ef Binary files /dev/null and b/utils/__pycache__/other_tools_hf.cpython-310.pyc differ diff --git a/utils/__pycache__/other_tools_hf.cpython-38.pyc b/utils/__pycache__/other_tools_hf.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d2e5f320c10ff27ab4948be839ebf6690af28f2 Binary files /dev/null and b/utils/__pycache__/other_tools_hf.cpython-38.pyc differ diff --git a/utils/__pycache__/rotation_conversions.cpython-310.pyc b/utils/__pycache__/rotation_conversions.cpython-310.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..965dc587eef87e67ffc1a8bc79635df7075d5ae1 Binary files /dev/null and b/utils/__pycache__/rotation_conversions.cpython-310.pyc differ diff --git a/utils/__pycache__/rotation_conversions.cpython-38.pyc b/utils/__pycache__/rotation_conversions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1628d0abc77b8f33f68b954ff62f464981614af9 Binary files /dev/null and b/utils/__pycache__/rotation_conversions.cpython-38.pyc differ diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac207d7270c15532fa55b9184af3d64ceae9fc6 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,275 @@ +import configargparse +import time +import json +import yaml +import os + + +def str2bool(v): + """ from https://stackoverflow.com/a/43357954/1361529 """ + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise configargparse.ArgumentTypeError('Boolean value expected.') + + +def parse_args(): + """ + requirement for config + 1. command > yaml > default + 2. avoid re-definition + 3. lowercase letters is better + 4. hierarchical is not necessary + """ + parser = configargparse.ArgParser() + parser.add("-c", "--config", default='./configs/emage_test_hf.yaml', is_config_file=True) + parser.add("--project", default="audio2pose", type=str) # wandb project name + parser.add("--stat", default="ts", type=str) + parser.add("--csv_name", default="a2g_0", type=str) # local device id + parser.add("--notes", default="", type=str) + parser.add("--trainer", default="camn", type=str) + + parser.add("--l", default=4, type=int) + # ------------- path and save name ---------------- # + parser.add("--is_train", default=True, type=str2bool) + parser.add("--debug", default=False, type=str2bool) + # different between environments + parser.add("--root_path", default="/home/ma-user/work/") + parser.add("--cache_path", default="/outputs/audio2pose/", type=str) + parser.add("--out_path", default="/outputs/audio2pose/", type=str) + parser.add("--data_path", default="/outputs/audio2pose/", type=str) + parser.add("--train_data_path", default="/datasets/trinity/train/", type=str) + parser.add("--val_data_path", default="/datasets/trinity/val/", type=str) + parser.add("--test_data_path", default="/datasets/trinity/test/", type=str) + parser.add("--mean_pose_path", default="/datasets/trinity/train/", type=str) + parser.add("--std_pose_path", default="/datasets/trinity/train/", type=str) + # for pretrian weights + parser.add("--data_path_1", default="../../datasets/checkpoints/", type=str) + # ------------------- evaluation ----------------------- # + parser.add("--test_ckpt", default="/datasets/beat_cache/beat_4english_15_141/last.bin") + parser.add("--eval_model", default="vae", type=str) + parser.add("--e_name", default=None, type=str) #HalfEmbeddingNet + parser.add("--e_path", default="/datasets/beat/generated_data/self_vae_128.bin") + parser.add("--variational", default=False, type=str2bool) + parser.add("--vae_length", default=256, type=int) + parser.add("--vae_test_dim", default=141, type=int) + parser.add("--vae_test_len", default=34, type=int) + parser.add("--vae_test_stride", default=10, type=int) + #parser.add("--vae_pose_length", default=34, type=int) + parser.add("--test_period", default=20, type=int) + parser.add("--vae_codebook_size", default=1024, type=int) + 
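+    # Illustrative usage (assumes the default YAML above exists): configargparse
+    # resolves every option as command line > config file > the defaults declared
+    # here, so e.g.
+    #   python app.py -c ./configs/emage_test_hf.yaml --debug True --render_video_fps 30
+    # keeps the YAML values except for the two flags overridden on the command line.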
parser.add("--vae_quantizer_lambda", default=1., type=float) + + parser.add("--vae_layer", default=2, type=int) + parser.add("--vae_grow", default=[1,1,2,1], type=int, nargs="*") + parser.add("--lf", default=0., type=float) + parser.add("--ll", default=0., type=float) + parser.add("--lu", default=0., type=float) + parser.add("--lh", default=0., type=float) + parser.add("--cf", default=0., type=float) + parser.add("--cl", default=0., type=float) + parser.add("--cu", default=0., type=float) + parser.add("--ch", default=0., type=float) + + + # --------------- data ---------------------------- # + parser.add("--additional_data", default=False, type=str2bool) + parser.add("--train_trans", default=True, type=str2bool) + parser.add("--dataset", default="beat", type=str) + parser.add("--rot6d", default=True, type=str2bool) + parser.add("--ori_joints", default="spine_neck_141", type=str) + parser.add("--tar_joints", default="spine_neck_141", type=str) + parser.add("--training_speakers", default=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], type=int, nargs="*") + #parser.add("--pose_version", default="spine_neck_141", type=str) + parser.add("--new_cache", default=True, type=str2bool) + parser.add("--beat_align", default=True, type=str2bool) + parser.add("--cache_only", default=False, type=str2bool) + parser.add("--word_cache", default=False, type=str2bool) + parser.add("--use_aug", default=False, type=str2bool) + parser.add("--disable_filtering", default=False, type=str2bool) + parser.add("--clean_first_seconds", default=0, type=int) + parser.add("--clean_final_seconds", default=0, type=int) + + parser.add("--audio_rep", default=None, type=str) + parser.add("--audio_sr", default=16000, type=int) + parser.add("--word_rep", default=None, type=str) + parser.add("--emo_rep", default=None, type=str) + parser.add("--sem_rep", default=None, type=str) + parser.add("--facial_rep", default=None, type=str) + parser.add("--pose_rep", default="bvhrot", type=str) + parser.add("--id_rep", default="onehot", type=str) + parser.add("--speaker_id", default="onehot", type=str) + + parser.add("--a_pre_encoder", default=None, type=str) + parser.add("--a_encoder", default=None, type=str) + parser.add("--a_fix_pre", default=False, type=str2bool) + parser.add("--t_pre_encoder", default=None, type=str) + parser.add("--t_encoder", default=None, type=str) + parser.add("--t_fix_pre", default=False, type=str2bool) + parser.add("--m_pre_encoder", default=None, type=str) + parser.add("--m_encoder", default=None, type=str) + parser.add("--m_fix_pre", default=False, type=str2bool) + parser.add("--f_pre_encoder", default=None, type=str) + parser.add("--f_encoder", default=None, type=str) + parser.add("--f_fix_pre", default=False, type=str2bool) + parser.add("--m_decoder", default=None, type=str) + parser.add("--decode_fusion", default=None, type=str) + parser.add("--atmr", default=0.0, type=float) + parser.add("--ttmr", default=0., type=float) + parser.add("--mtmr", default=0., type=float) + parser.add("--ftmr", default=0., type=float) + parser.add("--asmr", default=0., type=float) + parser.add("--tsmr", default=0., type=float) + parser.add("--msmr", default=0., type=float) + parser.add("--fsmr", default=0., type=float) +# parser.add("--m_encoder", default=None, type=str) +# parser.add("--m_pre_fix", default=None, type=str) +# parser.add("--id_rep", default=None, type=str) + + parser.add("--freeze_wordembed", default=True, type=str2bool) + parser.add("--audio_fps", default=16000, type=int) + 
parser.add("--facial_fps", default=15, type=int) + parser.add("--pose_fps", default=15, type=int) + + parser.add("--audio_dims", default=1, type=int) + parser.add("--facial_dims", default=39, type=int) + parser.add("--pose_dims", default=123, type=int) + parser.add("--word_index_num", default=5793, type=int) + parser.add("--word_dims", default=300, type=int) + parser.add("--speaker_dims", default=4, type=int) + parser.add("--emotion_dims", default=8, type=int) + + parser.add("--audio_norm", default=False, type=str2bool) + parser.add("--facial_norm", default=False, type=str2bool) + parser.add("--pose_norm", default=False, type=str2bool) + + parser.add("--pose_length", default=34, type=int) + parser.add("--pre_frames", default=4, type=int) + parser.add("--stride", default=10, type=int) + parser.add("--pre_type", default="zero", type=str) + + parser.add("--audio_f", default=0, type=int) + parser.add("--motion_f", default=0, type=int) + parser.add("--facial_f", default=0, type=int) + parser.add("--speaker_f", default=0, type=int) + parser.add("--word_f", default=0, type=int) + parser.add("--emotion_f", default=0, type=int) + parser.add("--aud_prob", default=1.0, type=float) + parser.add("--pos_prob", default=1.0, type=float) + parser.add("--txt_prob", default=1.0, type=float) + parser.add("--fac_prob", default=1.0, type=float) + parser.add("--multi_length_training", default=[1.0], type=float, nargs="*") + # --------------- model ---------------------------- # + parser.add("--pretrain", default=False, type=str2bool) + parser.add("--model", default="camn", type=str) + parser.add("--g_name", default="CaMN", type=str) + parser.add("--d_name", default=None, type=str) #ConvDiscriminator + parser.add("--dropout_prob", default=0.3, type=float) + parser.add("--n_layer", default=4, type=int) + parser.add("--hidden_size", default=300, type=int) + #parser.add("--period", default=34, type=int) + parser.add("--test_length", default=34, type=int) + # Self-designed "Multi-Stage", "Seprate", or "Original" + parser.add("--finger_net", default="original", type=str) + parser.add("--pos_encoding_type", default="sin", type=str) + parser.add("--queue_size", default=1024, type=int) + + # --------------- training ------------------------- # + parser.add("--epochs", default=120, type=int) + parser.add("--epoch_stage", default=0, type=int) + parser.add("--grad_norm", default=0, type=float) + parser.add("--no_adv_epoch", default=999, type=int) + parser.add("--batch_size", default=128, type=int) + parser.add("--opt", default="adam", type=str) + parser.add("--lr_base", default=0.00025, type=float) + parser.add("--opt_betas", default=[0.5, 0.999], type=float, nargs="*") + parser.add("--weight_decay", default=0., type=float) + # for warmup and cosine + parser.add("--lr_min", default=1e-7, type=float) + parser.add("--warmup_lr", default=5e-4, type=float) + parser.add("--warmup_epochs", default=0, type=int) + parser.add("--decay_epochs", default=9999, type=int) + parser.add("--decay_rate", default=0.1, type=float) + parser.add("--lr_policy", default="step", type=str) + # for sgd + parser.add("--momentum", default=0.8, type=float) + parser.add("--nesterov", default=True, type=str2bool) + parser.add("--amsgrad", default=False, type=str2bool) + parser.add("--d_lr_weight", default=0.2, type=float) + parser.add("--rec_weight", default=500, type=float) + parser.add("--adv_weight", default=20.0, type=float) + parser.add("--fid_weight", default=0.0, type=float) + parser.add("--vel_weight", default=0.0, type=float) + 
parser.add("--acc_weight", default=0.0, type=float) + parser.add("--kld_weight", default=0.0, type=float) + parser.add("--kld_aud_weight", default=0.0, type=float) + parser.add("--kld_fac_weight", default=0.0, type=float) + parser.add("--ali_weight", default=0.0, type=float) + parser.add("--ita_weight", default=0.0, type=float) + parser.add("--iwa_weight", default=0.0, type=float) + parser.add("--wei_weight", default=0.0, type=float) + parser.add("--gap_weight", default=0.0, type=float) + parser.add("--atcont", default=0.0, type=float) + parser.add("--fusion_mode", default="sum", type=str) + + parser.add("--div_reg_weight", default=0.0, type=float) + parser.add("--rec_aud_weight", default=0.0, type=float) + parser.add("--rec_ver_weight", default=0.0, type=float) + parser.add("--rec_pos_weight", default=0.0, type=float) + parser.add("--rec_fac_weight", default=0.0, type=float) + parser.add("--rec_txt_weight", default=0.0, type=float) +# parser.add("--gan_noise_size", default=0, type=int) + # --------------- ha2g -------------------------- # + parser.add("--n_pre_poses", default=4, type=int) + parser.add("--n_poses", default=34, type=int) + parser.add("--input_context", default="both", type=str) + parser.add("--loss_contrastive_pos_weight", default=0.2, type=float) + parser.add("--loss_contrastive_neg_weight", default=0.005, type=float) + parser.add("--loss_physical_weight", default=0.0, type=float) + parser.add("--loss_reg_weight", default=0.05, type=float) + parser.add("--loss_regression_weight", default=70.0, type=float) + parser.add("--loss_gan_weight", default=5.0, type=float) + parser.add("--loss_kld_weight", default=0.1, type=float) + parser.add("--z_type", default="speaker", type=str) + # --------------- device -------------------------- # + parser.add("--random_seed", default=2021, type=int) + parser.add("--deterministic", default=True, type=str2bool) + parser.add("--benchmark", default=True, type=str2bool) + parser.add("--cudnn_enabled", default=True, type=str2bool) + # mix precision + parser.add("--apex", default=False, type=str2bool) + parser.add("--gpus", default=[0], type=int, nargs="*") + parser.add("--loader_workers", default=0, type=int) + parser.add("--ddp", default=False, type=str2bool) + parser.add("--sparse", default=1, type=int) + #parser.add("--world_size") + parser.add("--render_video_fps", default=30, type=int) + parser.add("--render_video_width", default=1920, type=int) + parser.add("--render_video_height", default=720, type=int) + cpu_cores = os.cpu_count() if os.cpu_count() is not None else 1 + default_concurrent = max(1, cpu_cores // 2) + parser.add("--render_concurrent_num", default=default_concurrent, type=int) + parser.add("--render_tmp_img_filetype", default="bmp", type=str) + + # logging + parser.add("--log_period", default=10, type=int) + + + args = parser.parse_args() + idc = 0 + for i, char in enumerate(args.config): + if char == "/": idc = i + args.name = args.config[idc+1:-5] + + is_train = args.is_train + + if is_train: + time_local = time.localtime() + name_expend = "%02d%02d_%02d%02d%02d_"%(time_local[1], time_local[2],time_local[3], time_local[4], time_local[5]) + args.name = name_expend + args.name + + return args \ No newline at end of file diff --git a/utils/data_transfer.py b/utils/data_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..025110cac44f42d581c4414bd3d0c6c7b21f33e0 --- /dev/null +++ b/utils/data_transfer.py @@ -0,0 +1,202 @@ +import os +import logging +import random +import h5py +import numpy as np 
+import pickle +import math +import numbers +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim.lr_scheduler import StepLR +from torch.distributions import Normal + + +def _index_from_letter(letter: str) -> int: + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter must be either X, Y or Z.") + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
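+
+    For example, euler_angles_to_matrix(torch.tensor([math.pi / 2, 0.0, 0.0]), "XYZ")
+    is a 90-degree rotation about X: the per-axis rotations are composed in the
+    order of the convention string, i.e. R = R_X(a0) @ R_Y(a1) @ R_Z(a2) here.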
+ """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + Returns: + 6D rotation representation, of size (*, 6) + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Args: + d6: 6D rotation representation, of size (*, 6) + Returns: + batch of rotation matrices of size (*, 3, 3) + """ + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). 
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def so3_relative_angle(m1, m2): + m1 = m1.reshape(-1, 3, 3) + m2 = m2.reshape(-1, 3, 3) + #print(m2.shape) + m = torch.bmm(m1, m2.transpose(1, 2)) # batch*3*3 + #print(m.shape) + cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2 + #print(cos.shape) + cos = torch.clamp(cos, min=-1 + 1E-6, max=1-1E-6) + #print(cos.shape) + theta = torch.acos(cos) + #print(theta.shape) + return torch.mean(theta) diff --git a/utils/fast_render.py b/utils/fast_render.py new file mode 100644 index 0000000000000000000000000000000000000000..78311cbcc99ac414c7724e1081615825ad20843a --- /dev/null +++ b/utils/fast_render.py @@ -0,0 +1,266 @@ +import os +import time +import numpy as np +import pyrender +import trimesh +import queue +import imageio +import threading +import multiprocessing +import utils.media +import glob + +def deg_to_rad(degrees): + return degrees * np.pi / 180 + +def create_pose_camera(angle_deg): + angle_rad = deg_to_rad(angle_deg) + return np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 1.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 5.0], + [0.0, 0.0, 0.0, 1.0] + ]) + +def create_pose_light(angle_deg): + angle_rad = deg_to_rad(angle_deg) + return np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 0.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 3.0], + [0.0, 0.0, 0.0, 1.0] + ]) + +def create_scene_with_mesh(vertices, faces, uniform_color, pose_camera, pose_light): + trimesh_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=uniform_color) + mesh = pyrender.Mesh.from_trimesh(trimesh_mesh, smooth=True) + scene = pyrender.Scene() + scene.add(mesh) + camera = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0) + scene.add(camera, pose=pose_camera) + light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=4.0) + scene.add(light, pose=pose_light) + return scene + +def do_render_one_frame(renderer, frame_idx, vertices, vertices1, faces): + if frame_idx % 100 == 0: + print('processed', frame_idx, 'frames') + + uniform_color = [220, 220, 220, 255] + pose_camera = create_pose_camera(angle_deg=-2) + pose_light = create_pose_light(angle_deg=-30) + + figs = [] + for vtx in [vertices, vertices1]: + # print(vtx.shape) + scene = create_scene_with_mesh(vtx, faces, uniform_color, pose_camera, pose_light) + fig, _ = renderer.render(scene) + figs.append(fig) + + return figs[0], figs[1] + +def do_render_one_frame_no_gt(renderer, frame_idx, vertices, faces): + if frame_idx % 100 == 0: + print('processed', frame_idx, 'frames') + + uniform_color = [220, 220, 220, 
255] + pose_camera = create_pose_camera(angle_deg=-2) + pose_light = create_pose_light(angle_deg=-30) + + figs = [] + # for vtx in [vertices]: + # print(vtx.shape) + # print(vertices.shape) + scene = create_scene_with_mesh(vertices, faces, uniform_color, pose_camera, pose_light) + fig, _ = renderer.render(scene) + figs.append(fig) + + return figs[0] + +def write_images_from_queue(fig_queue, output_dir, img_filetype): + while True: + e = fig_queue.get() + if e is None: + break + fid, fig1, fig2 = e + filename = os.path.join(output_dir, f"frame_{fid}.{img_filetype}") + merged_fig = np.hstack((fig1, fig2)) + try: + imageio.imwrite(filename, merged_fig) + except Exception as ex: + print(f"Error writing image {filename}: {ex}") + raise ex + +def write_images_from_queue_no_gt(fig_queue, output_dir, img_filetype): + while True: + e = fig_queue.get() + if e is None: + break + fid, fig1, fig2 = e + filename = os.path.join(output_dir, f"frame_{fid}.{img_filetype}") + merged_fig = fig1 #np.hstack((fig1)) + try: + imageio.imwrite(filename, merged_fig) + except Exception as ex: + print(f"Error writing image {filename}: {ex}") + raise ex + + +def render_frames_and_enqueue(fids, frame_vertex_pairs, faces, render_width, render_height, fig_queue): + fig_resolution = (render_width // 2, render_height) + renderer = pyrender.OffscreenRenderer(*fig_resolution) + + for idx, fid in enumerate(fids): + fig1, fig2 = do_render_one_frame(renderer, fid, frame_vertex_pairs[idx][0], frame_vertex_pairs[idx][1], faces) + fig_queue.put((fid, fig1, fig2)) + + renderer.delete() + +def render_frames_and_enqueue_no_gt(fids, frame_vertex_pairs, faces, render_width, render_height, fig_queue): + fig_resolution = (render_width // 2, render_height) + renderer = pyrender.OffscreenRenderer(*fig_resolution) + + for idx, fid in enumerate(fids): + fig1 = do_render_one_frame_no_gt(renderer, fid, frame_vertex_pairs[idx][0], faces) + fig_queue.put((fid, fig1)) + + renderer.delete() + +def sub_process_process_frame(subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, fids, frame_vertex_pairs, faces, output_dir): + begin_ts = time.time() + print(f"subprocess_index={subprocess_index} begin_ts={begin_ts}") + + fig_queue = queue.Queue() + render_frames_and_enqueue(fids, frame_vertex_pairs, faces, render_video_width, render_video_height, fig_queue) + fig_queue.put(None) + render_end_ts = time.time() + + image_writer_thread = threading.Thread(target=write_images_from_queue, args=(fig_queue, output_dir, render_tmp_img_filetype)) + image_writer_thread.start() + image_writer_thread.join() + + write_end_ts = time.time() + print( + f"subprocess_index={subprocess_index} " + f"render={render_end_ts - begin_ts:.2f} " + f"all={write_end_ts - begin_ts:.2f} " + f"begin_ts={begin_ts:.2f} " + f"render_end_ts={render_end_ts:.2f} " + f"write_end_ts={write_end_ts:.2f}" + ) + +def sub_process_process_frame_no_gt(subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, fids, frame_vertex_pairs, faces, output_dir): + begin_ts = time.time() + print(f"subprocess_index={subprocess_index} begin_ts={begin_ts}") + + fig_queue = queue.Queue() + render_frames_and_enqueue(fids, frame_vertex_pairs, faces, render_video_width, render_video_height, fig_queue) + fig_queue.put(None) + render_end_ts = time.time() + + image_writer_thread = threading.Thread(target=write_images_from_queue_no_gt, args=(fig_queue, output_dir, render_tmp_img_filetype)) + image_writer_thread.start() + image_writer_thread.join() + + write_end_ts = 
time.time() + print( + f"subprocess_index={subprocess_index} " + f"render={render_end_ts - begin_ts:.2f} " + f"all={write_end_ts - begin_ts:.2f} " + f"begin_ts={begin_ts:.2f} " + f"render_end_ts={render_end_ts:.2f} " + f"write_end_ts={write_end_ts:.2f}" + ) + +def distribute_frames(frames, render_video_fps, render_concurent_nums, vertices_all, vertices1_all): + sample_interval = max(1, int(30 // render_video_fps)) + subproc_frame_ids = [[] for _ in range(render_concurent_nums)] + subproc_vertices = [[] for _ in range(render_concurent_nums)] + sampled_frame_id = 0 + + for i in range(frames): + if i % sample_interval != 0: + continue + subprocess_index = sampled_frame_id % render_concurent_nums + subproc_frame_ids[subprocess_index].append(sampled_frame_id) + subproc_vertices[subprocess_index].append((vertices_all[i], vertices1_all[i])) + sampled_frame_id += 1 + + return subproc_frame_ids, subproc_vertices + +def distribute_frames_no_gt(frames, render_video_fps, render_concurent_nums, vertices_all): + sample_interval = max(1, int(30 // render_video_fps)) + subproc_frame_ids = [[] for _ in range(render_concurent_nums)] + subproc_vertices = [[] for _ in range(render_concurent_nums)] + sampled_frame_id = 0 + + for i in range(frames): + if i % sample_interval != 0: + continue + subprocess_index = sampled_frame_id % render_concurent_nums + subproc_frame_ids[subprocess_index].append(sampled_frame_id) + subproc_vertices[subprocess_index].append((vertices_all[i], vertices_all[i])) + sampled_frame_id += 1 + + return subproc_frame_ids, subproc_vertices + +def generate_silent_videos(render_video_fps, + render_video_width, + render_video_height, + render_concurent_nums, + render_tmp_img_filetype, + frames, + vertices_all, + vertices1_all, + faces, + output_dir): + + subproc_frame_ids, subproc_vertices = distribute_frames(frames, render_video_fps, render_concurent_nums, vertices_all, vertices1_all) + + print(f"generate_silent_videos concurrentNum={render_concurent_nums} time={time.time()}") + with multiprocessing.Pool(render_concurent_nums) as pool: + pool.starmap( + sub_process_process_frame, + [ + (subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, subproc_frame_ids[subprocess_index], subproc_vertices[subprocess_index], faces, output_dir) + for subprocess_index in range(render_concurent_nums) + ] + ) + + output_file = os.path.join(output_dir, "silence_video.mp4") + utils.media.convert_img_to_mp4(os.path.join(output_dir, f"frame_%d.{render_tmp_img_filetype}"), output_file, render_video_fps) + filenames = glob.glob(os.path.join(output_dir, f"*.{render_tmp_img_filetype}")) + for filename in filenames: + os.remove(filename) + + return output_file + +def generate_silent_videos_no_gt(render_video_fps, + render_video_width, + render_video_height, + render_concurent_nums, + render_tmp_img_filetype, + frames, + vertices_all, + faces, + output_dir): + + subproc_frame_ids, subproc_vertices = distribute_frames_no_gt(frames, render_video_fps, render_concurent_nums, vertices_all) + + print(f"generate_silent_videos concurrentNum={render_concurent_nums} time={time.time()}") + with multiprocessing.Pool(render_concurent_nums) as pool: + pool.starmap( + sub_process_process_frame_no_gt, + [ + (subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, subproc_frame_ids[subprocess_index], subproc_vertices[subprocess_index], faces, output_dir) + for subprocess_index in range(render_concurent_nums) + ] + ) + + output_file = os.path.join(output_dir, "silence_video.mp4") 
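    # The per-frame images written by the worker processes are stitched into a
    # single silent mp4 below via ffmpeg (utils.media.convert_img_to_mp4); the
    # "frame_%d" pattern must match the "frame_{fid}" naming used when the
    # images were written, and the temporary images are removed afterwards.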
+ utils.media.convert_img_to_mp4(os.path.join(output_dir, f"frame_%d.{render_tmp_img_filetype}"), output_file, render_video_fps) + filenames = glob.glob(os.path.join(output_dir, f"*.{render_tmp_img_filetype}")) + for filename in filenames: + os.remove(filename) + + return output_file \ No newline at end of file diff --git a/utils/logger_tools.py b/utils/logger_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..5eaf6c06274589018ba3cc21e2d4214e76611d5d --- /dev/null +++ b/utils/logger_tools.py @@ -0,0 +1,59 @@ +import os +import inspect +import sys +import yaml +#import wandb +from loguru import logger + +def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"): + """setup logger for training and testing. + Args: + save_dir(str): location to save log file + distributed_rank(int): device rank when multi-gpu environment + filename (string): log save name. + mode(str): log file write mode, `append` or `override`. default is `a`. + + Return: + logger instance. + """ + loguru_format = ( + "{time: MM-DD HH:mm:ss} | " + #"{level: <8} | " + #"{name}:{line} - {message}" + "{message}" + ) + + logger.remove() + save_file = os.path.join(save_dir, filename) + if mode == "o" and os.path.exists(save_file): + os.remove(save_file) + # only keep logger in rank0 process + if distributed_rank == 0: + logger.add( + sys.stderr, + format=loguru_format, + level="INFO", + enqueue=True, + ) + logger.add(save_file, + format=loguru_format, + ) + + +def set_args_and_logger(args, rank): + """ + set logger file and print args + """ + args_name_dir = args.out_path + "custom/" + args.name + args.notes + "/" + if rank == 0: + if not os.path.exists(args_name_dir): os.makedirs(args_name_dir) + args_name = args_name_dir + "/" + args.name +".yaml" + if os.path.exists(args_name): + s_add = 10 + logger.warning(f"Already exist args, add {s_add} to ran_seed to continue training") + args.random_seed += s_add + else: + with open(args_name, "w+") as f: + yaml.dump(args.__dict__, f, default_flow_style=True) + #json.dump(args.__dict__, f) + setup_logger(args_name_dir, rank, filename=f"{args.name}.txt") \ No newline at end of file diff --git a/utils/media.py b/utils/media.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd21e079a9e48f97f1511bd289d39f4aeccc40e --- /dev/null +++ b/utils/media.py @@ -0,0 +1,39 @@ +import numpy as np +import subprocess + +def add_audio_to_video(silent_video_path, audio_path, output_video_path): + command = [ + 'ffmpeg', + '-y', + '-i', silent_video_path, + '-i', audio_path, + '-map', '0:v', + '-map', '1:a', + '-c:v', 'copy', + '-shortest', + output_video_path + ] + + try: + subprocess.run(command, check=True) + print(f"Video with audio generated successfully: {output_video_path}") + except subprocess.CalledProcessError as e: + print(f"Error occurred: {e}") + + +def convert_img_to_mp4(input_pattern, output_file, framerate=30): + command = [ + 'ffmpeg', + '-framerate', str(framerate), + '-i', input_pattern, + '-c:v', 'libx264', + '-pix_fmt', 'yuv420p', + output_file, + '-y' + ] + + try: + subprocess.run(command, check=True) + print(f"Video conversion successful. 
Output file: {output_file}") + except subprocess.CalledProcessError as e: + print(f"Error during video conversion: {e}") diff --git a/utils/metric.py b/utils/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..53930062137b7ee82adb21ce226f572be77176e5 --- /dev/null +++ b/utils/metric.py @@ -0,0 +1,242 @@ +import librosa +import glob +import os +import numpy as np +import matplotlib.pyplot as plt +import librosa.display +from matplotlib.pyplot import figure +import math +from scipy.signal import argrelextrema + + +class L1div(object): + def __init__(self): + self.counter = 0 + self.sum = 0 + def run(self, results): + self.counter += results.shape[0] + mean = np.mean(results, 0) + for i in range(results.shape[0]): + results[i, :] = abs(results[i, :] - mean) + sum_l1 = np.sum(results) + self.sum += sum_l1 + def avg(self): + return self.sum/self.counter + def reset(self): + self.counter = 0 + self.sum = 0 + + +class SRGR(object): + def __init__(self, threshold=0.1, joints=47): + self.threshold = threshold + self.pose_dimes = joints + self.counter = 0 + self.sum = 0 + + def run(self, results, targets, semantic): + results = results.reshape(-1, self.pose_dimes, 3) + targets = targets.reshape(-1, self.pose_dimes, 3) + semantic = semantic.reshape(-1) + diff = np.sum(abs(results-targets),2) + success = np.where(diffself.threshold) + #print(vel.shape) + #t_end = 80 + #vel[::2, :] -= 0.000001 + #print(vel[t_start:t_end, i], vel[t_start:t_end, i].shape) + beat_vel = argrelextrema(vel[t_start:t_end, i], np.less, order=self.order) # n*47 + #print(beat_vel, t_start, t_end) + beat_vel_list = [] + for j in beat_vel[0]: + if j in vel_mask[0]: + beat_vel_list.append(j) + beat_vel = np.array(beat_vel_list) + beat_vel_all.append(beat_vel) + #print(beat_vel_all) + return beat_vel_all #beat_right_arm, beat_right_shoulder, beat_right_wrist, beat_left_arm, beat_left_shoulder, beat_left_wrist + + + def load_data(self, wave, pose, t_start, t_end, pose_fps): + onset_raw, onset_bt, onset_bt_rms = self.load_audio(wave, t_start, t_end) + beat_right_arm, beat_right_shoulder, beat_right_wrist, beat_left_arm, beat_left_shoulder, beat_left_wrist = self.load_pose(pose, t_start, t_end, pose_fps) + return onset_raw, onset_bt, onset_bt_rms, beat_right_arm, beat_right_shoulder, beat_right_wrist, beat_left_arm, beat_left_shoulder, beat_left_wrist + + def eval_random_pose(self, wave, pose, t_start, t_end, pose_fps, num_random=60): + onset_raw, onset_bt, onset_bt_rms = self.load_audio(wave, t_start, t_end) + dur = t_end - t_start + for i in range(num_random): + beat_right_arm, beat_right_shoulder, beat_right_wrist, beat_left_arm, beat_left_shoulder, beat_left_wrist = self.load_pose(pose, i, i+dur, pose_fps) + dis_all_b2a= self.calculate_align(onset_raw, onset_bt, onset_bt_rms, beat_right_arm, beat_right_shoulder, beat_right_wrist, beat_left_arm, beat_left_shoulder, beat_left_wrist) + print(f"{i}s: ",dis_all_b2a) + + + @staticmethod + def plot_onsets(audio, sr, onset_times_1, onset_times_2): + import librosa + import librosa.display + import matplotlib.pyplot as plt + # Plot audio waveform + fig, axarr = plt.subplots(2, 1, figsize=(10, 10), sharex=True) + + # Plot audio waveform in both subplots + librosa.display.waveshow(audio, sr=sr, alpha=0.7, ax=axarr[0]) + librosa.display.waveshow(audio, sr=sr, alpha=0.7, ax=axarr[1]) + + # Plot onsets from first method on the first subplot + for onset in onset_times_1: + axarr[0].axvline(onset, color='r', linestyle='--', alpha=0.9, label='Onset Method 1') + 
axarr[0].legend() + axarr[0].set(title='Onset Method 1', xlabel='', ylabel='Amplitude') + + # Plot onsets from second method on the second subplot + for onset in onset_times_2: + axarr[1].axvline(onset, color='b', linestyle='-', alpha=0.7, label='Onset Method 2') + axarr[1].legend() + axarr[1].set(title='Onset Method 2', xlabel='Time (s)', ylabel='Amplitude') + + + # Add legend (eliminate duplicate labels) + handles, labels = plt.gca().get_legend_handles_labels() + by_label = dict(zip(labels, handles)) + plt.legend(by_label.values(), by_label.keys()) + + # Show plot + plt.title("Audio waveform with Onsets") + plt.savefig("./onset.png", dpi=500) + + def audio_beat_vis(self, onset_raw, onset_bt, onset_bt_rms): + figure(figsize=(24, 6), dpi=80) + fig, ax = plt.subplots(nrows=4, sharex=True) + librosa.display.specshow(librosa.amplitude_to_db(self.S, ref=np.max), + y_axis='log', x_axis='time', ax=ax[0]) + ax[0].label_outer() + ax[1].plot(self.times, self.oenv, label='Onset strength') + ax[1].vlines(librosa.frames_to_time(onset_raw), 0, self.oenv.max(), label='Raw onsets', color='r') + ax[1].legend() + ax[1].label_outer() + + ax[2].plot(self.times, self.oenv, label='Onset strength') + ax[2].vlines(librosa.frames_to_time(onset_bt), 0, self.oenv.max(), label='Backtracked', color='r') + ax[2].legend() + ax[2].label_outer() + + ax[3].plot(self.times, self.rms[0], label='RMS') + ax[3].vlines(librosa.frames_to_time(onset_bt_rms), 0, self.oenv.max(), label='Backtracked (RMS)', color='r') + ax[3].legend() + fig.savefig("./onset.png", dpi=500) + + @staticmethod + def motion_frames2time(vel, offset, pose_fps): + time_vel = vel/pose_fps + offset + return time_vel + + @staticmethod + def GAHR(a, b, sigma): + dis_all_a2b = 0 + dis_all_b2a = 0 + for b_each in b: + l2_min = np.inf + for a_each in a: + l2_dis = abs(a_each - b_each) + if l2_dis < l2_min: + l2_min = l2_dis + dis_all_b2a += math.exp(-(l2_min**2)/(2*sigma**2)) + dis_all_b2a /= len(b) + return dis_all_b2a + + @staticmethod + def fix_directed_GAHR(a, b, sigma): + a = alignment.motion_frames2time(a, 0, 30) + b = alignment.motion_frames2time(b, 0, 30) + t = len(a)/30 + a = [0] + a + [t] + b = [0] + b + [t] + dis_a2b = alignment.GAHR(a, b, sigma) + return dis_a2b + + def calculate_align(self, onset_bt_rms, beat_vel, pose_fps=30): + audio_bt = onset_bt_rms + avg_dis_all_b2a_list = [] + for its, beat_vel_each in enumerate(beat_vel): + if its not in self.upper_body: + continue + #print(beat_vel_each) + #print(audio_bt.shape, beat_vel_each.shape) + pose_bt = self.motion_frames2time(beat_vel_each, 0, pose_fps) + #print(pose_bt) + avg_dis_all_b2a_list.append(self.GAHR(pose_bt, audio_bt, self.sigma)) + # avg_dis_all_b2a = max(avg_dis_all_b2a_list) + avg_dis_all_b2a = sum(avg_dis_all_b2a_list)/len(avg_dis_all_b2a_list) #max(avg_dis_all_b2a_list) + #print(avg_dis_all_b2a, sum(avg_dis_all_b2a_list)/47) + return avg_dis_all_b2a \ No newline at end of file diff --git a/utils/other_tools.py b/utils/other_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..64dc17ffd10a8cb32c7a0bd1d141bc724c753b03 --- /dev/null +++ b/utils/other_tools.py @@ -0,0 +1,891 @@ +import os +import numpy as np +import random +import torch +import shutil +import csv +import pprint +import pandas as pd +from loguru import logger +from collections import OrderedDict +import matplotlib.pyplot as plt +import pickle +import time +import hashlib +from scipy.spatial.transform import Rotation as R +from scipy.spatial.transform import Slerp +import cv2 +import utils.media 
+import utils.fast_render + +def write_wav_names_to_csv(folder_path, csv_path): + """ + Traverse a folder and write the base names of all .wav files to a CSV file. + + :param folder_path: Path to the folder to traverse. + :param csv_path: Path to the CSV file to write. + """ + # Open the CSV file for writing + with open(csv_path, mode='w', newline='') as file: + writer = csv.writer(file) + # Write the header + writer.writerow(['id', 'type']) + + # Walk through the folder + for root, dirs, files in os.walk(folder_path): + for file in files: + # Check if the file ends with .wav + if file.endswith('.wav'): + # Extract the base name without the extension + base_name = os.path.splitext(file)[0] + # Write the base name and type to the CSV + writer.writerow([base_name, 'test']) + +def resize_motion_sequence_tensor(sequence, target_frames): + """ + Resize a batch of 8-frame motion sequences to a specified number of frames using interpolation. + + :param sequence: A (bs, 8, 165) tensor representing a batch of 8-frame motion sequences + :param target_frames: An integer representing the desired number of frames in the output sequences + :return: A (bs, target_frames, 165) tensor representing the resized motion sequences + """ + bs, _, _ = sequence.shape + + # Create a time vector for the original and target sequences + original_time = torch.linspace(0, 1, 8, device=sequence.device).view(1, -1, 1) + target_time = torch.linspace(0, 1, target_frames, device=sequence.device).view(1, -1, 1) + + # Permute the dimensions to (bs, 165, 8) for interpolation + sequence = sequence.permute(0, 2, 1) + + # Interpolate each joint's motion to the target number of frames + resized_sequence = torch.nn.functional.interpolate(sequence, size=target_frames, mode='linear', align_corners=True) + + # Permute the dimensions back to (bs, target_frames, 165) + resized_sequence = resized_sequence.permute(0, 2, 1) + + return resized_sequence + +def adjust_speed_according_to_ratio_tensor(chunks): + """ + Adjust the playback speed within a batch of 32-frame chunks according to random intervals. 
+ + :param chunks: A (bs, 32, 165) tensor representing a batch of motion chunks + :return: A (bs, 32, 165) tensor representing the motion chunks after speed adjustment + """ + bs, _, _ = chunks.shape + + # Step 1: Divide the chunk into 4 equal intervals of 8 frames + equal_intervals = torch.chunk(chunks, 4, dim=1) + + # Step 2: Randomly sample 3 points within the chunk to determine new intervals + success = 0 + all_success = [] + #sample_points = torch.sort(torch.randint(1, 32, (bs, 3), device=chunks.device), dim=1).values + # new_intervals_boundaries = torch.cat([torch.zeros((bs, 1), device=chunks.device, dtype=torch.long), sample_points, 32*torch.ones((bs, 1), device=chunks.device, dtype=torch.long)], dim=1) + while success != 1: + sample_points = sorted(random.sample(range(1, 32), 3)) + new_intervals_boundaries = [0] + sample_points + [32] + new_intervals = [chunks[0][new_intervals_boundaries[i]:new_intervals_boundaries[i+1]] for i in range(4)] + speed_ratios = [8 / len(new_interval) for new_interval in new_intervals] + # if any of the speed ratios is greater than 3 or less than 0.33, resample + if all([0.33 <= speed_ratio <= 3 for speed_ratio in speed_ratios]): + success += 1 + all_success.append(new_intervals_boundaries) + new_intervals_boundaries = torch.from_numpy(np.array(all_success)) + # print(new_intervals_boundaries) + all_shapes = new_intervals_boundaries[:, 1:] - new_intervals_boundaries[:, :-1] + # Step 4: Adjust the speed of each new interval + adjusted_intervals = [] + # print(equal_intervals[0].shape) + for i in range(4): + adjusted_interval = resize_motion_sequence_tensor(equal_intervals[i], all_shapes[0, i]) + adjusted_intervals.append(adjusted_interval) + + # Step 5: Concatenate the adjusted intervals + adjusted_chunk = torch.cat(adjusted_intervals, dim=1) + + return adjusted_chunk + +def compute_exact_iou(bbox1, bbox2): + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2]) + y2 = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3]) + + intersection_area = max(0, x2 - x1) * max(0, y2 - y1) + bbox1_area = bbox1[2] * bbox1[3] + bbox2_area = bbox2[2] * bbox2[3] + union_area = bbox1_area + bbox2_area - intersection_area + + if union_area == 0: + return 0 + + return intersection_area / union_area + +def compute_iou(mask1, mask2): + # Compute the intersection + intersection = np.logical_and(mask1, mask2).sum() + + # Compute the union + union = np.logical_or(mask1, mask2).sum() + + # Compute the IoU + iou = intersection / union + + return iou + +def blankblending(all_frames, x, n): + return all_frames[x:x+n+1] + +def synthesize_intermediate_frames_FILM(frame1, frame2, t, name, save_path): + import replicate + from urllib.request import urlretrieve + import os + cv2.imwrite(save_path[:-9]+name+"_frame1.png", frame1) + cv2.imwrite(save_path[:-9]+name+"_frame2.png", frame2) + os.environ["REPLICATE_API_TOKEN"] = "r8_He1rkPk9GAxNQ3LpOohK8sYw1SUfMYV3Fxk9b" + output = replicate.run( + "google-research/frame-interpolation:4f88a16a13673a8b589c18866e540556170a5bcb2ccdc12de556e800e9456d3d", + input={ + "frame1": open(save_path[:-9]+name+"_frame1.png", "rb"), + "frame2": open(save_path[:-9]+name+"_frame2.png", "rb"), + "times_to_interpolate": t, + } + ) + print(output) + urlretrieve(output, save_path[:-9]+name+"_inter.mp4") + return load_video_as_numpy_array(save_path[:-9]+name+"_inter.mp4") + +def load_video_as_numpy_array(video_path): + cap = cv2.VideoCapture(video_path) + + # Using list comprehension to read frames and store in a 
list + frames = [frame for ret, frame in iter(lambda: cap.read(), (False, None)) if ret] + + cap.release() + + return np.array(frames) + +def synthesize_intermediate_frames_bidirectional(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + # Convert the frames to grayscale + gray1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) + gray2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) + + # Calculate the forward and backward optical flow + forward_flow = cv2.calcOpticalFlowFarneback(gray1, gray2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + backward_flow = cv2.calcOpticalFlowFarneback(gray2, gray1, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / n # Interpolation factor + + # Compute the intermediate forward and backward flow + intermediate_forward_flow = forward_flow * alpha + intermediate_backward_flow = backward_flow * (1 - alpha) + + # Warp the frames based on the intermediate flow + h, w = frame1.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + forward_displacement = flow_map + intermediate_forward_flow.reshape(-1, 2) + backward_displacement = flow_map - intermediate_backward_flow.reshape(-1, 2) + + # Use cv2.remap for efficient warping + remap_x_forward, remap_y_forward = np.clip(forward_displacement[:, 1], 0, w - 1), np.clip(forward_displacement[:, 0], 0, h - 1) + remap_x_backward, remap_y_backward = np.clip(backward_displacement[:, 1], 0, w - 1), np.clip(backward_displacement[:, 0], 0, h - 1) + + warped_forward = cv2.remap(frame1, remap_x_forward.reshape(h, w).astype(np.float32), remap_y_forward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + warped_backward = cv2.remap(frame2, remap_x_backward.reshape(h, w).astype(np.float32), remap_y_backward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + + # Blend the warped frames to generate the intermediate frame + intermediate_frame = cv2.addWeighted(warped_forward, 1 - alpha, warped_backward, alpha, 0) + synthesized_frames.append(intermediate_frame) + + return synthesized_frames # Return n-2 synthesized intermediate frames + + +def linear_interpolate_frames(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / (n) # Correct interpolation factor + inter_frame = cv2.addWeighted(frame1, 1 - alpha, frame2, alpha, 0) + synthesized_frames.append(inter_frame) + return synthesized_frames[:-1] + +def warp_frame(src_frame, flow): + h, w = flow.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + displacement = flow_map + flow.reshape(-1, 2) + + # Extract x and y coordinates of the displacement + x_coords = np.clip(displacement[:, 1], 0, w - 1).reshape(h, w).astype(np.float32) + y_coords = np.clip(displacement[:, 0], 0, h - 1).reshape(h, w).astype(np.float32) + + # Use cv2.remap for efficient warping + warped_frame = cv2.remap(src_frame, x_coords, y_coords, interpolation=cv2.INTER_LINEAR) + + return warped_frame + +def synthesize_intermediate_frames(all_frames, x, n): + # Calculate Optical Flow between the first and last frame + frame1 = cv2.cvtColor(all_frames[x], cv2.COLOR_BGR2GRAY) + frame2 = cv2.cvtColor(all_frames[x + n], cv2.COLOR_BGR2GRAY) + flow = cv2.calcOpticalFlowFarneback(frame1, frame2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For 
each intermediate frame + alpha = i / (n) # Interpolation factor + intermediate_flow = flow * alpha # Interpolate the flow + intermediate_frame = warp_frame(all_frames[x], intermediate_flow) # Warp the first frame + synthesized_frames.append(intermediate_frame) + + return synthesized_frames + + +def map2color(s): + m = hashlib.md5() + m.update(s.encode('utf-8')) + color_code = m.hexdigest()[:6] + return '#' + color_code + +def euclidean_distance(a, b): + return np.sqrt(np.sum((a - b)**2)) + +def adjust_array(x, k): + len_x = len(x) + len_k = len(k) + + # If x is shorter than k, pad with zeros + if len_x < len_k: + return np.pad(x, (0, len_k - len_x), 'constant') + + # If x is longer than k, truncate x + elif len_x > len_k: + return x[:len_k] + + # If both are of same length + else: + return x + +def onset_to_frame(onset_times, audio_length, fps): + # Calculate total number of frames for the given audio length + total_frames = int(audio_length * fps) + + # Create an array of zeros of shape (total_frames,) + frame_array = np.zeros(total_frames, dtype=np.int32) + + # For each onset time, calculate the frame number and set it to 1 + for onset in onset_times: + frame_num = int(onset * fps) + # Check if the frame number is within the array bounds + if 0 <= frame_num < total_frames: + frame_array[frame_num] = 1 + + return frame_array + +# def np_slerp(q1, q2, t): +# dot_product = np.sum(q1 * q2, axis=-1) +# q2_flip = np.where(dot_product[:, None] < 0, -q2, q2) # Flip quaternions where dot_product is negative +# dot_product = np.abs(dot_product) + +# angle = np.arccos(np.clip(dot_product, -1, 1)) +# sin_angle = np.sin(angle) + +# t1 = np.sin((1.0 - t) * angle) / sin_angle +# t2 = np.sin(t * angle) / sin_angle + +# return t1 * q1 + t2 * q2_flip + + +def smooth_rotvec_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using SLERP. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. + - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + n1, k1 = animation1.shape + n2, k2 = animation2.shape + animation1 = animation1.reshape(n1, k1//3, 3) + animation2 = animation2.reshape(n2, k2//3, 3) + blend_frames = min(blend_frames, len(animation1), len(animation2)) + all_int = [] + for i in range(k1//3): + # Convert rotation vectors to quaternion for the overlapping part + q = R.from_rotvec(np.concatenate([animation1[0:1, i], animation2[-2:-1, i]], axis=0))#.as_quat() + # q2 = R.from_rotvec()#.as_quat() + times = [0, blend_frames * 2 - 1] + slerp = Slerp(times, q) + interpolated = slerp(np.arange(blend_frames * 2)) + interpolated_rotvecs = interpolated.as_rotvec() + all_int.append(interpolated_rotvecs) + interpolated_rotvecs = np.concatenate(all_int, axis=1) + # result = np.vstack((animation1[:-blend_frames], interpolated_rotvecs, animation2[blend_frames:])) + result = interpolated_rotvecs.reshape(2*n1, k1) + return result + +def smooth_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using linear interpolation. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. 
+ - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + blend_frames = min(blend_frames, len(animation1), len(animation2)) + + # Extract overlapping sections + overlap_a1 = animation1[-blend_frames:-blend_frames+1, :] + overlap_a2 = animation2[blend_frames-1:blend_frames, :] + + # Create blend weights for linear interpolation + alpha = np.linspace(0, 1, 2 * blend_frames).reshape(-1, 1) + + # Linearly interpolate between overlapping sections + blended_overlap = overlap_a1 * (1 - alpha) + overlap_a2 * alpha + + # Extend the animations to form the result with 2n frames + if blend_frames == len(animation1) and blend_frames == len(animation2): + result = blended_overlap + else: + before_blend = animation1[:-blend_frames] + after_blend = animation2[blend_frames:] + result = np.vstack((before_blend, blended_overlap, after_blend)) + return result + +def interpolate_sequence(quaternions): + bs, n, j, _ = quaternions.shape + new_n = 2 * n + new_quaternions = torch.zeros((bs, new_n, j, 4), device=quaternions.device, dtype=quaternions.dtype) + + for i in range(n): + q1 = quaternions[:, i, :, :] + new_quaternions[:, 2*i, :, :] = q1 + + if i < n - 1: + q2 = quaternions[:, i + 1, :, :] + new_quaternions[:, 2*i + 1, :, :] = slerp(q1, q2, 0.5) + else: + # For the last point, duplicate the value + new_quaternions[:, 2*i + 1, :, :] = q1 + + return new_quaternions + +def quaternion_multiply(q1, q2): + w1, x1, y1, z1 = q1 + w2, x2, y2, z2 = q2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + return w, x, y, z + +def quaternion_conjugate(q): + w, x, y, z = q + return (w, -x, -y, -z) + +def slerp(q1, q2, t): + dot = torch.sum(q1 * q2, dim=-1, keepdim=True) + + flip = (dot < 0).float() + q2 = (1 - flip * 2) * q2 + dot = dot * (1 - flip * 2) + + DOT_THRESHOLD = 0.9995 + mask = (dot > DOT_THRESHOLD).float() + + theta_0 = torch.acos(dot) + theta = theta_0 * t + + q3 = q2 - q1 * dot + q3 = q3 / torch.norm(q3, dim=-1, keepdim=True) + + interpolated = (torch.cos(theta) * q1 + torch.sin(theta) * q3) + + return mask * (q1 + t * (q2 - q1)) + (1 - mask) * interpolated + +def estimate_linear_velocity(data_seq, dt): + ''' + Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates + the velocity for the middle T-2 steps using a second order central difference scheme. + The first and last frames are with forward and backward first-order + differences, respectively + - h : step size + ''' + # first steps is forward diff (t+1 - t) / dt + init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt + + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) + return vel_seq + +def velocity2position(data_seq, dt, init_pos): + res_trans = [] + for i in range(data_seq.shape[1]): + if i == 0: + res_trans.append(init_pos.unsqueeze(1)) + else: + res = data_seq[:, i-1:i] * dt + res_trans[-1] + res_trans.append(res) + return torch.cat(res_trans, dim=1) + +def estimate_angular_velocity(rot_seq, dt): + ''' + Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. 
+ Input sequence should be of shape (B, T, ..., 3, 3) + ''' + # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix + dRdt = estimate_linear_velocity(rot_seq, dt) + R = rot_seq + RT = R.transpose(-1, -2) + # compute skew-symmetric angular velocity tensor + w_mat = torch.matmul(dRdt, RT) + # pull out angular velocity vector by averaging symmetric entries + w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 + w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 + w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 + w = torch.stack([w_x, w_y, w_z], axis=-1) + return w + +def image_from_bytes(image_bytes): + import matplotlib.image as mpimg + from io import BytesIO + return mpimg.imread(BytesIO(image_bytes), format='PNG') + +def process_frame(i, vertices_all, vertices1_all, faces, output_dir, filenames): + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + import trimesh + import pyrender + + def deg_to_rad(degrees): + return degrees * np.pi / 180 + + uniform_color = [220, 220, 220, 255] + resolution = (1000, 1000) + figsize = (10, 10) + + fig, axs = plt.subplots( + nrows=1, + ncols=2, + figsize=(figsize[0] * 2, figsize[1] * 1) + ) + axs = axs.flatten() + + vertices = vertices_all[i] + vertices1 = vertices1_all[i] + filename = f"{output_dir}frame_{i}.png" + filenames.append(filename) + if i%100 == 0: + print('processed', i, 'frames') + #time_s = time.time() + #print(vertices.shape) + angle_rad = deg_to_rad(-2) + pose_camera = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 1.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 5.0], + [0.0, 0.0, 0.0, 1.0] + ]) + angle_rad = deg_to_rad(-30) + pose_light = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 0.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 3.0], + [0.0, 0.0, 0.0, 1.0] + ]) + + for vtx_idx, vtx in enumerate([vertices, vertices1]): + trimesh_mesh = trimesh.Trimesh( + vertices=vtx, + faces=faces, + vertex_colors=uniform_color + ) + mesh = pyrender.Mesh.from_trimesh( + trimesh_mesh, smooth=True + ) + scene = pyrender.Scene() + scene.add(mesh) + camera = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0) + scene.add(camera, pose=pose_camera) + light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=4.0) + scene.add(light, pose=pose_light) + renderer = pyrender.OffscreenRenderer(*resolution) + color, _ = renderer.render(scene) + axs[vtx_idx].imshow(color) + axs[vtx_idx].axis('off') + renderer.delete() + + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + +def generate_images(frames, vertices_all, vertices1_all, faces, output_dir, filenames): + import multiprocessing + # import trimesh + num_cores = multiprocessing.cpu_count() - 1 # This will get the number of cores on your machine. 
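    # NOTE: num_cores is computed above but, as written, frames are rendered
    # serially by the plain for-loop further below; the multiprocessing.Pool
    # variant that would use it is left commented out.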
+ # mesh = trimesh.Trimesh(vertices_all[0], faces) + # scene = mesh.scene() + # fov = scene.camera.fov.copy() + # fov[0] = 80.0 + # fov[1] = 60.0 + # camera_params = { + # 'fov': fov, + # 'resolution': scene.camera.resolution, + # 'focal': scene.camera.focal, + # 'z_near': scene.camera.z_near, + # "z_far": scene.camera.z_far, + # 'transform': scene.graph[scene.camera.name][0] + # } + # mesh1 = trimesh.Trimesh(vertices1_all[0], faces) + # scene1 = mesh1.scene() + # camera_params1 = { + # 'fov': fov, + # 'resolution': scene1.camera.resolution, + # 'focal': scene1.camera.focal, + # 'z_near': scene1.camera.z_near, + # "z_far": scene1.camera.z_far, + # 'transform': scene1.graph[scene1.camera.name][0] + # } + # Use a Pool to manage the processes + # print(num_cores) + # for i in range(frames): + # process_frame(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) + for i in range(frames): + process_frame(i, vertices_all, vertices1_all, faces, output_dir, filenames) + + # progress = multiprocessing.Value('i', 0) + # lock = multiprocessing.Lock() + # with multiprocessing.Pool(num_cores) as pool: + # # pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + # pool.starmap( + # process_frame, + # [ + # (i, vertices_all, vertices1_all, faces, output_dir, filenames) + # for i in range(frames) + # ] + # ) + + # progress = multiprocessing.Value('i', 0) + # lock = multiprocessing.Lock() + # with multiprocessing.Pool(num_cores) as pool: + # # pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + # pool.starmap( + # process_frame, + # [ + # (i, vertices_all, vertices1_all, faces, output_dir, filenames) + # for i in range(frames) + # ] + # ) + +def render_one_sequence( + res_npz_path, + gt_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import moviepy.editor as mp + import librosa + + model = smplx.create(model_folder, model_type=model_type, + gender=gender, use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, use_pca=False).cuda() + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + filenames = [] + # if not use_matplotlib: + # import trimesh + #import pyrender + #!!! I only have Windows, the following lines of comments are feasible, but have not been tested on other platforms. 
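    # The commented-out pyvirtualdisplay lines below start a virtual X server so
    # pyrender can render offscreen on a headless machine; an EGL offscreen
    # context is a common alternative when no physical display is available.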
+ #from pyvirtualdisplay import Display + #display = Display(visible=0, size=(500, 500)) + #display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + #data_npz["jaw_pose"].shape[0] + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + beta = beta.repeat(n, 1) + expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).cuda() + pose = torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).cuda() + transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).cuda() + # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape) + output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose, + global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3], + leye_pose=pose[:, 69:72], + reye_pose=pose[:, 72:75], + return_verts=True) + vertices_all = output["vertices"].cpu().detach().numpy() + + beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).cuda() + pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).cuda() + transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).cuda() + output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3], + leye_pose=pose1[:, 69:72], + reye_pose=pose1[:, 72:75],return_verts=True) + vertices1_all = output1["vertices"].cpu().detach().numpy() + if args.debug: + seconds = 1 + else: + seconds = vertices_all.shape[0]//30 + + silent_video_file_path = utils.fast_render.generate_silent_videos(args.render_video_fps, + args.render_video_width, + args.render_video_height, + args.render_concurrent_num, + args.render_tmp_img_filetype, + int(seconds*args.render_video_fps), + vertices_all, + vertices1_all, + faces, + output_dir) + + #final_clip = f"{output_dir}{res_npz_path.split('/')[-1][4:-4]}.mp4" + base_filename_without_ext = os.path.splitext(os.path.basename(res_npz_path))[0] + final_clip = os.path.join(output_dir, f"{base_filename_without_ext}.mp4") + utils.media.add_audio_to_video(silent_video_file_path, audio_path, final_clip) + os.remove(silent_video_file_path) + +def print_exp_info(args): + logger.info(pprint.pformat(vars(args))) + logger.info(f"# ------------ {args.name} ----------- #") + logger.info("PyTorch version: {}".format(torch.__version__)) + logger.info("CUDA version: {}".format(torch.version.cuda)) + logger.info("{} GPUs".format(torch.cuda.device_count())) + logger.info(f"Random Seed: {args.random_seed}") + +def args2csv(args, get_head=False, list4print=[]): + for k, v in args.items(): + if isinstance(args[k], dict): + args2csv(args[k], get_head, list4print) + else: list4print.append(k) if get_head else list4print.append(v) + return list4print + +class EpochTracker: + def __init__(self, metric_names, metric_directions): + assert len(metric_names) == len(metric_directions), "Metric names and directions should have the same length" + + + self.metric_names = metric_names + self.states = ['train', 'val', 
'test'] + self.types = ['last', 'best'] + + + self.values = {name: {state: {type_: {'value': np.inf if not is_higher_better else -np.inf, 'epoch': 0} + for type_ in self.types} + for state in self.states} + for name, is_higher_better in zip(metric_names, metric_directions)} + + self.loss_meters = {name: {state: AverageMeter(f"{name}_{state}") + for state in self.states} + for name in metric_names} + + + self.is_higher_better = {name: direction for name, direction in zip(metric_names, metric_directions)} + self.train_history = {name: [] for name in metric_names} + self.val_history = {name: [] for name in metric_names} + + + def update_meter(self, name, state, value): + self.loss_meters[name][state].update(value) + + + def update_values(self, name, state, epoch): + value_avg = self.loss_meters[name][state].avg + new_best = False + + + if ((value_avg < self.values[name][state]['best']['value'] and not self.is_higher_better[name]) or + (value_avg > self.values[name][state]['best']['value'] and self.is_higher_better[name])): + self.values[name][state]['best']['value'] = value_avg + self.values[name][state]['best']['epoch'] = epoch + new_best = True + self.values[name][state]['last']['value'] = value_avg + self.values[name][state]['last']['epoch'] = epoch + return new_best + + + def get(self, name, state, type_): + return self.values[name][state][type_] + + + def reset(self): + for name in self.metric_names: + for state in self.states: + self.loss_meters[name][state].reset() + + + def flatten_values(self): + flat_dict = {} + for name in self.metric_names: + for state in self.states: + for type_ in self.types: + value_key = f"{name}_{state}_{type_}" + epoch_key = f"{name}_{state}_{type_}_epoch" + flat_dict[value_key] = self.values[name][state][type_]['value'] + flat_dict[epoch_key] = self.values[name][state][type_]['epoch'] + return flat_dict + + def update_and_plot(self, name, epoch, save_path): + new_best_train = self.update_values(name, 'train', epoch) + new_best_val = self.update_values(name, 'val', epoch) + + + self.train_history[name].append(self.loss_meters[name]['train'].avg) + self.val_history[name].append(self.loss_meters[name]['val'].avg) + + + train_values = self.train_history[name] + val_values = self.val_history[name] + epochs = list(range(1, len(train_values) + 1)) + + + plt.figure(figsize=(10, 6)) + plt.plot(epochs, train_values, label='Train') + plt.plot(epochs, val_values, label='Val') + plt.title(f'Train vs Val {name} over epochs') + plt.xlabel('Epochs') + plt.ylabel(name) + plt.legend() + plt.savefig(save_path) + plt.close() + + + return new_best_train, new_best_val + +def record_trial(args, tracker): + """ + 1. 
record notes, score, env_name, experments_path, + """ + csv_path = args.out_path + "custom/" +args.csv_name+".csv" + all_print_dict = vars(args) + all_print_dict.update(tracker.flatten_values()) + if not os.path.exists(csv_path): + pd.DataFrame([all_print_dict]).to_csv(csv_path, index=False) + else: + df_existing = pd.read_csv(csv_path) + df_new = pd.DataFrame([all_print_dict]) + df_aligned = df_existing.append(df_new).fillna("") + df_aligned.to_csv(csv_path, index=False) + +def set_random_seed(args): + os.environ['PYTHONHASHSEED'] = str(args.random_seed) + random.seed(args.random_seed) + np.random.seed(args.random_seed) + torch.manual_seed(args.random_seed) + torch.cuda.manual_seed_all(args.random_seed) + torch.cuda.manual_seed(args.random_seed) + torch.backends.cudnn.deterministic = args.deterministic #args.CUDNN_DETERMINISTIC + torch.backends.cudnn.benchmark = args.benchmark + torch.backends.cudnn.enabled = args.cudnn_enabled + +def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None): + if lrs is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(), + 'lrs':lrs.state_dict(),} + elif opt is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(),} + else: + states = { 'model_state': model.state_dict(),} + torch.save(states, save_path) + +def load_checkpoints(model, save_path, load_name='model'): + states = torch.load(save_path) + new_weights = OrderedDict() + flag=False + for k, v in states['model_state'].items(): + #print(k) + if "module" not in k: + break + else: + new_weights[k[7:]]=v + flag=True + if flag: + try: + model.load_state_dict(new_weights) + except: + #print(states['model_state']) + model.load_state_dict(states['model_state']) + else: + model.load_state_dict(states['model_state']) + logger.info(f"load self-pretrained checkpoints for {load_name}") + +def model_complexity(model, args): + from ptflops import get_model_complexity_info + flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN), + as_strings=False, print_per_layer_stat=False) + logging.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9)) + logging.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6)) + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) diff --git a/utils/other_tools_hf.py b/utils/other_tools_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b9f943120954e5745f1beca4e37ff461f663bb --- /dev/null +++ b/utils/other_tools_hf.py @@ -0,0 +1,975 @@ +import os +import numpy as np +import random +import torch +import shutil +import csv +import pprint +import pandas as pd +from loguru import logger +from collections import OrderedDict +import matplotlib.pyplot as plt +import pickle +import time +import hashlib +from scipy.spatial.transform import Rotation as R +from scipy.spatial.transform import Slerp +import cv2 +import utils.media +import utils.fast_render + +def write_wav_names_to_csv(folder_path, csv_path): + 
""" + Traverse a folder and write the base names of all .wav files to a CSV file. + + :param folder_path: Path to the folder to traverse. + :param csv_path: Path to the CSV file to write. + """ + # Open the CSV file for writing + with open(csv_path, mode='w', newline='') as file: + writer = csv.writer(file) + # Write the header + writer.writerow(['id', 'type']) + + # Walk through the folder + for root, dirs, files in os.walk(folder_path): + for file in files: + # Check if the file ends with .wav + if file.endswith('.wav'): + # Extract the base name without the extension + base_name = os.path.splitext(file)[0] + # Write the base name and type to the CSV + writer.writerow([base_name, 'test']) + +def resize_motion_sequence_tensor(sequence, target_frames): + """ + Resize a batch of 8-frame motion sequences to a specified number of frames using interpolation. + + :param sequence: A (bs, 8, 165) tensor representing a batch of 8-frame motion sequences + :param target_frames: An integer representing the desired number of frames in the output sequences + :return: A (bs, target_frames, 165) tensor representing the resized motion sequences + """ + bs, _, _ = sequence.shape + + # Create a time vector for the original and target sequences + original_time = torch.linspace(0, 1, 8, device=sequence.device).view(1, -1, 1) + target_time = torch.linspace(0, 1, target_frames, device=sequence.device).view(1, -1, 1) + + # Permute the dimensions to (bs, 165, 8) for interpolation + sequence = sequence.permute(0, 2, 1) + + # Interpolate each joint's motion to the target number of frames + resized_sequence = torch.nn.functional.interpolate(sequence, size=target_frames, mode='linear', align_corners=True) + + # Permute the dimensions back to (bs, target_frames, 165) + resized_sequence = resized_sequence.permute(0, 2, 1) + + return resized_sequence + +def adjust_speed_according_to_ratio_tensor(chunks): + """ + Adjust the playback speed within a batch of 32-frame chunks according to random intervals. 
+ + :param chunks: A (bs, 32, 165) tensor representing a batch of motion chunks + :return: A (bs, 32, 165) tensor representing the motion chunks after speed adjustment + """ + bs, _, _ = chunks.shape + + # Step 1: Divide the chunk into 4 equal intervals of 8 frames + equal_intervals = torch.chunk(chunks, 4, dim=1) + + # Step 2: Randomly sample 3 points within the chunk to determine new intervals + success = 0 + all_success = [] + #sample_points = torch.sort(torch.randint(1, 32, (bs, 3), device=chunks.device), dim=1).values + # new_intervals_boundaries = torch.cat([torch.zeros((bs, 1), device=chunks.device, dtype=torch.long), sample_points, 32*torch.ones((bs, 1), device=chunks.device, dtype=torch.long)], dim=1) + while success != 1: + sample_points = sorted(random.sample(range(1, 32), 3)) + new_intervals_boundaries = [0] + sample_points + [32] + new_intervals = [chunks[0][new_intervals_boundaries[i]:new_intervals_boundaries[i+1]] for i in range(4)] + speed_ratios = [8 / len(new_interval) for new_interval in new_intervals] + # if any of the speed ratios is greater than 3 or less than 0.33, resample + if all([0.33 <= speed_ratio <= 3 for speed_ratio in speed_ratios]): + success += 1 + all_success.append(new_intervals_boundaries) + new_intervals_boundaries = torch.from_numpy(np.array(all_success)) + # print(new_intervals_boundaries) + all_shapes = new_intervals_boundaries[:, 1:] - new_intervals_boundaries[:, :-1] + # Step 4: Adjust the speed of each new interval + adjusted_intervals = [] + # print(equal_intervals[0].shape) + for i in range(4): + adjusted_interval = resize_motion_sequence_tensor(equal_intervals[i], all_shapes[0, i]) + adjusted_intervals.append(adjusted_interval) + + # Step 5: Concatenate the adjusted intervals + adjusted_chunk = torch.cat(adjusted_intervals, dim=1) + + return adjusted_chunk + +def compute_exact_iou(bbox1, bbox2): + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2]) + y2 = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3]) + + intersection_area = max(0, x2 - x1) * max(0, y2 - y1) + bbox1_area = bbox1[2] * bbox1[3] + bbox2_area = bbox2[2] * bbox2[3] + union_area = bbox1_area + bbox2_area - intersection_area + + if union_area == 0: + return 0 + + return intersection_area / union_area + +def compute_iou(mask1, mask2): + # Compute the intersection + intersection = np.logical_and(mask1, mask2).sum() + + # Compute the union + union = np.logical_or(mask1, mask2).sum() + + # Compute the IoU + iou = intersection / union + + return iou + +def blankblending(all_frames, x, n): + return all_frames[x:x+n+1] + +def synthesize_intermediate_frames_FILM(frame1, frame2, t, name, save_path): + import replicate + from urllib.request import urlretrieve + import os + cv2.imwrite(save_path[:-9]+name+"_frame1.png", frame1) + cv2.imwrite(save_path[:-9]+name+"_frame2.png", frame2) + os.environ["REPLICATE_API_TOKEN"] = "r8_He1rkPk9GAxNQ3LpOohK8sYw1SUfMYV3Fxk9b" + output = replicate.run( + "google-research/frame-interpolation:4f88a16a13673a8b589c18866e540556170a5bcb2ccdc12de556e800e9456d3d", + input={ + "frame1": open(save_path[:-9]+name+"_frame1.png", "rb"), + "frame2": open(save_path[:-9]+name+"_frame2.png", "rb"), + "times_to_interpolate": t, + } + ) + print(output) + urlretrieve(output, save_path[:-9]+name+"_inter.mp4") + return load_video_as_numpy_array(save_path[:-9]+name+"_inter.mp4") + +def load_video_as_numpy_array(video_path): + cap = cv2.VideoCapture(video_path) + + # Using list comprehension to read frames and store in a 
list + frames = [frame for ret, frame in iter(lambda: cap.read(), (False, None)) if ret] + + cap.release() + + return np.array(frames) + +def synthesize_intermediate_frames_bidirectional(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + # Convert the frames to grayscale + gray1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) + gray2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) + + # Calculate the forward and backward optical flow + forward_flow = cv2.calcOpticalFlowFarneback(gray1, gray2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + backward_flow = cv2.calcOpticalFlowFarneback(gray2, gray1, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / n # Interpolation factor + + # Compute the intermediate forward and backward flow + intermediate_forward_flow = forward_flow * alpha + intermediate_backward_flow = backward_flow * (1 - alpha) + + # Warp the frames based on the intermediate flow + h, w = frame1.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + forward_displacement = flow_map + intermediate_forward_flow.reshape(-1, 2) + backward_displacement = flow_map - intermediate_backward_flow.reshape(-1, 2) + + # Use cv2.remap for efficient warping + remap_x_forward, remap_y_forward = np.clip(forward_displacement[:, 1], 0, w - 1), np.clip(forward_displacement[:, 0], 0, h - 1) + remap_x_backward, remap_y_backward = np.clip(backward_displacement[:, 1], 0, w - 1), np.clip(backward_displacement[:, 0], 0, h - 1) + + warped_forward = cv2.remap(frame1, remap_x_forward.reshape(h, w).astype(np.float32), remap_y_forward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + warped_backward = cv2.remap(frame2, remap_x_backward.reshape(h, w).astype(np.float32), remap_y_backward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + + # Blend the warped frames to generate the intermediate frame + intermediate_frame = cv2.addWeighted(warped_forward, 1 - alpha, warped_backward, alpha, 0) + synthesized_frames.append(intermediate_frame) + + return synthesized_frames # Return n-2 synthesized intermediate frames + + +def linear_interpolate_frames(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / (n) # Correct interpolation factor + inter_frame = cv2.addWeighted(frame1, 1 - alpha, frame2, alpha, 0) + synthesized_frames.append(inter_frame) + return synthesized_frames[:-1] + +def warp_frame(src_frame, flow): + h, w = flow.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + displacement = flow_map + flow.reshape(-1, 2) + + # Extract x and y coordinates of the displacement + x_coords = np.clip(displacement[:, 1], 0, w - 1).reshape(h, w).astype(np.float32) + y_coords = np.clip(displacement[:, 0], 0, h - 1).reshape(h, w).astype(np.float32) + + # Use cv2.remap for efficient warping + warped_frame = cv2.remap(src_frame, x_coords, y_coords, interpolation=cv2.INTER_LINEAR) + + return warped_frame + +def synthesize_intermediate_frames(all_frames, x, n): + # Calculate Optical Flow between the first and last frame + frame1 = cv2.cvtColor(all_frames[x], cv2.COLOR_BGR2GRAY) + frame2 = cv2.cvtColor(all_frames[x + n], cv2.COLOR_BGR2GRAY) + flow = cv2.calcOpticalFlowFarneback(frame1, frame2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For 
each intermediate frame + alpha = i / (n) # Interpolation factor + intermediate_flow = flow * alpha # Interpolate the flow + intermediate_frame = warp_frame(all_frames[x], intermediate_flow) # Warp the first frame + synthesized_frames.append(intermediate_frame) + + return synthesized_frames + + +def map2color(s): + m = hashlib.md5() + m.update(s.encode('utf-8')) + color_code = m.hexdigest()[:6] + return '#' + color_code + +def euclidean_distance(a, b): + return np.sqrt(np.sum((a - b)**2)) + +def adjust_array(x, k): + len_x = len(x) + len_k = len(k) + + # If x is shorter than k, pad with zeros + if len_x < len_k: + return np.pad(x, (0, len_k - len_x), 'constant') + + # If x is longer than k, truncate x + elif len_x > len_k: + return x[:len_k] + + # If both are of same length + else: + return x + +def onset_to_frame(onset_times, audio_length, fps): + # Calculate total number of frames for the given audio length + total_frames = int(audio_length * fps) + + # Create an array of zeros of shape (total_frames,) + frame_array = np.zeros(total_frames, dtype=np.int32) + + # For each onset time, calculate the frame number and set it to 1 + for onset in onset_times: + frame_num = int(onset * fps) + # Check if the frame number is within the array bounds + if 0 <= frame_num < total_frames: + frame_array[frame_num] = 1 + + return frame_array + +# def np_slerp(q1, q2, t): +# dot_product = np.sum(q1 * q2, axis=-1) +# q2_flip = np.where(dot_product[:, None] < 0, -q2, q2) # Flip quaternions where dot_product is negative +# dot_product = np.abs(dot_product) + +# angle = np.arccos(np.clip(dot_product, -1, 1)) +# sin_angle = np.sin(angle) + +# t1 = np.sin((1.0 - t) * angle) / sin_angle +# t2 = np.sin(t * angle) / sin_angle + +# return t1 * q1 + t2 * q2_flip + + +def smooth_rotvec_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using SLERP. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. + - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + n1, k1 = animation1.shape + n2, k2 = animation2.shape + animation1 = animation1.reshape(n1, k1//3, 3) + animation2 = animation2.reshape(n2, k2//3, 3) + blend_frames = min(blend_frames, len(animation1), len(animation2)) + all_int = [] + for i in range(k1//3): + # Convert rotation vectors to quaternion for the overlapping part + q = R.from_rotvec(np.concatenate([animation1[0:1, i], animation2[-2:-1, i]], axis=0))#.as_quat() + # q2 = R.from_rotvec()#.as_quat() + times = [0, blend_frames * 2 - 1] + slerp = Slerp(times, q) + interpolated = slerp(np.arange(blend_frames * 2)) + interpolated_rotvecs = interpolated.as_rotvec() + all_int.append(interpolated_rotvecs) + interpolated_rotvecs = np.concatenate(all_int, axis=1) + # result = np.vstack((animation1[:-blend_frames], interpolated_rotvecs, animation2[blend_frames:])) + result = interpolated_rotvecs.reshape(2*n1, k1) + return result + +def smooth_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using linear interpolation. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. 
+ - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + blend_frames = min(blend_frames, len(animation1), len(animation2)) + + # Extract overlapping sections + overlap_a1 = animation1[-blend_frames:-blend_frames+1, :] + overlap_a2 = animation2[blend_frames-1:blend_frames, :] + + # Create blend weights for linear interpolation + alpha = np.linspace(0, 1, 2 * blend_frames).reshape(-1, 1) + + # Linearly interpolate between overlapping sections + blended_overlap = overlap_a1 * (1 - alpha) + overlap_a2 * alpha + + # Extend the animations to form the result with 2n frames + if blend_frames == len(animation1) and blend_frames == len(animation2): + result = blended_overlap + else: + before_blend = animation1[:-blend_frames] + after_blend = animation2[blend_frames:] + result = np.vstack((before_blend, blended_overlap, after_blend)) + return result + +def interpolate_sequence(quaternions): + bs, n, j, _ = quaternions.shape + new_n = 2 * n + new_quaternions = torch.zeros((bs, new_n, j, 4), device=quaternions.device, dtype=quaternions.dtype) + + for i in range(n): + q1 = quaternions[:, i, :, :] + new_quaternions[:, 2*i, :, :] = q1 + + if i < n - 1: + q2 = quaternions[:, i + 1, :, :] + new_quaternions[:, 2*i + 1, :, :] = slerp(q1, q2, 0.5) + else: + # For the last point, duplicate the value + new_quaternions[:, 2*i + 1, :, :] = q1 + + return new_quaternions + +def quaternion_multiply(q1, q2): + w1, x1, y1, z1 = q1 + w2, x2, y2, z2 = q2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + return w, x, y, z + +def quaternion_conjugate(q): + w, x, y, z = q + return (w, -x, -y, -z) + +def slerp(q1, q2, t): + dot = torch.sum(q1 * q2, dim=-1, keepdim=True) + + flip = (dot < 0).float() + q2 = (1 - flip * 2) * q2 + dot = dot * (1 - flip * 2) + + DOT_THRESHOLD = 0.9995 + mask = (dot > DOT_THRESHOLD).float() + + theta_0 = torch.acos(dot) + theta = theta_0 * t + + q3 = q2 - q1 * dot + q3 = q3 / torch.norm(q3, dim=-1, keepdim=True) + + interpolated = (torch.cos(theta) * q1 + torch.sin(theta) * q3) + + return mask * (q1 + t * (q2 - q1)) + (1 - mask) * interpolated + +def estimate_linear_velocity(data_seq, dt): + ''' + Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates + the velocity for the middle T-2 steps using a second order central difference scheme. + The first and last frames are with forward and backward first-order + differences, respectively + - h : step size + ''' + # first steps is forward diff (t+1 - t) / dt + init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt + + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) + return vel_seq + +def velocity2position(data_seq, dt, init_pos): + res_trans = [] + for i in range(data_seq.shape[1]): + if i == 0: + res_trans.append(init_pos.unsqueeze(1)) + else: + res = data_seq[:, i-1:i] * dt + res_trans[-1] + res_trans.append(res) + return torch.cat(res_trans, dim=1) + +def estimate_angular_velocity(rot_seq, dt): + ''' + Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. 
+ Input sequence should be of shape (B, T, ..., 3, 3) + ''' + # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix + dRdt = estimate_linear_velocity(rot_seq, dt) + R = rot_seq + RT = R.transpose(-1, -2) + # compute skew-symmetric angular velocity tensor + w_mat = torch.matmul(dRdt, RT) + # pull out angular velocity vector by averaging symmetric entries + w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 + w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 + w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 + w = torch.stack([w_x, w_y, w_z], axis=-1) + return w + +def image_from_bytes(image_bytes): + import matplotlib.image as mpimg + from io import BytesIO + return mpimg.imread(BytesIO(image_bytes), format='PNG') + +def process_frame(i, vertices_all, vertices1_all, faces, output_dir, filenames): + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + import trimesh + import pyrender + + def deg_to_rad(degrees): + return degrees * np.pi / 180 + + uniform_color = [220, 220, 220, 255] + resolution = (1000, 1000) + figsize = (10, 10) + + fig, axs = plt.subplots( + nrows=1, + ncols=2, + figsize=(figsize[0] * 2, figsize[1] * 1) + ) + axs = axs.flatten() + + vertices = vertices_all[i] + vertices1 = vertices1_all[i] + filename = f"{output_dir}frame_{i}.png" + filenames.append(filename) + if i%100 == 0: + print('processed', i, 'frames') + #time_s = time.time() + #print(vertices.shape) + angle_rad = deg_to_rad(-2) + pose_camera = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 1.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 5.0], + [0.0, 0.0, 0.0, 1.0] + ]) + angle_rad = deg_to_rad(-30) + pose_light = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 0.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 3.0], + [0.0, 0.0, 0.0, 1.0] + ]) + + for vtx_idx, vtx in enumerate([vertices, vertices1]): + trimesh_mesh = trimesh.Trimesh( + vertices=vtx, + faces=faces, + vertex_colors=uniform_color + ) + mesh = pyrender.Mesh.from_trimesh( + trimesh_mesh, smooth=True + ) + scene = pyrender.Scene() + scene.add(mesh) + camera = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0) + scene.add(camera, pose=pose_camera) + light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=4.0) + scene.add(light, pose=pose_light) + renderer = pyrender.OffscreenRenderer(*resolution) + color, _ = renderer.render(scene) + axs[vtx_idx].imshow(color) + axs[vtx_idx].axis('off') + renderer.delete() + + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + +def generate_images(frames, vertices_all, vertices1_all, faces, output_dir, filenames): + import multiprocessing + # import trimesh + num_cores = multiprocessing.cpu_count() - 1 # This will get the number of cores on your machine. 
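+    # NOTE: num_cores is only consumed by the multiprocessing.Pool variant that is kept
+    # commented out below for reference; the active path renders every third frame
+    # serially through process_frame.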
+ # mesh = trimesh.Trimesh(vertices_all[0], faces) + # scene = mesh.scene() + # fov = scene.camera.fov.copy() + # fov[0] = 80.0 + # fov[1] = 60.0 + # camera_params = { + # 'fov': fov, + # 'resolution': scene.camera.resolution, + # 'focal': scene.camera.focal, + # 'z_near': scene.camera.z_near, + # "z_far": scene.camera.z_far, + # 'transform': scene.graph[scene.camera.name][0] + # } + # mesh1 = trimesh.Trimesh(vertices1_all[0], faces) + # scene1 = mesh1.scene() + # camera_params1 = { + # 'fov': fov, + # 'resolution': scene1.camera.resolution, + # 'focal': scene1.camera.focal, + # 'z_near': scene1.camera.z_near, + # "z_far": scene1.camera.z_far, + # 'transform': scene1.graph[scene1.camera.name][0] + # } + # Use a Pool to manage the processes + # print(num_cores) + # for i in range(frames): + # process_frame(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) + for i in range(frames): + process_frame(i*3, vertices_all, vertices1_all, faces, output_dir, filenames) + + # progress = multiprocessing.Value('i', 0) + # lock = multiprocessing.Lock() + # with multiprocessing.Pool(num_cores) as pool: + # # pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + # pool.starmap( + # process_frame, + # [ + # (i, vertices_all, vertices1_all, faces, output_dir, filenames) + # for i in range(frames) + # ] + # ) + + # progress = multiprocessing.Value('i', 0) + # lock = multiprocessing.Lock() + # with multiprocessing.Pool(num_cores) as pool: + # # pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + # pool.starmap( + # process_frame, + # [ + # (i, vertices_all, vertices1_all, faces, output_dir, filenames) + # for i in range(frames) + # ] + # ) + +def render_one_sequence( + res_npz_path, + gt_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import moviepy.editor as mp + import librosa + + model = smplx.create(model_folder, model_type=model_type, + gender=gender, use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, use_pca=False).cuda() + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + # if not use_matplotlib: + # import trimesh + #import pyrender + from pyvirtualdisplay import Display + #''' + #display = Display(visible=0, size=(1000, 1000)) + #display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + #data_npz["jaw_pose"].shape[0] + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + beta = beta.repeat(n, 1) + expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).cuda() + pose = 
torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).cuda() + transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).cuda() + # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape) + output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose, + global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3], + leye_pose=pose[:, 69:72], + reye_pose=pose[:, 72:75], + return_verts=True) + vertices_all = output["vertices"].cpu().detach().numpy() + + beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).cuda() + pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).cuda() + transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).cuda() + output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3], + leye_pose=pose1[:, 69:72], + reye_pose=pose1[:, 72:75],return_verts=True) + vertices1_all = output1["vertices"].cpu().detach().numpy() + if args.debug: + seconds = 1 + else: + seconds = vertices_all.shape[0]//30 + silent_video_file_path = utils.fast_render.generate_silent_videos(args.render_video_fps, + args.render_video_width, + args.render_video_height, + args.render_concurrent_num, + args.render_tmp_img_filetype, + int(seconds*args.render_video_fps), + vertices_all, + vertices1_all, + faces, + output_dir) + base_filename_without_ext = os.path.splitext(os.path.basename(res_npz_path))[0] + final_clip = os.path.join(output_dir, f"{base_filename_without_ext}.mp4") + utils.media.add_audio_to_video(silent_video_file_path, audio_path, final_clip) + os.remove(silent_video_file_path) + return final_clip + +def render_one_sequence_no_gt( + res_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import moviepy.editor as mp + import librosa + + model = smplx.create(model_folder, model_type=model_type, + gender=gender, use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, use_pca=False).cuda() + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + # gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + # if not use_matplotlib: + # import trimesh + #import pyrender + from pyvirtualdisplay import Display + #''' + #display = Display(visible=0, size=(1000, 1000)) + #display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + #data_npz["jaw_pose"].shape[0] + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + beta = beta.repeat(n, 1) + expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).cuda() + jaw_pose = 
torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).cuda() + pose = torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).cuda() + transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).cuda() + # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape) + output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose, + global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3], + leye_pose=pose[:, 69:72], + reye_pose=pose[:, 72:75], + return_verts=True) + vertices_all = output["vertices"].cpu().detach().numpy() + + # beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda() + # expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).cuda() + # jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).cuda() + # pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).cuda() + # transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).cuda() + # output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3], + # leye_pose=pose1[:, 69:72], + # reye_pose=pose1[:, 72:75],return_verts=True) + # vertices1_all = output1["vertices"].cpu().detach().numpy() + if args.debug: + seconds = 1 + else: + seconds = vertices_all.shape[0]//30 + silent_video_file_path = utils.fast_render.generate_silent_videos_no_gt(args.render_video_fps, + args.render_video_width, + args.render_video_height, + args.render_concurrent_num, + args.render_tmp_img_filetype, + int(seconds*args.render_video_fps), + vertices_all, + faces, + output_dir) + base_filename_without_ext = os.path.splitext(os.path.basename(res_npz_path))[0] + final_clip = os.path.join(output_dir, f"{base_filename_without_ext}.mp4") + utils.media.add_audio_to_video(silent_video_file_path, audio_path, final_clip) + os.remove(silent_video_file_path) + return final_clip + +def print_exp_info(args): + logger.info(pprint.pformat(vars(args))) + logger.info(f"# ------------ {args.name} ----------- #") + logger.info("PyTorch version: {}".format(torch.__version__)) + logger.info("CUDA version: {}".format(torch.version.cuda)) + logger.info("{} GPUs".format(torch.cuda.device_count())) + logger.info(f"Random Seed: {args.random_seed}") + +def args2csv(args, get_head=False, list4print=[]): + for k, v in args.items(): + if isinstance(args[k], dict): + args2csv(args[k], get_head, list4print) + else: list4print.append(k) if get_head else list4print.append(v) + return list4print + +class EpochTracker: + def __init__(self, metric_names, metric_directions): + assert len(metric_names) == len(metric_directions), "Metric names and directions should have the same length" + + + self.metric_names = metric_names + self.states = ['train', 'val', 'test'] + self.types = ['last', 'best'] + + + self.values = {name: {state: {type_: {'value': np.inf if not is_higher_better else -np.inf, 'epoch': 0} + for type_ in self.types} + for state in self.states} + for name, is_higher_better in zip(metric_names, metric_directions)} + + self.loss_meters = {name: {state: AverageMeter(f"{name}_{state}") + for state in self.states} + for name in metric_names} + + + self.is_higher_better = {name: direction for name, direction in zip(metric_names, metric_directions)} + self.train_history = {name: [] for 
name in metric_names} + self.val_history = {name: [] for name in metric_names} + + + def update_meter(self, name, state, value): + self.loss_meters[name][state].update(value) + + + def update_values(self, name, state, epoch): + value_avg = self.loss_meters[name][state].avg + new_best = False + + + if ((value_avg < self.values[name][state]['best']['value'] and not self.is_higher_better[name]) or + (value_avg > self.values[name][state]['best']['value'] and self.is_higher_better[name])): + self.values[name][state]['best']['value'] = value_avg + self.values[name][state]['best']['epoch'] = epoch + new_best = True + self.values[name][state]['last']['value'] = value_avg + self.values[name][state]['last']['epoch'] = epoch + return new_best + + + def get(self, name, state, type_): + return self.values[name][state][type_] + + + def reset(self): + for name in self.metric_names: + for state in self.states: + self.loss_meters[name][state].reset() + + + def flatten_values(self): + flat_dict = {} + for name in self.metric_names: + for state in self.states: + for type_ in self.types: + value_key = f"{name}_{state}_{type_}" + epoch_key = f"{name}_{state}_{type_}_epoch" + flat_dict[value_key] = self.values[name][state][type_]['value'] + flat_dict[epoch_key] = self.values[name][state][type_]['epoch'] + return flat_dict + + def update_and_plot(self, name, epoch, save_path): + new_best_train = self.update_values(name, 'train', epoch) + new_best_val = self.update_values(name, 'val', epoch) + + + self.train_history[name].append(self.loss_meters[name]['train'].avg) + self.val_history[name].append(self.loss_meters[name]['val'].avg) + + + train_values = self.train_history[name] + val_values = self.val_history[name] + epochs = list(range(1, len(train_values) + 1)) + + + plt.figure(figsize=(10, 6)) + plt.plot(epochs, train_values, label='Train') + plt.plot(epochs, val_values, label='Val') + plt.title(f'Train vs Val {name} over epochs') + plt.xlabel('Epochs') + plt.ylabel(name) + plt.legend() + plt.savefig(save_path) + plt.close() + + + return new_best_train, new_best_val + +def record_trial(args, tracker): + """ + 1. 
record notes, score, env_name, experments_path, + """ + csv_path = args.out_path + "custom/" +args.csv_name+".csv" + all_print_dict = vars(args) + all_print_dict.update(tracker.flatten_values()) + if not os.path.exists(csv_path): + pd.DataFrame([all_print_dict]).to_csv(csv_path, index=False) + else: + df_existing = pd.read_csv(csv_path) + df_new = pd.DataFrame([all_print_dict]) + df_aligned = df_existing.append(df_new).fillna("") + df_aligned.to_csv(csv_path, index=False) + +def set_random_seed(args): + os.environ['PYTHONHASHSEED'] = str(args.random_seed) + random.seed(args.random_seed) + np.random.seed(args.random_seed) + torch.manual_seed(args.random_seed) + torch.cuda.manual_seed_all(args.random_seed) + torch.cuda.manual_seed(args.random_seed) + torch.backends.cudnn.deterministic = args.deterministic #args.CUDNN_DETERMINISTIC + torch.backends.cudnn.benchmark = args.benchmark + torch.backends.cudnn.enabled = args.cudnn_enabled + +def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None): + if lrs is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(), + 'lrs':lrs.state_dict(),} + elif opt is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(),} + else: + states = { 'model_state': model.state_dict(),} + torch.save(states, save_path) + +def load_checkpoints(model, save_path, load_name='model'): + states = torch.load(save_path) + new_weights = OrderedDict() + flag=False + for k, v in states['model_state'].items(): + #print(k) + if "module" not in k: + break + else: + new_weights[k[7:]]=v + flag=True + if flag: + try: + model.load_state_dict(new_weights) + except: + #print(states['model_state']) + model.load_state_dict(states['model_state']) + else: + model.load_state_dict(states['model_state']) + logger.info(f"load self-pretrained checkpoints for {load_name}") + +def model_complexity(model, args): + from ptflops import get_model_complexity_info + flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN), + as_strings=False, print_per_layer_stat=False) + logging.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9)) + logging.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6)) + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) \ No newline at end of file diff --git a/utils/rotation_conversions.py b/utils/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bfaa1b2247622bff35d3f9b15e8eb84064aa53 --- /dev/null +++ b/utils/rotation_conversions.py @@ -0,0 +1,550 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. 
the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). 
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. 
+ + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
+ """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)