# Generate_human_motion/VQ-Trans/models/evaluator_wrapper.py
from os.path import join as pjoin

import numpy as np
import torch

from models.modules import MovementConvEncoder, TextEncoderBiGRUCo, MotionEncoderBiGRUCo
from utils.word_vectorizer import POS_enumerator


def build_models(opt):
    # Build the three evaluator networks: a convolutional movement encoder,
    # a bidirectional-GRU text encoder, and a bidirectional-GRU motion encoder.
    movement_enc = MovementConvEncoder(opt.dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
    text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
                                  pos_size=opt.dim_pos_ohot,
                                  hidden_size=opt.dim_text_hidden,
                                  output_size=opt.dim_coemb_hidden,
                                  device=opt.device)
    motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
                                      hidden_size=opt.dim_motion_hidden,
                                      output_size=opt.dim_coemb_hidden,
                                      device=opt.device)

    # Restore the pretrained text-motion-matching weights.
    checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
                            map_location=opt.device)
    movement_enc.load_state_dict(checkpoint['movement_encoder'])
    text_enc.load_state_dict(checkpoint['text_encoder'])
    motion_enc.load_state_dict(checkpoint['motion_encoder'])
    print('Loading Evaluation Model Wrapper (Epoch %d) Completed!' % checkpoint['epoch'])
    return text_enc, motion_enc, movement_enc
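
# The checkpoint path and key layout assumed by build_models -- i.e.
# <checkpoints_dir>/<dataset_name>/text_mot_match/model/finest.tar with the
# state dicts stored under 'movement_encoder', 'text_encoder', and
# 'motion_encoder', plus an 'epoch' counter -- is inferred from the load
# calls above, not from an external specification.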

class EvaluatorModelWrapper(object):

    def __init__(self, opt):
        # Pose feature dimension depends on the dataset: 263 for HumanML3D
        # ('t2m') and 251 for KIT-ML ('kit').
        if opt.dataset_name == 't2m':
            opt.dim_pose = 263
        elif opt.dataset_name == 'kit':
            opt.dim_pose = 251
        else:
            raise KeyError('Dataset not recognized!')

        # These dimensions are fixed by the pretrained evaluator checkpoint.
        opt.dim_word = 300
        opt.max_motion_length = 196
        opt.dim_pos_ohot = len(POS_enumerator)
        opt.dim_motion_hidden = 1024
        opt.max_text_len = 20
        opt.dim_text_hidden = 512
        opt.dim_coemb_hidden = 512

        self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
        self.opt = opt
        self.device = opt.device

        self.text_encoder.to(opt.device)
        self.motion_encoder.to(opt.device)
        self.movement_encoder.to(opt.device)

        self.text_encoder.eval()
        self.motion_encoder.eval()
        self.movement_encoder.eval()

    # Note: the returned embeddings do not necessarily follow the order of the
    # inputs; the caller is expected to pass batches already sorted by motion
    # length (descending), since the GRU motion encoder packs sequences.
    def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
        with torch.no_grad():
            word_embs = word_embs.detach().to(self.device).float()
            pos_ohot = pos_ohot.detach().to(self.device).float()
            motions = motions.detach().to(self.device).float()

            '''Movement Encoding'''
            # Drop the last four pose channels (foot-contact features) to match
            # the dim_pose - 4 input of the movement encoder, then downsample
            # the lengths by the movement encoder's temporal stride.
            movements = self.movement_encoder(motions[..., :-4]).detach()
            m_lens = m_lens // self.opt.unit_length
            motion_embedding = self.motion_encoder(movements, m_lens)

            '''Text Encoding'''
            text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
        return text_embedding, motion_embedding

    # Note: the returned embeddings follow the length-sorted order computed
    # below, not the order of the inputs.
    def get_motion_embeddings(self, motions, m_lens):
        with torch.no_grad():
            motions = motions.detach().to(self.device).float()

            # Sort the batch by motion length (descending), as required for
            # packed sequences inside the GRU motion encoder.
            align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
            motions = motions[align_idx]
            m_lens = m_lens[align_idx]

            '''Movement Encoding'''
            movements = self.movement_encoder(motions[..., :-4]).detach()
            m_lens = m_lens // self.opt.unit_length
            motion_embedding = self.motion_encoder(movements, m_lens)
        return motion_embedding
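

# A minimal usage sketch, not part of the original module. It assumes an `opt`
# namespace carrying the fields consumed above (dataset_name, checkpoints_dir,
# device, unit_length, and the movement-encoder widths read by build_models);
# the concrete values below (unit_length=4, 512-wide movement encoder,
# './checkpoints' as the checkpoint root) are assumptions for the 't2m'
# setting, not guarantees.
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(
        dataset_name='t2m',
        checkpoints_dir='./checkpoints',   # assumed location of finest.tar
        device=torch.device('cpu'),
        unit_length=4,                     # assumed temporal downsampling factor
        dim_movement_enc_hidden=512,       # assumed encoder widths
        dim_movement_latent=512,
    )
    wrapper = EvaluatorModelWrapper(opt)   # sets opt.dim_pose and loads weights

    # Random motions stand in for real preprocessed features here.
    motions = torch.randn(2, opt.max_motion_length, opt.dim_pose)
    m_lens = torch.tensor([196, 128])
    motion_emb = wrapper.get_motion_embeddings(motions, m_lens)
    print(motion_emb.shape)                # (batch, dim_coemb_hidden)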