|
|
|
|
|
import torch |
|
from os.path import join as pjoin |
|
import numpy as np |
|
from .evaluator_modules import * |
|
|
|
|
|
def build_models(opt):
    """Create the evaluator encoders and restore their weights from disk.

    Three modules are instantiated from the sizes found on ``opt``, then a
    single checkpoint file located at
    ``<opt.evaluator_dir>/<opt.dataset_name>/text_mot_match/model/finest.tar``
    is loaded (mapped onto ``opt.device``) and its ``movement_encoder``,
    ``text_encoder`` and ``motion_encoder`` entries are applied to the
    corresponding modules. The checkpoint's ``epoch`` entry is printed.

    Args:
        opt: Options object. The attributes read here are ``dim_pose``,
            ``dim_movement_enc_hidden``, ``dim_movement_latent``, ``dim_word``,
            ``dim_pos_ohot``, ``dim_text_hidden``, ``dim_coemb_hidden``,
            ``dim_motion_hidden``, ``device``, ``evaluator_dir`` and
            ``dataset_name``.

    Returns:
        The tuple ``(text_encoder, motion_encoder, movement_encoder)``.
        Note that the modules are not moved to ``opt.device`` here.
    """
    # The movement encoder is built for four fewer input features than
    # ``opt.dim_pose``.
    movement_encoder = MovementConvEncoder(
        opt.dim_pose - 4,
        opt.dim_movement_enc_hidden,
        opt.dim_movement_latent,
    )
    text_encoder = TextEncoderBiGRUCo(
        word_size=opt.dim_word,
        pos_size=opt.dim_pos_ohot,
        hidden_size=opt.dim_text_hidden,
        output_size=opt.dim_coemb_hidden,
        device=opt.device,
    )
    motion_encoder = MotionEncoderBiGRUCo(
        input_size=opt.dim_movement_latent,
        hidden_size=opt.dim_motion_hidden,
        output_size=opt.dim_coemb_hidden,
        device=opt.device,
    )

    ckpt_path = pjoin(opt.evaluator_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar')
    state = torch.load(ckpt_path, map_location=opt.device)

    # Restore weights in a fixed order: movement, text, motion.
    restore_plan = (
        (movement_encoder, 'movement_encoder'),
        (text_encoder, 'text_encoder'),
        (motion_encoder, 'motion_encoder'),
    )
    for module, key in restore_plan:
        module.load_state_dict(state[key])

    print('\nLoading Evaluation Model Wrapper (Epoch %d) Completed!!' % (state['epoch']))
    return text_encoder, motion_encoder, movement_encoder
|
|
|
class EvaluatorModelWrapper(object):
    """Inference-only wrapper around the text, motion and movement encoders.

    The encoders are obtained from ``build_models(opt)``, moved to
    ``opt.device`` and switched to evaluation mode. All embedding methods run
    under ``torch.no_grad()``.

    Important: every embedding returned by this class is ordered by
    *descending motion length*, not by the order of the input batch.
    """

    def __init__(self, opt):
        """Build the encoders and prepare them for evaluation.

        Args:
            opt: Options object. It is forwarded to ``build_models`` and must
                also provide ``device`` and ``unit_length`` (the latter is read
                when motions are encoded).
        """
        self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
        self.opt = opt
        self.device = opt.device

        for encoder in (self.text_encoder, self.motion_encoder, self.movement_encoder):
            encoder.to(opt.device)
            encoder.eval()

    def _encode_motions(self, motions, m_lens):
        """Run the shared motion pipeline and return the embedding and sort order.

        The caller is expected to hold ``torch.no_grad()``.

        Args:
            motions: Tensor of motions, indexed by sample along its first axis.
            m_lens: Tensor with one length per sample in ``motions``.

        Returns:
            A tuple ``(motion_embedding, align_idx)`` where ``align_idx`` is
            the index array that sorts the batch by descending ``m_lens`` and
            ``motion_embedding`` is the motion encoder output in that order.
        """
        motions = motions.detach().to(self.device).float()

        # Indices that order the batch from longest to shortest motion.
        # NOTE(review): presumably required because the motion encoder packs
        # padded sequences -- confirm against the encoder implementation.
        align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
        motions = motions[align_idx]
        m_lens = m_lens[align_idx]

        # Movement encoding: the last four entries of the final axis are
        # dropped before the movement encoder sees the data.
        movements = self.movement_encoder(motions[..., :-4]).detach()
        # Lengths are rescaled by the configured unit length (truncating).
        m_lens = torch.div(m_lens, self.opt.unit_length, rounding_mode='trunc')
        motion_embedding = self.motion_encoder(movements, m_lens)
        return motion_embedding, align_idx

    def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
        """Compute matching text and motion embeddings for a batch.

        Args:
            word_embs: Word-embedding tensor passed to the text encoder.
            pos_ohot: Second tensor input of the text encoder.
            cap_lens: Caption lengths passed to the text encoder unchanged.
            motions: Motion tensor, indexed by sample along its first axis.
            m_lens: Tensor with one motion length per sample.

        Returns:
            ``(text_embedding, motion_embedding)``. Both are ordered by
            descending motion length: the text embedding is computed in input
            order and then re-indexed so that row ``i`` of each output refers
            to the same sample.
        """
        with torch.no_grad():
            word_embs = word_embs.detach().to(self.device).float()
            pos_ohot = pos_ohot.detach().to(self.device).float()

            motion_embedding, align_idx = self._encode_motions(motions, m_lens)

            # Text encoding, then re-ordered to line up with the motions.
            text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
            text_embedding = text_embedding[align_idx]
        return text_embedding, motion_embedding

    def get_motion_embeddings(self, motions, m_lens):
        """Compute motion embeddings for a batch.

        Args:
            motions: Motion tensor, indexed by sample along its first axis.
            m_lens: Tensor with one motion length per sample.

        Returns:
            The motion encoder output, ordered by descending motion length
            (not by input order).
        """
        with torch.no_grad():
            motion_embedding, _ = self._encode_motions(motions, m_lens)
        return motion_embedding
|
|