# MotionDiffuse / datasets / evaluator.py
# root
# initial commit
# 12deb01
import torch
from utils.word_vectorizer import WordVectorizer, POS_enumerator
from utils.get_opt import get_opt
from models import MotionTransformer
from torch.utils.data import Dataset, DataLoader
from os.path import join as pjoin
from tqdm import tqdm
import numpy as np
from .evaluator_models import *
import os
import codecs as cs
import random
from torch.utils.data._utils.collate import default_collate
class EvaluationDataset(Dataset):
    """Motions generated by ``trainer`` for the captions of a ground-truth dataset.

    Construction walks ``dataset`` once (batch size 1, shuffled), collects every
    caption together with its clamped target length, generates all motions in a
    single ``trainer.generate`` call and stores the results in memory.

    ``mm_num_samples`` randomly chosen positions of that walk are generated
    ``mm_num_repeats`` times instead of once; those repeated generations are
    kept in ``self.mm_generated_motion`` (presumably for multimodality
    evaluation -- confirm against the caller).

    Attributes:
        generated_motion: list of dicts with keys ``motion``, ``length``,
            ``caption``, ``cap_len`` and ``tokens`` (one per sample that was
            generated at least once).
        mm_generated_motion: list of dicts with keys ``caption``, ``tokens``,
            ``cap_len`` and ``mm_motions`` (a list of ``motion``/``length``
            dicts, one per repeat).
    """

    def __init__(self, opt, trainer, dataset, w_vectorizer, mm_num_samples, mm_num_repeats):
        assert mm_num_samples < len(dataset)
        print(opt.model_dir)
        dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
        trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar'))
        min_mov_length = 10 if opt.dataset_name == 't2m' else 6
        trainer.eval_mode()
        trainer.to(opt.device)

        # Positions (in iteration order) whose caption is generated repeatedly.
        # The indices are unique, so a set lookup is equivalent to walking a
        # sorted index array in lock-step with the loader.
        mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
        mm_positions = set(mm_idxs.tolist())

        # Pre-process all target captions.  Only the lightweight fields needed
        # to assemble the results are kept, so the loader is iterated exactly
        # once and the ground-truth motions are not held in memory.
        samples = []
        all_caption = []
        all_m_lens = []
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader)):
                _, _, caption, cap_lens, _, m_lens, tokens = data
                length = self._clamp_length(
                    m_lens[0].item(), opt.unit_length, min_mov_length, opt.max_motion_length)
                is_mm = i in mm_positions
                repeat_times = mm_num_repeats if is_mm else 1

                length_tensor = torch.LongTensor([length]).to(opt.device)
                for _ in range(repeat_times):
                    all_m_lens.append(length_tensor)
                    all_caption.extend(caption)

                samples.append({
                    'caption': caption[0],
                    'cap_len': cap_lens[0].item(),
                    'tokens': tokens[0].split('_'),
                    'length': length,
                    'is_mm': is_mm,
                    'repeat_times': repeat_times,
                })
        all_m_lens = torch.stack(all_m_lens)

        # Generate all sequences in one call.
        with torch.no_grad():
            all_pred_motions = trainer.generate(all_caption, all_m_lens, opt.dim_pose)

        # Distribute the generated motions back to their samples; predictions
        # are consumed in the same order in which the captions were queued.
        generated_motion = []
        mm_generated_motions = []
        cur_idx = 0
        for sample in samples:
            length = sample['length']
            mm_motions = []
            for t in range(sample['repeat_times']):
                pred_motions = all_pred_motions[cur_idx][:length]
                assert pred_motions.shape[0] == length
                cur_idx += 1
                if t == 0:
                    # The first generation always feeds the regular dataset.
                    generated_motion.append({
                        'motion': pred_motions.cpu().numpy(),
                        'length': pred_motions.shape[0],
                        'caption': sample['caption'],
                        'cap_len': sample['cap_len'],
                        'tokens': sample['tokens'],
                    })
                if sample['is_mm']:
                    mm_motions.append({
                        'motion': pred_motions.cpu().numpy(),
                        'length': length,
                    })
            if sample['is_mm']:
                mm_generated_motions.append({
                    'caption': sample['caption'],
                    'tokens': sample['tokens'],
                    'cap_len': sample['cap_len'],
                    'mm_motions': mm_motions,
                })

        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.opt = opt
        self.w_vectorizer = w_vectorizer

    @staticmethod
    def _clamp_length(length, unit_length, min_mov_length, max_motion_length):
        """Round ``length`` down to a multiple of ``unit_length`` and clamp it.

        The result is at least ``min_mov_length * unit_length`` and at most
        ``max_motion_length``.
        """
        length = max(length // unit_length * unit_length, min_mov_length * unit_length)
        return min(length, max_motion_length)

    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        """Return one generated sample.

        Returns:
            Tuple ``(word_embeddings, pos_one_hots, caption, sent_len, motion,
            m_length, tokens)`` where ``motion`` is zero-padded along axis 0 up
            to ``opt.max_motion_length`` and ``tokens`` is the token list
            joined with ``'_'``.
        """
        data = self.generated_motion[item]
        motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
        sent_len = data['cap_len']

        # Look up every token once and stack the results row-wise.
        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        if m_length < self.opt.max_motion_length:
            padding = np.zeros((self.opt.max_motion_length - m_length, motion.shape[1]))
            motion = np.concatenate([motion, padding], axis=0)
        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
def collate_fn(batch):
    """Collate ``batch`` after ordering it by descending value of field 3.

    Field 3 of every sample is used as the sort key (the datasets in this file
    put the caption length there).  ``batch`` is reordered in place before it
    is handed to ``default_collate``.
    """
    def sort_key(sample):
        return sample[3]

    batch.sort(key=sort_key, reverse=True)
    collated = default_collate(batch)
    return collated
'''For use in training the text-motion matching model, and in evaluation.'''
class Text2MotionDatasetV2(Dataset):
    """Ground-truth text/motion pairs loaded from ``.npy`` and ``.txt`` files.

    Every id listed in ``split_file`` is loaded from ``opt.motion_dir`` and
    ``opt.text_dir``.  Each text line has the form
    ``caption#tokens#f_tag#to_tag``; lines whose two tags are both zero (or
    NaN) describe the whole motion, other lines describe the sub-clip
    ``motion[int(f_tag * 20):int(to_tag * 20)]``, which is stored as a
    separate entry under a randomly prefixed name.

    Samples are kept sorted by length; ``self.pointer`` marks the first sample
    whose length is at least ``self.max_length`` and shorter ones are hidden
    from ``__len__``/``__getitem__``.

    Raises:
        ValueError: if no usable sample could be loaded at all.
    """

    def __init__(self, opt, mean, std, split_file, w_vectorizer):
        self.opt = opt
        self.w_vectorizer = w_vectorizer
        self.max_length = 20
        self.pointer = 0
        self.max_motion_length = opt.max_motion_length
        min_motion_len = 40 if self.opt.dataset_name == 't2m' else 24

        data_dict = {}
        with cs.open(split_file, 'r') as f:
            id_list = [line.strip() for line in f]

        new_name_list = []
        length_list = []
        for name in tqdm(id_list):
            try:
                motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
                if len(motion) < min_motion_len or len(motion) >= 200:
                    continue
                text_data = []
                flag = False
                with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
                    for line in f:
                        text_dict = {}
                        line_split = line.strip().split('#')
                        caption = line_split[0]
                        tokens = line_split[1].split(' ')
                        f_tag = float(line_split[2])
                        to_tag = float(line_split[3])
                        f_tag = 0.0 if np.isnan(f_tag) else f_tag
                        to_tag = 0.0 if np.isnan(to_tag) else to_tag

                        text_dict['caption'] = caption
                        text_dict['tokens'] = tokens
                        if f_tag == 0.0 and to_tag == 0.0:
                            # The caption describes the whole motion.
                            flag = True
                            text_data.append(text_dict)
                        else:
                            # The caption describes a sub-clip; register it as
                            # a separate sample under a unique prefixed name.
                            try:
                                n_motion = motion[int(f_tag * 20):int(to_tag * 20)]
                                if len(n_motion) < min_motion_len or len(n_motion) >= 200:
                                    continue
                                new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                                while new_name in data_dict:
                                    new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                                data_dict[new_name] = {'motion': n_motion,
                                                       'length': len(n_motion),
                                                       'text': [text_dict]}
                                new_name_list.append(new_name)
                                length_list.append(len(n_motion))
                            except Exception:
                                # Report the offending line and keep loading.
                                print(line_split)
                                print(line_split[2], line_split[3], f_tag, to_tag, name)

                if flag:
                    data_dict[name] = {'motion': motion,
                                       'length': len(motion),
                                       'text': text_data}
                    new_name_list.append(name)
                    length_list.append(len(motion))
            except Exception:
                # Best effort: ids with missing or malformed files are skipped.
                # ``Exception`` (not a bare ``except``) keeps Ctrl-C and
                # SystemExit working during this long loop.
                pass

        if not new_name_list:
            # Fail with a readable message instead of an unpacking error below.
            raise ValueError('No usable motion samples could be loaded from %s' % split_file)

        name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))

        self.mean = mean
        self.std = std
        self.length_arr = np.array(length_list)
        self.data_dict = data_dict
        self.name_list = name_list
        self.reset_max_len(self.max_length)

    def reset_max_len(self, length):
        """Hide all samples shorter than ``length`` by moving ``self.pointer``."""
        assert length <= self.max_motion_length
        self.pointer = np.searchsorted(self.length_arr, length)
        print("Pointer Pointing at %d" % self.pointer)
        self.max_length = length

    def inv_transform(self, data):
        """Undo the z-normalisation applied in ``__getitem__``."""
        return data * self.std + self.mean

    def __len__(self):
        return len(self.data_dict) - self.pointer

    def __getitem__(self, item):
        """Return one sample, counted from ``self.pointer``.

        Returns:
            Tuple ``(word_embeddings, pos_one_hots, caption, sent_len, motion,
            m_length, tokens)``.  ``motion`` is cropped to a multiple of
            ``opt.unit_length``, z-normalised and zero-padded along axis 0 up
            to ``self.max_motion_length``; ``tokens`` is the padded token list
            joined with ``'_'``.
        """
        idx = self.pointer + item
        data = self.data_dict[self.name_list[idx]]
        motion, m_length, text_list = data['motion'], data['length'], data['text']

        # Randomly select a caption.
        text_data = random.choice(text_list)
        caption, tokens = text_data['caption'], text_data['tokens']

        max_text_len = self.opt.max_text_len
        if len(tokens) < max_text_len:
            # Wrap with sos/eos, then pad with "unk" up to max_text_len + 2.
            tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
            sent_len = len(tokens)
            tokens = tokens + ['unk/OTHER'] * (max_text_len + 2 - sent_len)
        else:
            # Crop to max_text_len, then wrap with sos/eos.
            tokens = tokens[:max_text_len]
            tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
            sent_len = len(tokens)

        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        # Crop the motion to a multiple of unit_length and introduce small
        # variations: for small units one unit is dropped a third of the time.
        unit_length = self.opt.unit_length
        if unit_length < 10:
            coin2 = np.random.choice(['single', 'single', 'double'])
        else:
            coin2 = 'single'

        if coin2 == 'double':
            m_length = (m_length // unit_length - 1) * unit_length
        elif coin2 == 'single':
            m_length = (m_length // unit_length) * unit_length
        idx = random.randint(0, len(motion) - m_length)
        motion = motion[idx:idx + m_length]

        # Z normalization.
        motion = (motion - self.mean) / self.std

        if m_length < self.max_motion_length:
            padding = np.zeros((self.max_motion_length - m_length, motion.shape[1]))
            motion = np.concatenate([motion, padding], axis=0)
        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
def get_dataset_motion_loader(opt_path, batch_size, device):
    """Build the ground-truth test-split dataset and a shuffled loader over it.

    Args:
        opt_path: path handed to ``get_opt`` to load the options.
        batch_size: batch size of the returned loader (incomplete last batches
            are dropped).
        device: device handed to ``get_opt``.

    Returns:
        Tuple ``(dataloader, dataset)``.

    Raises:
        KeyError: if ``opt.dataset_name`` is neither ``'t2m'`` nor ``'kit'``.
    """
    opt = get_opt(opt_path, device)

    # The T2M and KIT configurations are almost the same, so both share one
    # code path; any other dataset name is rejected up front.
    if opt.dataset_name not in ('t2m', 'kit'):
        raise KeyError('Dataset not Recognized !!')

    print('Loading dataset %s ...' % opt.dataset_name)

    mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
    std = np.load(pjoin(opt.meta_dir, 'std.npy'))

    w_vectorizer = WordVectorizer('./data/glove', 'our_vab')
    split_file = pjoin(opt.data_root, 'test.txt')
    dataset = Text2MotionDatasetV2(opt, mean, std, split_file, w_vectorizer)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        drop_last=True,
        collate_fn=collate_fn,
        shuffle=True,
    )

    print('Ground Truth Dataset Loading Completed!!!')
    return dataloader, dataset
class MMGeneratedDataset(Dataset):
    """View over the repeated generations stored in ``mm_generated_motion``.

    Each item bundles all motions generated for one caption, zero-padded to
    ``opt.max_motion_length`` and ordered by descending length.

    Args:
        opt: options object; only ``max_motion_length`` is read here.
        motion_dataset: object exposing ``mm_generated_motion``, a list of
            dicts that each hold an ``'mm_motions'`` list of
            ``{'motion', 'length'}`` dicts.
        w_vectorizer: stored on the instance but not used by this class.
    """

    def __init__(self, opt, motion_dataset, w_vectorizer):
        self.opt = opt
        self.dataset = motion_dataset.mm_generated_motion
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        """Return ``(motions, m_lens)`` for one caption.

        ``motions`` stacks the padded motions along a new leading axis and
        ``m_lens`` holds the matching lengths; both are sorted by descending
        length.
        """
        mm_motions = self.dataset[item]['mm_motions']
        max_len = self.opt.max_motion_length

        m_lens = []
        motions = []
        for mm_motion in mm_motions:
            m_lens.append(mm_motion['length'])
            motion = mm_motion['motion']
            if len(motion) < max_len:
                padding = np.zeros((max_len - len(motion), motion.shape[1]))
                motion = np.concatenate([motion, padding], axis=0)
            motions.append(motion[None, :])

        # ``np.int`` was only an alias of the builtin ``int`` and was removed
        # in NumPy 1.24, so the builtin is used directly.
        m_lens = np.array(m_lens, dtype=int)
        motions = np.concatenate(motions, axis=0)

        # Longest first; ``copy`` turns the reversed view into a regular array.
        sort_indx = np.argsort(m_lens)[::-1].copy()
        m_lens = m_lens[sort_indx]
        motions = motions[sort_indx]
        return motions, m_lens
def get_motion_loader(opt, batch_size, trainer, ground_truth_dataset, mm_num_samples, mm_num_repeats):
    """Generate motions with ``trainer`` and wrap them in two data loaders.

    Returns:
        Tuple ``(motion_loader, mm_motion_loader)``: the first iterates the
        ``EvaluationDataset`` in batches of ``batch_size`` (incomplete last
        batch dropped), the second iterates the ``MMGeneratedDataset`` one
        item at a time.

    Raises:
        KeyError: if ``opt.dataset_name`` is neither ``'t2m'`` nor ``'kit'``.
    """
    # Currently the configurations of the two datasets are almost the same.
    if opt.dataset_name not in ('t2m', 'kit'):
        raise KeyError('Dataset not recognized!!')
    w_vectorizer = WordVectorizer('./data/glove', 'our_vab')

    print('Generating %s ...' % opt.name)

    generated = EvaluationDataset(
        opt, trainer, ground_truth_dataset, w_vectorizer, mm_num_samples, mm_num_repeats)
    mm_generated = MMGeneratedDataset(opt, generated, w_vectorizer)

    motion_loader = DataLoader(
        generated,
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=True,
        num_workers=4,
    )
    mm_motion_loader = DataLoader(mm_generated, batch_size=1, num_workers=1)

    print('Generated Dataset Loading Completed!!!')
    return motion_loader, mm_motion_loader
def build_models(opt):
    """Create the three evaluator networks and load their pretrained weights.

    The checkpoint is read from
    ``data/pretrained_models/<dataset_name>/text_mot_match/model/finest.tar``
    and mapped onto ``opt.device``.

    Returns:
        Tuple ``(text_enc, motion_enc, movement_enc)``.
    """
    # The movement encoder takes the pose vector without its last 4 channels.
    movement_enc = MovementConvEncoder(
        opt.dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
    text_enc = TextEncoderBiGRUCo(
        word_size=opt.dim_word,
        pos_size=opt.dim_pos_ohot,
        hidden_size=opt.dim_text_hidden,
        output_size=opt.dim_coemb_hidden,
        device=opt.device,
    )
    motion_enc = MotionEncoderBiGRUCo(
        input_size=opt.dim_movement_latent,
        hidden_size=opt.dim_motion_hidden,
        output_size=opt.dim_coemb_hidden,
        device=opt.device,
    )

    checkpoint_path = pjoin(
        'data/pretrained_models', opt.dataset_name, 'text_mot_match', 'model', 'finest.tar')
    checkpoint = torch.load(checkpoint_path, map_location=opt.device)

    # Restore each network from its own entry of the checkpoint.
    weight_keys = (
        (movement_enc, 'movement_encoder'),
        (text_enc, 'text_encoder'),
        (motion_enc, 'motion_encoder'),
    )
    for network, key in weight_keys:
        network.load_state_dict(checkpoint[key])

    print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
    return text_enc, motion_enc, movement_enc
class EvaluatorModelWrapper(object):
    """Pretrained text / motion encoders used to embed evaluation samples.

    The constructor fills in the fixed evaluator hyper-parameters on ``opt``
    (mutating it), builds the encoders via ``build_models``, moves them to
    ``opt.device`` and switches them to eval mode.

    Note that both embedding methods return their results ordered by
    descending motion length, not in the order of the inputs.
    """

    def __init__(self, opt):
        # Pose dimensionality depends on the dataset; unknown names are rejected.
        for known_name, pose_dim in (('t2m', 263), ('kit', 251)):
            if opt.dataset_name == known_name:
                opt.dim_pose = pose_dim
                break
        else:
            raise KeyError('Dataset not Recognized!!!')

        opt.dim_word = 300
        opt.max_motion_length = 196
        opt.dim_pos_ohot = len(POS_enumerator)
        opt.dim_motion_hidden = 1024
        opt.max_text_len = 20
        opt.dim_text_hidden = 512
        opt.dim_coemb_hidden = 512

        encoders = build_models(opt)
        self.text_encoder, self.motion_encoder, self.movement_encoder = encoders
        self.opt = opt
        self.device = opt.device

        for encoder in encoders:
            encoder.to(opt.device)
        for encoder in encoders:
            encoder.eval()

    def _prepare(self, tensor):
        """Detach ``tensor``, move it to ``self.device`` and cast it to float."""
        return tensor.detach().to(self.device).float()

    def _embed_sorted_motions(self, motions, m_lens):
        """Embed ``motions`` after sorting them by descending length.

        Returns the motion embedding together with the permutation that was
        applied, so callers can reorder companion tensors the same way.
        """
        motions = self._prepare(motions)

        order = np.argsort(m_lens.data.tolist())[::-1].copy()
        sorted_motions = motions[order]
        sorted_lens = m_lens[order]

        # Movement encoding: the last 4 pose channels are not fed to the encoder.
        movements = self.movement_encoder(sorted_motions[..., :-4]).detach()
        # Lengths are passed to the motion encoder in units of opt.unit_length.
        unit_lens = sorted_lens // self.opt.unit_length
        return self.motion_encoder(movements, unit_lens), order

    # Please note that the results do not follow the order of the inputs.
    def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
        """Return ``(text_embedding, motion_embedding)``, both sorted by descending motion length."""
        with torch.no_grad():
            word_embs = self._prepare(word_embs)
            pos_ohot = self._prepare(pos_ohot)

            motion_embedding, order = self._embed_sorted_motions(motions, m_lens)

            # Text encoding, reordered to stay aligned with the motions.
            text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
            text_embedding = text_embedding[order]
        return text_embedding, motion_embedding

    # Please note that the results do not follow the order of the inputs.
    def get_motion_embeddings(self, motions, m_lens):
        """Return the motion embedding, sorted by descending motion length."""
        with torch.no_grad():
            motion_embedding, _ = self._embed_sorted_motions(motions, m_lens)
        return motion_embedding