import torch
from utils.word_vectorizer import WordVectorizer, POS_enumerator
from utils.get_opt import get_opt
from models import MotionTransformer
from torch.utils.data import Dataset, DataLoader
from os.path import join as pjoin
from tqdm import tqdm
import numpy as np
from .evaluator_models import *
import os
import codecs as cs
import random
from torch.utils.data._utils.collate import default_collate


def _embed_tokens(w_vectorizer, tokens):
    """Look up every token in ``w_vectorizer`` and stack the per-token arrays.

    Returns ``(word_embeddings, pos_one_hots)``; row ``i`` of each array comes
    from ``w_vectorizer[tokens[i]]``.
    """
    word_embeddings = []
    pos_one_hots = []
    for token in tokens:
        word_emb, pos_oh = w_vectorizer[token]
        word_embeddings.append(word_emb[None, :])
        pos_one_hots.append(pos_oh[None, :])
    return np.concatenate(word_embeddings, axis=0), np.concatenate(pos_one_hots, axis=0)


def _pad_motion(motion, length, max_length):
    """Append ``max_length - length`` zero rows to ``motion`` (no-op if length >= max_length)."""
    if length < max_length:
        padding = np.zeros((max_length - length, motion.shape[1]))
        motion = np.concatenate([motion, padding], axis=0)
    return motion


def _clip_generation_length(m_lens, opt, min_mov_length):
    """Turn a ground-truth length into the length requested from the generator.

    The length is floored to a multiple of ``opt.unit_length``, raised to at least
    ``min_mov_length * opt.unit_length`` and capped at ``opt.max_motion_length``.
    Returns a LongTensor on ``opt.device``.

    NOTE: uses builtin ``max``/``min`` on a tensor, which only works for a
    single-element tensor (the caller's loader uses ``batch_size=1``).
    """
    m_lens = max(m_lens // opt.unit_length * opt.unit_length, min_mov_length * opt.unit_length)
    m_lens = min(m_lens, opt.max_motion_length)
    if isinstance(m_lens, int):
        return torch.LongTensor([m_lens]).to(opt.device)
    return m_lens.to(opt.device)


class EvaluationDataset(Dataset):
    """Motions generated by ``trainer`` for the captions of a ground-truth dataset.

    Every sample of ``dataset`` is generated once. ``mm_num_samples`` randomly
    chosen samples are generated ``mm_num_repeats`` times instead; those repeats
    are stored in ``self.mm_generated_motion`` for the multimodality metric.

    Items have the same layout as ``Text2MotionDatasetV2`` items:
    ``(word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, tokens_str)``.
    """

    def __init__(self, opt, trainer, dataset, w_vectorizer, mm_num_samples, mm_num_repeats):
        assert mm_num_samples < len(dataset)
        print(opt.model_dir)

        dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
        trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar'))
        min_mov_length = 10 if opt.dataset_name == 't2m' else 6

        trainer.eval_mode()
        trainer.to(opt.device)

        # Sorted loader positions of the samples used for the multimodality metric.
        mm_idxs = np.sort(np.random.choice(len(dataset), mm_num_samples, replace=False))

        # Pass 1: draw each sample once and queue every generation request, so the
        # trainer is called a single time. The per-sample plan is kept so pass 2
        # can split the predictions without iterating the loader again.
        plan = []  # (data, m_lens, repeat_times, is_mm) per loader sample
        all_caption = []
        all_m_lens = []
        mm_count = 0
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader)):
                _, _, caption, _, _, m_lens, _ = data
                is_mm = mm_count < mm_num_samples and i == mm_idxs[mm_count]
                repeat_times = mm_num_repeats if is_mm else 1
                m_lens = _clip_generation_length(m_lens, opt, min_mov_length)
                for _ in range(repeat_times):
                    all_m_lens.append(m_lens)
                    all_caption.extend(caption)
                if is_mm:
                    mm_count += 1
                plan.append((data, m_lens, repeat_times, is_mm))
            all_m_lens = torch.stack(all_m_lens)

        # Generate all sequences in one call.
        with torch.no_grad():
            all_pred_motions = trainer.generate(all_caption, all_m_lens, opt.dim_pose)

        # Pass 2: consume predictions in the same order they were requested.
        generated_motion = []
        mm_generated_motions = []
        cur_idx = 0
        with torch.no_grad():
            for data, m_lens, repeat_times, is_mm in tqdm(plan):
                _, _, caption, cap_lens, _, _, tokens = data
                tokens = tokens[0].split('_')
                m_len = m_lens[0].item()
                mm_motions = []
                for t in range(repeat_times):
                    pred_motions = all_pred_motions[cur_idx][:m_len]
                    assert pred_motions.shape[0] == m_len
                    cur_idx += 1
                    if t == 0:
                        generated_motion.append({'motion': pred_motions.cpu().numpy(),
                                                 'length': pred_motions.shape[0],
                                                 'caption': caption[0],
                                                 'cap_len': cap_lens[0].item(),
                                                 'tokens': tokens})
                    if is_mm:
                        mm_motions.append({'motion': pred_motions.cpu().numpy(),
                                           'length': m_len})
                if is_mm:
                    mm_generated_motions.append({'caption': caption[0],
                                                 'tokens': tokens,
                                                 'cap_len': cap_lens[0].item(),
                                                 'mm_motions': mm_motions})

        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.opt = opt
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        data = self.generated_motion[item]
        motion, m_length = data['motion'], data['length']
        caption, tokens, sent_len = data['caption'], data['tokens'], data['cap_len']

        word_embeddings, pos_one_hots = _embed_tokens(self.w_vectorizer, tokens)
        motion = _pad_motion(motion, m_length, self.opt.max_motion_length)
        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)


def collate_fn(batch):
    """Sort ``batch`` in place by element 3 (sentence length), descending, then collate."""
    batch.sort(key=lambda x: x[3], reverse=True)
    return default_collate(batch)


class Text2MotionDatasetV2(Dataset):
    """Ground-truth text/motion pairs, used to train the matching model and for evaluation.

    Motions are loaded from ``opt.motion_dir/<id>.npy`` and captions from
    ``opt.text_dir/<id>.txt``. Each caption line is split on ``#`` into
    caption, space-separated tokens, and start/end tags. A caption with non-zero
    tags becomes its own sample cut from the motion. Samples are sorted by length.
    """

    def __init__(self, opt, mean, std, split_file, w_vectorizer):
        self.opt = opt
        self.w_vectorizer = w_vectorizer
        self.max_length = 20
        self.pointer = 0
        self.max_motion_length = opt.max_motion_length
        min_motion_len = 40 if self.opt.dataset_name == 't2m' else 24

        data_dict = {}
        with cs.open(split_file, 'r') as f:
            id_list = [line.strip() for line in f]

        new_name_list = []
        length_list = []
        num_skipped = 0
        for name in tqdm(id_list):
            try:
                motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
                if len(motion) < min_motion_len or len(motion) >= 200:
                    continue
                text_data = []
                flag = False  # True once a caption covering the whole motion is found
                with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
                    for line in f:
                        text_dict = {}
                        line_split = line.strip().split('#')
                        caption = line_split[0]
                        tokens = line_split[1].split(' ')
                        f_tag = float(line_split[2])
                        to_tag = float(line_split[3])
                        f_tag = 0.0 if np.isnan(f_tag) else f_tag
                        to_tag = 0.0 if np.isnan(to_tag) else to_tag

                        text_dict['caption'] = caption
                        text_dict['tokens'] = tokens
                        if f_tag == 0.0 and to_tag == 0.0:
                            flag = True
                            text_data.append(text_dict)
                        else:
                            try:
                                # Tags are scaled by 20 to frame indices (presumably
                                # seconds at 20 fps -- confirm against the dataset).
                                n_motion = motion[int(f_tag * 20): int(to_tag * 20)]
                                if len(n_motion) < min_motion_len or len(n_motion) >= 200:
                                    continue
                                # NOTE(review): only 23 prefixes exist, so this loop would
                                # never end for an id with more than 23 kept sub-segments.
                                new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                                while new_name in data_dict:
                                    new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                                data_dict[new_name] = {'motion': n_motion,
                                                       'length': len(n_motion),
                                                       'text': [text_dict]}
                                new_name_list.append(new_name)
                                length_list.append(len(n_motion))
                            except Exception:
                                print(line_split)
                                print(line_split[2], line_split[3], f_tag, to_tag, name)

                if flag:
                    data_dict[name] = {'motion': motion,
                                       'length': len(motion),
                                       'text': text_data}
                    new_name_list.append(name)
                    length_list.append(len(motion))
            except Exception:
                # Best effort: ids whose files are missing or malformed are skipped.
                # (A bare except here also swallowed KeyboardInterrupt/SystemExit.)
                num_skipped += 1

        if num_skipped:
            print('Skipped %d ids that could not be loaded completely' % num_skipped)
        if not new_name_list:
            raise ValueError('No usable samples found for split file %s' % split_file)

        name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))

        self.mean = mean
        self.std = std
        self.length_arr = np.array(length_list)
        self.data_dict = data_dict
        self.name_list = name_list
        self.reset_max_len(self.max_length)

    def reset_max_len(self, length):
        """Skip every sample shorter than ``length`` frames (the length array is sorted)."""
        assert length <= self.max_motion_length
        self.pointer = np.searchsorted(self.length_arr, length)
        print("Pointer Pointing at %d" % self.pointer)
        self.max_length = length

    def inv_transform(self, data):
        """Undo the z-normalization applied in ``__getitem__``."""
        return data * self.std + self.mean

    def __len__(self):
        # Indexing goes through name_list, so its length is the authoritative count.
        return len(self.name_list) - self.pointer

    def __getitem__(self, item):
        idx = self.pointer + item
        data = self.data_dict[self.name_list[idx]]
        motion, m_length, text_list = data['motion'], data['length'], data['text']
        # Randomly select a caption
        text_data = random.choice(text_list)
        caption, tokens = text_data['caption'], text_data['tokens']

        if len(tokens) < self.opt.max_text_len:
            # pad with "unk"
            tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
            sent_len = len(tokens)
            tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
        else:
            # crop
            tokens = tokens[:self.opt.max_text_len]
            tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
            sent_len = len(tokens)
        word_embeddings, pos_one_hots = _embed_tokens(self.w_vectorizer, tokens)

        # Floor the length to a multiple of unit_length; when unit_length < 10,
        # one draw in three drops one extra unit to introduce small variations.
        if self.opt.unit_length < 10:
            coin2 = np.random.choice(['single', 'single', 'double'])
        else:
            coin2 = 'single'

        if coin2 == 'double':
            m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
        elif coin2 == 'single':
            m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
        idx = random.randint(0, len(motion) - m_length)
        motion = motion[idx:idx + m_length]

        # Z normalization
        motion = (motion - self.mean) / self.std

        motion = _pad_motion(motion, m_length, self.max_motion_length)
        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)


def get_dataset_motion_loader(opt_path, batch_size, device):
    """Build the ground-truth test-split loader described by the options file at ``opt_path``.

    Returns ``(dataloader, dataset)``. Raises KeyError for datasets other than 't2m'/'kit'.
    """
    opt = get_opt(opt_path, device)

    # Configurations of T2M dataset and KIT dataset is almost the same
    if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
        print('Loading dataset %s ...' % opt.dataset_name)

        mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
        std = np.load(pjoin(opt.meta_dir, 'std.npy'))

        w_vectorizer = WordVectorizer('./data/glove', 'our_vab')
        split_file = pjoin(opt.data_root, 'test.txt')
        dataset = Text2MotionDatasetV2(opt, mean, std, split_file, w_vectorizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True,
                                collate_fn=collate_fn, shuffle=True)
    else:
        raise KeyError('Dataset not Recognized !!')

    print('Ground Truth Dataset Loading Completed!!!')
    return dataloader, dataset


class MMGeneratedDataset(Dataset):
    """Multimodality view of an ``EvaluationDataset``: one item per repeated caption.

    Each item is ``(motions, m_lens)``: every generation for that caption, zero-padded
    to ``opt.max_motion_length`` and stacked, sorted by length in descending order.
    """

    def __init__(self, opt, motion_dataset, w_vectorizer):
        self.opt = opt
        self.dataset = motion_dataset.mm_generated_motion
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        mm_motions = self.dataset[item]['mm_motions']
        # dtype=int: the np.int alias used previously was removed in NumPy 1.24.
        m_lens = np.array([mm_motion['length'] for mm_motion in mm_motions], dtype=int)
        motions = np.concatenate(
            [_pad_motion(mm_motion['motion'], len(mm_motion['motion']),
                         self.opt.max_motion_length)[None, :]
             for mm_motion in mm_motions],
            axis=0)
        sort_indx = np.argsort(m_lens)[::-1].copy()
        m_lens = m_lens[sort_indx]
        motions = motions[sort_indx]
        return motions, m_lens


def get_motion_loader(opt, batch_size, trainer, ground_truth_dataset, mm_num_samples, mm_num_repeats):
    """Generate motions with ``trainer`` for ``ground_truth_dataset`` and wrap them in loaders.

    Returns ``(motion_loader, mm_motion_loader)``. Raises KeyError for datasets
    other than 't2m'/'kit'.
    """
    # Currently the configurations of two datasets are almost the same
    if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
        w_vectorizer = WordVectorizer('./data/glove', 'our_vab')
    else:
        raise KeyError('Dataset not recognized!!')
    print('Generating %s ...' % opt.name)

    dataset = EvaluationDataset(opt, trainer, ground_truth_dataset, w_vectorizer,
                                mm_num_samples, mm_num_repeats)
    mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)

    motion_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn,
                               drop_last=True, num_workers=4)
    mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)

    print('Generated Dataset Loading Completed!!!')

    return motion_loader, mm_motion_loader


def build_models(opt):
    """Create the evaluator encoders and load their pretrained weights.

    Weights come from ``data/pretrained_models/<dataset_name>/text_mot_match/model/finest.tar``.
    Returns ``(text_enc, motion_enc, movement_enc)``.
    """
    # dim_pose - 4: the last 4 pose features are dropped (see motions[..., :-4] below).
    movement_enc = MovementConvEncoder(opt.dim_pose - 4, opt.dim_movement_enc_hidden,
                                       opt.dim_movement_latent)
    text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
                                  pos_size=opt.dim_pos_ohot,
                                  hidden_size=opt.dim_text_hidden,
                                  output_size=opt.dim_coemb_hidden,
                                  device=opt.device)

    motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
                                      hidden_size=opt.dim_motion_hidden,
                                      output_size=opt.dim_coemb_hidden,
                                      device=opt.device)

    checkpoint = torch.load(pjoin('data/pretrained_models', opt.dataset_name, 'text_mot_match',
                                  'model', 'finest.tar'),
                            map_location=opt.device)
    movement_enc.load_state_dict(checkpoint['movement_encoder'])
    text_enc.load_state_dict(checkpoint['text_encoder'])
    motion_enc.load_state_dict(checkpoint['motion_encoder'])
    print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
    return text_enc, motion_enc, movement_enc


class EvaluatorModelWrapper(object):
    """Pretrained text/motion encoders used to compute evaluation embeddings.

    NOTE: ``__init__`` overwrites several fields of the ``opt`` it is given.
    """

    def __init__(self, opt):
        if opt.dataset_name == 't2m':
            opt.dim_pose = 263
        elif opt.dataset_name == 'kit':
            opt.dim_pose = 251
        else:
            raise KeyError('Dataset not Recognized!!!')

        opt.dim_word = 300
        opt.max_motion_length = 196
        opt.dim_pos_ohot = len(POS_enumerator)
        opt.dim_motion_hidden = 1024
        opt.max_text_len = 20
        opt.dim_text_hidden = 512
        opt.dim_coemb_hidden = 512

        self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
        self.opt = opt
        self.device = opt.device

        self.text_encoder.to(opt.device)
        self.motion_encoder.to(opt.device)
        self.movement_encoder.to(opt.device)

        self.text_encoder.eval()
        self.motion_encoder.eval()
        self.movement_encoder.eval()

    def _encode_motions_sorted(self, motions, m_lens):
        """Sort motions by descending length and encode them.

        Returns ``(motion_embedding, align_idx)``, where ``align_idx`` is the
        permutation that was applied to the inputs.
        """
        with torch.no_grad():
            motions = motions.detach().to(self.device).float()

            align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
            motions = motions[align_idx]
            m_lens = m_lens[align_idx]

            # Movement encoding (last 4 pose features excluded).
            movements = self.movement_encoder(motions[..., :-4]).detach()
            m_lens = m_lens // self.opt.unit_length
            motion_embedding = self.motion_encoder(movements, m_lens)
        return motion_embedding, align_idx

    # NOTE: results are sorted by descending motion length, not in input order.
    def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
        """Return ``(text_embedding, motion_embedding)``, both in motion-length order."""
        with torch.no_grad():
            word_embs = word_embs.detach().to(self.device).float()
            pos_ohot = pos_ohot.detach().to(self.device).float()

            motion_embedding, align_idx = self._encode_motions_sorted(motions, m_lens)

            # Text is encoded in input order (cap_lens untouched; presumably already
            # sorted by collate_fn -- confirm), then permuted to match the motions.
            text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
            text_embedding = text_embedding[align_idx]
        return text_embedding, motion_embedding

    # NOTE: results are sorted by descending motion length, not in input order.
    def get_motion_embeddings(self, motions, m_lens):
        """Return motion embeddings in descending motion-length order."""
        motion_embedding, _ = self._encode_motions_sorted(motions, m_lens)
        return motion_embedding