| |
|
| | import torch
|
| | from os.path import join as pjoin
|
| | import numpy as np
|
| | from models.modules import MovementConvEncoder, TextEncoderBiGRUCo, MotionEncoderBiGRUCo
|
| | from utils.word_vectorizer import POS_enumerator
|
| |
|
def build_models(opt):
    """Build the three evaluator encoders and load their pretrained weights.

    Reads from ``opt``: ``dim_pose``, ``dim_movement_enc_hidden``,
    ``dim_movement_latent``, ``dim_word``, ``dim_pos_ohot``,
    ``dim_text_hidden``, ``dim_coemb_hidden``, ``dim_motion_hidden``,
    ``device``, ``checkpoints_dir`` and ``dataset_name``.

    The checkpoint is read from
    ``<checkpoints_dir>/<dataset_name>/text_mot_match/model/finest.tar`` and
    must provide the keys ``'movement_encoder'``, ``'text_encoder'``,
    ``'motion_encoder'`` and ``'epoch'``.

    Returns:
        A ``(text_enc, motion_enc, movement_enc)`` tuple. This function does
        not call ``.to(device)`` or ``.eval()`` on the returned modules;
        that is left to the caller.
    """
    # The movement encoder is fed poses with the last four feature channels
    # removed, hence the ``dim_pose - 4`` input size.
    movement_enc = MovementConvEncoder(
        opt.dim_pose - 4,
        opt.dim_movement_enc_hidden,
        opt.dim_movement_latent,
    )
    text_enc = TextEncoderBiGRUCo(
        word_size=opt.dim_word,
        pos_size=opt.dim_pos_ohot,
        hidden_size=opt.dim_text_hidden,
        output_size=opt.dim_coemb_hidden,
        device=opt.device,
    )
    motion_enc = MotionEncoderBiGRUCo(
        input_size=opt.dim_movement_latent,
        hidden_size=opt.dim_motion_hidden,
        output_size=opt.dim_coemb_hidden,
        device=opt.device,
    )

    checkpoint_path = pjoin(
        opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'
    )
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=opt.device)

    movement_enc.load_state_dict(checkpoint['movement_encoder'])
    text_enc.load_state_dict(checkpoint['text_encoder'])
    motion_enc.load_state_dict(checkpoint['motion_encoder'])
    print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
    return text_enc, motion_enc, movement_enc
|
| |
|
| |
|
class EvaluatorModelWrapper(object):
    """Frozen text/motion encoders used to compute evaluation embeddings.

    The wrapper loads the pretrained encoders through ``build_models``,
    moves them to ``opt.device`` and switches them to eval mode. All
    embedding methods run under ``torch.no_grad()``.
    """

    def __init__(self, opt):
        """Fill in the evaluator hyper-parameters on ``opt`` and load models.

        ``opt`` is mutated in place: ``dim_pose`` is chosen from
        ``opt.dataset_name`` and the remaining evaluator dimensions are
        overwritten with fixed values. ``opt.unit_length`` is NOT set here
        but is read by the embedding methods, so the caller must provide it.

        Raises:
            KeyError: if ``opt.dataset_name`` is neither ``'t2m'`` nor
                ``'kit'``.
        """
        if opt.dataset_name == 't2m':
            opt.dim_pose = 263
        elif opt.dataset_name == 'kit':
            opt.dim_pose = 251
        else:
            raise KeyError('Dataset not Recognized!!!')

        opt.dim_word = 300
        opt.max_motion_length = 196
        opt.dim_pos_ohot = len(POS_enumerator)
        opt.dim_motion_hidden = 1024
        opt.max_text_len = 20
        opt.dim_text_hidden = 512
        opt.dim_coemb_hidden = 512

        self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
        self.opt = opt
        self.device = opt.device

        self.text_encoder.to(opt.device)
        self.motion_encoder.to(opt.device)
        self.movement_encoder.to(opt.device)

        self.text_encoder.eval()
        self.motion_encoder.eval()
        self.movement_encoder.eval()

    def _encode_motions(self, motions, m_lens):
        """Run the movement encoder followed by the motion encoder.

        ``motions`` must already be a float tensor on ``self.device``. The
        last four feature channels are dropped before movement encoding,
        matching the ``dim_pose - 4`` input size the movement encoder is
        built with. ``m_lens`` is floor-divided by ``opt.unit_length``
        before being handed to the motion encoder.
        """
        movements = self.movement_encoder(motions[..., :-4]).detach()
        m_lens = m_lens // self.opt.unit_length
        return self.motion_encoder(movements, m_lens)

    def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
        """Return ``(text_embedding, motion_embedding)`` for a paired batch.

        The inputs are used in the order given; no re-sorting is done here.
        NOTE(review): presumably the caller supplies batches already ordered
        the way the encoders expect - confirm against the data loader.
        """
        with torch.no_grad():
            word_embs = word_embs.detach().to(self.device).float()
            pos_ohot = pos_ohot.detach().to(self.device).float()
            motions = motions.detach().to(self.device).float()

            # Movement encoding
            motion_embedding = self._encode_motions(motions, m_lens)

            # Text encoding
            text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
        return text_embedding, motion_embedding

    def get_motion_embeddings(self, motions, m_lens):
        """Return motion embeddings for a batch of motions.

        The batch is reordered by descending ``m_lens`` before encoding and
        the result is returned in that sorted order, NOT in the caller's
        original order.
        NOTE(review): the sort is presumably required by the motion encoder
        (e.g. packed sequences) - confirm before changing it.
        """
        with torch.no_grad():
            motions = motions.detach().to(self.device).float()

            # Indices that order the batch from longest to shortest motion;
            # .copy() turns the reversed view into a contiguous array so it
            # can be used for indexing.
            align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
            motions = motions[align_idx]
            m_lens = m_lens[align_idx]

            # Movement encoding
            motion_embedding = self._encode_motions(motions, m_lens)
        return motion_embedding
|
| |
|