import codecs as cs
import os
import random
from os.path import join as pjoin
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data._utils.collate import default_collate
from tqdm import tqdm
from models import MotionTransformer
from utils.get_opt import get_opt
from utils.word_vectorizer import POS_enumerator, WordVectorizer
from .evaluator_models import *
from .utils import drop_shapes_from_motion_arr
class EvaluationDataset(Dataset):
def __init__(self, opt, trainer, dataset, w_vectorizer, mm_num_samples, mm_num_repeats):
assert mm_num_samples < len(dataset)
print(opt.model_dir)
dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
epoch, it = trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar'))
generated_motion = []
min_mov_length = 10 if opt.dataset_name == 't2m' else 6
trainer.eval_mode()
trainer.to(opt.device)
# Pre-process all target captions
mm_generated_motions = []
mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
mm_idxs = np.sort(mm_idxs)
all_caption = []
all_m_lens = []
all_data = []
with torch.no_grad():
for i, data in tqdm(enumerate(dataloader)):
word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
all_data.append(data)
tokens = tokens[0].split('_')
mm_num_now = len(mm_generated_motions)
                is_mm = (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
repeat_times = mm_num_repeats if is_mm else 1
m_lens = max(m_lens // opt.unit_length * opt.unit_length, min_mov_length * opt.unit_length)
m_lens = min(m_lens, opt.max_motion_length)
if isinstance(m_lens, int):
m_lens = torch.LongTensor([m_lens]).to(opt.device)
else:
m_lens = m_lens.to(opt.device)
for t in range(repeat_times):
all_m_lens.append(m_lens)
all_caption.extend(caption)
if is_mm:
                        mm_generated_motions.append(0)  # placeholder; only the running count is used in this first pass
all_m_lens = torch.stack(all_m_lens)
# Generate all sequences
with torch.no_grad():
all_pred_motions = trainer.generate(all_caption, all_m_lens, opt.dim_pose)
cur_idx = 0
mm_generated_motions = []
with torch.no_grad():
for i, data_dummy in tqdm(enumerate(dataloader)):
data = all_data[i]
word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
tokens = tokens[0].split('_')
mm_num_now = len(mm_generated_motions)
                is_mm = (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
repeat_times = mm_num_repeats if is_mm else 1
mm_motions = []
m_lens = max(m_lens // opt.unit_length * opt.unit_length, min_mov_length * opt.unit_length)
m_lens = min(m_lens, opt.max_motion_length)
if isinstance(m_lens, int):
m_lens = torch.LongTensor([m_lens]).to(opt.device)
else:
m_lens = m_lens.to(opt.device)
for t in range(repeat_times):
                    m_len = m_lens[0].item()
                    pred_motions = all_pred_motions[cur_idx][:m_len]
                    assert pred_motions.shape[0] == m_len
cur_idx += 1
if t == 0:
sub_dict = {'motion': pred_motions.cpu().numpy(),
'length': pred_motions.shape[0],
'caption': caption[0],
'cap_len': cap_lens[0].item(),
'tokens': tokens}
generated_motion.append(sub_dict)
if is_mm:
mm_motions.append({
'motion': pred_motions.cpu().numpy(),
'length': m_lens[0].item()
})
if is_mm:
mm_generated_motions.append({'caption': caption[0],
'tokens': tokens,
'cap_len': cap_lens[0].item(),
'mm_motions': mm_motions})
self.generated_motion = generated_motion
self.mm_generated_motion = mm_generated_motions
self.opt = opt
self.w_vectorizer = w_vectorizer
def __len__(self):
return len(self.generated_motion)
def __getitem__(self, item):
data = self.generated_motion[item]
motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
sent_len = data['cap_len']
pos_one_hots = []
word_embeddings = []
for token in tokens:
word_emb, pos_oh = self.w_vectorizer[token]
pos_one_hots.append(pos_oh[None, :])
word_embeddings.append(word_emb[None, :])
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
word_embeddings = np.concatenate(word_embeddings, axis=0)
if m_length < self.opt.max_motion_length:
motion = np.concatenate([motion,
np.zeros((self.opt.max_motion_length - m_length, motion.shape[1]))
], axis=0)
return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
def collate_fn(batch):
batch.sort(key=lambda x: x[3], reverse=True)
return default_collate(batch)
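# A minimal usage sketch (not part of the original code): `opt` is assumed to come from
# get_opt(), `trainer` is assumed to wrap a trained MotionTransformer with checkpoints
# under opt.model_dir, and `gt_dataset` is assumed to be a Text2MotionDatasetV2 instance.
# The mm_num_* values are illustrative only.
#
#   eval_dataset = EvaluationDataset(opt, trainer, gt_dataset, w_vectorizer,
#                                    mm_num_samples=100, mm_num_repeats=30)
#   eval_loader = DataLoader(eval_dataset, batch_size=32, collate_fn=collate_fn,
#                            drop_last=True, num_workers=4)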
'''For training the text-motion matching model, and for evaluations'''
class Text2MotionDatasetV2(Dataset):
def __init__(self, opt, mean, std, split_file, w_vectorizer):
self.opt = opt
self.w_vectorizer = w_vectorizer
self.max_length = 20
self.pointer = 0
self.max_motion_length = opt.max_motion_length
        min_motion_len = 40 if self.opt.dataset_name == 't2m' else 24
data_dict = {}
id_list = []
with cs.open(split_file, 'r') as f:
for line in f.readlines():
id_list.append(line.strip())
new_name_list = []
length_list = []
for name in tqdm(id_list):
try:
print(f"attempting to load motion for {name} at {pjoin(opt.motion_dir, name + '.npy')}")
motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
if self.opt.dataset_name.lower() == 'grab':
motion = drop_shapes_from_motion_arr(motion)
assert motion.shape[-1] == opt.dim_pose, f"motion shape {motion.shape} does not match dim_pose {opt.dim_pose}"
print(f"grab motion shape: {motion.shape}")
print(f"len of motion: {len(motion)}")
# TODO (elmc): verify we don't need this for GRAB data
# if (len(motion)) < min_motion_len or (len(motion) >= 200):
# continue
text_data = []
flag = False
with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
for line in f.readlines():
text_dict = {}
line_split = line.strip().split('#')
caption = line_split[0]
f_tag = 0.0
to_tag = 0.0
# TODO (elmc): add actual tokens back for grab
tokens = []
if self.opt.dataset_name.lower() != 'grab':
tokens = line_split[1].split(' ')
f_tag = float(line_split[2])
to_tag = float(line_split[3])
f_tag = 0.0 if np.isnan(f_tag) else f_tag
to_tag = 0.0 if np.isnan(to_tag) else to_tag
text_dict['caption'] = caption
text_dict['tokens'] = tokens
if f_tag == 0.0 and to_tag == 0.0:
flag = True
text_data.append(text_dict)
else:
n_motion = motion[int(f_tag*20) : int(to_tag*20)]
if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
continue
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
while new_name in data_dict:
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
data_dict[new_name] = {'motion': n_motion,
'length': len(n_motion),
'text':[text_dict]}
new_name_list.append(new_name)
length_list.append(len(n_motion))
if flag:
data_dict[name] = {'motion': motion,
'length': len(motion),
'text':text_data}
new_name_list.append(name)
length_list.append(len(motion))
except Exception as e:
                # Some motions may not exist in the KIT dataset
                print(f"failed to load motion for {name} at {pjoin(opt.motion_dir, name + '.npy')} due to {e}")
if not new_name_list or not length_list:
raise ValueError(f'No data loaded, new_name_list has len {len(new_name_list)} and length_list has len {len(length_list)}')
name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
print(f"LOADED length of name_list: {len(name_list)}")
self.mean = mean
self.std = std
self.length_arr = np.array(length_list)
self.data_dict = data_dict
self.name_list = name_list
# TODO (elmc): so.... V2 is same as V1 but has reset_max_len??
self.reset_max_len(self.max_length)
def reset_max_len(self, length):
assert length <= self.max_motion_length
self.pointer = np.searchsorted(self.length_arr, length)
print("Pointer Pointing at %d"%self.pointer)
self.max_length = length
def inv_transform(self, data):
return data * self.std + self.mean
def __len__(self):
return len(self.data_dict) - self.pointer
def __getitem__(self, item):
idx = self.pointer + item
data = self.data_dict[self.name_list[idx]]
motion, m_length, text_list = data['motion'], data['length'], data['text']
# Randomly select a caption
text_data = random.choice(text_list)
caption, tokens = text_data['caption'], text_data['tokens']
if len(tokens) < self.opt.max_text_len:
# pad with "unk"
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
sent_len = len(tokens)
tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
else:
# crop
tokens = tokens[:self.opt.max_text_len]
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
sent_len = len(tokens)
pos_one_hots = []
word_embeddings = []
for token in tokens:
word_emb, pos_oh = self.w_vectorizer[token]
pos_one_hots.append(pos_oh[None, :])
word_embeddings.append(word_emb[None, :])
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
word_embeddings = np.concatenate(word_embeddings, axis=0)
        # Crop the motion length to a multiple of opt.unit_length and introduce small variations
if self.opt.unit_length < 10:
coin2 = np.random.choice(['single', 'single', 'double'])
else:
coin2 = 'single'
if coin2 == 'double':
m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
elif coin2 == 'single':
m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
idx = random.randint(0, len(motion) - m_length)
motion = motion[idx:idx+m_length]
"Z Normalization"
motion = (motion - self.mean) / self.std
if m_length < self.max_motion_length:
motion = np.concatenate([motion,
np.zeros((self.max_motion_length - m_length, motion.shape[1]))
], axis=0)
return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
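# Note: __getitem__ above returns the same 7-tuple layout as EvaluationDataset.__getitem__
# (word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)),
# so both datasets can share collate_fn, which sorts each batch by sent_len (index 3) in
# descending order before calling default_collate.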
def get_dataset_motion_loader(opt_path, batch_size, device):
opt = get_opt(opt_path, device)
    # The configurations of the T2M, KIT, and GRAB datasets are almost the same
    if opt.dataset_name in ('t2m', 'kit', 'grab'):
print('Loading dataset %s ...' % opt.dataset_name)
        mean_path = pjoin(opt.meta_dir, 'mean.npy')
        std_path = pjoin(opt.meta_dir, 'std.npy')
        mean = np.load(mean_path) if os.path.exists(mean_path) else np.zeros(opt.dim_pose)
        std = np.load(std_path) if os.path.exists(std_path) else np.ones(opt.dim_pose)
# get glove data via following instructions here
# https://github.com/mingyuan-zhang/MotionDiffuse/blob/main/text2motion/install.md#data-preparation
w_vectorizer = WordVectorizer('./data/glove', 'our_vab')
split_file = pjoin(opt.data_root, 'test.txt')
dataset = Text2MotionDatasetV2(opt, mean, std, split_file, w_vectorizer)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True,
collate_fn=collate_fn, shuffle=True)
else:
raise KeyError('Dataset not Recognized !!')
print('Ground Truth Dataset Loading Completed!!!')
return dataloader, dataset
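# Example usage (a sketch under assumptions; the opt-file path and device are placeholders
# and get_opt() is expected to resolve data_root, meta_dir, motion_dir, text_dir, etc.):
#
#   gt_loader, gt_dataset = get_dataset_motion_loader('path/to/opt.txt', batch_size=32,
#                                                     device=torch.device('cuda'))
#   word_embs, pos_ohot, caption, sent_len, motion, m_length, tokens = next(iter(gt_loader))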
class MMGeneratedDataset(Dataset):
def __init__(self, opt, motion_dataset, w_vectorizer):
self.opt = opt
self.dataset = motion_dataset.mm_generated_motion
self.w_vectorizer = w_vectorizer
def __len__(self):
return len(self.dataset)
def __getitem__(self, item):
data = self.dataset[item]
mm_motions = data['mm_motions']
m_lens = []
motions = []
for mm_motion in mm_motions:
m_lens.append(mm_motion['length'])
motion = mm_motion['motion']
if len(motion) < self.opt.max_motion_length:
motion = np.concatenate([motion,
np.zeros((self.opt.max_motion_length - len(motion), motion.shape[1]))
], axis=0)
motion = motion[None, :]
motions.append(motion)
        m_lens = np.array(m_lens, dtype=int)
motions = np.concatenate(motions, axis=0)
sort_indx = np.argsort(m_lens)[::-1].copy()
# print(m_lens)
# print(sort_indx)
# print(m_lens[sort_indx])
m_lens = m_lens[sort_indx]
motions = motions[sort_indx]
return motions, m_lens
def get_motion_loader(opt, batch_size, trainer, ground_truth_dataset, mm_num_samples, mm_num_repeats):
    # Currently the configurations of these datasets are almost the same
    if opt.dataset_name in ('t2m', 'kit', 'grab'):
w_vectorizer = WordVectorizer('./data/glove', 'our_vab')
else:
raise KeyError('Dataset not recognized!!')
print('Generating %s ...' % opt.name)
dataset = EvaluationDataset(opt, trainer, ground_truth_dataset, w_vectorizer, mm_num_samples, mm_num_repeats)
mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)
motion_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, drop_last=True, num_workers=4)
mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)
print('Generated Dataset Loading Completed!!!')
return motion_loader, mm_motion_loader
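# Assumed end-to-end sketch (not from the original file): the generated-motion loaders are
# built from the ground-truth dataset returned by get_dataset_motion_loader above, with
# `trainer` constructed elsewhere from a trained MotionTransformer checkpoint.
#
#   motion_loader, mm_motion_loader = get_motion_loader(
#       opt, batch_size=32, trainer=trainer, ground_truth_dataset=gt_dataset,
#       mm_num_samples=100, mm_num_repeats=30)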
def build_models(opt):
movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
pos_size=opt.dim_pos_ohot,
hidden_size=opt.dim_text_hidden,
output_size=opt.dim_coemb_hidden,
device=opt.device)
motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
hidden_size=opt.dim_motion_hidden,
output_size=opt.dim_coemb_hidden,
device=opt.device)
checkpoint = torch.load(pjoin('data/pretrained_models', opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
map_location=opt.device)
movement_enc.load_state_dict(checkpoint['movement_encoder'])
text_enc.load_state_dict(checkpoint['text_encoder'])
motion_enc.load_state_dict(checkpoint['motion_encoder'])
print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
return text_enc, motion_enc, movement_enc
class EvaluatorModelWrapper(object):
def __init__(self, opt):
if opt.dataset_name == 't2m':
opt.dim_pose = 263
elif opt.dataset_name == 'kit':
opt.dim_pose = 251
elif opt.dataset_name == 'grab':
opt.dim_pose = 212
else:
raise KeyError('Dataset not Recognized!!!')
opt.dim_word = 300
opt.max_motion_length = 196
opt.dim_pos_ohot = len(POS_enumerator)
opt.dim_motion_hidden = 1024
opt.max_text_len = 20
opt.dim_text_hidden = 512
opt.dim_coemb_hidden = 512
self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
self.opt = opt
self.device = opt.device
self.text_encoder.to(opt.device)
self.motion_encoder.to(opt.device)
self.movement_encoder.to(opt.device)
self.text_encoder.eval()
self.motion_encoder.eval()
self.movement_encoder.eval()
    # Please note that the results do not follow the order of the inputs
def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
with torch.no_grad():
word_embs = word_embs.detach().to(self.device).float()
pos_ohot = pos_ohot.detach().to(self.device).float()
motions = motions.detach().to(self.device).float()
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
motions = motions[align_idx]
m_lens = m_lens[align_idx]
'''Movement Encoding'''
movements = self.movement_encoder(motions[..., :-4]).detach()
m_lens = m_lens // self.opt.unit_length
motion_embedding = self.motion_encoder(movements, m_lens)
'''Text Encoding'''
text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
text_embedding = text_embedding[align_idx]
return text_embedding, motion_embedding
    # Please note that the results do not follow the order of the inputs
def get_motion_embeddings(self, motions, m_lens):
with torch.no_grad():
motions = motions.detach().to(self.device).float()
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
motions = motions[align_idx]
m_lens = m_lens[align_idx]
'''Movement Encoding'''
movements = self.movement_encoder(motions[..., :-4]).detach()
m_lens = m_lens // self.opt.unit_length
motion_embedding = self.motion_encoder(movements, m_lens)
return motion_embedding
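# Minimal evaluation sketch (assumed usage): feed batches produced by the loaders above
# into the wrapper; note that both embedding getters re-sort each batch by descending
# m_lens, so the returned embeddings do not follow the input order.
#
#   wrapper = EvaluatorModelWrapper(opt)
#   for word_embs, pos_ohot, caption, sent_len, motions, m_lens, tokens in motion_loader:
#       text_emb, motion_emb = wrapper.get_co_embeddings(word_embs, pos_ohot,
#                                                        sent_len, motions, m_lens)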