Commit 12deb01 ("initial commit"), committed by root
1 parent: c050fbf
.gitattributes CHANGED
@@ -29,3 +29,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/ filter=lfs diff=lfs merge=lfs -text
+checkpoints/t2m/t2m_motiondiffuse/model/latest.tar filter=lfs diff=lfs merge=lfs -text
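These two added rules route the new checkpoints/ directory and the bundled latest.tar weights through Git LFS, which is why the checkpoint entries further down in this commit appear as LFS pointer files rather than raw binaries.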
app.py ADDED
@@ -0,0 +1,72 @@
import os
import sys
import gradio as gr

try:
    os.system("pip install -r requirements.txt")
except Exception as e:
    print(e)

sys.path.insert(0, '.')


from utils.get_opt import get_opt
from os.path import join as pjoin
import numpy as np
from trainers import DDPMTrainer
from models import MotionTransformer

device = 'cpu'
opt = get_opt("checkpoints/t2m/t2m_motiondiffuse/opt.txt", device)
opt.do_denoise = True

assert opt.dataset_name == "t2m"
opt.data_root = './dataset/HumanML3D'
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
opt.text_dir = pjoin(opt.data_root, 'texts')
opt.joints_num = 22
opt.dim_pose = 263

mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
std = np.load(pjoin(opt.meta_dir, 'std.npy'))


def build_models(opt):
    encoder = MotionTransformer(
        input_feats=opt.dim_pose,
        num_frames=opt.max_motion_length,
        num_layers=opt.num_layers,
        latent_dim=opt.latent_dim,
        no_clip=opt.no_clip,
        no_eff=opt.no_eff)
    return encoder


encoder = build_models(opt).to(device)
trainer = DDPMTrainer(opt, encoder)
trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar'))

trainer.eval_mode()
trainer.to(opt.device)


def generate(prompt, length):
    from tools.visualization import process
    result_path = "outputs/" + str(hash(prompt)) + ".mp4"
    process(trainer, opt, device, mean, std, prompt, int(length), result_path)
    return result_path


demo = gr.Interface(
    fn=generate,
    inputs=["text", gr.Slider(20, 196, value=60)],
    examples=[
        ["the man throws a punch with each hand.", 58],
        ["a person jogs clockwise in a circle.", 178],
        ["a person spins quickly and takes off running.", 29],
        ["a person is walking slowly forward.", 142],
        ["a person quickly waves with their right hand", 46],
        ["a person performing a slight bow", 89],
    ],
    outputs="video",
    title="MotionDiffuse: Text-Driven Human Motion Generation with Diffusion Model",
    description="This is an interactive demo for MotionDiffuse. For more information, feel free to visit our project page (https://mingyuan-zhang.github.io/projects/MotionDiffuse.html).")

demo.launch(share=True)
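Once the checkpoint and requirements are in place, the same generate() helper can be exercised without the Gradio UI. A minimal headless sketch, run in the same module scope as app.py (it assumes an outputs/ directory and that tools.visualization.process renders the mp4 exactly as above):

# Hypothetical headless driver for the generate() helper defined in app.py.
if __name__ == "__main__":
    os.makedirs("outputs", exist_ok=True)
    video_path = generate("a person is walking slowly forward.", 60)
    print("rendered to", video_path)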
checkpoints/t2m/t2m_motiondiffuse/meta/mean.npy ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ebc7c543b4e27e886dba7e1bcde8fc0149f12a981586a548df98977b4b7c1a6a
size 1180
checkpoints/t2m/t2m_motiondiffuse/meta/std.npy ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9a91111b5b44b2b6785cedf426ede6f8bc412ecb995914563e2810738de179e4
size 1180
checkpoints/t2m/t2m_motiondiffuse/model/latest.tar ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:60a8b95aa5d190a95a6b3c20515301f72d70646b7a0d7f927c7167b968fa1222
size 953997124
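The three checkpoint entries above are Git LFS pointer stubs rather than the binaries themselves: each records only the LFS spec version, a sha256 object id, and the payload size (about 954 MB for latest.tar), and the real content is downloaded by git-lfs when the repository is cloned or pulled.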
checkpoints/t2m/t2m_motiondiffuse/opt.txt ADDED
@@ -0,0 +1,38 @@
------------ Options -------------
batch_size: 128
checkpoints_dir: ./checkpoints
dataset_name: t2m
decomp_name: Decomp_SP001_SM001_H512
dim_att_vec: 512
dim_dec_hidden: 1024
dim_movement_dec_hidden: 512
dim_movement_enc_hidden: 512
dim_movement_latent: 512
dim_pos_hidden: 1024
dim_pri_hidden: 1024
dim_text_hidden: 512
dim_z: 128
early_stop_count: 3
estimator_mod: bigru
eval_every_e: 5
feat_bias: 5
gpu_id: -1
is_continue: False
is_train: True
lambda_kld: 0.005
lambda_rec_mot: 1
lambda_rec_mov: 1
log_every: 50
lr: 0.0002
max_sub_epoch: 50
max_text_len: 20
n_layers_dec: 1
n_layers_pos: 1
n_layers_pri: 1
name: t2m_motiondiffuse
save_every_e: 10
save_latest: 500
text_enc_mod: bigru
times: 50
unit_length: 4
-------------- End ----------------
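app.py reads this file through utils.get_opt.get_opt. That parser is not part of this commit, but the format is just a banner line, a series of key: value pairs, and a closing banner, so a minimal reader is easy to sketch (illustrative only; the repository's actual get_opt also casts values and derives extra fields such as model_dir and meta_dir, which app.py relies on):

from argparse import Namespace

def read_opt_txt(path):
    # Parse an opt.txt dump of 'key: value' lines into a Namespace (values kept as strings).
    opt = Namespace()
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('-'):  # skip the Options/End banner lines
                continue
            key, value = line.split(': ', 1)
            setattr(opt, key, value)
    return opt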
datasets/__init__.py ADDED
@@ -0,0 +1,11 @@
from .dataset import Text2MotionDataset
from .evaluator import (
    EvaluationDataset,
    get_dataset_motion_loader,
    get_motion_loader,
    EvaluatorModelWrapper)
from .dataloader import build_dataloader

__all__ = [
    'Text2MotionDataset', 'EvaluationDataset', 'build_dataloader',
    'get_dataset_motion_loader', 'get_motion_loader']
datasets/dataloader.py ADDED
@@ -0,0 +1,130 @@
import platform
import random
from functools import partial
from typing import Optional, Union

import numpy as np
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

import torch
from torch.utils.data import DistributedSampler as _DistributedSampler


class DistributedSampler(_DistributedSampler):

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 round_up=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        self.round_up = round_up
        if self.round_up:
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(self.dataset)

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        if self.round_up:
            assert len(indices) == self.num_samples

        return iter(indices)


def build_dataloader(dataset: Dataset,
                     samples_per_gpu: int,
                     workers_per_gpu: int,
                     num_gpus: Optional[int] = 1,
                     dist: Optional[bool] = True,
                     shuffle: Optional[bool] = True,
                     round_up: Optional[bool] = True,
                     seed: Optional[Union[int, None]] = None,
                     persistent_workers: Optional[bool] = True,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (:obj:`Dataset`): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int, optional): Number of GPUs. Only used in non-distributed
            training.
        dist (bool, optional): Distributed training/test or not. Default: True.
        shuffle (bool, optional): Whether to shuffle the data at every epoch.
            Default: True.
        round_up (bool, optional): Whether to round up the length of dataset by
            adding extra samples to make it evenly divisible. Default: True.
        persistent_workers (bool): If True, the data loader will not shut down
            the worker processes after a dataset has been consumed once.
            This allows to maintain the workers' Dataset instances alive.
            The argument also has effect in PyTorch>=1.7.0.
            Default: True
        kwargs: any keyword argument to be used to initialize DataLoader

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    if dist:
        sampler = DistributedSampler(
            dataset, world_size, rank, shuffle=shuffle, round_up=round_up)
        shuffle = False
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=False,
        shuffle=shuffle,
        worker_init_fn=init_fn,
        persistent_workers=persistent_workers,
        **kwargs)

    return data_loader


def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
    """Init random seed for each worker."""
    # The seed of each worker equals to
    # num_worker * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
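As a quick sanity check, build_dataloader can be driven with a toy in-memory dataset along the non-distributed path (a sketch; it assumes mmcv is installed, in which case get_dist_info falls back to rank 0 / world size 1 when no process group is initialized):

import torch
from torch.utils.data import TensorDataset

toy = TensorDataset(torch.randn(16, 196, 263))   # 16 fake motion clips, 196 frames x 263 features
loader = build_dataloader(toy, samples_per_gpu=4, workers_per_gpu=0,
                          dist=False, shuffle=True,
                          persistent_workers=False)  # must be False when workers_per_gpu == 0
for (batch,) in loader:
    print(batch.shape)   # torch.Size([4, 196, 263])
    break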
datasets/dataset.py ADDED
@@ -0,0 +1,164 @@
import torch
from torch.utils import data
import numpy as np
import os
from os.path import join as pjoin
import random
import codecs as cs
from tqdm import tqdm


class Text2MotionDataset(data.Dataset):
    """Dataset for Text2Motion generation task."""

    def __init__(self, opt, mean, std, split_file, times=1, w_vectorizer=None, eval_mode=False):
        self.opt = opt
        self.max_length = 20
        self.times = times
        self.w_vectorizer = w_vectorizer
        self.eval_mode = eval_mode
        min_motion_len = 40 if self.opt.dataset_name == 't2m' else 24

        joints_num = opt.joints_num

        data_dict = {}
        id_list = []
        with cs.open(split_file, 'r') as f:
            for line in f.readlines():
                id_list.append(line.strip())

        new_name_list = []
        length_list = []
        for name in tqdm(id_list):
            try:
                motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
                if (len(motion)) < min_motion_len or (len(motion) >= 200):
                    continue
                text_data = []
                flag = False
                with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
                    for line in f.readlines():
                        text_dict = {}
                        line_split = line.strip().split('#')
                        caption = line_split[0]
                        tokens = line_split[1].split(' ')
                        f_tag = float(line_split[2])
                        to_tag = float(line_split[3])
                        f_tag = 0.0 if np.isnan(f_tag) else f_tag
                        to_tag = 0.0 if np.isnan(to_tag) else to_tag

                        text_dict['caption'] = caption
                        text_dict['tokens'] = tokens
                        if f_tag == 0.0 and to_tag == 0.0:
                            flag = True
                            text_data.append(text_dict)
                        else:
                            n_motion = motion[int(f_tag * 20): int(to_tag * 20)]
                            if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
                                continue
                            new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                            while new_name in data_dict:
                                new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                            data_dict[new_name] = {'motion': n_motion,
                                                   'length': len(n_motion),
                                                   'text': [text_dict]}
                            new_name_list.append(new_name)
                            length_list.append(len(n_motion))

                if flag:
                    data_dict[name] = {'motion': motion,
                                       'length': len(motion),
                                       'text': text_data}
                    new_name_list.append(name)
                    length_list.append(len(motion))
            except:
                # Some motion may not exist in KIT dataset
                pass

        name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))

        if opt.is_train:
            # root_rot_velocity (B, seq_len, 1)
            std[0:1] = std[0:1] / opt.feat_bias
            # root_linear_velocity (B, seq_len, 2)
            std[1:3] = std[1:3] / opt.feat_bias
            # root_y (B, seq_len, 1)
            std[3:4] = std[3:4] / opt.feat_bias
            # ric_data (B, seq_len, (joint_num - 1)*3)
            std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0
            # rot_data (B, seq_len, (joint_num - 1)*6)
            std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + (
                joints_num - 1) * 9] / 1.0
            # local_velocity (B, seq_len, joint_num*3)
            std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[
                4 + (joints_num - 1) * 9: 4 + (
                    joints_num - 1) * 9 + joints_num * 3] / 1.0
            # foot contact (B, seq_len, 4)
            std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[
                4 + (joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias

            assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
            np.save(pjoin(opt.meta_dir, 'mean.npy'), mean)
            np.save(pjoin(opt.meta_dir, 'std.npy'), std)

        self.mean = mean
        self.std = std
        self.length_arr = np.array(length_list)
        self.data_dict = data_dict
        self.name_list = name_list

    def inv_transform(self, data):
        return data * self.std + self.mean

    def real_len(self):
        return len(self.data_dict)

    def __len__(self):
        return self.real_len() * self.times

    def __getitem__(self, item):
        idx = item % self.real_len()
        data = self.data_dict[self.name_list[idx]]
        motion, m_length, text_list = data['motion'], data['length'], data['text']
        # Randomly select a caption
        text_data = random.choice(text_list)
        caption = text_data['caption']

        max_motion_length = self.opt.max_motion_length
        if m_length >= self.opt.max_motion_length:
            idx = random.randint(0, len(motion) - max_motion_length)
            motion = motion[idx: idx + max_motion_length]
        else:
            padding_len = max_motion_length - m_length
            D = motion.shape[1]
            padding_zeros = np.zeros((padding_len, D))
            motion = np.concatenate((motion, padding_zeros), axis=0)

        assert len(motion) == max_motion_length

        # Z Normalization
        motion = (motion - self.mean) / self.std

        if self.eval_mode:
            tokens = text_data['tokens']
            if len(tokens) < self.opt.max_text_len:
                # pad with "unk"
                tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
                sent_len = len(tokens)
                tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
            else:
                # crop
                tokens = tokens[:self.opt.max_text_len]
                tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
                sent_len = len(tokens)
            pos_one_hots = []
            word_embeddings = []
            for token in tokens:
                word_emb, pos_oh = self.w_vectorizer[token]
                pos_one_hots.append(pos_oh[None, :])
                word_embeddings.append(word_emb[None, :])
            pos_one_hots = np.concatenate(pos_one_hots, axis=0)
            word_embeddings = np.concatenate(word_embeddings, axis=0)
            return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length
        return caption, motion, m_length
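The normalization applied in __getitem__ is plain z-scoring against the dataset-level statistics, and inv_transform undoes it; a self-contained round-trip sketch with random stand-ins for the 263-dimensional HumanML3D features:

import numpy as np

mean = np.random.rand(263)
std = np.random.rand(263) + 0.5          # keep std away from zero
motion = np.random.randn(60, 263)        # 60 frames of pose features
normed = (motion - mean) / std           # what __getitem__ returns
restored = normed * std + mean           # what inv_transform computes
assert np.allclose(restored, motion)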
datasets/evaluator.py ADDED
@@ -0,0 +1,441 @@
import torch
from utils.word_vectorizer import WordVectorizer, POS_enumerator
from utils.get_opt import get_opt
from models import MotionTransformer
from torch.utils.data import Dataset, DataLoader
from os.path import join as pjoin
from tqdm import tqdm
import numpy as np
from .evaluator_models import *
import os
import codecs as cs
import random
from torch.utils.data._utils.collate import default_collate


class EvaluationDataset(Dataset):

    def __init__(self, opt, trainer, dataset, w_vectorizer, mm_num_samples, mm_num_repeats):
        assert mm_num_samples < len(dataset)
        print(opt.model_dir)

        dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
        epoch, it = trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar'))

        generated_motion = []
        min_mov_length = 10 if opt.dataset_name == 't2m' else 6

        trainer.eval_mode()
        trainer.to(opt.device)

        # Pre-process all target captions
        mm_generated_motions = []
        mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
        mm_idxs = np.sort(mm_idxs)
        all_caption = []
        all_m_lens = []
        all_data = []
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader)):
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                all_data.append(data)
                tokens = tokens[0].split('_')
                mm_num_now = len(mm_generated_motions)
                is_mm = True if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])) else False
                repeat_times = mm_num_repeats if is_mm else 1
                m_lens = max(m_lens // opt.unit_length * opt.unit_length, min_mov_length * opt.unit_length)
                m_lens = min(m_lens, opt.max_motion_length)
                if isinstance(m_lens, int):
                    m_lens = torch.LongTensor([m_lens]).to(opt.device)
                else:
                    m_lens = m_lens.to(opt.device)
                for t in range(repeat_times):
                    all_m_lens.append(m_lens)
                    all_caption.extend(caption)
                if is_mm:
                    mm_generated_motions.append(0)
            all_m_lens = torch.stack(all_m_lens)

        # Generate all sequences
        with torch.no_grad():
            all_pred_motions = trainer.generate(all_caption, all_m_lens, opt.dim_pose)

        cur_idx = 0
        mm_generated_motions = []
        with torch.no_grad():
            for i, data_dummy in tqdm(enumerate(dataloader)):
                data = all_data[i]
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                tokens = tokens[0].split('_')
                mm_num_now = len(mm_generated_motions)
                is_mm = True if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])) else False
                repeat_times = mm_num_repeats if is_mm else 1
                mm_motions = []
                m_lens = max(m_lens // opt.unit_length * opt.unit_length, min_mov_length * opt.unit_length)
                m_lens = min(m_lens, opt.max_motion_length)
                if isinstance(m_lens, int):
                    m_lens = torch.LongTensor([m_lens]).to(opt.device)
                else:
                    m_lens = m_lens.to(opt.device)
                for t in range(repeat_times):
                    m_len = m_lens[0].item()
                    pred_motions = all_pred_motions[cur_idx][:m_lens[0].item()]
                    assert pred_motions.shape[0] == m_lens[0].item()
                    cur_idx += 1
                    if t == 0:
                        sub_dict = {'motion': pred_motions.cpu().numpy(),
                                    'length': pred_motions.shape[0],
                                    'caption': caption[0],
                                    'cap_len': cap_lens[0].item(),
                                    'tokens': tokens}
                        generated_motion.append(sub_dict)

                    if is_mm:
                        mm_motions.append({
                            'motion': pred_motions.cpu().numpy(),
                            'length': m_lens[0].item()
                        })
                if is_mm:
                    mm_generated_motions.append({'caption': caption[0],
                                                 'tokens': tokens,
                                                 'cap_len': cap_lens[0].item(),
                                                 'mm_motions': mm_motions})
        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.opt = opt
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        data = self.generated_motion[item]
        motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
        sent_len = data['cap_len']
        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        if m_length < self.opt.max_motion_length:
            motion = np.concatenate([motion,
                                     np.zeros((self.opt.max_motion_length - m_length, motion.shape[1]))
                                     ], axis=0)
        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)


def collate_fn(batch):
    batch.sort(key=lambda x: x[3], reverse=True)
    return default_collate(batch)


'''For use of training text motion matching model, and evaluations'''
class Text2MotionDatasetV2(Dataset):
    def __init__(self, opt, mean, std, split_file, w_vectorizer):
        self.opt = opt
        self.w_vectorizer = w_vectorizer
        self.max_length = 20
        self.pointer = 0
        self.max_motion_length = opt.max_motion_length
        min_motion_len = 40 if self.opt.dataset_name == 't2m' else 24

        data_dict = {}
        id_list = []
        with cs.open(split_file, 'r') as f:
            for line in f.readlines():
                id_list.append(line.strip())

        new_name_list = []
        length_list = []
        for name in tqdm(id_list):
            try:
                motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
                if (len(motion)) < min_motion_len or (len(motion) >= 200):
                    continue
                text_data = []
                flag = False
                with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
                    for line in f.readlines():
                        text_dict = {}
                        line_split = line.strip().split('#')
                        caption = line_split[0]
                        tokens = line_split[1].split(' ')
                        f_tag = float(line_split[2])
                        to_tag = float(line_split[3])
                        f_tag = 0.0 if np.isnan(f_tag) else f_tag
                        to_tag = 0.0 if np.isnan(to_tag) else to_tag

                        text_dict['caption'] = caption
                        text_dict['tokens'] = tokens
                        if f_tag == 0.0 and to_tag == 0.0:
                            flag = True
                            text_data.append(text_dict)
                        else:
                            try:
                                n_motion = motion[int(f_tag * 20): int(to_tag * 20)]
                                if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
                                    continue
                                new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                                while new_name in data_dict:
                                    new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                                data_dict[new_name] = {'motion': n_motion,
                                                       'length': len(n_motion),
                                                       'text': [text_dict]}
                                new_name_list.append(new_name)
                                length_list.append(len(n_motion))
                            except:
                                print(line_split)
                                print(line_split[2], line_split[3], f_tag, to_tag, name)
                                # break

                if flag:
                    data_dict[name] = {'motion': motion,
                                       'length': len(motion),
                                       'text': text_data}
                    new_name_list.append(name)
                    length_list.append(len(motion))
            except:
                pass

        name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))

        self.mean = mean
        self.std = std
        self.length_arr = np.array(length_list)
        self.data_dict = data_dict
        self.name_list = name_list
        self.reset_max_len(self.max_length)

    def reset_max_len(self, length):
        assert length <= self.max_motion_length
        self.pointer = np.searchsorted(self.length_arr, length)
        print("Pointer Pointing at %d" % self.pointer)
        self.max_length = length

    def inv_transform(self, data):
        return data * self.std + self.mean

    def __len__(self):
        return len(self.data_dict) - self.pointer

    def __getitem__(self, item):
        idx = self.pointer + item
        data = self.data_dict[self.name_list[idx]]
        motion, m_length, text_list = data['motion'], data['length'], data['text']
        # Randomly select a caption
        text_data = random.choice(text_list)
        caption, tokens = text_data['caption'], text_data['tokens']

        if len(tokens) < self.opt.max_text_len:
            # pad with "unk"
            tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
            sent_len = len(tokens)
            tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
        else:
            # crop
            tokens = tokens[:self.opt.max_text_len]
            tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
            sent_len = len(tokens)
        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        # Crop the motions into multiples of 4, and introduce small variations
        if self.opt.unit_length < 10:
            coin2 = np.random.choice(['single', 'single', 'double'])
        else:
            coin2 = 'single'

        if coin2 == 'double':
            m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
        elif coin2 == 'single':
            m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
        idx = random.randint(0, len(motion) - m_length)
        motion = motion[idx:idx + m_length]

        # Z Normalization
        motion = (motion - self.mean) / self.std

        if m_length < self.max_motion_length:
            motion = np.concatenate([motion,
                                     np.zeros((self.max_motion_length - m_length, motion.shape[1]))
                                     ], axis=0)
        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)


def get_dataset_motion_loader(opt_path, batch_size, device):
    opt = get_opt(opt_path, device)

    # Configurations of the T2M dataset and the KIT dataset are almost the same
    if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
        print('Loading dataset %s ...' % opt.dataset_name)

        mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
        std = np.load(pjoin(opt.meta_dir, 'std.npy'))

        w_vectorizer = WordVectorizer('./data/glove', 'our_vab')
        split_file = pjoin(opt.data_root, 'test.txt')
        dataset = Text2MotionDatasetV2(opt, mean, std, split_file, w_vectorizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True,
                                collate_fn=collate_fn, shuffle=True)
    else:
        raise KeyError('Dataset not Recognized !!')

    print('Ground Truth Dataset Loading Completed!!!')
    return dataloader, dataset


class MMGeneratedDataset(Dataset):
    def __init__(self, opt, motion_dataset, w_vectorizer):
        self.opt = opt
        self.dataset = motion_dataset.mm_generated_motion
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        data = self.dataset[item]
        mm_motions = data['mm_motions']
        m_lens = []
        motions = []
        for mm_motion in mm_motions:
            m_lens.append(mm_motion['length'])
            motion = mm_motion['motion']
            if len(motion) < self.opt.max_motion_length:
                motion = np.concatenate([motion,
                                         np.zeros((self.opt.max_motion_length - len(motion), motion.shape[1]))
                                         ], axis=0)
            motion = motion[None, :]
            motions.append(motion)
        m_lens = np.array(m_lens, dtype=np.int)
        motions = np.concatenate(motions, axis=0)
        sort_indx = np.argsort(m_lens)[::-1].copy()
        # print(m_lens)
        # print(sort_indx)
        # print(m_lens[sort_indx])
        m_lens = m_lens[sort_indx]
        motions = motions[sort_indx]
        return motions, m_lens


def get_motion_loader(opt, batch_size, trainer, ground_truth_dataset, mm_num_samples, mm_num_repeats):
    # Currently the configurations of the two datasets are almost the same
    if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
        w_vectorizer = WordVectorizer('./data/glove', 'our_vab')
    else:
        raise KeyError('Dataset not recognized!!')
    print('Generating %s ...' % opt.name)

    dataset = EvaluationDataset(opt, trainer, ground_truth_dataset, w_vectorizer, mm_num_samples, mm_num_repeats)
    mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)

    motion_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, drop_last=True, num_workers=4)
    mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)

    print('Generated Dataset Loading Completed!!!')

    return motion_loader, mm_motion_loader


def build_models(opt):
    movement_enc = MovementConvEncoder(opt.dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
    text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
                                  pos_size=opt.dim_pos_ohot,
                                  hidden_size=opt.dim_text_hidden,
                                  output_size=opt.dim_coemb_hidden,
                                  device=opt.device)

    motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
                                      hidden_size=opt.dim_motion_hidden,
                                      output_size=opt.dim_coemb_hidden,
                                      device=opt.device)

    checkpoint = torch.load(pjoin('data/pretrained_models', opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
                            map_location=opt.device)
    movement_enc.load_state_dict(checkpoint['movement_encoder'])
    text_enc.load_state_dict(checkpoint['text_encoder'])
    motion_enc.load_state_dict(checkpoint['motion_encoder'])
    print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
    return text_enc, motion_enc, movement_enc


class EvaluatorModelWrapper(object):

    def __init__(self, opt):
        if opt.dataset_name == 't2m':
            opt.dim_pose = 263
        elif opt.dataset_name == 'kit':
            opt.dim_pose = 251
        else:
            raise KeyError('Dataset not Recognized!!!')

        opt.dim_word = 300
        opt.max_motion_length = 196
        opt.dim_pos_ohot = len(POS_enumerator)
        opt.dim_motion_hidden = 1024
        opt.max_text_len = 20
        opt.dim_text_hidden = 512
        opt.dim_coemb_hidden = 512

        self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
        self.opt = opt
        self.device = opt.device

        self.text_encoder.to(opt.device)
        self.motion_encoder.to(opt.device)
        self.movement_encoder.to(opt.device)

        self.text_encoder.eval()
        self.motion_encoder.eval()
        self.movement_encoder.eval()

    # Please note that the results do not follow the order of the inputs
    def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
        with torch.no_grad():
            word_embs = word_embs.detach().to(self.device).float()
            pos_ohot = pos_ohot.detach().to(self.device).float()
            motions = motions.detach().to(self.device).float()

            align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
            motions = motions[align_idx]
            m_lens = m_lens[align_idx]

            # Movement Encoding
            movements = self.movement_encoder(motions[..., :-4]).detach()
            m_lens = m_lens // self.opt.unit_length
            motion_embedding = self.motion_encoder(movements, m_lens)

            # Text Encoding
            text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
            text_embedding = text_embedding[align_idx]
        return text_embedding, motion_embedding

    # Please note that the results do not follow the order of the inputs
    def get_motion_embeddings(self, motions, m_lens):
        with torch.no_grad():
            motions = motions.detach().to(self.device).float()

            align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
            motions = motions[align_idx]
            m_lens = m_lens[align_idx]

            # Movement Encoding
            movements = self.movement_encoder(motions[..., :-4]).detach()
            m_lens = m_lens // self.opt.unit_length
            motion_embedding = self.motion_encoder(movements, m_lens)
        return motion_embedding
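For orientation, these pieces are typically wired together roughly as follows during evaluation (a sketch, not runnable as shipped: the opt path is a placeholder, `trainer` stands for an already-loaded DDPMTrainer, and the ./data/glove vectors plus the text_mot_match checkpoint referenced in build_models must exist on disk):

# Ground-truth batches, the evaluator embedding wrapper, and loaders over generated motions.
gt_loader, gt_dataset = get_dataset_motion_loader('path/to/eval_opt.txt', batch_size=32, device='cpu')
eval_wrapper = EvaluatorModelWrapper(gt_dataset.opt)
motion_loader, mm_loader = get_motion_loader(gt_dataset.opt, 32, trainer, gt_dataset,
                                             mm_num_samples=100, mm_num_repeats=30)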
datasets/evaluator_models.py ADDED
@@ -0,0 +1,438 @@
import torch
import torch.nn as nn
import numpy as np
import time
import math
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# from networks.layers import *
import torch.nn.functional as F


class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """
    def __init__(self, margin=3.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive


def init_weight(m):
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
        nn.init.xavier_normal_(m.weight)
        # m.bias.data.fill_(0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


def reparameterize(mu, logvar):
    s_var = logvar.mul(0.5).exp_()
    eps = s_var.data.new(s_var.size()).normal_()
    return eps.mul(s_var).add_(mu)


# batch_size, dimension and position
# output: (batch_size, dim)
def positional_encoding(batch_size, dim, pos):
    assert batch_size == pos.shape[0]
    positions_enc = np.array([
        [pos[j] / np.power(10000, (i - i % 2) / dim) for i in range(dim)]
        for j in range(batch_size)
    ], dtype=np.float32)
    positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2])
    positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2])
    return torch.from_numpy(positions_enc).float()


def get_padding_mask(batch_size, seq_len, cap_lens):
    cap_lens = cap_lens.data.tolist()
    mask_2d = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32)
    for i, cap_len in enumerate(cap_lens):
        mask_2d[i, :, :cap_len] = 0
    return mask_2d.bool(), 1 - mask_2d[:, :, 0].clone()


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=300):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, pos):
        return self.pe[pos]


class MovementConvEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MovementConvEncoder, self).__init__()
        self.main = nn.Sequential(
            nn.Conv1d(input_size, hidden_size, 4, 2, 1),
            nn.Dropout(0.2, inplace=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(hidden_size, output_size, 4, 2, 1),
            nn.Dropout(0.2, inplace=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.out_net = nn.Linear(output_size, output_size)
        self.main.apply(init_weight)
        self.out_net.apply(init_weight)

    def forward(self, inputs):
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        # print(outputs.shape)
        return self.out_net(outputs)


class MovementConvDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MovementConvDecoder, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1),
            # nn.Dropout(0.2, inplace=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1),
            # nn.Dropout(0.2, inplace=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.out_net = nn.Linear(output_size, output_size)

        self.main.apply(init_weight)
        self.out_net.apply(init_weight)

    def forward(self, inputs):
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        return self.out_net(outputs)


class TextVAEDecoder(nn.Module):
    def __init__(self, text_size, input_size, output_size, hidden_size, n_layers):
        super(TextVAEDecoder, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.emb = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.LeakyReLU(0.2, inplace=True))

        self.z2init = nn.Linear(text_size, hidden_size * n_layers)
        self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)])
        self.positional_encoder = PositionalEncoding(hidden_size)


        self.output = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size)
        )

        #
        # self.output = nn.Sequential(
        #     nn.Linear(hidden_size, hidden_size),
        #     nn.LayerNorm(hidden_size),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.Linear(hidden_size, output_size-4)
        # )

        # self.contact_net = nn.Sequential(
        #     nn.Linear(output_size-4, 64),
        #     nn.LayerNorm(64),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.Linear(64, 4)
        # )

        self.output.apply(init_weight)
        self.emb.apply(init_weight)
        self.z2init.apply(init_weight)
        # self.contact_net.apply(init_weight)

    def get_init_hidden(self, latent):
        hidden = self.z2init(latent)
        hidden = torch.split(hidden, self.hidden_size, dim=-1)
        return list(hidden)

    def forward(self, inputs, last_pred, hidden, p):
        h_in = self.emb(inputs)
        pos_enc = self.positional_encoder(p).to(inputs.device).detach()
        h_in = h_in + pos_enc
        for i in range(self.n_layers):
            # print(h_in.shape)
            hidden[i] = self.gru[i](h_in, hidden[i])
            h_in = hidden[i]
        pose_pred = self.output(h_in)
        # pose_pred = self.output(h_in) + last_pred.detach()
        # contact = self.contact_net(pose_pred)
        # return torch.cat([pose_pred, contact], dim=-1), hidden
        return pose_pred, hidden


class TextDecoder(nn.Module):
    def __init__(self, text_size, input_size, output_size, hidden_size, n_layers):
        super(TextDecoder, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.emb = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.LeakyReLU(0.2, inplace=True))

        self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)])
        self.z2init = nn.Linear(text_size, hidden_size * n_layers)
        self.positional_encoder = PositionalEncoding(hidden_size)

        self.mu_net = nn.Linear(hidden_size, output_size)
        self.logvar_net = nn.Linear(hidden_size, output_size)

        self.emb.apply(init_weight)
        self.z2init.apply(init_weight)
        self.mu_net.apply(init_weight)
        self.logvar_net.apply(init_weight)

    def get_init_hidden(self, latent):

        hidden = self.z2init(latent)
        hidden = torch.split(hidden, self.hidden_size, dim=-1)

        return list(hidden)

    def forward(self, inputs, hidden, p):
        # print(inputs.shape)
        x_in = self.emb(inputs)
        pos_enc = self.positional_encoder(p).to(inputs.device).detach()
        x_in = x_in + pos_enc

        for i in range(self.n_layers):
            hidden[i] = self.gru[i](x_in, hidden[i])
            h_in = hidden[i]
        mu = self.mu_net(h_in)
        logvar = self.logvar_net(h_in)
        z = reparameterize(mu, logvar)
        return z, mu, logvar, hidden


class AttLayer(nn.Module):
    def __init__(self, query_dim, key_dim, value_dim):
        super(AttLayer, self).__init__()
        self.W_q = nn.Linear(query_dim, value_dim)
        self.W_k = nn.Linear(key_dim, value_dim, bias=False)
        self.W_v = nn.Linear(key_dim, value_dim)

        self.softmax = nn.Softmax(dim=1)
        self.dim = value_dim

        self.W_q.apply(init_weight)
        self.W_k.apply(init_weight)
        self.W_v.apply(init_weight)

    def forward(self, query, key_mat):
        '''
        query (batch, query_dim)
        key (batch, seq_len, key_dim)
        '''
        # print(query.shape)
        query_vec = self.W_q(query).unsqueeze(-1)   # (batch, value_dim, 1)
        val_set = self.W_v(key_mat)                 # (batch, seq_len, value_dim)
        key_set = self.W_k(key_mat)                 # (batch, seq_len, value_dim)

        weights = torch.matmul(key_set, query_vec) / np.sqrt(self.dim)

        co_weights = self.softmax(weights)          # (batch, seq_len, 1)
        values = val_set * co_weights               # (batch, seq_len, value_dim)
        pred = values.sum(dim=1)                    # (batch, value_dim)
        return pred, co_weights

    def short_cut(self, querys, keys):
        return self.W_q(querys), self.W_k(keys)


class TextEncoderBiGRU(nn.Module):
    def __init__(self, word_size, pos_size, hidden_size, device):
        super(TextEncoderBiGRU, self).__init__()
        self.device = device

        self.pos_emb = nn.Linear(pos_size, word_size)
        self.input_emb = nn.Linear(word_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        # self.linear2 = nn.Linear(hidden_size, output_size)

        self.input_emb.apply(init_weight)
        self.pos_emb.apply(init_weight)
        # self.linear2.apply(init_weight)
        # self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))

    # input(batch_size, seq_len, dim)
    def forward(self, word_embs, pos_onehot, cap_lens):
        num_samples = word_embs.shape[0]

        pos_embs = self.pos_emb(pos_onehot)
        inputs = word_embs + pos_embs
        input_embs = self.input_emb(inputs)
        hidden = self.hidden.repeat(1, num_samples, 1)

        cap_lens = cap_lens.data.tolist()
        emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)

        gru_seq, gru_last = self.gru(emb, hidden)

        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
        gru_seq = pad_packed_sequence(gru_seq, batch_first=True)[0]
        forward_seq = gru_seq[..., :self.hidden_size]
        backward_seq = gru_seq[..., self.hidden_size:].clone()

        # Concatenate the forward and backward word embeddings
        for i, length in enumerate(cap_lens):
            backward_seq[i:i + 1, :length] = torch.flip(backward_seq[i:i + 1, :length].clone(), dims=[1])
        gru_seq = torch.cat([forward_seq, backward_seq], dim=-1)

        return gru_seq, gru_last


class TextEncoderBiGRUCo(nn.Module):
    def __init__(self, word_size, pos_size, hidden_size, output_size, device):
        super(TextEncoderBiGRUCo, self).__init__()
        self.device = device

        self.pos_emb = nn.Linear(pos_size, word_size)
        self.input_emb = nn.Linear(word_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_net = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size)
        )

        self.input_emb.apply(init_weight)
        self.pos_emb.apply(init_weight)
        self.output_net.apply(init_weight)
        # self.linear2.apply(init_weight)
        # self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))

    # input(batch_size, seq_len, dim)
    def forward(self, word_embs, pos_onehot, cap_lens):
        num_samples = word_embs.shape[0]

        pos_embs = self.pos_emb(pos_onehot)
        inputs = word_embs + pos_embs
        input_embs = self.input_emb(inputs)
        hidden = self.hidden.repeat(1, num_samples, 1)

        cap_lens = cap_lens.data.tolist()
        emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)

        gru_seq, gru_last = self.gru(emb, hidden)

        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)

        return self.output_net(gru_last)


class MotionEncoderBiGRUCo(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, device):
        super(MotionEncoderBiGRUCo, self).__init__()
        self.device = device

        self.input_emb = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_net = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size)
        )

        self.input_emb.apply(init_weight)
        self.output_net.apply(init_weight)
        self.hidden_size = hidden_size
        self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))

    # input(batch_size, seq_len, dim)
    def forward(self, inputs, m_lens):
        num_samples = inputs.shape[0]

        input_embs = self.input_emb(inputs)
        hidden = self.hidden.repeat(1, num_samples, 1)

        cap_lens = m_lens.data.tolist()
        emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)

        gru_seq, gru_last = self.gru(emb, hidden)

        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)

        return self.output_net(gru_last)


class MotionLenEstimatorBiGRU(nn.Module):
    def __init__(self, word_size, pos_size, hidden_size, output_size):
        super(MotionLenEstimatorBiGRU, self).__init__()

        self.pos_emb = nn.Linear(pos_size, word_size)
        self.input_emb = nn.Linear(word_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        nd = 512
        self.output = nn.Sequential(
            nn.Linear(hidden_size * 2, nd),
            nn.LayerNorm(nd),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(nd, nd // 2),
            nn.LayerNorm(nd // 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(nd // 2, nd // 4),
            nn.LayerNorm(nd // 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(nd // 4, output_size)
        )
        # self.linear2 = nn.Linear(hidden_size, output_size)

        self.input_emb.apply(init_weight)
        self.pos_emb.apply(init_weight)
        self.output.apply(init_weight)
        # self.linear2.apply(init_weight)
        # self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))

    # input(batch_size, seq_len, dim)
    def forward(self, word_embs, pos_onehot, cap_lens):
        num_samples = word_embs.shape[0]

        pos_embs = self.pos_emb(pos_onehot)
        inputs = word_embs + pos_embs
        input_embs = self.input_emb(inputs)
        hidden = self.hidden.repeat(1, num_samples, 1)

        cap_lens = cap_lens.data.tolist()
        emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)

        gru_seq, gru_last = self.gru(emb, hidden)

        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)

        return self.output(gru_last)
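For a quick shape check, the co-embedding text encoder can be run on random tensors (a sketch; dimensions mirror the EvaluatorModelWrapper defaults of word_size=300 and hidden/output sizes of 512, with a made-up 15-way POS one-hot standing in for len(POS_enumerator)):

import torch

enc = TextEncoderBiGRUCo(word_size=300, pos_size=15, hidden_size=512, output_size=512, device='cpu')
word_embs = torch.randn(2, 20, 300)      # batch of 2 captions, 20 tokens each
pos_onehot = torch.randn(2, 20, 15)
cap_lens = torch.tensor([20, 12])        # pack_padded_sequence expects descending lengths
print(enc(word_embs, pos_onehot, cap_lens).shape)   # torch.Size([2, 512])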
models/__init__.py ADDED
@@ -0,0 +1,4 @@
from .transformer import MotionTransformer
from .gaussian_diffusion import GaussianDiffusion

__all__ = ['MotionTransformer', 'GaussianDiffusion']
models/gaussian_diffusion.py ADDED
@@ -0,0 +1,1145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This code is borrowed from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/gaussian_diffusion.py
3
+ """
4
+
5
+ import enum
6
+ import math
7
+
8
+ import numpy as np
9
+ import torch as th
10
+
11
+
12
+ from abc import ABC, abstractmethod
13
+ import torch.distributed as dist
14
+
15
+
16
+ def create_named_schedule_sampler(name, diffusion):
17
+ """
18
+ Create a ScheduleSampler from a library of pre-defined samplers.
19
+ :param name: the name of the sampler.
20
+ :param diffusion: the diffusion object to sample for.
21
+ """
22
+ if name == "uniform":
23
+ return UniformSampler(diffusion)
24
+ elif name == "loss-second-moment":
25
+ return LossSecondMomentResampler(diffusion)
26
+ else:
27
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
28
+
29
+
30
+ class ScheduleSampler(ABC):
31
+ """
32
+ A distribution over timesteps in the diffusion process, intended to reduce
33
+ variance of the objective.
34
+ By default, samplers perform unbiased importance sampling, in which the
35
+ objective's mean is unchanged.
36
+ However, subclasses may override sample() to change how the resampled
37
+ terms are reweighted, allowing for actual changes in the objective.
38
+ """
39
+
40
+ @abstractmethod
41
+ def weights(self):
42
+ """
43
+ Get a numpy array of weights, one per diffusion step.
44
+ The weights needn't be normalized, but must be positive.
45
+ """
46
+
47
+ def sample(self, batch_size, device):
48
+ """
49
+ Importance-sample timesteps for a batch.
50
+ :param batch_size: the number of timesteps.
51
+ :param device: the torch device to save to.
52
+ :return: a tuple (timesteps, weights):
53
+ - timesteps: a tensor of timestep indices.
54
+ - weights: a tensor of weights to scale the resulting losses.
55
+ """
56
+ w = self.weights()
57
+ p = w / np.sum(w)
58
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
59
+ indices = th.from_numpy(indices_np).long().to(device)
60
+ weights_np = 1 / (len(p) * p[indices_np])
61
+ weights = th.from_numpy(weights_np).float().to(device)
62
+ return indices, weights
63
+
64
+
65
+ class UniformSampler(ScheduleSampler):
66
+ def __init__(self, diffusion):
67
+ self.diffusion = diffusion
68
+ self._weights = np.ones([diffusion.num_timesteps])
69
+
70
+ def weights(self):
71
+ return self._weights
72
+
73
+
74
+ class LossAwareSampler(ScheduleSampler):
75
+ def update_with_local_losses(self, local_ts, local_losses):
76
+ """
77
+ Update the reweighting using losses from a model.
78
+ Call this method from each rank with a batch of timesteps and the
79
+ corresponding losses for each of those timesteps.
80
+ This method will perform synchronization to make sure all of the ranks
81
+ maintain the exact same reweighting.
82
+ :param local_ts: an integer Tensor of timesteps.
83
+ :param local_losses: a 1D Tensor of losses.
84
+ """
85
+ batch_sizes = [
86
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
87
+ for _ in range(dist.get_world_size())
88
+ ]
89
+ dist.all_gather(
90
+ batch_sizes,
91
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
92
+ )
93
+
94
+ # Pad all_gather batches to be the maximum batch size.
95
+ batch_sizes = [x.item() for x in batch_sizes]
96
+ max_bs = max(batch_sizes)
97
+
98
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
99
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
100
+ dist.all_gather(timestep_batches, local_ts)
101
+ dist.all_gather(loss_batches, local_losses)
102
+ timesteps = [
103
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
104
+ ]
105
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
106
+ self.update_with_all_losses(timesteps, losses)
107
+
108
+ @abstractmethod
109
+ def update_with_all_losses(self, ts, losses):
110
+ """
111
+ Update the reweighting using losses from a model.
112
+ Sub-classes should override this method to update the reweighting
113
+ using losses from the model.
114
+ This method directly updates the reweighting without synchronizing
115
+ between workers. It is called by update_with_local_losses from all
116
+ ranks with identical arguments. Thus, it should have deterministic
117
+ behavior to maintain state across workers.
118
+ :param ts: a list of int timesteps.
119
+ :param losses: a list of float losses, one per timestep.
120
+ """
121
+
122
+
123
+ class LossSecondMomentResampler(LossAwareSampler):
124
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
125
+ self.diffusion = diffusion
126
+ self.history_per_term = history_per_term
127
+ self.uniform_prob = uniform_prob
128
+ self._loss_history = np.zeros(
129
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
130
+ )
131
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
132
+
133
+ def weights(self):
134
+ if not self._warmed_up():
135
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
136
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
137
+ weights /= np.sum(weights)
138
+ weights *= 1 - self.uniform_prob
139
+ weights += self.uniform_prob / len(weights)
140
+ return weights
141
+
142
+ def update_with_all_losses(self, ts, losses):
143
+ for t, loss in zip(ts, losses):
144
+ if self._loss_counts[t] == self.history_per_term:
145
+ # Shift out the oldest loss term.
146
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
147
+ self._loss_history[t, -1] = loss
148
+ else:
149
+ self._loss_history[t, self._loss_counts[t]] = loss
150
+ self._loss_counts[t] += 1
151
+
152
+ def _warmed_up(self):
153
+ return (self._loss_counts == self.history_per_term).all()
154
+
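A small sketch of the resampler's warm-up behaviour (same assumed module path as above): the weights stay uniform until every timestep has history_per_term recorded losses, after which they track the RMS loss per timestep with a small uniform floor.

from models.gaussian_diffusion import LossSecondMomentResampler  # assumed module path

class _FakeDiffusion:
    num_timesteps = 4

sampler = LossSecondMomentResampler(_FakeDiffusion(), history_per_term=2)
print(sampler.weights())   # uniform ones: not warmed up yet
for _ in range(2):
    sampler.update_with_all_losses(ts=[0, 1, 2, 3], losses=[0.1, 0.2, 0.4, 0.8])
print(sampler.weights())   # proportional to the RMS loss per timestep, plus a small uniform floor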
155
+
156
+ def mean_flat(tensor):
157
+ """
158
+ Take the mean over all non-batch dimensions.
159
+ """
160
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
161
+
162
+
163
+ def normal_kl(mean1, logvar1, mean2, logvar2):
164
+ """
165
+ Compute the KL divergence between two gaussians.
166
+ Shapes are automatically broadcasted, so batches can be compared to
167
+ scalars, among other use cases.
168
+ """
169
+ tensor = None
170
+ for obj in (mean1, logvar1, mean2, logvar2):
171
+ if isinstance(obj, th.Tensor):
172
+ tensor = obj
173
+ break
174
+ assert tensor is not None, "at least one argument must be a Tensor"
175
+
176
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
177
+ # Tensors, but it does not work for th.exp().
178
+ logvar1, logvar2 = [
179
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
180
+ for x in (logvar1, logvar2)
181
+ ]
182
+
183
+ return 0.5 * (
184
+ -1.0
185
+ + logvar2
186
+ - logvar1
187
+ + th.exp(logvar1 - logvar2)
188
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
189
+ )
190
+
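A quick sanity check of the log-variance convention used by normal_kl (module path assumed as above):

import torch as th
from models.gaussian_diffusion import normal_kl  # assumed module path

zero = th.zeros(3)
print(normal_kl(zero, zero, zero, zero))        # tensor([0., 0., 0.]): identical Gaussians
print(normal_kl(zero + 1.0, zero, zero, zero))  # tensor([0.5, 0.5, 0.5]): 0.5 * (mean gap)^2 for unit variances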
191
+
192
+ def approx_standard_normal_cdf(x):
193
+ """
194
+ A fast approximation of the cumulative distribution function of the
195
+ standard normal.
196
+ """
197
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
198
+
199
+
200
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
201
+ """
202
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
203
+ given image.
204
+ :param x: the target images. It is assumed that these were uint8 values,
205
+ rescaled to the range [-1, 1].
206
+ :param means: the Gaussian mean Tensor.
207
+ :param log_scales: the Gaussian log stddev Tensor.
208
+ :return: a tensor like x of log probabilities (in nats).
209
+ """
210
+ assert x.shape == means.shape == log_scales.shape
211
+ centered_x = x - means
212
+ inv_stdv = th.exp(-log_scales)
213
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
214
+ cdf_plus = approx_standard_normal_cdf(plus_in)
215
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
216
+ cdf_min = approx_standard_normal_cdf(min_in)
217
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
218
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
219
+ cdf_delta = cdf_plus - cdf_min
220
+ log_probs = th.where(
221
+ x < -0.999,
222
+ log_cdf_plus,
223
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
224
+ )
225
+ assert log_probs.shape == x.shape
226
+ return log_probs
227
+
228
+
229
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
230
+ """
231
+ Get a pre-defined beta schedule for the given name.
232
+
233
+ The beta schedule library consists of beta schedules which remain similar
234
+ in the limit of num_diffusion_timesteps.
235
+ Beta schedules may be added, but should not be removed or changed once
236
+ they are committed to maintain backwards compatibility.
237
+ """
238
+ if schedule_name == "linear":
239
+ # Linear schedule from Ho et al, extended to work for any number of
240
+ # diffusion steps.
241
+ scale = 1000 / num_diffusion_timesteps
242
+ beta_start = scale * 0.0001
243
+ beta_end = scale * 0.02
244
+ return np.linspace(
245
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
246
+ )
247
+ elif schedule_name == "cosine":
248
+ return betas_for_alpha_bar(
249
+ num_diffusion_timesteps,
250
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
251
+ )
252
+ else:
253
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
254
+
255
+
256
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
257
+ """
258
+ Create a beta schedule that discretizes the given alpha_t_bar function,
259
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
260
+
261
+ :param num_diffusion_timesteps: the number of betas to produce.
262
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
263
+ produces the cumulative product of (1-beta) up to that
264
+ part of the diffusion process.
265
+ :param max_beta: the maximum beta to use; use values lower than 1 to
266
+ prevent singularities.
267
+ """
268
+ betas = []
269
+ for i in range(num_diffusion_timesteps):
270
+ t1 = i / num_diffusion_timesteps
271
+ t2 = (i + 1) / num_diffusion_timesteps
272
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
273
+ return np.array(betas)
274
+
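For reference, a short sketch of the schedule helpers (module path assumed as above): both named schedules return num_diffusion_timesteps betas in (0, 1], and the implied cumulative product alpha_bar decays from roughly 1 toward 0.

import numpy as np
from models.gaussian_diffusion import get_named_beta_schedule  # assumed module path

for name in ("linear", "cosine"):
    betas = get_named_beta_schedule(name, 1000)
    alpha_bar = np.cumprod(1.0 - betas)
    print(name, betas.shape, float(alpha_bar[0]), float(alpha_bar[-1]))
    # each prints: <name> (1000,) value near 1.0, value near 0.0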
275
+
276
+ class ModelMeanType(enum.Enum):
277
+ """
278
+ Which type of output the model predicts.
279
+ """
280
+
281
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
282
+ START_X = enum.auto() # the model predicts x_0
283
+ EPSILON = enum.auto() # the model predicts epsilon
284
+
285
+
286
+ class ModelVarType(enum.Enum):
287
+ """
288
+ What is used as the model's output variance.
289
+
290
+ The LEARNED_RANGE option has been added to allow the model to predict
291
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
292
+ """
293
+
294
+ LEARNED = enum.auto()
295
+ FIXED_SMALL = enum.auto()
296
+ FIXED_LARGE = enum.auto()
297
+ LEARNED_RANGE = enum.auto()
298
+
299
+
300
+ class LossType(enum.Enum):
301
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
302
+ RESCALED_MSE = (
303
+ enum.auto()
304
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
305
+ KL = enum.auto() # use the variational lower-bound
306
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
307
+
308
+ def is_vb(self):
309
+ return self == LossType.KL or self == LossType.RESCALED_KL
310
+
311
+
312
+ class GaussianDiffusion:
313
+ """
314
+ Utilities for training and sampling diffusion models.
315
+
316
+ Ported directly from here, and then adapted over time to further experimentation.
317
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
318
+
319
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
320
+ starting at T and going to 1.
321
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
322
+ :param model_var_type: a ModelVarType determining how variance is output.
323
+ :param loss_type: a LossType determining the loss function to use.
324
+ :param rescale_timesteps: if True, pass floating point timesteps into the
325
+ model so that they are always scaled like in the
326
+ original paper (0 to 1000).
327
+ """
328
+
329
+ def __init__(
330
+ self,
331
+ *,
332
+ betas,
333
+ model_mean_type,
334
+ model_var_type,
335
+ loss_type,
336
+ rescale_timesteps=False,
337
+ ):
338
+ self.model_mean_type = model_mean_type
339
+ self.model_var_type = model_var_type
340
+ self.loss_type = loss_type
341
+ self.rescale_timesteps = rescale_timesteps
342
+
343
+ # Use float64 for accuracy.
344
+ betas = np.array(betas, dtype=np.float64)
345
+ self.betas = betas
346
+ assert len(betas.shape) == 1, "betas must be 1-D"
347
+ assert (betas > 0).all() and (betas <= 1).all()
348
+
349
+ self.num_timesteps = int(betas.shape[0])
350
+
351
+ alphas = 1.0 - betas
352
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
353
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
354
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
355
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
356
+
357
+ # calculations for diffusion q(x_t | x_{t-1}) and others
358
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
359
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
360
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
361
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
362
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
363
+
364
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
365
+ self.posterior_variance = (
366
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
367
+ )
368
+ # log calculation clipped because the posterior variance is 0 at the
369
+ # beginning of the diffusion chain.
370
+ self.posterior_log_variance_clipped = np.log(
371
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
372
+ )
373
+ self.posterior_mean_coef1 = (
374
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
375
+ )
376
+ self.posterior_mean_coef2 = (
377
+ (1.0 - self.alphas_cumprod_prev)
378
+ * np.sqrt(alphas)
379
+ / (1.0 - self.alphas_cumprod)
380
+ )
381
+
382
+ def q_mean_variance(self, x_start, t):
383
+ """
384
+ Get the distribution q(x_t | x_0).
385
+
386
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
387
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
388
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
389
+ """
390
+ mean = (
391
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
392
+ )
393
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
394
+ log_variance = _extract_into_tensor(
395
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
396
+ )
397
+ return mean, variance, log_variance
398
+
399
+ def q_sample(self, x_start, t, noise=None):
400
+ """
401
+ Diffuse the data for a given number of diffusion steps.
402
+
403
+ In other words, sample from q(x_t | x_0).
404
+
405
+ :param x_start: the initial data batch.
406
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
407
+ :param noise: if specified, the split-out normal noise.
408
+ :return: A noisy version of x_start.
409
+ """
410
+ if noise is None:
411
+ noise = th.randn_like(x_start)
412
+ assert noise.shape == x_start.shape
413
+ return (
414
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
415
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
416
+ * noise
417
+ )
418
+
419
+ def q_posterior_mean_variance(self, x_start, x_t, t):
420
+ """
421
+ Compute the mean and variance of the diffusion posterior:
422
+
423
+ q(x_{t-1} | x_t, x_0)
424
+
425
+ """
426
+ assert x_start.shape == x_t.shape
427
+ posterior_mean = (
428
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
429
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
430
+ )
431
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
432
+ posterior_log_variance_clipped = _extract_into_tensor(
433
+ self.posterior_log_variance_clipped, t, x_t.shape
434
+ )
435
+ assert (
436
+ posterior_mean.shape[0]
437
+ == posterior_variance.shape[0]
438
+ == posterior_log_variance_clipped.shape[0]
439
+ == x_start.shape[0]
440
+ )
441
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
442
+
443
+ def p_mean_variance(
444
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
445
+ ):
446
+ """
447
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
448
+ the initial x, x_0.
449
+
450
+ :param model: the model, which takes a signal and a batch of timesteps
451
+ as input.
452
+ :param x: the [N x C x ...] tensor at time t.
453
+ :param t: a 1-D Tensor of timesteps.
454
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
455
+ :param denoised_fn: if not None, a function which applies to the
456
+ x_start prediction before it is used to sample. Applies before
457
+ clip_denoised.
458
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
459
+ pass to the model. This can be used for conditioning.
460
+ :return: a dict with the following keys:
461
+ - 'mean': the model mean output.
462
+ - 'variance': the model variance output.
463
+ - 'log_variance': the log of 'variance'.
464
+ - 'pred_xstart': the prediction for x_0.
465
+ """
466
+ if model_kwargs is None:
467
+ model_kwargs = {}
468
+
469
+ B, C = x.shape[:2]
470
+ assert t.shape == (B,)
471
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
472
+
473
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
474
+ assert model_output.shape == (B, 2 * C, *x.shape[2:])
475
+ model_output, model_var_values = th.split(model_output, C, dim=1)
476
+ if self.model_var_type == ModelVarType.LEARNED:
477
+ model_log_variance = model_var_values
478
+ model_variance = th.exp(model_log_variance)
479
+ else:
480
+ min_log = _extract_into_tensor(
481
+ self.posterior_log_variance_clipped, t, x.shape
482
+ )
483
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
484
+ # The model_var_values is [-1, 1] for [min_var, max_var].
485
+ frac = (model_var_values + 1) / 2
486
+ model_log_variance = frac * max_log + (1 - frac) * min_log
487
+ model_variance = th.exp(model_log_variance)
488
+ else:
489
+ model_variance, model_log_variance = {
490
+ # for fixedlarge, we set the initial (log-)variance like so
491
+ # to get a better decoder log likelihood.
492
+ ModelVarType.FIXED_LARGE: (
493
+ np.append(self.posterior_variance[1], self.betas[1:]),
494
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
495
+ ),
496
+ ModelVarType.FIXED_SMALL: (
497
+ self.posterior_variance,
498
+ self.posterior_log_variance_clipped,
499
+ ),
500
+ }[self.model_var_type]
501
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
502
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
503
+
504
+ def process_xstart(x):
505
+ if denoised_fn is not None:
506
+ x = denoised_fn(x)
507
+ if clip_denoised:
508
+ return x.clamp(-1, 1)
509
+ return x
510
+
511
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
512
+ pred_xstart = process_xstart(
513
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
514
+ )
515
+ model_mean = model_output
516
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
517
+ if self.model_mean_type == ModelMeanType.START_X:
518
+ pred_xstart = process_xstart(model_output)
519
+ else:
520
+ pred_xstart = process_xstart(
521
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
522
+ )
523
+ model_mean, _, _ = self.q_posterior_mean_variance(
524
+ x_start=pred_xstart, x_t=x, t=t
525
+ )
526
+ else:
527
+ raise NotImplementedError(self.model_mean_type)
528
+
529
+ assert (
530
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
531
+ )
532
+ return {
533
+ "mean": model_mean,
534
+ "variance": model_variance,
535
+ "log_variance": model_log_variance,
536
+ "pred_xstart": pred_xstart,
537
+ }
538
+
539
+ def _predict_xstart_from_eps(self, x_t, t, eps):
540
+ assert x_t.shape == eps.shape
541
+ return (
542
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
543
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
544
+ )
545
+
546
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
547
+ assert x_t.shape == xprev.shape
548
+ return ( # (xprev - coef2*x_t) / coef1
549
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
550
+ - _extract_into_tensor(
551
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
552
+ )
553
+ * x_t
554
+ )
555
+
556
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
557
+ return (
558
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
559
+ - pred_xstart
560
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
561
+
562
+ def _scale_timesteps(self, t):
563
+ if self.rescale_timesteps:
564
+ return t.float() * (1000.0 / self.num_timesteps)
565
+ return t
566
+
567
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
568
+ """
569
+ Compute the mean for the previous step, given a function cond_fn that
570
+ computes the gradient of a conditional log probability with respect to
571
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
572
+ condition on y.
573
+
574
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
575
+ """
576
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
577
+ new_mean = (
578
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
579
+ )
580
+ return new_mean
581
+
582
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
583
+ """
584
+ Compute what the p_mean_variance output would have been, should the
585
+ model's score function be conditioned by cond_fn.
586
+
587
+ See condition_mean() for details on cond_fn.
588
+
589
+ Unlike condition_mean(), this instead uses the conditioning strategy
590
+ from Song et al (2020).
591
+ """
592
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
593
+
594
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
595
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
596
+ x, self._scale_timesteps(t), **model_kwargs
597
+ )
598
+
599
+ out = p_mean_var.copy()
600
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
601
+ out["mean"], _, _ = self.q_posterior_mean_variance(
602
+ x_start=out["pred_xstart"], x_t=x, t=t
603
+ )
604
+ return out
605
+
606
+ def p_sample(
607
+ self,
608
+ model,
609
+ x,
610
+ t,
611
+ clip_denoised=True,
612
+ denoised_fn=None,
613
+ cond_fn=None,
614
+ pre_seq=None,
615
+ transl_req=None,
616
+ model_kwargs=None,
617
+ ):
618
+ """
619
+ Sample x_{t-1} from the model at the given timestep.
620
+
621
+ :param model: the model to sample from.
622
+ :param x: the current tensor at x_{t-1}.
623
+ :param t: the value of t, starting at 0 for the first diffusion step.
624
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
625
+ :param denoised_fn: if not None, a function which applies to the
626
+ x_start prediction before it is used to sample.
627
+ :param cond_fn: if not None, this is a gradient function that acts
628
+ similarly to the model.
629
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
630
+ pass to the model. This can be used for conditioning.
631
+ :return: a dict containing the following keys:
632
+ - 'sample': a random sample from the model.
633
+ - 'pred_xstart': a prediction of x_0.
634
+ """
635
+ # concat seq
636
+ if pre_seq is not None:
637
+ T = pre_seq.shape[2]
638
+ noise = th.randn_like(pre_seq)
639
+ x_t = self.q_sample(pre_seq, t, noise=noise)
640
+ x[:, :, :T] = x_t
641
+
642
+ if transl_req is not None:
643
+ for item in transl_req:
644
+ noise = th.randn(2).type_as(x)
645
+ transl = th.Tensor(item[1:]).type_as(x)
646
+ x_t = self.q_sample(transl, t, noise=noise)
647
+ x[:, :2, item[0]] = x_t
648
+
649
+ out = self.p_mean_variance(
650
+ model,
651
+ x,
652
+ t,
653
+ clip_denoised=clip_denoised,
654
+ denoised_fn=denoised_fn,
655
+ model_kwargs=model_kwargs,
656
+ )
657
+ noise = th.randn_like(x)
658
+ nonzero_mask = (
659
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
660
+ ) # no noise when t == 0
661
+ if cond_fn is not None:
662
+ out["mean"] = self.condition_mean(
663
+ cond_fn, out, x, t, model_kwargs=model_kwargs
664
+ )
665
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
666
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
667
+
668
+ def p_sample_loop(
669
+ self,
670
+ model,
671
+ shape,
672
+ noise=None,
673
+ clip_denoised=True,
674
+ denoised_fn=None,
675
+ cond_fn=None,
676
+ model_kwargs=None,
677
+ device=None,
678
+ pre_seq=None,
679
+ transl_req=None,
680
+ progress=False,
681
+ ):
682
+ """
683
+ Generate samples from the model.
684
+
685
+ :param model: the model module.
686
+ :param shape: the shape of the samples, (N, C, H, W).
687
+ :param noise: if specified, the noise from the encoder to sample.
688
+ Should be of the same shape as `shape`.
689
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
690
+ :param denoised_fn: if not None, a function which applies to the
691
+ x_start prediction before it is used to sample.
692
+ :param cond_fn: if not None, this is a gradient function that acts
693
+ similarly to the model.
694
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
695
+ pass to the model. This can be used for conditioning.
696
+ :param device: if specified, the device to create the samples on.
697
+ If not specified, use a model parameter's device.
698
+ :param progress: if True, show a tqdm progress bar.
699
+ :return: a non-differentiable batch of samples.
700
+ """
701
+ final = None
702
+ for sample in self.p_sample_loop_progressive(
703
+ model,
704
+ shape,
705
+ noise=noise,
706
+ clip_denoised=clip_denoised,
707
+ denoised_fn=denoised_fn,
708
+ cond_fn=cond_fn,
709
+ model_kwargs=model_kwargs,
710
+ device=device,
711
+ pre_seq=pre_seq,
712
+ transl_req=transl_req,
713
+ progress=progress,
714
+ ):
715
+ final = sample
716
+ return final["sample"]
717
+
718
+ def p_sample_loop_progressive(
719
+ self,
720
+ model,
721
+ shape,
722
+ noise=None,
723
+ clip_denoised=True,
724
+ denoised_fn=None,
725
+ cond_fn=None,
726
+ model_kwargs=None,
727
+ device=None,
728
+ pre_seq=None,
729
+ transl_req=None,
730
+ progress=False,
731
+ ):
732
+ """
733
+ Generate samples from the model and yield intermediate samples from
734
+ each timestep of diffusion.
735
+
736
+ Arguments are the same as p_sample_loop().
737
+ Returns a generator over dicts, where each dict is the return value of
738
+ p_sample().
739
+ """
740
+ if device is None:
741
+ device = next(model.parameters()).device
742
+ assert isinstance(shape, (tuple, list))
743
+ if noise is not None:
744
+ img = noise
745
+ else:
746
+ img = th.randn(*shape, device=device)
747
+ indices = list(range(self.num_timesteps))[::-1]
748
+ if progress:
749
+ # Lazy import so that we don't depend on tqdm.
750
+ from tqdm.auto import tqdm
751
+
752
+ indices = tqdm(indices)
753
+
754
+ for i in indices:
755
+ t = th.tensor([i] * shape[0], device=device)
756
+ with th.no_grad():
757
+ out = self.p_sample(
758
+ model,
759
+ img,
760
+ t,
761
+ clip_denoised=clip_denoised,
762
+ denoised_fn=denoised_fn,
763
+ cond_fn=cond_fn,
764
+ model_kwargs=model_kwargs,
765
+ pre_seq=pre_seq,
766
+ transl_req=transl_req
767
+ )
768
+ yield out
769
+ img = out["sample"]
770
+
771
+ def ddim_sample(
772
+ self,
773
+ model,
774
+ x,
775
+ t,
776
+ clip_denoised=True,
777
+ denoised_fn=None,
778
+ cond_fn=None,
779
+ model_kwargs=None,
780
+ eta=0.0,
781
+ ):
782
+ """
783
+ Sample x_{t-1} from the model using DDIM.
784
+
785
+ Same usage as p_sample().
786
+ """
787
+ out = self.p_mean_variance(
788
+ model,
789
+ x,
790
+ t,
791
+ clip_denoised=clip_denoised,
792
+ denoised_fn=denoised_fn,
793
+ model_kwargs=model_kwargs,
794
+ )
795
+ if cond_fn is not None:
796
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
797
+
798
+ # Usually our model outputs epsilon, but we re-derive it
799
+ # in case we used x_start or x_prev prediction.
800
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
801
+
802
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
803
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
804
+ sigma = (
805
+ eta
806
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
807
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
808
+ )
809
+ # Equation 12.
810
+ noise = th.randn_like(x)
811
+ mean_pred = (
812
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
813
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
814
+ )
815
+ nonzero_mask = (
816
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
817
+ ) # no noise when t == 0
818
+ sample = mean_pred + nonzero_mask * sigma * noise
819
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
820
+
821
+ def ddim_reverse_sample(
822
+ self,
823
+ model,
824
+ x,
825
+ t,
826
+ clip_denoised=True,
827
+ denoised_fn=None,
828
+ model_kwargs=None,
829
+ eta=0.0,
830
+ ):
831
+ """
832
+ Sample x_{t+1} from the model using DDIM reverse ODE.
833
+ """
834
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
835
+ out = self.p_mean_variance(
836
+ model,
837
+ x,
838
+ t,
839
+ clip_denoised=clip_denoised,
840
+ denoised_fn=denoised_fn,
841
+ model_kwargs=model_kwargs,
842
+ )
843
+ # Usually our model outputs epsilon, but we re-derive it
844
+ # in case we used x_start or x_prev prediction.
845
+ eps = (
846
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
847
+ - out["pred_xstart"]
848
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
849
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
850
+
851
+ # Equation 12. reversed
852
+ mean_pred = (
853
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
854
+ + th.sqrt(1 - alpha_bar_next) * eps
855
+ )
856
+
857
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
858
+
859
+ def ddim_sample_loop(
860
+ self,
861
+ model,
862
+ shape,
863
+ noise=None,
864
+ clip_denoised=True,
865
+ denoised_fn=None,
866
+ cond_fn=None,
867
+ model_kwargs=None,
868
+ device=None,
869
+ progress=False,
870
+ eta=0.0,
871
+ ):
872
+ """
873
+ Generate samples from the model using DDIM.
874
+
875
+ Same usage as p_sample_loop().
876
+ """
877
+ final = None
878
+ for sample in self.ddim_sample_loop_progressive(
879
+ model,
880
+ shape,
881
+ noise=noise,
882
+ clip_denoised=clip_denoised,
883
+ denoised_fn=denoised_fn,
884
+ cond_fn=cond_fn,
885
+ model_kwargs=model_kwargs,
886
+ device=device,
887
+ progress=progress,
888
+ eta=eta,
889
+ ):
890
+ final = sample
891
+ return final["sample"]
892
+
893
+ def ddim_sample_loop_progressive(
894
+ self,
895
+ model,
896
+ shape,
897
+ noise=None,
898
+ clip_denoised=True,
899
+ denoised_fn=None,
900
+ cond_fn=None,
901
+ model_kwargs=None,
902
+ device=None,
903
+ progress=False,
904
+ eta=0.0,
905
+ ):
906
+ """
907
+ Use DDIM to sample from the model and yield intermediate samples from
908
+ each timestep of DDIM.
909
+
910
+ Same usage as p_sample_loop_progressive().
911
+ """
912
+ if device is None:
913
+ device = next(model.parameters()).device
914
+ assert isinstance(shape, (tuple, list))
915
+ if noise is not None:
916
+ img = noise
917
+ else:
918
+ img = th.randn(*shape, device=device)
919
+ indices = list(range(self.num_timesteps))[::-1]
920
+
921
+ if progress:
922
+ # Lazy import so that we don't depend on tqdm.
923
+ from tqdm.auto import tqdm
924
+
925
+ indices = tqdm(indices)
926
+
927
+ for i in indices:
928
+ t = th.tensor([i] * shape[0], device=device)
929
+ with th.no_grad():
930
+ out = self.ddim_sample(
931
+ model,
932
+ img,
933
+ t,
934
+ clip_denoised=clip_denoised,
935
+ denoised_fn=denoised_fn,
936
+ cond_fn=cond_fn,
937
+ model_kwargs=model_kwargs,
938
+ eta=eta,
939
+ )
940
+ yield out
941
+ img = out["sample"]
942
+
943
+ def _vb_terms_bpd(
944
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
945
+ ):
946
+ """
947
+ Get a term for the variational lower-bound.
948
+
949
+ The resulting units are bits (rather than nats, as one might expect).
950
+ This allows for comparison to other papers.
951
+
952
+ :return: a dict with the following keys:
953
+ - 'output': a shape [N] tensor of NLLs or KLs.
954
+ - 'pred_xstart': the x_0 predictions.
955
+ """
956
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
957
+ x_start=x_start, x_t=x_t, t=t
958
+ )
959
+ out = self.p_mean_variance(
960
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
961
+ )
962
+ kl = normal_kl(
963
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
964
+ )
965
+ kl = mean_flat(kl) / np.log(2.0)
966
+
967
+ decoder_nll = -discretized_gaussian_log_likelihood(
968
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
969
+ )
970
+ assert decoder_nll.shape == x_start.shape
971
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
972
+
973
+ # At the first timestep return the decoder NLL,
974
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
975
+ output = th.where((t == 0), decoder_nll, kl)
976
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
977
+
978
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
979
+ """
980
+ Compute training losses for a single timestep.
981
+
982
+ :param model: the model to evaluate loss on.
983
+ :param x_start: the [N x C x ...] tensor of inputs.
984
+ :param t: a batch of timestep indices.
985
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
986
+ pass to the model. This can be used for conditioning.
987
+ :param noise: if specified, the specific Gaussian noise to try to remove.
988
+ :return: a dict with the key "loss" containing a tensor of shape [N].
989
+ Some mean or variance settings may also have other keys.
990
+ """
991
+ if model_kwargs is None:
992
+ model_kwargs = {}
993
+ if noise is None:
994
+ noise = th.randn_like(x_start)
995
+ x_t = self.q_sample(x_start, t, noise=noise)
996
+
997
+ terms = {}
998
+
999
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
1000
+ terms["loss"] = self._vb_terms_bpd(
1001
+ model=model,
1002
+ x_start=x_start,
1003
+ x_t=x_t,
1004
+ t=t,
1005
+ clip_denoised=False,
1006
+ model_kwargs=model_kwargs,
1007
+ )["output"]
1008
+ if self.loss_type == LossType.RESCALED_KL:
1009
+ terms["loss"] *= self.num_timesteps
1010
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
1011
+ model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
1012
+
1013
+ if self.model_var_type in [
1014
+ ModelVarType.LEARNED,
1015
+ ModelVarType.LEARNED_RANGE,
1016
+ ]:
1017
+ B, C = x_t.shape[:2]
1018
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
1019
+ model_output, model_var_values = th.split(model_output, C, dim=1)
1020
+ # Learn the variance using the variational bound, but don't let
1021
+ # it affect our mean prediction.
1022
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
1023
+ terms["vb"] = self._vb_terms_bpd(
1024
+ model=lambda *args, r=frozen_out: r,
1025
+ x_start=x_start,
1026
+ x_t=x_t,
1027
+ t=t,
1028
+ clip_denoised=False,
1029
+ )["output"]
1030
+ if self.loss_type == LossType.RESCALED_MSE:
1031
+ # Divide by 1000 for equivalence with initial implementation.
1032
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
1033
+ terms["vb"] *= self.num_timesteps / 1000.0
1034
+
1035
+ target = {
1036
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
1037
+ x_start=x_start, x_t=x_t, t=t
1038
+ )[0],
1039
+ ModelMeanType.START_X: x_start,
1040
+ ModelMeanType.EPSILON: noise,
1041
+ }[self.model_mean_type]
1042
+ assert model_output.shape == target.shape == x_start.shape
1043
+ terms["mse"] = mean_flat((target - model_output) ** 2).view(-1, 1).mean(-1)
1044
+ # if "vb" in terms:
1045
+ # terms["loss"] = terms["mse"] + terms["vb"]
1046
+ # else:
1047
+ # terms["loss"] = terms["mse"]
1048
+ terms["target"] = target
1049
+ terms["pred"] = model_output
1050
+ else:
1051
+ raise NotImplementedError(self.loss_type)
1052
+
1053
+ return terms
1054
+
1055
+ def _prior_bpd(self, x_start):
1056
+ """
1057
+ Get the prior KL term for the variational lower-bound, measured in
1058
+ bits-per-dim.
1059
+
1060
+ This term can't be optimized, as it only depends on the encoder.
1061
+
1062
+ :param x_start: the [N x C x ...] tensor of inputs.
1063
+ :return: a batch of [N] KL values (in bits), one per batch element.
1064
+ """
1065
+ batch_size = x_start.shape[0]
1066
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1067
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1068
+ kl_prior = normal_kl(
1069
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
1070
+ )
1071
+ return mean_flat(kl_prior) / np.log(2.0)
1072
+
1073
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
1074
+ """
1075
+ Compute the entire variational lower-bound, measured in bits-per-dim,
1076
+ as well as other related quantities.
1077
+
1078
+ :param model: the model to evaluate loss on.
1079
+ :param x_start: the [N x C x ...] tensor of inputs.
1080
+ :param clip_denoised: if True, clip denoised samples.
1081
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1082
+ pass to the model. This can be used for conditioning.
1083
+
1084
+ :return: a dict containing the following keys:
1085
+ - total_bpd: the total variational lower-bound, per batch element.
1086
+ - prior_bpd: the prior term in the lower-bound.
1087
+ - vb: an [N x T] tensor of terms in the lower-bound.
1088
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
1089
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
1090
+ """
1091
+ device = x_start.device
1092
+ batch_size = x_start.shape[0]
1093
+
1094
+ vb = []
1095
+ xstart_mse = []
1096
+ mse = []
1097
+ for t in list(range(self.num_timesteps))[::-1]:
1098
+ t_batch = th.tensor([t] * batch_size, device=device)
1099
+ noise = th.randn_like(x_start)
1100
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
1101
+ # Calculate VLB term at the current timestep
1102
+ with th.no_grad():
1103
+ out = self._vb_terms_bpd(
1104
+ model,
1105
+ x_start=x_start,
1106
+ x_t=x_t,
1107
+ t=t_batch,
1108
+ clip_denoised=clip_denoised,
1109
+ model_kwargs=model_kwargs,
1110
+ )
1111
+ vb.append(out["output"])
1112
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
1113
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
1114
+ mse.append(mean_flat((eps - noise) ** 2))
1115
+
1116
+ vb = th.stack(vb, dim=1)
1117
+ xstart_mse = th.stack(xstart_mse, dim=1)
1118
+ mse = th.stack(mse, dim=1)
1119
+
1120
+ prior_bpd = self._prior_bpd(x_start)
1121
+ total_bpd = vb.sum(dim=1) + prior_bpd
1122
+ return {
1123
+ "total_bpd": total_bpd,
1124
+ "prior_bpd": prior_bpd,
1125
+ "vb": vb,
1126
+ "xstart_mse": xstart_mse,
1127
+ "mse": mse,
1128
+ }
1129
+
1130
+
1131
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1132
+ """
1133
+ Extract values from a 1-D numpy array for a batch of indices.
1134
+
1135
+ :param arr: the 1-D numpy array.
1136
+ :param timesteps: a tensor of indices into the array to extract.
1137
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1138
+ dimension equal to the length of timesteps.
1139
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1140
+ """
1141
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1142
+ while len(res.shape) < len(broadcast_shape):
1143
+ res = res[..., None]
1144
+ return res.expand(broadcast_shape)
1145
+
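To make the pieces above concrete, here is a hedged end-to-end sketch: build a named schedule, forward-noise a tensor with q_sample, and run the reverse p_sample_loop with a stand-in model that always predicts zero noise. The enum choices and tensor shapes are illustrative only, the module path models.gaussian_diffusion is assumed, and the real denoiser is the MotionTransformer defined in models/transformer.py below.

import torch as th
from models.gaussian_diffusion import (
    GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule)

diffusion = GaussianDiffusion(
    betas=get_named_beta_schedule("linear", 50),     # short chain just for the demo
    model_mean_type=ModelMeanType.EPSILON,           # illustrative; the actual trainer may differ
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)

# Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
x_start = th.randn(2, 8, 4)                          # (batch, frames, features), arbitrary here
t = th.tensor([0, 49])
x_t = diffusion.q_sample(x_start, t)                 # mostly signal at t=0, mostly noise at t=49
print(x_t.shape)                                     # torch.Size([2, 8, 4])

# Reverse process with a stand-in model that always predicts zero noise.
def dummy_model(x, t):
    return th.zeros_like(x)

sample = diffusion.p_sample_loop(
    dummy_model, shape=(2, 8, 4), device=th.device("cpu"), clip_denoised=True)
print(sample.shape)                                  # torch.Size([2, 8, 4])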
models/transformer.py ADDED
@@ -0,0 +1,426 @@
1
+ """
2
+ Copyright 2021 S-Lab
3
+ """
4
+
5
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ import numpy as np
10
+ import clip
11
+
12
+ import math
13
+
14
+
15
+ def timestep_embedding(timesteps, dim, max_period=10000):
16
+ """
17
+ Create sinusoidal timestep embeddings.
18
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
19
+ These may be fractional.
20
+ :param dim: the dimension of the output.
21
+ :param max_period: controls the minimum frequency of the embeddings.
22
+ :return: an [N x dim] Tensor of positional embeddings.
23
+ """
24
+ half = dim // 2
25
+ freqs = torch.exp(
26
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
27
+ ).to(device=timesteps.device)
28
+ args = timesteps[:, None].float() * freqs[None]
29
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
30
+ if dim % 2:
31
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
32
+ return embedding
33
+
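A quick shape check of the sinusoidal embedding (the module path models.transformer is assumed; importing it also loads the clip package):

import torch
from models.transformer import timestep_embedding  # assumed module path

emb = timestep_embedding(torch.tensor([0, 10, 999]), dim=512)
print(emb.shape)      # torch.Size([3, 512])
print(emb[0, :3])     # tensor([1., 1., 1.]): the cosine half is 1 at t = 0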
34
+
35
+ def set_requires_grad(nets, requires_grad=False):
36
+ """Set requies_grad for all the networks.
37
+
38
+ Args:
39
+ nets (nn.Module | list[nn.Module]): A list of networks or a single
40
+ network.
41
+ requires_grad (bool): Whether the networks require gradients or not
42
+ """
43
+ if not isinstance(nets, list):
44
+ nets = [nets]
45
+ for net in nets:
46
+ if net is not None:
47
+ for param in net.parameters():
48
+ param.requires_grad = requires_grad
49
+
50
+
51
+ def zero_module(module):
52
+ """
53
+ Zero out the parameters of a module and return it.
54
+ """
55
+ for p in module.parameters():
56
+ p.detach().zero_()
57
+ return module
58
+
59
+
60
+ class StylizationBlock(nn.Module):
61
+
62
+ def __init__(self, latent_dim, time_embed_dim, dropout):
63
+ super().__init__()
64
+ self.emb_layers = nn.Sequential(
65
+ nn.SiLU(),
66
+ nn.Linear(time_embed_dim, 2 * latent_dim),
67
+ )
68
+ self.norm = nn.LayerNorm(latent_dim)
69
+ self.out_layers = nn.Sequential(
70
+ nn.SiLU(),
71
+ nn.Dropout(p=dropout),
72
+ zero_module(nn.Linear(latent_dim, latent_dim)),
73
+ )
74
+
75
+ def forward(self, h, emb):
76
+ """
77
+ h: B, T, D
78
+ emb: B, D
79
+ """
80
+ # B, 1, 2D
81
+ emb_out = self.emb_layers(emb).unsqueeze(1)
82
+ # scale: B, 1, D / shift: B, 1, D
83
+ scale, shift = torch.chunk(emb_out, 2, dim=2)
84
+ h = self.norm(h) * (1 + scale) + shift
85
+ h = self.out_layers(h)
86
+ return h
87
+
88
+
89
+ class LinearTemporalSelfAttention(nn.Module):
90
+
91
+ def __init__(self, seq_len, latent_dim, num_head, dropout, time_embed_dim):
92
+ super().__init__()
93
+ self.num_head = num_head
94
+ self.norm = nn.LayerNorm(latent_dim)
95
+ self.query = nn.Linear(latent_dim, latent_dim)
96
+ self.key = nn.Linear(latent_dim, latent_dim)
97
+ self.value = nn.Linear(latent_dim, latent_dim)
98
+ self.dropout = nn.Dropout(dropout)
99
+ self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
100
+
101
+ def forward(self, x, emb, src_mask):
102
+ """
103
+ x: B, T, D
104
+ """
105
+ B, T, D = x.shape
106
+ H = self.num_head
107
+ # B, T, D
108
+ query = self.query(self.norm(x))
109
+ # B, T, D
110
+ key = (self.key(self.norm(x)) + (1 - src_mask) * -1000000)
111
+ query = F.softmax(query.view(B, T, H, -1), dim=-1)
112
+ key = F.softmax(key.view(B, T, H, -1), dim=1)
113
+ # B, T, H, HD
114
+ value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1)
115
+ # B, H, HD, HD
116
+ attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
117
+ y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
118
+ y = x + self.proj_out(y, emb)
119
+ return y
120
+
121
+
122
+ class LinearTemporalCrossAttention(nn.Module):
123
+
124
+ def __init__(self, seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim):
125
+ super().__init__()
126
+ self.num_head = num_head
127
+ self.norm = nn.LayerNorm(latent_dim)
128
+ self.text_norm = nn.LayerNorm(text_latent_dim)
129
+ self.query = nn.Linear(latent_dim, latent_dim)
130
+ self.key = nn.Linear(text_latent_dim, latent_dim)
131
+ self.value = nn.Linear(text_latent_dim, latent_dim)
132
+ self.dropout = nn.Dropout(dropout)
133
+ self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
134
+
135
+ def forward(self, x, xf, emb):
136
+ """
137
+ x: B, T, D
138
+ xf: B, N, L
139
+ """
140
+ B, T, D = x.shape
141
+ N = xf.shape[1]
142
+ H = self.num_head
143
+ # B, T, D
144
+ query = self.query(self.norm(x))
145
+ # B, N, D
146
+ key = self.key(self.text_norm(xf))
147
+ query = F.softmax(query.view(B, T, H, -1), dim=-1)
148
+ key = F.softmax(key.view(B, N, H, -1), dim=1)
149
+ # B, N, H, HD
150
+ value = self.value(self.text_norm(xf)).view(B, N, H, -1)
151
+ # B, H, HD, HD
152
+ attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
153
+ y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
154
+ y = x + self.proj_out(y, emb)
155
+ return y
156
+
157
+ class FFN(nn.Module):
158
+
159
+ def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim):
160
+ super().__init__()
161
+ self.linear1 = nn.Linear(latent_dim, ffn_dim)
162
+ self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
163
+ self.activation = nn.GELU()
164
+ self.dropout = nn.Dropout(dropout)
165
+ self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
166
+
167
+ def forward(self, x, emb):
168
+ y = self.linear2(self.dropout(self.activation(self.linear1(x))))
169
+ y = x + self.proj_out(y, emb)
170
+ return y
171
+
172
+
173
+ class LinearTemporalDiffusionTransformerDecoderLayer(nn.Module):
174
+
175
+ def __init__(self,
176
+ seq_len=60,
177
+ latent_dim=32,
178
+ text_latent_dim=512,
179
+ time_embed_dim=128,
180
+ ffn_dim=256,
181
+ num_head=4,
182
+ dropout=0.1):
183
+ super().__init__()
184
+ self.sa_block = LinearTemporalSelfAttention(
185
+ seq_len, latent_dim, num_head, dropout, time_embed_dim)
186
+ self.ca_block = LinearTemporalCrossAttention(
187
+ seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim)
188
+ self.ffn = FFN(latent_dim, ffn_dim, dropout, time_embed_dim)
189
+
190
+ def forward(self, x, xf, emb, src_mask):
191
+ x = self.sa_block(x, emb, src_mask)
192
+ x = self.ca_block(x, xf, emb)
193
+ x = self.ffn(x, emb)
194
+ return x
195
+
196
+ class TemporalSelfAttention(nn.Module):
197
+
198
+ def __init__(self, seq_len, latent_dim, num_head, dropout, time_embed_dim):
199
+ super().__init__()
200
+ self.num_head = num_head
201
+ self.norm = nn.LayerNorm(latent_dim)
202
+ self.query = nn.Linear(latent_dim, latent_dim)
203
+ self.key = nn.Linear(latent_dim, latent_dim)
204
+ self.value = nn.Linear(latent_dim, latent_dim)
205
+ self.dropout = nn.Dropout(dropout)
206
+ self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
207
+
208
+ def forward(self, x, emb, src_mask):
209
+ """
210
+ x: B, T, D
211
+ """
212
+ B, T, D = x.shape
213
+ H = self.num_head
214
+ # B, T, 1, D
215
+ query = self.query(self.norm(x)).unsqueeze(2)
216
+ # B, 1, T, D
217
+ key = self.key(self.norm(x)).unsqueeze(1)
218
+ query = query.view(B, T, H, -1)
219
+ key = key.view(B, T, H, -1)
220
+ # B, T, T, H
221
+ attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
222
+ attention = attention + (1 - src_mask.unsqueeze(-1)) * -100000
223
+ weight = self.dropout(F.softmax(attention, dim=2))
224
+ value = self.value(self.norm(x)).view(B, T, H, -1)
225
+ y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
226
+ y = x + self.proj_out(y, emb)
227
+ return y
228
+
229
+ class TemporalCrossAttention(nn.Module):
230
+
231
+ def __init__(self, seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim):
232
+ super().__init__()
233
+ self.num_head = num_head
234
+ self.norm = nn.LayerNorm(latent_dim)
235
+ self.text_norm = nn.LayerNorm(text_latent_dim)
236
+ self.query = nn.Linear(latent_dim, latent_dim)
237
+ self.key = nn.Linear(text_latent_dim, latent_dim)
238
+ self.value = nn.Linear(text_latent_dim, latent_dim)
239
+ self.dropout = nn.Dropout(dropout)
240
+ self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
241
+
242
+ def forward(self, x, xf, emb):
243
+ """
244
+ x: B, T, D
245
+ xf: B, N, L
246
+ """
247
+ B, T, D = x.shape
248
+ N = xf.shape[1]
249
+ H = self.num_head
250
+ # B, T, 1, D
251
+ query = self.query(self.norm(x)).unsqueeze(2)
252
+ # B, 1, N, D
253
+ key = self.key(self.text_norm(xf)).unsqueeze(1)
254
+ query = query.view(B, T, H, -1)
255
+ key = key.view(B, N, H, -1)
256
+ # B, T, N, H
257
+ attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
258
+ weight = self.dropout(F.softmax(attention, dim=2))
259
+ value = self.value(self.text_norm(xf)).view(B, N, H, -1)
260
+ y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
261
+ y = x + self.proj_out(y, emb)
262
+ return y
263
+
264
+ class TemporalDiffusionTransformerDecoderLayer(nn.Module):
265
+
266
+ def __init__(self,
267
+ seq_len=60,
268
+ latent_dim=32,
269
+ text_latent_dim=512,
270
+ time_embed_dim=128,
271
+ ffn_dim=256,
272
+ num_head=4,
273
+ dropout=0.1):
274
+ super().__init__()
275
+ self.sa_block = TemporalSelfAttention(
276
+ seq_len, latent_dim, num_head, dropout, time_embed_dim)
277
+ self.ca_block = TemporalCrossAttention(
278
+ seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim)
279
+ self.ffn = FFN(latent_dim, ffn_dim, dropout, time_embed_dim)
280
+
281
+ def forward(self, x, xf, emb, src_mask):
282
+ x = self.sa_block(x, emb, src_mask)
283
+ x = self.ca_block(x, xf, emb)
284
+ x = self.ffn(x, emb)
285
+ return x
286
+
287
+
288
+ class MotionTransformer(nn.Module):
289
+ def __init__(self,
290
+ input_feats,
291
+ num_frames=240,
292
+ latent_dim=512,
293
+ ff_size=1024,
294
+ num_layers=8,
295
+ num_heads=8,
296
+ dropout=0,
297
+ activation="gelu",
298
+ num_text_layers=4,
299
+ text_latent_dim=256,
300
+ text_ff_size=2048,
301
+ text_num_heads=4,
302
+ no_clip=False,
303
+ no_eff=False,
304
+ **kargs):
305
+ super().__init__()
306
+
307
+ self.num_frames = num_frames
308
+ self.latent_dim = latent_dim
309
+ self.ff_size = ff_size
310
+ self.num_layers = num_layers
311
+ self.num_heads = num_heads
312
+ self.dropout = dropout
313
+ self.activation = activation
314
+ self.input_feats = input_feats
315
+ self.time_embed_dim = latent_dim * 4
316
+ self.sequence_embedding = nn.Parameter(torch.randn(num_frames, latent_dim))
317
+
318
+ # Text Transformer
319
+ self.clip, _ = clip.load('ViT-B/32', "cpu")
320
+ if no_clip:
321
+ self.clip.initialize_parameters()
322
+ else:
323
+ set_requires_grad(self.clip, False)
324
+ if text_latent_dim != 512:
325
+ self.text_pre_proj = nn.Linear(512, text_latent_dim)
326
+ else:
327
+ self.text_pre_proj = nn.Identity()
328
+ textTransEncoderLayer = nn.TransformerEncoderLayer(
329
+ d_model=text_latent_dim,
330
+ nhead=text_num_heads,
331
+ dim_feedforward=text_ff_size,
332
+ dropout=dropout,
333
+ activation=activation)
334
+ self.textTransEncoder = nn.TransformerEncoder(
335
+ textTransEncoderLayer,
336
+ num_layers=num_text_layers)
337
+ self.text_ln = nn.LayerNorm(text_latent_dim)
338
+ self.text_proj = nn.Sequential(
339
+ nn.Linear(text_latent_dim, self.time_embed_dim)
340
+ )
341
+
342
+ # Input Embedding
343
+ self.joint_embed = nn.Linear(self.input_feats, self.latent_dim)
344
+
345
+ self.time_embed = nn.Sequential(
346
+ nn.Linear(self.latent_dim, self.time_embed_dim),
347
+ nn.SiLU(),
348
+ nn.Linear(self.time_embed_dim, self.time_embed_dim),
349
+ )
350
+ self.temporal_decoder_blocks = nn.ModuleList()
351
+ for i in range(num_layers):
352
+ if no_eff:
353
+ self.temporal_decoder_blocks.append(
354
+ TemporalDiffusionTransformerDecoderLayer(
355
+ seq_len=num_frames,
356
+ latent_dim=latent_dim,
357
+ text_latent_dim=text_latent_dim,
358
+ time_embed_dim=self.time_embed_dim,
359
+ ffn_dim=ff_size,
360
+ num_head=num_heads,
361
+ dropout=dropout
362
+ )
363
+ )
364
+ else:
365
+ self.temporal_decoder_blocks.append(
366
+ LinearTemporalDiffusionTransformerDecoderLayer(
367
+ seq_len=num_frames,
368
+ latent_dim=latent_dim,
369
+ text_latent_dim=text_latent_dim,
370
+ time_embed_dim=self.time_embed_dim,
371
+ ffn_dim=ff_size,
372
+ num_head=num_heads,
373
+ dropout=dropout
374
+ )
375
+ )
376
+
377
+ # Output Module
378
+ self.out = zero_module(nn.Linear(self.latent_dim, self.input_feats))
379
+
380
+ def encode_text(self, text, device):
381
+ with torch.no_grad():
382
+ text = clip.tokenize(text, truncate=True).to(device)
383
+ x = self.clip.token_embedding(text).type(self.clip.dtype) # [batch_size, n_ctx, d_model]
384
+
385
+ x = x + self.clip.positional_embedding.type(self.clip.dtype)
386
+ x = x.permute(1, 0, 2) # NLD -> LND
387
+ x = self.clip.transformer(x)
388
+ x = self.clip.ln_final(x).type(self.clip.dtype)
389
+
390
+ # T, B, D
391
+ x = self.text_pre_proj(x)
392
+ xf_out = self.textTransEncoder(x)
393
+ xf_out = self.text_ln(xf_out)
394
+ xf_proj = self.text_proj(xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])])
395
+ # B, T, D
396
+ xf_out = xf_out.permute(1, 0, 2)
397
+ return xf_proj, xf_out
398
+
399
+ def generate_src_mask(self, T, length):
400
+ B = len(length)
401
+ src_mask = torch.ones(B, T)
402
+ for i in range(B):
403
+ for j in range(length[i], T):
404
+ src_mask[i, j] = 0
405
+ return src_mask
406
+
407
+ def forward(self, x, timesteps, length=None, text=None, xf_proj=None, xf_out=None):
408
+ """
409
+ x: B, T, D
410
+ """
411
+ B, T = x.shape[0], x.shape[1]
412
+ if xf_proj is None or xf_out is None:
413
+ xf_proj, xf_out = self.encode_text(text, x.device)
414
+
415
+ emb = self.time_embed(timestep_embedding(timesteps, self.latent_dim)) + xf_proj
416
+
417
+ # B, T, latent_dim
418
+ h = self.joint_embed(x)
419
+ h = h + self.sequence_embedding.unsqueeze(0)[:, :T, :]
420
+
421
+ src_mask = self.generate_src_mask(T, length).to(x.device).unsqueeze(-1)
422
+ for module in self.temporal_decoder_blocks:
423
+ h = module(h, xf_out, emb, src_mask)
424
+
425
+ output = self.out(h).view(B, T, -1).contiguous()
426
+ return output
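A shape-level sketch of a single denoiser forward pass. Everything here is hedged: the module path models.transformer is assumed, constructing the model downloads the frozen CLIP ViT-B/32 text encoder, input_feats=263 matches the HumanML3D pose-vector size, and the remaining numbers are chosen only to keep the example small.

import torch
from models.transformer import MotionTransformer  # assumed module path

model = MotionTransformer(input_feats=263, num_frames=196, num_layers=2, latent_dim=128)
model.eval()

x = torch.randn(2, 196, 263)                      # B, T, D noisy motion
timesteps = torch.randint(0, 1000, (2,))          # one diffusion step per sample
lengths = [60, 196]                               # valid frames per sample; the rest is masked
with torch.no_grad():
    out = model(x, timesteps,
                length=lengths,
                text=["a person walks forward", "a person jumps"])
print(out.shape)                                  # torch.Size([2, 196, 263])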
options/base_options.py ADDED
@@ -0,0 +1,86 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ from mmcv.runner import get_dist_info, init_dist
5
+ import torch.distributed as dist
6
+
7
+
8
+ class BaseOptions():
9
+ def __init__(self):
10
+ self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
11
+ self.initialized = False
12
+
13
+ def initialize(self):
14
+ self.parser.add_argument('--name', type=str, default="test", help='Name of this trial')
15
+ self.parser.add_argument('--decomp_name', type=str, default="Decomp_SP001_SM001_H512", help='Name of autoencoder model')
16
+
17
+ self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id')
18
+ self.parser.add_argument("--distributed", action="store_true", help='Weather to use DDP training')
19
+
20
+ self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name')
21
+ self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
22
+
23
+ self.parser.add_argument("--unit_length", type=int, default=4, help="Motions are cropped to the maximum times of unit_length")
24
+ self.parser.add_argument("--max_text_len", type=int, default=20, help="Maximum length of text description")
25
+
26
+ self.parser.add_argument('--text_enc_mod', type=str, default='bigru')
27
+ self.parser.add_argument('--estimator_mod', type=str, default='bigru')
28
+
29
+ self.parser.add_argument('--dim_text_hidden', type=int, default=512, help='Dimension of hidden unit in text encoder')
30
+ self.parser.add_argument('--dim_att_vec', type=int, default=512, help='Dimension of attention vector')
31
+ self.parser.add_argument('--dim_z', type=int, default=128, help='Dimension of latent Gaussian vector')
32
+
33
+ self.parser.add_argument('--n_layers_pri', type=int, default=1, help='Number of layers in prior network')
34
+ self.parser.add_argument('--n_layers_pos', type=int, default=1, help='Number of layers in posterior network')
35
+ self.parser.add_argument('--n_layers_dec', type=int, default=1, help='Number of layers in generator')
36
+
37
+ self.parser.add_argument('--dim_pri_hidden', type=int, default=1024, help='Dimension of hidden unit in prior network')
38
+ self.parser.add_argument('--dim_pos_hidden', type=int, default=1024, help='Dimension of hidden unit in posterior network')
39
+ self.parser.add_argument('--dim_dec_hidden', type=int, default=1024, help='Dimension of hidden unit in generator')
40
+
41
+ self.parser.add_argument('--dim_movement_enc_hidden', type=int, default=512,
42
+ help='Dimension of hidden in AutoEncoder(encoder)')
43
+ self.parser.add_argument('--dim_movement_dec_hidden', type=int, default=512,
44
+ help='Dimension of hidden in AutoEncoder(decoder)')
45
+ self.parser.add_argument('--dim_movement_latent', type=int, default=512, help='Dimension of motion snippet')
46
+
47
+ self.initialized = True
48
+
49
+
50
+
51
+ def parse(self):
52
+ if not self.initialized:
53
+ self.initialize()
54
+
55
+ self.opt = self.parser.parse_args()
56
+
57
+ self.opt.is_train = self.is_train
58
+
59
+ if self.opt.gpu_id != -1:
60
+ # self.opt.gpu_id = int(self.opt.gpu_id)
61
+ torch.cuda.set_device(self.opt.gpu_id)
62
+
63
+ args = vars(self.opt)
64
+
65
+ if args["distributed"]:
66
+ init_dist('slurm')
67
+ rank, world_size = get_dist_info()
68
+ if rank == 0:
69
+ print('------------ Options -------------')
70
+ for k, v in sorted(args.items()):
71
+ print('%s: %s' % (str(k), str(v)))
72
+ print('-------------- End ----------------')
73
+ if self.is_train:
74
+ # save to the disk
75
+ expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.dataset_name, self.opt.name)
76
+ if not os.path.exists(expr_dir):
77
+ os.makedirs(expr_dir)
78
+ file_name = os.path.join(expr_dir, 'opt.txt')
79
+ with open(file_name, 'wt') as opt_file:
80
+ opt_file.write('------------ Options -------------\n')
81
+ for k, v in sorted(args.items()):
82
+ opt_file.write('%s: %s\n' % (str(k), str(v)))
83
+ opt_file.write('-------------- End ----------------\n')
84
+ if world_size > 1:
85
+ dist.barrier()
86
+ return self.opt
options/evaluate_options.py ADDED
@@ -0,0 +1,27 @@
1
+ from options.base_options import BaseOptions
2
+
3
+
4
+ class TestOptions(BaseOptions):
5
+ def initialize(self):
6
+ BaseOptions.initialize(self)
7
+ self.parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
8
+ self.parser.add_argument('--start_mov_len', type=int, default=10)
9
+ self.parser.add_argument('--est_length', action="store_true", help="Whether to use sampled motion length")
10
+ self.parser.add_argument('--num_layers', type=int, default=8, help='num_layers of transformer')
11
+ self.parser.add_argument('--latent_dim', type=int, default=512, help='latent_dim of transformer')
12
+ self.parser.add_argument('--diffusion_steps', type=int, default=1000, help='diffusion_steps of transformer')
13
+ self.parser.add_argument('--no_clip', action='store_true', help='reinitialize CLIP instead of loading its pretrained weights')
14
+ self.parser.add_argument('--no_eff', action='store_true', help='disable efficient (linear) attention and use full attention')
15
+
16
+
17
+ self.parser.add_argument('--repeat_times', type=int, default=3, help="Number of generation rounds for each text description")
18
+ self.parser.add_argument('--split_file', type=str, default='test.txt')
19
+ self.parser.add_argument('--text', type=str, default="", help='Text description for motion generation')
20
+ self.parser.add_argument('--motion_length', type=int, default=0, help='Number of frames for motion generation')
21
+ self.parser.add_argument('--text_file', type=str, default="", help='Path of text description for motion generation')
22
+ self.parser.add_argument('--which_epoch', type=str, default="latest", help='Checkpoint that will be used')
23
+ self.parser.add_argument('--result_path', type=str, default="./eval_results/", help='Path to save generation results')
24
+ self.parser.add_argument('--num_results', type=int, default=40, help='Number of descriptions that will be used')
25
+ self.parser.add_argument('--ext', type=str, default='default', help='Save file path extension')
26
+
27
+ self.is_train = False
options/train_options.py ADDED
@@ -0,0 +1,26 @@
1
+ from options.base_options import BaseOptions
2
+ import argparse
3
+
4
+ class TrainCompOptions(BaseOptions):
5
+ def initialize(self):
6
+ BaseOptions.initialize(self)
7
+ self.parser.add_argument('--num_layers', type=int, default=8, help='num_layers of transformer')
8
+ self.parser.add_argument('--latent_dim', type=int, default=512, help='latent_dim of transformer')
9
+ self.parser.add_argument('--diffusion_steps', type=int, default=1000, help='diffusion_steps of transformer')
10
+ self.parser.add_argument('--no_clip', action='store_true', help='whether to skip loading pretrained CLIP weights')
11
+ self.parser.add_argument('--no_eff', action='store_true', help='whether to disable efficient attention')
12
+
13
+ self.parser.add_argument('--num_epochs', type=int, default=50, help='Number of epochs')
14
+ self.parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate')
15
+ self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU')
16
+ self.parser.add_argument('--times', type=int, default=1, help='Number of times the dataset is repeated per epoch')
17
+
18
+ self.parser.add_argument('--feat_bias', type=float, default=5, help='Scales for global motion features and foot contact')
19
+
20
+ self.parser.add_argument('--is_continue', action="store_true", help='Continue training from a previous trial?')
21
+
22
+ self.parser.add_argument('--log_every', type=int, default=50, help='Frequency of printing training progress (by iteration)')
23
+ self.parser.add_argument('--save_every_e', type=int, default=5, help='Frequency of saving models (by epoch)')
24
+ self.parser.add_argument('--eval_every_e', type=int, default=5, help='Frequency of saving animation results (by epoch)')
25
+ self.parser.add_argument('--save_latest', type=int, default=500, help='Frequency of saving models (by iteration)')
26
+ self.is_train = True
requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ mmcv
2
+ matplotlib==3.3.1
3
+ torch==1.7.1
4
+ git+https://github.com/openai/CLIP.git
tools/evaluation.py ADDED
@@ -0,0 +1,278 @@
1
+ from datetime import datetime
2
+ import numpy as np
3
+ import torch
4
+ from datasets import get_dataset_motion_loader, get_motion_loader
5
+ from models import MotionTransformer
6
+ from utils.get_opt import get_opt
7
+ from utils.metrics import *
8
+ from datasets import EvaluatorModelWrapper
9
+ from collections import OrderedDict
10
+ from utils.plot_script import *
11
+ from utils import paramUtil
12
+ from utils.utils import *
13
+ from trainers import DDPMTrainer
14
+
15
+ from os.path import join as pjoin
16
+ import sys
17
+
18
+
19
+ def build_models(opt, dim_pose):
20
+ encoder = MotionTransformer(
21
+ input_feats=dim_pose,
22
+ num_frames=opt.max_motion_length,
23
+ num_layers=opt.num_layers,
24
+ latent_dim=opt.latent_dim,
25
+ no_clip=opt.no_clip,
26
+ no_eff=opt.no_eff)
27
+ return encoder
28
+
29
+
30
+ torch.multiprocessing.set_sharing_strategy('file_system')
31
+
32
+
33
+ def evaluate_matching_score(motion_loaders, file):
34
+ match_score_dict = OrderedDict({})
35
+ R_precision_dict = OrderedDict({})
36
+ activation_dict = OrderedDict({})
37
+ # print(motion_loaders.keys())
38
+ print('========== Evaluating Matching Score ==========')
39
+ for motion_loader_name, motion_loader in motion_loaders.items():
40
+ all_motion_embeddings = []
41
+ score_list = []
42
+ all_size = 0
43
+ matching_score_sum = 0
44
+ top_k_count = 0
45
+ # print(motion_loader_name)
46
+ with torch.no_grad():
47
+ for idx, batch in enumerate(motion_loader):
48
+ word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch
49
+ text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(
50
+ word_embs=word_embeddings,
51
+ pos_ohot=pos_one_hots,
52
+ cap_lens=sent_lens,
53
+ motions=motions,
54
+ m_lens=m_lens
55
+ )
56
+ dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
57
+ motion_embeddings.cpu().numpy())
58
+ matching_score_sum += dist_mat.trace()
59
+
60
+ argsmax = np.argsort(dist_mat, axis=1)
61
+ top_k_mat = calculate_top_k(argsmax, top_k=3)
62
+ top_k_count += top_k_mat.sum(axis=0)
63
+
64
+ all_size += text_embeddings.shape[0]
65
+
66
+ all_motion_embeddings.append(motion_embeddings.cpu().numpy())
67
+
68
+ all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
69
+ matching_score = matching_score_sum / all_size
70
+ R_precision = top_k_count / all_size
71
+ match_score_dict[motion_loader_name] = matching_score
72
+ R_precision_dict[motion_loader_name] = R_precision
73
+ activation_dict[motion_loader_name] = all_motion_embeddings
74
+
75
+ print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}')
76
+ print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True)
77
+
78
+ line = f'---> [{motion_loader_name}] R_precision: '
79
+ for i in range(len(R_precision)):
80
+ line += '(top %d): %.4f ' % (i+1, R_precision[i])
81
+ print(line)
82
+ print(line, file=file, flush=True)
83
+
84
+ return match_score_dict, R_precision_dict, activation_dict
85
+
86
+
87
+ def evaluate_fid(groundtruth_loader, activation_dict, file):
88
+ eval_dict = OrderedDict({})
89
+ gt_motion_embeddings = []
90
+ print('========== Evaluating FID ==========')
91
+ with torch.no_grad():
92
+ for idx, batch in enumerate(groundtruth_loader):
93
+ _, _, _, sent_lens, motions, m_lens, _ = batch
94
+ motion_embeddings = eval_wrapper.get_motion_embeddings(
95
+ motions=motions,
96
+ m_lens=m_lens
97
+ )
98
+ gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
99
+ gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
100
+ gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
101
+
102
+ # print(gt_mu)
103
+ for model_name, motion_embeddings in activation_dict.items():
104
+ mu, cov = calculate_activation_statistics(motion_embeddings)
105
+ # print(mu)
106
+ fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
107
+ print(f'---> [{model_name}] FID: {fid:.4f}')
108
+ print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
109
+ eval_dict[model_name] = fid
110
+ return eval_dict
111
+
112
+
113
+ def evaluate_diversity(activation_dict, file):
114
+ eval_dict = OrderedDict({})
115
+ print('========== Evaluating Diversity ==========')
116
+ for model_name, motion_embeddings in activation_dict.items():
117
+ diversity = calculate_diversity(motion_embeddings, diversity_times)
118
+ eval_dict[model_name] = diversity
119
+ print(f'---> [{model_name}] Diversity: {diversity:.4f}')
120
+ print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
121
+ return eval_dict
122
+
123
+
124
+ def evaluate_multimodality(mm_motion_loaders, file):
125
+ eval_dict = OrderedDict({})
126
+ print('========== Evaluating MultiModality ==========')
127
+ for model_name, mm_motion_loader in mm_motion_loaders.items():
128
+ mm_motion_embeddings = []
129
+ with torch.no_grad():
130
+ for idx, batch in enumerate(mm_motion_loader):
131
+ # (1, mm_replications, dim_pos)
132
+ motions, m_lens = batch
133
+ motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0])
134
+ mm_motion_embeddings.append(motion_embedings.unsqueeze(0))
135
+ if len(mm_motion_embeddings) == 0:
136
+ multimodality = 0
137
+ else:
138
+ mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
139
+ multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
140
+ print(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
141
+ print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
142
+ eval_dict[model_name] = multimodality
143
+ return eval_dict
144
+
145
+
146
+ def get_metric_statistics(values):
147
+ mean = np.mean(values, axis=0)
148
+ std = np.std(values, axis=0)
149
+ conf_interval = 1.96 * std / np.sqrt(replication_times)
150
+ return mean, conf_interval
151
+
152
+
153
+ def evaluation(log_file):
154
+ with open(log_file, 'w') as f:
155
+ all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
156
+ 'R_precision': OrderedDict({}),
157
+ 'FID': OrderedDict({}),
158
+ 'Diversity': OrderedDict({}),
159
+ 'MultiModality': OrderedDict({})})
160
+ for replication in range(replication_times):
161
+ motion_loaders = {}
162
+ mm_motion_loaders = {}
163
+ motion_loaders['ground truth'] = gt_loader
164
+ for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
165
+ motion_loader, mm_motion_loader = motion_loader_getter()
166
+ motion_loaders[motion_loader_name] = motion_loader
167
+ mm_motion_loaders[motion_loader_name] = mm_motion_loader
168
+
169
+ print(f'==================== Replication {replication} ====================')
170
+ print(f'==================== Replication {replication} ====================', file=f, flush=True)
171
+ print(f'Time: {datetime.now()}')
172
+ print(f'Time: {datetime.now()}', file=f, flush=True)
173
+ mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(motion_loaders, f)
174
+
175
+ print(f'Time: {datetime.now()}')
176
+ print(f'Time: {datetime.now()}', file=f, flush=True)
177
+ fid_score_dict = evaluate_fid(gt_loader, acti_dict, f)
178
+
179
+ print(f'Time: {datetime.now()}')
180
+ print(f'Time: {datetime.now()}', file=f, flush=True)
181
+ div_score_dict = evaluate_diversity(acti_dict, f)
182
+
183
+ print(f'Time: {datetime.now()}')
184
+ print(f'Time: {datetime.now()}', file=f, flush=True)
185
+ mm_score_dict = evaluate_multimodality(mm_motion_loaders, f)
186
+
187
+ print(f'!!! DONE !!!')
188
+ print(f'!!! DONE !!!', file=f, flush=True)
189
+
190
+ for key, item in mat_score_dict.items():
191
+ if key not in all_metrics['Matching Score']:
192
+ all_metrics['Matching Score'][key] = [item]
193
+ else:
194
+ all_metrics['Matching Score'][key] += [item]
195
+
196
+ for key, item in R_precision_dict.items():
197
+ if key not in all_metrics['R_precision']:
198
+ all_metrics['R_precision'][key] = [item]
199
+ else:
200
+ all_metrics['R_precision'][key] += [item]
201
+
202
+ for key, item in fid_score_dict.items():
203
+ if key not in all_metrics['FID']:
204
+ all_metrics['FID'][key] = [item]
205
+ else:
206
+ all_metrics['FID'][key] += [item]
207
+
208
+ for key, item in div_score_dict.items():
209
+ if key not in all_metrics['Diversity']:
210
+ all_metrics['Diversity'][key] = [item]
211
+ else:
212
+ all_metrics['Diversity'][key] += [item]
213
+
214
+ for key, item in mm_score_dict.items():
215
+ if key not in all_metrics['MultiModality']:
216
+ all_metrics['MultiModality'][key] = [item]
217
+ else:
218
+ all_metrics['MultiModality'][key] += [item]
219
+
220
+
221
+ # print(all_metrics['Diversity'])
222
+ for metric_name, metric_dict in all_metrics.items():
223
+ print('========== %s Summary ==========' % metric_name)
224
+ print('========== %s Summary ==========' % metric_name, file=f, flush=True)
225
+
226
+ for model_name, values in metric_dict.items():
227
+ # print(metric_name, model_name)
228
+ mean, conf_interval = get_metric_statistics(np.array(values))
229
+ # print(mean, mean.dtype)
230
+ if isinstance(mean, np.float64) or isinstance(mean, np.float32):
231
+ print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
232
+ print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
233
+ elif isinstance(mean, np.ndarray):
234
+ line = f'---> [{model_name}]'
235
+ for i in range(len(mean)):
236
+ line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
237
+ print(line)
238
+ print(line, file=f, flush=True)
239
+
240
+
241
+ if __name__ == '__main__':
242
+ mm_num_samples = 100
243
+ mm_num_repeats = 30
244
+ mm_num_times = 10
245
+
246
+ diversity_times = 300
247
+ replication_times = 1
248
+ batch_size = 32
249
+ opt_path = sys.argv[1]
250
+ dataset_opt_path = opt_path
251
+
252
+ try:
253
+ device_id = int(sys.argv[2])
254
+ except:
255
+ device_id = 0
256
+ device = torch.device('cuda:%d' % device_id if torch.cuda.is_available() else 'cpu')
257
+ torch.cuda.set_device(device_id)
258
+
259
+ gt_loader, gt_dataset = get_dataset_motion_loader(dataset_opt_path, batch_size, device)
260
+ wrapper_opt = get_opt(dataset_opt_path, device)
261
+ eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
262
+
263
+ opt = get_opt(opt_path, device)
264
+ encoder = build_models(opt, opt.dim_pose)
265
+ trainer = DDPMTrainer(opt, encoder)
266
+ eval_motion_loaders = {
267
+ 'text2motion': lambda: get_motion_loader(
268
+ opt,
269
+ batch_size,
270
+ trainer,
271
+ gt_dataset,
272
+ mm_num_samples,
273
+ mm_num_repeats
274
+ )
275
+ }
276
+
277
+ log_file = './t2m_evaluation.log'
278
+ evaluation(log_file)
tools/train.py ADDED
@@ -0,0 +1,89 @@
1
+ import os
2
+ from os.path import join as pjoin
3
+
4
+ import utils.paramUtil as paramUtil
5
+ from options.train_options import TrainCompOptions
6
+ from utils.plot_script import *
7
+
8
+ from models import MotionTransformer
9
+ from trainers import DDPMTrainer
10
+ from datasets import Text2MotionDataset
11
+
12
+ from mmcv.runner import get_dist_info, init_dist
13
+ from mmcv.parallel import MMDistributedDataParallel
14
+ import torch
15
+ import torch.distributed as dist
16
+
17
+
18
+ def build_models(opt, dim_pose):
19
+ encoder = MotionTransformer(
20
+ input_feats=dim_pose,
21
+ num_frames=opt.max_motion_length,
22
+ num_layers=opt.num_layers,
23
+ latent_dim=opt.latent_dim,
24
+ no_clip=opt.no_clip,
25
+ no_eff=opt.no_eff)
26
+ return encoder
27
+
28
+
29
+ if __name__ == '__main__':
30
+ parser = TrainCompOptions()
31
+ opt = parser.parse()
32
+ rank, world_size = get_dist_info()
33
+
34
+ opt.device = torch.device("cuda")
35
+ torch.autograd.set_detect_anomaly(True)
36
+
37
+ opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
38
+ opt.model_dir = pjoin(opt.save_root, 'model')
39
+ opt.meta_dir = pjoin(opt.save_root, 'meta')
40
+
41
+ if rank == 0:
42
+ os.makedirs(opt.model_dir, exist_ok=True)
43
+ os.makedirs(opt.meta_dir, exist_ok=True)
44
+ if world_size > 1:
45
+ dist.barrier()
46
+
47
+ if opt.dataset_name == 't2m':
48
+ opt.data_root = './data/HumanML3D'
49
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
50
+ opt.text_dir = pjoin(opt.data_root, 'texts')
51
+ opt.joints_num = 22
52
+ radius = 4
53
+ fps = 20
54
+ opt.max_motion_length = 196
55
+ dim_pose = 263
56
+ kinematic_chain = paramUtil.t2m_kinematic_chain
57
+ elif opt.dataset_name == 'kit':
58
+ opt.data_root = './data/KIT-ML'
59
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
60
+ opt.text_dir = pjoin(opt.data_root, 'texts')
61
+ opt.joints_num = 21
62
+ radius = 240 * 8
63
+ fps = 12.5
64
+ dim_pose = 251
65
+ opt.max_motion_length = 196
66
+ kinematic_chain = paramUtil.kit_kinematic_chain
67
+
68
+ else:
69
+ raise KeyError('Dataset Does Not Exist')
70
+
71
+ dim_word = 300
72
+ mean = np.load(pjoin(opt.data_root, 'Mean.npy'))
73
+ std = np.load(pjoin(opt.data_root, 'Std.npy'))
74
+
75
+ train_split_file = pjoin(opt.data_root, 'train.txt')
76
+
77
+ encoder = build_models(opt, dim_pose)
78
+ if world_size > 1:
79
+ encoder = MMDistributedDataParallel(
80
+ encoder.cuda(),
81
+ device_ids=[torch.cuda.current_device()],
82
+ broadcast_buffers=False,
83
+ find_unused_parameters=True)
84
+ else:
85
+ encoder = encoder.cuda()
86
+
87
+ trainer = DDPMTrainer(opt, encoder)
88
+ train_dataset = Text2MotionDataset(opt, mean, std, train_split_file, opt.times)
89
+ trainer.train(train_dataset)
tools/visualization.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+
5
+ import utils.paramUtil as paramUtil
6
+ from torch.utils.data import DataLoader
7
+ from utils.plot_script import *
8
+
9
+ from utils.utils import *
10
+ from utils.motion_process import recover_from_ric
11
+
12
+
13
+ def plot_t2m(opt, data, result_path, caption):
14
+ joint = recover_from_ric(torch.from_numpy(data).float(), opt.joints_num).numpy()
15
+ # joint = motion_temporal_filter(joint, sigma=1)
16
+ plot_3d_motion(result_path, paramUtil.t2m_kinematic_chain, joint, title=caption, fps=20)
17
+
18
+
19
+ def process(trainer, opt, device, mean, std, text, motion_length, result_path):
20
+
21
+ result_dict = {}
22
+ with torch.no_grad():
23
+ if motion_length != -1:
24
+ caption = [text]
25
+ m_lens = torch.LongTensor([motion_length]).to(device)
26
+ pred_motions = trainer.generate(caption, m_lens, opt.dim_pose)
27
+ motion = pred_motions[0].cpu().numpy()
28
+ motion = motion * std + mean
29
+ title = text + " #%d" % motion.shape[0]
30
+ plot_t2m(opt, motion, result_path, title)
trainers/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .ddpm_trainer import DDPMTrainer
2
+
3
+
4
+ __all__ = ['DDPMTrainer']
trainers/ddpm_trainer.py ADDED
@@ -0,0 +1,222 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import random
4
+ import time
5
+ from models.transformer import MotionTransformer
6
+ from torch.utils.data import DataLoader
7
+ import torch.optim as optim
8
+ from torch.nn.utils import clip_grad_norm_
9
+ from collections import OrderedDict
10
+ from utils.utils import print_current_loss
11
+ from os.path import join as pjoin
12
+ import codecs as cs
13
+ import torch.distributed as dist
14
+
15
+
16
+ from mmcv.runner import get_dist_info
17
+ from models.gaussian_diffusion import (
18
+ GaussianDiffusion,
19
+ get_named_beta_schedule,
20
+ create_named_schedule_sampler,
21
+ ModelMeanType,
22
+ ModelVarType,
23
+ LossType
24
+ )
25
+
26
+ from datasets import build_dataloader
27
+
28
+
29
+ class DDPMTrainer(object):
30
+
31
+ def __init__(self, args, encoder):
32
+ self.opt = args
33
+ self.device = args.device
34
+ self.encoder = encoder
35
+ self.diffusion_steps = args.diffusion_steps
36
+ sampler = 'uniform'
37
+ beta_scheduler = 'linear'
38
+ betas = get_named_beta_schedule(beta_scheduler, self.diffusion_steps)
39
+ self.diffusion = GaussianDiffusion(
40
+ betas=betas,
41
+ model_mean_type=ModelMeanType.EPSILON,
42
+ model_var_type=ModelVarType.FIXED_SMALL,
43
+ loss_type=LossType.MSE
44
+ )
45
+ self.sampler = create_named_schedule_sampler(sampler, self.diffusion)
46
+ self.sampler_name = sampler
47
+
48
+ if args.is_train:
49
+ self.mse_criterion = torch.nn.MSELoss(reduction='none')
50
+ self.to(self.device)
51
+
52
+ @staticmethod
53
+ def zero_grad(opt_list):
54
+ for opt in opt_list:
55
+ opt.zero_grad()
56
+
57
+ @staticmethod
58
+ def clip_norm(network_list):
59
+ for network in network_list:
60
+ clip_grad_norm_(network.parameters(), 0.5)
61
+
62
+ @staticmethod
63
+ def step(opt_list):
64
+ for opt in opt_list:
65
+ opt.step()
66
+
67
+ def forward(self, batch_data, eval_mode=False):
68
+ caption, motions, m_lens = batch_data
69
+ motions = motions.detach().to(self.device).float()
70
+
71
+ self.caption = caption
72
+ self.motions = motions
73
+ x_start = motions
74
+ B, T = x_start.shape[:2]
75
+ cur_len = torch.LongTensor([min(T, m_len) for m_len in m_lens]).to(self.device)
76
+ t, _ = self.sampler.sample(B, x_start.device)
77
+ output = self.diffusion.training_losses(
78
+ model=self.encoder,
79
+ x_start=x_start,
80
+ t=t,
81
+ model_kwargs={"text": caption, "length": cur_len}
82
+ )
83
+
84
+ self.real_noise = output['target']
85
+ self.fake_noise = output['pred']
86
+ try:
87
+ self.src_mask = self.encoder.module.generate_src_mask(T, cur_len).to(x_start.device)
88
+ except:
89
+ self.src_mask = self.encoder.generate_src_mask(T, cur_len).to(x_start.device)
90
+
91
+ def generate_batch(self, caption, m_lens, dim_pose):
92
+ xf_proj, xf_out = self.encoder.encode_text(caption, self.device)
93
+
94
+ B = len(caption)
95
+ T = min(m_lens.max(), self.encoder.num_frames)
96
+ output = self.diffusion.p_sample_loop(
97
+ self.encoder,
98
+ (B, T, dim_pose),
99
+ clip_denoised=False,
100
+ progress=True,
101
+ model_kwargs={
102
+ 'xf_proj': xf_proj,
103
+ 'xf_out': xf_out,
104
+ 'length': m_lens
105
+ })
106
+ return output
107
+
108
+ def generate(self, caption, m_lens, dim_pose, batch_size=1024):
109
+ N = len(caption)
110
+ cur_idx = 0
111
+ self.encoder.eval()
112
+ all_output = []
113
+ while cur_idx < N:
114
+ if cur_idx + batch_size >= N:
115
+ batch_caption = caption[cur_idx:]
116
+ batch_m_lens = m_lens[cur_idx:]
117
+ else:
118
+ batch_caption = caption[cur_idx: cur_idx + batch_size]
119
+ batch_m_lens = m_lens[cur_idx: cur_idx + batch_size]
120
+ output = self.generate_batch(batch_caption, batch_m_lens, dim_pose)
121
+ B = output.shape[0]
122
+
123
+ for i in range(B):
124
+ all_output.append(output[i])
125
+ cur_idx += batch_size
126
+ return all_output
127
+
128
+ def backward_G(self):
129
+ loss_mot_rec = self.mse_criterion(self.fake_noise, self.real_noise).mean(dim=-1)
130
+ loss_mot_rec = (loss_mot_rec * self.src_mask).sum() / self.src_mask.sum()
131
+ self.loss_mot_rec = loss_mot_rec
132
+ loss_logs = OrderedDict({})
133
+ loss_logs['loss_mot_rec'] = self.loss_mot_rec.item()
134
+ return loss_logs
135
+
136
+ def update(self):
137
+ self.zero_grad([self.opt_encoder])
138
+ loss_logs = self.backward_G()
139
+ self.loss_mot_rec.backward()
140
+ self.clip_norm([self.encoder])
141
+ self.step([self.opt_encoder])
142
+
143
+ return loss_logs
144
+
145
+ def to(self, device):
146
+ if self.opt.is_train:
147
+ self.mse_criterion.to(device)
148
+ self.encoder = self.encoder.to(device)
149
+
150
+ def train_mode(self):
151
+ self.encoder.train()
152
+
153
+ def eval_mode(self):
154
+ self.encoder.eval()
155
+
156
+ def save(self, file_name, ep, total_it):
157
+ state = {
158
+ 'opt_encoder': self.opt_encoder.state_dict(),
159
+ 'ep': ep,
160
+ 'total_it': total_it
161
+ }
162
+ try:
163
+ state['encoder'] = self.encoder.module.state_dict()
164
+ except:
165
+ state['encoder'] = self.encoder.state_dict()
166
+ torch.save(state, file_name)
167
+ return
168
+
169
+ def load(self, model_dir):
170
+ checkpoint = torch.load(model_dir, map_location=self.device)
171
+ if self.opt.is_train:
172
+ self.opt_encoder.load_state_dict(checkpoint['opt_encoder'])
173
+ self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
174
+ return checkpoint['ep'], checkpoint.get('total_it', 0)
175
+
176
+ def train(self, train_dataset):
177
+ rank, world_size = get_dist_info()
178
+ self.to(self.device)
179
+ self.opt_encoder = optim.Adam(self.encoder.parameters(), lr=self.opt.lr)
180
+ it = 0
181
+ cur_epoch = 0
182
+ if self.opt.is_continue:
183
+ model_dir = pjoin(self.opt.model_dir, 'latest.tar')
184
+ cur_epoch, it = self.load(model_dir)
185
+
186
+ start_time = time.time()
187
+
188
+ train_loader = build_dataloader(
189
+ train_dataset,
190
+ samples_per_gpu=self.opt.batch_size,
191
+ drop_last=True,
192
+ workers_per_gpu=4,
193
+ shuffle=True)
194
+
195
+ logs = OrderedDict()
196
+ for epoch in range(cur_epoch, self.opt.num_epochs):
197
+ self.train_mode()
198
+ for i, batch_data in enumerate(train_loader):
199
+ self.forward(batch_data)
200
+ log_dict = self.update()
201
+ for k, v in log_dict.items():
202
+ if k not in logs:
203
+ logs[k] = v
204
+ else:
205
+ logs[k] += v
206
+ it += 1
207
+ if it % self.opt.log_every == 0 and rank == 0:
208
+ mean_loss = OrderedDict({})
209
+ for tag, value in logs.items():
210
+ mean_loss[tag] = value / self.opt.log_every
211
+ logs = OrderedDict()
212
+ print_current_loss(start_time, it, mean_loss, epoch, inner_iter=i)
213
+
214
+ if it % self.opt.save_latest == 0 and rank == 0:
215
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
216
+
217
+ if rank == 0:
218
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
219
+
220
+ if epoch % self.opt.save_every_e == 0 and rank == 0:
221
+ self.save(pjoin(self.opt.model_dir, 'ckpt_e%03d.tar'%(epoch)),
222
+ epoch, total_it=it)
utils/__init__.py ADDED
File without changes
utils/get_opt.py ADDED
@@ -0,0 +1,91 @@
1
+ import os
2
+ from argparse import Namespace
3
+ import re
4
+ from os.path import join as pjoin
5
+ from utils.word_vectorizer import POS_enumerator
6
+
7
+
8
+ def is_float(numStr):
9
+ flag = False
10
+ numStr = str(numStr).strip().lstrip('-').lstrip('+')
11
+ try:
12
+ reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$')
13
+ res = reg.match(str(numStr))
14
+ if res:
15
+ flag = True
16
+ except Exception as ex:
17
+ print("is_float() - error: " + str(ex))
18
+ return flag
19
+
20
+
21
+ def is_number(numStr):
22
+ flag = False
23
+ numStr = str(numStr).strip().lstrip('-').lstrip('+')
24
+ if str(numStr).isdigit():
25
+ flag = True
26
+ return flag
27
+
28
+
29
+ def get_opt(opt_path, device):
30
+ opt = Namespace()
31
+ opt_dict = vars(opt)
32
+
33
+ skip = ('-------------- End ----------------',
34
+ '------------ Options -------------',
35
+ '\n')
36
+ print('Reading', opt_path)
37
+ with open(opt_path) as f:
38
+ for line in f:
39
+ if line.strip() not in skip:
40
+ # print(line.strip())
41
+ key, value = line.strip().split(': ')
42
+ if value in ('True', 'False'):
43
+ opt_dict[key] = True if value == 'True' else False
44
+ elif is_float(value):
45
+ opt_dict[key] = float(value)
46
+ elif is_number(value):
47
+ opt_dict[key] = int(value)
48
+ else:
49
+ opt_dict[key] = str(value)
50
+
51
+ opt_dict['which_epoch'] = 'latest'
52
+ if 'num_layers' not in opt_dict:
53
+ opt_dict['num_layers'] = 8
54
+ if 'latent_dim' not in opt_dict:
55
+ opt_dict['latent_dim'] = 512
56
+ if 'diffusion_steps' not in opt_dict:
57
+ opt_dict['diffusion_steps'] = 1000
58
+ if 'no_clip' not in opt_dict:
59
+ opt_dict['no_clip'] = False
60
+ if 'no_eff' not in opt_dict:
61
+ opt_dict['no_eff'] = False
62
+
63
+ opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
64
+ opt.model_dir = pjoin(opt.save_root, 'model')
65
+ opt.meta_dir = pjoin(opt.save_root, 'meta')
66
+
67
+ if opt.dataset_name == 't2m':
68
+ opt.data_root = './data/HumanML3D'
69
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
70
+ opt.text_dir = pjoin(opt.data_root, 'texts')
71
+ opt.joints_num = 22
72
+ opt.dim_pose = 263
73
+ opt.max_motion_length = 196
74
+ elif opt.dataset_name == 'kit':
75
+ opt.data_root = './data/KIT-ML'
76
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
77
+ opt.text_dir = pjoin(opt.data_root, 'texts')
78
+ opt.joints_num = 21
79
+ opt.dim_pose = 251
80
+ opt.max_motion_length = 196
81
+ else:
82
+ raise KeyError('Dataset not recognized')
83
+
84
+ opt.dim_word = 300
85
+ opt.num_classes = 200 // opt.unit_length
86
+ opt.dim_pos_ohot = len(POS_enumerator)
87
+ opt.is_train = False
88
+ opt.is_continue = False
89
+ opt.device = device
90
+
91
+ return opt
utils/metrics.py ADDED
@@ -0,0 +1,146 @@
1
+ import numpy as np
2
+ from scipy import linalg
3
+
4
+
5
+ # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
6
+ def euclidean_distance_matrix(matrix1, matrix2):
7
+ """
8
+ Params:
9
+ -- matrix1: N1 x D
10
+ -- matrix2: N2 x D
11
+ Returns:
12
+ -- dist: N1 x N2
13
+ dist[i, j] == distance(matrix1[i], matrix2[j])
14
+ """
15
+ assert matrix1.shape[1] == matrix2.shape[1]
16
+ d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
17
+ d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
18
+ d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
19
+ dists = np.sqrt(d1 + d2 + d3) # broadcasting
20
+ return dists
21
+
22
+ def calculate_top_k(mat, top_k):
23
+ size = mat.shape[0]
24
+ gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
25
+ bool_mat = (mat == gt_mat)
26
+ correct_vec = False
27
+ top_k_list = []
28
+ for i in range(top_k):
29
+ # print(correct_vec, bool_mat[:, i])
30
+ correct_vec = (correct_vec | bool_mat[:, i])
31
+ # print(correct_vec)
32
+ top_k_list.append(correct_vec[:, None])
33
+ top_k_mat = np.concatenate(top_k_list, axis=1)
34
+ return top_k_mat
35
+
36
+
37
+ def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
38
+ dist_mat = euclidean_distance_matrix(embedding1, embedding2)
39
+ argmax = np.argsort(dist_mat, axis=1)
40
+ top_k_mat = calculate_top_k(argmax, top_k)
41
+ if sum_all:
42
+ return top_k_mat.sum(axis=0)
43
+ else:
44
+ return top_k_mat
45
+
46
+
47
+ def calculate_matching_score(embedding1, embedding2, sum_all=False):
48
+ assert len(embedding1.shape) == 2
49
+ assert embedding1.shape[0] == embedding2.shape[0]
50
+ assert embedding1.shape[1] == embedding2.shape[1]
51
+
52
+ dist = linalg.norm(embedding1 - embedding2, axis=1)
53
+ if sum_all:
54
+ return dist.sum(axis=0)
55
+ else:
56
+ return dist
57
+
58
+
59
+
60
+ def calculate_activation_statistics(activations):
61
+ """
62
+ Params:
63
+ -- activation: num_samples x dim_feat
64
+ Returns:
65
+ -- mu: dim_feat
66
+ -- sigma: dim_feat x dim_feat
67
+ """
68
+ mu = np.mean(activations, axis=0)
69
+ cov = np.cov(activations, rowvar=False)
70
+ return mu, cov
71
+
72
+
73
+ def calculate_diversity(activation, diversity_times):
74
+ assert len(activation.shape) == 2
75
+ assert activation.shape[0] > diversity_times
76
+ num_samples = activation.shape[0]
77
+
78
+ first_indices = np.random.choice(num_samples, diversity_times, replace=False)
79
+ second_indices = np.random.choice(num_samples, diversity_times, replace=False)
80
+ dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
81
+ return dist.mean()
82
+
83
+
84
+ def calculate_multimodality(activation, multimodality_times):
85
+ assert len(activation.shape) == 3
86
+ assert activation.shape[1] > multimodality_times
87
+ num_per_sent = activation.shape[1]
88
+
89
+ first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
90
+ second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
91
+ dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
92
+ return dist.mean()
93
+
94
+
95
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
96
+ """Numpy implementation of the Frechet Distance.
97
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
98
+ and X_2 ~ N(mu_2, C_2) is
99
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
100
+ Stable version by Dougal J. Sutherland.
101
+ Params:
102
+ -- mu1 : Numpy array containing the activations of a layer of the
103
+ inception net (like returned by the function 'get_predictions')
104
+ for generated samples.
105
+ -- mu2 : The sample mean over activations, precalculated on an
106
+ representative data set.
107
+ -- sigma1: The covariance matrix over activations for generated samples.
108
+ -- sigma2: The covariance matrix over activations, precalculated on an
109
+ representative data set.
110
+ Returns:
111
+ -- : The Frechet Distance.
112
+ """
113
+
114
+ mu1 = np.atleast_1d(mu1)
115
+ mu2 = np.atleast_1d(mu2)
116
+
117
+ sigma1 = np.atleast_2d(sigma1)
118
+ sigma2 = np.atleast_2d(sigma2)
119
+
120
+ assert mu1.shape == mu2.shape, \
121
+ 'Training and test mean vectors have different lengths'
122
+ assert sigma1.shape == sigma2.shape, \
123
+ 'Training and test covariances have different dimensions'
124
+
125
+ diff = mu1 - mu2
126
+
127
+ # Product might be almost singular
128
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
129
+ if not np.isfinite(covmean).all():
130
+ msg = ('fid calculation produces singular product; '
131
+ 'adding %s to diagonal of cov estimates') % eps
132
+ print(msg)
133
+ offset = np.eye(sigma1.shape[0]) * eps
134
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
135
+
136
+ # Numerical error might give slight imaginary component
137
+ if np.iscomplexobj(covmean):
138
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
139
+ m = np.max(np.abs(covmean.imag))
140
+ raise ValueError('Imaginary component {}'.format(m))
141
+ covmean = covmean.real
142
+
143
+ tr_covmean = np.trace(covmean)
144
+
145
+ return (diff.dot(diff) + np.trace(sigma1) +
146
+ np.trace(sigma2) - 2 * tr_covmean)
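A minimal sanity check for the FID helpers above, assuming the file is importable as utils.metrics and that numpy and scipy are installed; the random activations are placeholders for real evaluator features, not part of this commit:

import numpy as np
from utils.metrics import calculate_activation_statistics, calculate_frechet_distance

# Synthetic activations (num_samples x dim_feat) standing in for ground-truth and generated features.
rng = np.random.default_rng(0)
gt_act = rng.normal(size=(512, 64))
gen_act = rng.normal(size=(512, 64))

# Fit a Gaussian to each set and compare them with the Frechet distance.
gt_mu, gt_cov = calculate_activation_statistics(gt_act)
mu, cov = calculate_activation_statistics(gen_act)
print('FID:', calculate_frechet_distance(gt_mu, gt_cov, mu, cov))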
utils/motion_process.py ADDED
@@ -0,0 +1,515 @@
1
+ from os.path import join as pjoin
2
+
3
+ import numpy as np
4
+ import os
5
+ from utils.quaternion import *
6
+ from utils.skeleton import Skeleton
7
+ from utils.paramUtil import *
8
+
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ # positions (batch, joint_num, 3)
13
+ def uniform_skeleton(positions, target_offset):
14
+ src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
15
+ src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0]))
16
+ src_offset = src_offset.numpy()
17
+ tgt_offset = target_offset.numpy()
18
+ # print(src_offset)
19
+ # print(tgt_offset)
20
+ '''Calculate Scale Ratio as the ratio of legs'''
21
+ src_leg_len = np.abs(src_offset[l_idx1]).max() + np.abs(src_offset[l_idx2]).max()
22
+ tgt_leg_len = np.abs(tgt_offset[l_idx1]).max() + np.abs(tgt_offset[l_idx2]).max()
23
+
24
+ scale_rt = tgt_leg_len / src_leg_len
25
+ # print(scale_rt)
26
+ src_root_pos = positions[:, 0]
27
+ tgt_root_pos = src_root_pos * scale_rt
28
+
29
+ '''Inverse Kinematics'''
30
+ quat_params = src_skel.inverse_kinematics_np(positions, face_joint_indx)
31
+ # print(quat_params.shape)
32
+
33
+ '''Forward Kinematics'''
34
+ src_skel.set_offset(target_offset)
35
+ new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos)
36
+ return new_joints
37
+
38
+
39
+ def extract_features(positions, feet_thre, n_raw_offsets, kinematic_chain, face_joint_indx, fid_r, fid_l):
40
+ global_positions = positions.copy()
41
+ """ Get Foot Contacts """
42
+
43
+ def foot_detect(positions, thres):
44
+ velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])
45
+
46
+ feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
47
+ feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
48
+ feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
49
+ # feet_l_h = positions[:-1,fid_l,1]
50
+ # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float)
51
+ feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(float)
52
+
53
+ feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
54
+ feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
55
+ feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
56
+ # feet_r_h = positions[:-1,fid_r,1]
57
+ # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float)
58
+ feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(float)
59
+ return feet_l, feet_r
60
+
61
+ #
62
+ feet_l, feet_r = foot_detect(positions, feet_thre)
63
+ # feet_l, feet_r = foot_detect(positions, 0.002)
64
+
65
+ '''Quaternion and Cartesian representation'''
66
+ r_rot = None
67
+
68
+ def get_rifke(positions):
69
+ '''Local pose'''
70
+ positions[..., 0] -= positions[:, 0:1, 0]
71
+ positions[..., 2] -= positions[:, 0:1, 2]
72
+ '''All pose face Z+'''
73
+ positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)
74
+ return positions
75
+
76
+ def get_quaternion(positions):
77
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
78
+ # (seq_len, joints_num, 4)
79
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)
80
+
81
+ '''Fix Quaternion Discontinuity'''
82
+ quat_params = qfix(quat_params)
83
+ # (seq_len, 4)
84
+ r_rot = quat_params[:, 0].copy()
85
+ # print(r_rot[0])
86
+ '''Root Linear Velocity'''
87
+ # (seq_len - 1, 3)
88
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
89
+ # print(r_rot.shape, velocity.shape)
90
+ velocity = qrot_np(r_rot[1:], velocity)
91
+ '''Root Angular Velocity'''
92
+ # (seq_len - 1, 4)
93
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
94
+ quat_params[1:, 0] = r_velocity
95
+ # (seq_len, joints_num, 4)
96
+ return quat_params, r_velocity, velocity, r_rot
97
+
98
+ def get_cont6d_params(positions):
99
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
100
+ # (seq_len, joints_num, 4)
101
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)
102
+
103
+ '''Quaternion to continuous 6D'''
104
+ cont_6d_params = quaternion_to_cont6d_np(quat_params)
105
+ # (seq_len, 4)
106
+ r_rot = quat_params[:, 0].copy()
107
+ # print(r_rot[0])
108
+ '''Root Linear Velocity'''
109
+ # (seq_len - 1, 3)
110
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
111
+ # print(r_rot.shape, velocity.shape)
112
+ velocity = qrot_np(r_rot[1:], velocity)
113
+ '''Root Angular Velocity'''
114
+ # (seq_len - 1, 4)
115
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
116
+ # (seq_len, joints_num, 4)
117
+ return cont_6d_params, r_velocity, velocity, r_rot
118
+
119
+ cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)
120
+ positions = get_rifke(positions)
121
+
122
+ # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0)
123
+ # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]])
124
+
125
+ # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*')
126
+ # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r')
127
+ # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g')
128
+ # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y')
129
+ # plt.xlabel('x')
130
+ # plt.ylabel('z')
131
+ # plt.axis('equal')
132
+ # plt.show()
133
+
134
+ '''Root height'''
135
+ root_y = positions[:, 0, 1:2]
136
+
137
+ '''Root rotation and linear velocity'''
138
+ # (seq_len-1, 1) rotation velocity along y-axis
139
+ # (seq_len-1, 2) linear velocity on the xz plane
140
+ r_velocity = np.arcsin(r_velocity[:, 2:3])
141
+ l_velocity = velocity[:, [0, 2]]
142
+ # print(r_velocity.shape, l_velocity.shape, root_y.shape)
143
+ root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)
144
+
145
+ '''Get Joint Rotation Representation'''
146
+ # (seq_len, (joints_num-1) * 6) continuous 6D rotation for skeleton joints
147
+ rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)
148
+
149
+ '''Get Joint Rotation Invariant Position Representation'''
150
+ # (seq_len, (joints_num-1)*3) local joint position
151
+ ric_data = positions[:, 1:].reshape(len(positions), -1)
152
+
153
+ '''Get Joint Velocity Representation'''
154
+ # (seq_len-1, joints_num*3)
155
+ local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),
156
+ global_positions[1:] - global_positions[:-1])
157
+ local_vel = local_vel.reshape(len(local_vel), -1)
158
+
159
+ data = root_data
160
+ data = np.concatenate([data, ric_data[:-1]], axis=-1)
161
+ data = np.concatenate([data, rot_data[:-1]], axis=-1)
162
+ # print(data.shape, local_vel.shape)
163
+ data = np.concatenate([data, local_vel], axis=-1)
164
+ data = np.concatenate([data, feet_l, feet_r], axis=-1)
165
+
166
+ return data
167
+
168
+
169
+ def process_file(positions, feet_thre):
170
+ # (seq_len, joints_num, 3)
171
+ # '''Down Sample'''
172
+ # positions = positions[::ds_num]
173
+
174
+ '''Uniform Skeleton'''
175
+ positions = uniform_skeleton(positions, tgt_offsets)
176
+
177
+ '''Put on Floor'''
178
+ floor_height = positions.min(axis=0).min(axis=0)[1]
179
+ positions[:, :, 1] -= floor_height
180
+ # print(floor_height)
181
+
182
+ # plot_3d_motion("./positions_1.mp4", kinematic_chain, positions, 'title', fps=20)
183
+
184
+ '''XZ at origin'''
185
+ root_pos_init = positions[0]
186
+ root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1])
187
+ positions = positions - root_pose_init_xz
188
+
189
+ # '''Move the first pose to origin '''
190
+ # root_pos_init = positions[0]
191
+ # positions = positions - root_pos_init[0]
192
+
193
+ '''All initially face Z+'''
194
+ r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
195
+ across1 = root_pos_init[r_hip] - root_pos_init[l_hip]
196
+ across2 = root_pos_init[sdr_r] - root_pos_init[sdr_l]
197
+ across = across1 + across2
198
+ across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis]
199
+
200
+ # forward (3,), rotate around y-axis
201
+ forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
202
+ # forward (3,)
203
+ forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis]
204
+
205
+ # print(forward_init)
206
+
207
+ target = np.array([[0, 0, 1]])
208
+ root_quat_init = qbetween_np(forward_init, target)
209
+ root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init
210
+
211
+ positions_b = positions.copy()
212
+
213
+ positions = qrot_np(root_quat_init, positions)
214
+
215
+ # plot_3d_motion("./positions_2.mp4", kinematic_chain, positions, 'title', fps=20)
216
+
217
+ '''New ground truth positions'''
218
+ global_positions = positions.copy()
219
+
220
+ # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*')
221
+ # plt.plot(positions[:, 0, 0], positions[:, 0, 2], marker='o', color='r')
222
+ # plt.xlabel('x')
223
+ # plt.ylabel('z')
224
+ # plt.axis('equal')
225
+ # plt.show()
226
+
227
+ """ Get Foot Contacts """
228
+
229
+ def foot_detect(positions, thres):
230
+ velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])
231
+
232
+ feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
233
+ feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
234
+ feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
235
+ # feet_l_h = positions[:-1,fid_l,1]
236
+ # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float)
237
+ feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(float)
238
+
239
+ feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
240
+ feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
241
+ feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
242
+ # feet_r_h = positions[:-1,fid_r,1]
243
+ # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float)
244
+ feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(float)
245
+ return feet_l, feet_r
246
+ #
247
+ feet_l, feet_r = foot_detect(positions, feet_thre)
248
+ # feet_l, feet_r = foot_detect(positions, 0.002)
249
+
250
+ '''Quaternion and Cartesian representation'''
251
+ r_rot = None
252
+
253
+ def get_rifke(positions):
254
+ '''Local pose'''
255
+ positions[..., 0] -= positions[:, 0:1, 0]
256
+ positions[..., 2] -= positions[:, 0:1, 2]
257
+ '''All pose face Z+'''
258
+ positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)
259
+ return positions
260
+
261
+ def get_quaternion(positions):
262
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
263
+ # (seq_len, joints_num, 4)
264
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)
265
+
266
+ '''Fix Quaternion Discontinuity'''
267
+ quat_params = qfix(quat_params)
268
+ # (seq_len, 4)
269
+ r_rot = quat_params[:, 0].copy()
270
+ # print(r_rot[0])
271
+ '''Root Linear Velocity'''
272
+ # (seq_len - 1, 3)
273
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
274
+ # print(r_rot.shape, velocity.shape)
275
+ velocity = qrot_np(r_rot[1:], velocity)
276
+ '''Root Angular Velocity'''
277
+ # (seq_len - 1, 4)
278
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
279
+ quat_params[1:, 0] = r_velocity
280
+ # (seq_len, joints_num, 4)
281
+ return quat_params, r_velocity, velocity, r_rot
282
+
283
+ def get_cont6d_params(positions):
284
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
285
+ # (seq_len, joints_num, 4)
286
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)
287
+
288
+ '''Quaternion to continuous 6D'''
289
+ cont_6d_params = quaternion_to_cont6d_np(quat_params)
290
+ # (seq_len, 4)
291
+ r_rot = quat_params[:, 0].copy()
292
+ # print(r_rot[0])
293
+ '''Root Linear Velocity'''
294
+ # (seq_len - 1, 3)
295
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
296
+ # print(r_rot.shape, velocity.shape)
297
+ velocity = qrot_np(r_rot[1:], velocity)
298
+ '''Root Angular Velocity'''
299
+ # (seq_len - 1, 4)
300
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
301
+ # (seq_len, joints_num, 4)
302
+ return cont_6d_params, r_velocity, velocity, r_rot
303
+
304
+ cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)
305
+ positions = get_rifke(positions)
306
+
307
+ # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0)
308
+ # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]])
309
+
310
+ # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*')
311
+ # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r')
312
+ # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g')
313
+ # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y')
314
+ # plt.xlabel('x')
315
+ # plt.ylabel('z')
316
+ # plt.axis('equal')
317
+ # plt.show()
318
+
319
+ '''Root height'''
320
+ root_y = positions[:, 0, 1:2]
321
+
322
+ '''Root rotation and linear velocity'''
323
+ # (seq_len-1, 1) rotation velocity along y-axis
324
+ # (seq_len-1, 2) linear velocity on the xz plane
325
+ r_velocity = np.arcsin(r_velocity[:, 2:3])
326
+ l_velocity = velocity[:, [0, 2]]
327
+ # print(r_velocity.shape, l_velocity.shape, root_y.shape)
328
+ root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)
329
+
330
+ '''Get Joint Rotation Representation'''
331
+ # (seq_len, (joints_num-1) * 6) continuous 6D rotation for skeleton joints
332
+ rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)
333
+
334
+ '''Get Joint Rotation Invariant Position Representation'''
335
+ # (seq_len, (joints_num-1)*3) local joint position
336
+ ric_data = positions[:, 1:].reshape(len(positions), -1)
337
+
338
+ '''Get Joint Velocity Representation'''
339
+ # (seq_len-1, joints_num*3)
340
+ local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),
341
+ global_positions[1:] - global_positions[:-1])
342
+ local_vel = local_vel.reshape(len(local_vel), -1)
343
+
344
+ data = root_data
345
+ data = np.concatenate([data, ric_data[:-1]], axis=-1)
346
+ data = np.concatenate([data, rot_data[:-1]], axis=-1)
347
+ # print(data.shape, local_vel.shape)
348
+ data = np.concatenate([data, local_vel], axis=-1)
349
+ data = np.concatenate([data, feet_l, feet_r], axis=-1)
350
+
351
+ return data, global_positions, positions, l_velocity
352
+
353
+
354
+ # Recover global angle and positions for rotation data
355
+ # root_rot_velocity (B, seq_len, 1)
356
+ # root_linear_velocity (B, seq_len, 2)
357
+ # root_y (B, seq_len, 1)
358
+ # ric_data (B, seq_len, (joint_num - 1)*3)
359
+ # rot_data (B, seq_len, (joint_num - 1)*6)
360
+ # local_velocity (B, seq_len, joint_num*3)
361
+ # foot contact (B, seq_len, 4)
362
+ def recover_root_rot_pos(data):
363
+ rot_vel = data[..., 0]
364
+ r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
365
+ '''Get Y-axis rotation from rotation velocity'''
366
+ r_rot_ang[..., 1:] = rot_vel[..., :-1]
367
+ r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
368
+
369
+ r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
370
+ r_rot_quat[..., 0] = torch.cos(r_rot_ang)
371
+ r_rot_quat[..., 2] = torch.sin(r_rot_ang)
372
+
373
+ r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
374
+ r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
375
+ '''Add Y-axis rotation to root position'''
376
+ r_pos = qrot(qinv(r_rot_quat), r_pos)
377
+
378
+ r_pos = torch.cumsum(r_pos, dim=-2)
379
+
380
+ r_pos[..., 1] = data[..., 3]
381
+ return r_rot_quat, r_pos
382
+
383
+
384
+ def recover_from_rot(data, joints_num, skeleton):
385
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
386
+
387
+ r_rot_cont6d = quaternion_to_cont6d(r_rot_quat)
388
+
389
+ start_indx = 1 + 2 + 1 + (joints_num - 1) * 3
390
+ end_indx = start_indx + (joints_num - 1) * 6
391
+ cont6d_params = data[..., start_indx:end_indx]
392
+ # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape)
393
+ cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1)
394
+ cont6d_params = cont6d_params.view(-1, joints_num, 6)
395
+
396
+ positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos)
397
+
398
+ return positions
399
+
400
+
401
+ def recover_from_ric(data, joints_num):
402
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
403
+ positions = data[..., 4:(joints_num - 1) * 3 + 4]
404
+ positions = positions.view(positions.shape[:-1] + (-1, 3))
405
+
406
+ '''Add Y-axis rotation to local joints'''
407
+ positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
408
+
409
+ '''Add root XZ to joints'''
410
+ positions[..., 0] += r_pos[..., 0:1]
411
+ positions[..., 2] += r_pos[..., 2:3]
412
+
413
+ '''Concate root and joints'''
414
+ positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
415
+
416
+ return positions
417
+ '''
418
+ For Text2Motion Dataset
419
+ '''
420
+ '''
421
+ if __name__ == "__main__":
422
+ example_id = "000021"
423
+ # Lower legs
424
+ l_idx1, l_idx2 = 5, 8
425
+ # Right/Left foot
426
+ fid_r, fid_l = [8, 11], [7, 10]
427
+ # Face direction, r_hip, l_hip, sdr_r, sdr_l
428
+ face_joint_indx = [2, 1, 17, 16]
429
+ # l_hip, r_hip
430
+ r_hip, l_hip = 2, 1
431
+ joints_num = 22
432
+ # ds_num = 8
433
+ data_dir = '../dataset/pose_data_raw/joints/'
434
+ save_dir1 = '../dataset/pose_data_raw/new_joints/'
435
+ save_dir2 = '../dataset/pose_data_raw/new_joint_vecs/'
436
+
437
+ n_raw_offsets = torch.from_numpy(t2m_raw_offsets)
438
+ kinematic_chain = t2m_kinematic_chain
439
+
440
+ # Get offsets of target skeleton
441
+ example_data = np.load(os.path.join(data_dir, example_id + '.npy'))
442
+ example_data = example_data.reshape(len(example_data), -1, 3)
443
+ example_data = torch.from_numpy(example_data)
444
+ tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
445
+ # (joints_num, 3)
446
+ tgt_offsets = tgt_skel.get_offsets_joints(example_data[0])
447
+ # print(tgt_offsets)
448
+
449
+ source_list = os.listdir(data_dir)
450
+ frame_num = 0
451
+ for source_file in tqdm(source_list):
452
+ source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num]
453
+ try:
454
+ data, ground_positions, positions, l_velocity = process_file(source_data, 0.002)
455
+ rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num)
456
+ np.save(pjoin(save_dir1, source_file), rec_ric_data.squeeze().numpy())
457
+ np.save(pjoin(save_dir2, source_file), data)
458
+ frame_num += data.shape[0]
459
+ except Exception as e:
460
+ print(source_file)
461
+ print(e)
462
+
463
+ print('Total clips: %d, Frames: %d, Duration: %fm' %
464
+ (len(source_list), frame_num, frame_num / 20 / 60))
465
+ '''
466
+
467
+ if __name__ == "__main__":
468
+ example_id = "03950_gt"
469
+ # Lower legs
470
+ l_idx1, l_idx2 = 17, 18
471
+ # Right/Left foot
472
+ fid_r, fid_l = [14, 15], [19, 20]
473
+ # Face direction, r_hip, l_hip, sdr_r, sdr_l
474
+ face_joint_indx = [11, 16, 5, 8]
475
+ # l_hip, r_hip
476
+ r_hip, l_hip = 11, 16
477
+ joints_num = 21
478
+ # ds_num = 8
479
+ data_dir = '../dataset/kit_mocap_dataset/joints/'
480
+ save_dir1 = '../dataset/kit_mocap_dataset/new_joints/'
481
+ save_dir2 = '../dataset/kit_mocap_dataset/new_joint_vecs/'
482
+
483
+ n_raw_offsets = torch.from_numpy(kit_raw_offsets)
484
+ kinematic_chain = kit_kinematic_chain
485
+
486
+ '''Get offsets of target skeleton'''
487
+ example_data = np.load(os.path.join(data_dir, example_id + '.npy'))
488
+ example_data = example_data.reshape(len(example_data), -1, 3)
489
+ example_data = torch.from_numpy(example_data)
490
+ tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
491
+ # (joints_num, 3)
492
+ tgt_offsets = tgt_skel.get_offsets_joints(example_data[0])
493
+ # print(tgt_offsets)
494
+
495
+ source_list = os.listdir(data_dir)
496
+ frame_num = 0
497
+ '''Read source data'''
498
+ for source_file in tqdm(source_list):
499
+ source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num]
500
+ try:
501
+ name = ''.join(source_file[:-7].split('_')) + '.npy'
502
+ data, ground_positions, positions, l_velocity = process_file(source_data, 0.05)
503
+ rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num)
504
+ if np.isnan(rec_ric_data.numpy()).any():
505
+ print(source_file)
506
+ continue
507
+ np.save(pjoin(save_dir1, name), rec_ric_data.squeeze().numpy())
508
+ np.save(pjoin(save_dir2, name), data)
509
+ frame_num += data.shape[0]
510
+ except Exception as e:
511
+ print(source_file)
512
+ print(e)
513
+
514
+ print('Total clips: %d, Frames: %d, Duration: %fm' %
515
+ (len(source_list), frame_num, frame_num / 12.5 / 60))
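A short usage sketch for recover_from_ric above, assuming utils.skeleton and utils.quaternion from this repository are importable; the random array below is a placeholder for a real 263-dim HumanML3D feature sequence, not part of this commit:

import numpy as np
import torch
from utils.motion_process import recover_from_ric

# Placeholder feature sequence (seq_len x dim_pose) in the HumanML3D layout described above.
feats = np.random.randn(60, 263).astype(np.float32)
# 22 joints for the t2m (HumanML3D) skeleton.
joints = recover_from_ric(torch.from_numpy(feats).float(), 22)
print(joints.shape)  # expected: (60, 22, 3)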
utils/paramUtil.py ADDED
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+
3
+ # Define a kinematic tree for the skeletal structure
4
+ kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
5
+
6
+ kit_raw_offsets = np.array(
7
+ [
8
+ [0, 0, 0],
9
+ [0, 1, 0],
10
+ [0, 1, 0],
11
+ [0, 1, 0],
12
+ [0, 1, 0],
13
+ [1, 0, 0],
14
+ [0, -1, 0],
15
+ [0, -1, 0],
16
+ [-1, 0, 0],
17
+ [0, -1, 0],
18
+ [0, -1, 0],
19
+ [1, 0, 0],
20
+ [0, -1, 0],
21
+ [0, -1, 0],
22
+ [0, 0, 1],
23
+ [0, 0, 1],
24
+ [-1, 0, 0],
25
+ [0, -1, 0],
26
+ [0, -1, 0],
27
+ [0, 0, 1],
28
+ [0, 0, 1]
29
+ ]
30
+ )
31
+
32
+ t2m_raw_offsets = np.array([[0,0,0],
33
+ [1,0,0],
34
+ [-1,0,0],
35
+ [0,1,0],
36
+ [0,-1,0],
37
+ [0,-1,0],
38
+ [0,1,0],
39
+ [0,-1,0],
40
+ [0,-1,0],
41
+ [0,1,0],
42
+ [0,0,1],
43
+ [0,0,1],
44
+ [0,1,0],
45
+ [1,0,0],
46
+ [-1,0,0],
47
+ [0,0,1],
48
+ [0,-1,0],
49
+ [0,-1,0],
50
+ [0,-1,0],
51
+ [0,-1,0],
52
+ [0,-1,0],
53
+ [0,-1,0]])
54
+
55
+ t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
56
+ t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
57
+ t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
58
+
59
+
60
+ kit_tgt_skel_id = '03950'
61
+
62
+ t2m_tgt_skel_id = '000021'
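+
+ # Illustrative sketch: each chain lists joint indices from the root outwards, so bones
+ # can be drawn by pairing consecutive entries. `joints` (a (joints_num, 3) array) and
+ # `draw_bone` are hypothetical placeholders here.
+ #
+ #   for chain in t2m_kinematic_chain:
+ #       for a, b in zip(chain[:-1], chain[1:]):
+ #           draw_bone(joints[a], joints[b])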
63
+
utils/plot_script.py ADDED
@@ -0,0 +1,115 @@
1
+ import math
2
+ import numpy as np
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ from mpl_toolkits.mplot3d import Axes3D
6
+ from matplotlib.animation import FuncAnimation, FFMpegFileWriter
7
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
8
+ import mpl_toolkits.mplot3d.axes3d as p3
9
+ # import cv2
10
+
11
+
12
+ def list_cut_average(ll, intervals):
13
+ if intervals == 1:
14
+ return ll
15
+
16
+ bins = math.ceil(len(ll) * 1.0 / intervals)
17
+ ll_new = []
18
+ for i in range(bins):
19
+ l_low = intervals * i
20
+ l_high = l_low + intervals
21
+ l_high = l_high if l_high < len(ll) else len(ll)
22
+ ll_new.append(np.mean(ll[l_low:l_high]))
23
+ return ll_new
24
+
25
+
26
+ def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4):
27
+ matplotlib.use('Agg')
28
+
29
+ title_sp = title.split(' ')
30
+ if len(title_sp) > 20:
31
+ title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])])
32
+ elif len(title_sp) > 10:
33
+ title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])
34
+
35
+ def init():
36
+ ax.set_xlim3d([-radius / 4, radius / 4])
37
+ ax.set_ylim3d([0, radius / 2])
38
+ ax.set_zlim3d([0, radius / 2])
39
+ # print(title)
40
+ fig.suptitle(title, fontsize=20)
41
+ ax.grid(False)
42
+
43
+ def plot_xzPlane(minx, maxx, miny, minz, maxz):
44
+ ## Plot a plane XZ
45
+ verts = [
46
+ [minx, miny, minz],
47
+ [minx, miny, maxz],
48
+ [maxx, miny, maxz],
49
+ [maxx, miny, minz]
50
+ ]
51
+ xz_plane = Poly3DCollection([verts])
52
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
53
+ ax.add_collection3d(xz_plane)
54
+
55
+ # return ax
56
+
57
+ # (seq_len, joints_num, 3)
58
+ data = joints.copy().reshape(len(joints), -1, 3)
59
+ fig = plt.figure(figsize=figsize)
60
+ ax = p3.Axes3D(fig)
61
+ init()
62
+ MINS = data.min(axis=0).min(axis=0)
63
+ MAXS = data.max(axis=0).max(axis=0)
64
+ colors = ['red', 'blue', 'black', 'red', 'blue',
65
+ 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
66
+ 'darkred', 'darkred', 'darkred', 'darkred', 'darkred']
67
+ frame_number = data.shape[0]
68
+ # print(data.shape)
69
+
70
+ height_offset = MINS[1]
71
+ data[:, :, 1] -= height_offset
72
+ trajec = data[:, 0, [0, 2]]
73
+
74
+ data[..., 0] -= data[:, 0:1, 0]
75
+ data[..., 2] -= data[:, 0:1, 2]
76
+
77
+ # print(trajec.shape)
78
+
79
+ def update(index):
80
+ # print(index)
81
+ ax.lines = []
82
+ ax.collections = []
83
+ ax.view_init(elev=120, azim=-90)
84
+ ax.dist = 7.5
85
+ # ax =
86
+ plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
87
+ MAXS[2] - trajec[index, 1])
88
+ # ax.scatter(data[index, :22, 0], data[index, :22, 1], data[index, :22, 2], color='black', s=3)
89
+
90
+ if index > 1:
91
+ ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]),
92
+ trajec[:index, 1] - trajec[index, 1], linewidth=1.0,
93
+ color='blue')
94
+ # ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2])
95
+
96
+ for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
97
+ # print(color)
98
+ if i < 5:
99
+ linewidth = 4.0
100
+ else:
101
+ linewidth = 2.0
102
+ ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
103
+ color=color)
104
+ # print(trajec[:index, 0].shape)
105
+
106
+ plt.axis('off')
107
+ ax.set_xticklabels([])
108
+ ax.set_yticklabels([])
109
+ ax.set_zticklabels([])
110
+
111
+ ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
112
+
113
+ # writer = FFMpegFileWriter(fps=fps)
114
+ ani.save(save_path, fps=fps)
115
+ plt.close()
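+
+ # Minimal usage sketch, assuming `motion` is a (seq_len, 22, 3) joint array (e.g. the
+ # output of recover_from_ric) and that HumanML3D motions play back at 20 fps:
+ #
+ #   from utils.paramUtil import t2m_kinematic_chain
+ #   plot_3d_motion('sample.mp4', t2m_kinematic_chain, motion,
+ #                  title='a person walks forward', fps=20)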
utils/quaternion.py ADDED
@@ -0,0 +1,423 @@
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import torch
9
+ import numpy as np
10
+
11
+ _EPS4 = np.finfo(float).eps * 4.0
12
+
13
+ _FLOAT_EPS = np.finfo(np.float64).eps
14
+
15
+ # PyTorch-backed implementations
16
+ def qinv(q):
17
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18
+ mask = torch.ones_like(q)
19
+ mask[..., 1:] = -mask[..., 1:]
20
+ return q * mask
21
+
22
+
23
+ def qinv_np(q):
24
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25
+ return qinv(torch.from_numpy(q).float()).numpy()
26
+
27
+
28
+ def qnormalize(q):
29
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30
+ return q / torch.norm(q, dim=-1, keepdim=True)
31
+
32
+
33
+ def qmul(q, r):
34
+ """
35
+ Multiply quaternion(s) q with quaternion(s) r.
36
+ Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37
+ Returns q*r as a tensor of shape (*, 4).
38
+ """
39
+ assert q.shape[-1] == 4
40
+ assert r.shape[-1] == 4
41
+
42
+ original_shape = q.shape
43
+
44
+ # Compute outer product
45
+ terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46
+
47
+ w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48
+ x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49
+ y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50
+ z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51
+ return torch.stack((w, x, y, z), dim=1).view(original_shape)
52
+
53
+
54
+ def qrot(q, v):
55
+ """
56
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
57
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58
+ where * denotes any number of dimensions.
59
+ Returns a tensor of shape (*, 3).
60
+ """
61
+ assert q.shape[-1] == 4
62
+ assert v.shape[-1] == 3
63
+ assert q.shape[:-1] == v.shape[:-1]
64
+
65
+ original_shape = list(v.shape)
66
+ # print(q.shape)
67
+ q = q.contiguous().view(-1, 4)
68
+ v = v.contiguous().view(-1, 3)
69
+
70
+ qvec = q[:, 1:]
71
+ uv = torch.cross(qvec, v, dim=1)
72
+ uuv = torch.cross(qvec, uv, dim=1)
73
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
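+
+ # Quick sketch of composing rotations (illustrative values, w-first convention):
+ #
+ #   q90 = torch.tensor([[0.7071, 0.0, 0.7071, 0.0]])   # ~90 degrees about the y axis
+ #   v = torch.tensor([[1.0, 0.0, 0.0]])
+ #   qrot(qmul(q90, q90), v)                             # ~180 degrees: roughly [-1, 0, 0]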
74
+
75
+
76
+ def qeuler(q, order, epsilon=0, deg=True):
77
+ """
78
+ Convert quaternion(s) q to Euler angles.
79
+ Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80
+ Returns a tensor of shape (*, 3).
81
+ """
82
+ assert q.shape[-1] == 4
83
+
84
+ original_shape = list(q.shape)
85
+ original_shape[-1] = 3
86
+ q = q.view(-1, 4)
87
+
88
+ q0 = q[:, 0]
89
+ q1 = q[:, 1]
90
+ q2 = q[:, 2]
91
+ q3 = q[:, 3]
92
+
93
+ if order == 'xyz':
94
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95
+ y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97
+ elif order == 'yzx':
98
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100
+ z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101
+ elif order == 'zxy':
102
+ x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105
+ elif order == 'xzy':
106
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107
+ y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108
+ z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109
+ elif order == 'yxz':
110
+ x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111
+ y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112
+ z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113
+ elif order == 'zyx':
114
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115
+ y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116
+ z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117
+ else:
118
+ raise ValueError('Unknown rotation order: %s' % order)
119
+
120
+ if deg:
121
+ return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
122
+ else:
123
+ return torch.stack((x, y, z), dim=1).view(original_shape)
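+
+ # Sketch: the identity quaternion maps to zero Euler angles in any order (degrees by default):
+ #
+ #   qeuler(torch.tensor([[1.0, 0.0, 0.0, 0.0]]), 'xyz')   # -> tensor([[0., 0., 0.]])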
124
+
125
+
126
+ # Numpy-backed implementations
127
+
128
+ def qmul_np(q, r):
129
+ q = torch.from_numpy(q).contiguous().float()
130
+ r = torch.from_numpy(r).contiguous().float()
131
+ return qmul(q, r).numpy()
132
+
133
+
134
+ def qrot_np(q, v):
135
+ q = torch.from_numpy(q).contiguous().float()
136
+ v = torch.from_numpy(v).contiguous().float()
137
+ return qrot(q, v).numpy()
138
+
139
+
140
+ def qeuler_np(q, order, epsilon=0, use_gpu=False):
141
+ if use_gpu:
142
+ q = torch.from_numpy(q).cuda().float()
143
+ return qeuler(q, order, epsilon).cpu().numpy()
144
+ else:
145
+ q = torch.from_numpy(q).contiguous().float()
146
+ return qeuler(q, order, epsilon).numpy()
147
+
148
+
149
+ def qfix(q):
150
+ """
151
+ Enforce quaternion continuity across the time dimension by selecting
152
+ the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
153
+ between two consecutive frames.
154
+
155
+ Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
156
+ Returns a tensor of the same shape.
157
+ """
158
+ assert len(q.shape) == 3
159
+ assert q.shape[-1] == 4
160
+
161
+ result = q.copy()
162
+ dot_products = np.sum(q[1:] * q[:-1], axis=2)
163
+ mask = dot_products < 0
164
+ mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
165
+ result[1:][mask] *= -1
166
+ return result
167
+
168
+
169
+ def euler2quat(e, order, deg=True):
170
+ """
171
+ Convert Euler angles to quaternions.
172
+ """
173
+ assert e.shape[-1] == 3
174
+
175
+ original_shape = list(e.shape)
176
+ original_shape[-1] = 4
177
+
178
+ e = e.view(-1, 3)
179
+
180
+ ## if euler angles in degrees
181
+ if deg:
182
+ e = e * np.pi / 180.
183
+
184
+ x = e[:, 0]
185
+ y = e[:, 1]
186
+ z = e[:, 2]
187
+
188
+ rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
189
+ ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
190
+ rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
191
+
192
+ result = None
193
+ for coord in order:
194
+ if coord == 'x':
195
+ r = rx
196
+ elif coord == 'y':
197
+ r = ry
198
+ elif coord == 'z':
199
+ r = rz
200
+ else:
201
+ raise ValueError('Unknown axis: %s' % coord)
202
+ if result is None:
203
+ result = r
204
+ else:
205
+ result = qmul(result, r)
206
+
207
+ # Reverse antipodal representation to have a non-negative "w"
208
+ if order in ['xyz', 'yzx', 'zxy']:
209
+ result *= -1
210
+
211
+ return result.view(original_shape)
212
+
213
+
214
+ def expmap_to_quaternion(e):
215
+ """
216
+ Convert axis-angle rotations (aka exponential maps) to quaternions.
217
+ Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
218
+ Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
219
+ Returns a tensor of shape (*, 4).
220
+ """
221
+ assert e.shape[-1] == 3
222
+
223
+ original_shape = list(e.shape)
224
+ original_shape[-1] = 4
225
+ e = e.reshape(-1, 3)
226
+
227
+ theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
228
+ w = np.cos(0.5 * theta).reshape(-1, 1)
229
+ xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
230
+ return np.concatenate((w, xyz), axis=1).reshape(original_shape)
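+
+ # Sketch: the input is an axis-angle vector whose norm is the rotation angle; pi about
+ # the z axis maps to (0, 0, 0, 1) in (w, x, y, z):
+ #
+ #   expmap_to_quaternion(np.array([0.0, 0.0, np.pi]))   # -> array([0., 0., 0., 1.])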
231
+
232
+
233
+ def euler_to_quaternion(e, order):
234
+ """
235
+ Convert Euler angles to quaternions.
236
+ """
237
+ assert e.shape[-1] == 3
238
+
239
+ original_shape = list(e.shape)
240
+ original_shape[-1] = 4
241
+
242
+ e = e.reshape(-1, 3)
243
+
244
+ x = e[:, 0]
245
+ y = e[:, 1]
246
+ z = e[:, 2]
247
+
248
+ rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
249
+ ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
250
+ rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
251
+
252
+ result = None
253
+ for coord in order:
254
+ if coord == 'x':
255
+ r = rx
256
+ elif coord == 'y':
257
+ r = ry
258
+ elif coord == 'z':
259
+ r = rz
260
+ else:
261
+ raise ValueError('Unknown axis: %s' % coord)
262
+ if result is None:
263
+ result = r
264
+ else:
265
+ result = qmul_np(result, r)
266
+
267
+ # Reverse antipodal representation to have a non-negative "w"
268
+ if order in ['xyz', 'yzx', 'zxy']:
269
+ result *= -1
270
+
271
+ return result.reshape(original_shape)
272
+
273
+
274
+ def quaternion_to_matrix(quaternions):
275
+ """
276
+ Convert rotations given as quaternions to rotation matrices.
277
+ Args:
278
+ quaternions: quaternions with real part first,
279
+ as tensor of shape (..., 4).
280
+ Returns:
281
+ Rotation matrices as tensor of shape (..., 3, 3).
282
+ """
283
+ r, i, j, k = torch.unbind(quaternions, -1)
284
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
285
+
286
+ o = torch.stack(
287
+ (
288
+ 1 - two_s * (j * j + k * k),
289
+ two_s * (i * j - k * r),
290
+ two_s * (i * k + j * r),
291
+ two_s * (i * j + k * r),
292
+ 1 - two_s * (i * i + k * k),
293
+ two_s * (j * k - i * r),
294
+ two_s * (i * k - j * r),
295
+ two_s * (j * k + i * r),
296
+ 1 - two_s * (i * i + j * j),
297
+ ),
298
+ -1,
299
+ )
300
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
301
+
302
+
303
+ def quaternion_to_matrix_np(quaternions):
304
+ q = torch.from_numpy(quaternions).contiguous().float()
305
+ return quaternion_to_matrix(q).numpy()
306
+
307
+
308
+ def quaternion_to_cont6d_np(quaternions):
309
+ rotation_mat = quaternion_to_matrix_np(quaternions)
310
+ cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
311
+ return cont_6d
312
+
313
+
314
+ def quaternion_to_cont6d(quaternions):
315
+ rotation_mat = quaternion_to_matrix(quaternions)
316
+ cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
317
+ return cont_6d
318
+
319
+
320
+ def cont6d_to_matrix(cont6d):
321
+ assert cont6d.shape[-1] == 6, "The last dimension must be 6"
322
+ x_raw = cont6d[..., 0:3]
323
+ y_raw = cont6d[..., 3:6]
324
+
325
+ x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
326
+ z = torch.cross(x, y_raw, dim=-1)
327
+ z = z / torch.norm(z, dim=-1, keepdim=True)
328
+
329
+ y = torch.cross(z, x, dim=-1)
330
+
331
+ x = x[..., None]
332
+ y = y[..., None]
333
+ z = z[..., None]
334
+
335
+ mat = torch.cat([x, y, z], dim=-1)
336
+ return mat
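+
+ # Round-trip sketch for the 6D rotation representation (identity rotation, illustrative):
+ #
+ #   q_id = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
+ #   cont6d_to_matrix(quaternion_to_cont6d(q_id))   # ~ 3x3 identity matrix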
337
+
338
+
339
+ def cont6d_to_matrix_np(cont6d):
340
+ q = torch.from_numpy(cont6d).contiguous().float()
341
+ return cont6d_to_matrix(q).numpy()
342
+
343
+
344
+ def qpow(q0, t, dtype=torch.float):
345
+ ''' q0 : tensor of quaternions
346
+ t: tensor of powers
347
+ '''
348
+ q0 = qnormalize(q0)
349
+ theta0 = torch.acos(q0[..., 0])
350
+
351
+ ## if theta0 is close to zero, add epsilon to avoid NaNs
352
+ mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
353
+ theta0 = (1 - mask) * theta0 + mask * 10e-10
354
+ v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
355
+
356
+ if isinstance(t, torch.Tensor):
357
+ q = torch.zeros(t.shape + q0.shape)
358
+ theta = t.view(-1, 1) * theta0.view(1, -1)
359
+ else: ## if t is a number
360
+ q = torch.zeros(q0.shape)
361
+ theta = t * theta0
362
+
363
+ q[..., 0] = torch.cos(theta)
364
+ q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
365
+
366
+ return q.to(dtype)
367
+
368
+
369
+ def qslerp(q0, q1, t):
370
+ '''
371
+ q0: starting quaternion
372
+ q1: ending quaternion
373
+ t: array of points along the way
374
+
375
+ Returns:
376
+ Tensor of Slerps: t.shape + q0.shape
377
+ '''
378
+
379
+ q0 = qnormalize(q0)
380
+ q1 = qnormalize(q1)
381
+ q_ = qpow(qmul(q1, qinv(q0)), t)
382
+
383
+ return qmul(q_,
384
+ q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
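+
+ # Sketch, assuming `q_start` and `q_end` are unit quaternions of shape (*, 4):
+ #
+ #   t = torch.tensor([0.25, 0.5, 0.75])
+ #   qslerp(q_start, q_end, t)   # shape: t.shape + q_start.shape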
385
+
386
+
387
+ def qbetween(v0, v1):
388
+ '''
389
+ find the quaternion used to rotate v0 to v1
390
+ '''
391
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
392
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
393
+
394
+ v = torch.cross(v0, v1)
395
+ w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
396
+ keepdim=True)
397
+ return qnormalize(torch.cat([w, v], dim=-1))
398
+
399
+
400
+ def qbetween_np(v0, v1):
401
+ '''
402
+ find the quaternion used to rotate v0 to v1
403
+ '''
404
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
405
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
406
+
407
+ v0 = torch.from_numpy(v0).float()
408
+ v1 = torch.from_numpy(v1).float()
409
+ return qbetween(v0, v1).numpy()
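+
+ # Sketch of the typical use in this codebase: find the root rotation that aligns a
+ # character's forward direction with +z (both arguments are unit 3-vectors):
+ #
+ #   forward = np.array([[1.0, 0.0, 0.0]])
+ #   target = np.array([[0.0, 0.0, 1.0]])
+ #   root_quat = qbetween_np(forward, target)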
410
+
411
+
412
+ def lerp(p0, p1, t):
413
+ if not isinstance(t, torch.Tensor):
414
+ t = torch.Tensor([t])
415
+
416
+ new_shape = t.shape + p0.shape
417
+ new_view_t = t.shape + torch.Size([1] * len(p0.shape))
418
+ new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
419
+ p0 = p0.view(new_view_p).expand(new_shape)
420
+ p1 = p1.view(new_view_p).expand(new_shape)
421
+ t = t.view(new_view_t).expand(new_shape)
422
+
423
+ return p0 + t * (p1 - p0)
utils/skeleton.py ADDED
@@ -0,0 +1,199 @@
1
+ from utils.quaternion import *
2
+ import scipy.ndimage.filters as filters
3
+
4
+ class Skeleton(object):
5
+ def __init__(self, offset, kinematic_tree, device):
6
+ self.device = device
7
+ self._raw_offset_np = offset.numpy()
8
+ self._raw_offset = offset.clone().detach().to(device).float()
9
+ self._kinematic_tree = kinematic_tree
10
+ self._offset = None
11
+ self._parents = [0] * len(self._raw_offset)
12
+ self._parents[0] = -1
13
+ for chain in self._kinematic_tree:
14
+ for j in range(1, len(chain)):
15
+ self._parents[chain[j]] = chain[j-1]
16
+
17
+ def njoints(self):
18
+ return len(self._raw_offset)
19
+
20
+ def offset(self):
21
+ return self._offset
22
+
23
+ def set_offset(self, offsets):
24
+ self._offset = offsets.clone().detach().to(self.device).float()
25
+
26
+ def kinematic_tree(self):
27
+ return self._kinematic_tree
28
+
29
+ def parents(self):
30
+ return self._parents
31
+
32
+ # joints (batch_size, joints_num, 3)
33
+ def get_offsets_joints_batch(self, joints):
34
+ assert len(joints.shape) == 3
35
+ _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
36
+ for i in range(1, self._raw_offset.shape[0]):
37
+ _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
38
+
39
+ self._offset = _offsets.detach()
40
+ return _offsets
41
+
42
+ # joints (joints_num, 3)
43
+ def get_offsets_joints(self, joints):
44
+ assert len(joints.shape) == 2
45
+ _offsets = self._raw_offset.clone()
46
+ for i in range(1, self._raw_offset.shape[0]):
47
+ # print(joints.shape)
48
+ _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
49
+
50
+ self._offset = _offsets.detach()
51
+ return _offsets
52
+
53
+ # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
54
+ # joints (batch_size, joints_num, 3)
55
+ def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
56
+ assert len(face_joint_idx) == 4
57
+ '''Get Forward Direction'''
58
+ l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
59
+ across1 = joints[:, r_hip] - joints[:, l_hip]
60
+ across2 = joints[:, sdr_r] - joints[:, sdr_l]
61
+ across = across1 + across2
62
+ across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
63
+ # print(across1.shape, across2.shape)
64
+
65
+ # forward (batch_size, 3)
66
+ forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
67
+ if smooth_forward:
68
+ forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
69
+ # forward (batch_size, 3)
70
+ forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
71
+
72
+ '''Get Root Rotation'''
73
+ target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
74
+ root_quat = qbetween_np(forward, target)
75
+
76
+ '''Inverse Kinematics'''
77
+ # quat_params (batch_size, joints_num, 4)
78
+ # print(joints.shape[:-1])
79
+ quat_params = np.zeros(joints.shape[:-1] + (4,))
80
+ # print(quat_params.shape)
81
+ root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
82
+ quat_params[:, 0] = root_quat
83
+ # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
84
+ for chain in self._kinematic_tree:
85
+ R = root_quat
86
+ for j in range(len(chain) - 1):
87
+ # (batch, 3)
88
+ u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
89
+ # print(u.shape)
90
+ # (batch, 3)
91
+ v = joints[:, chain[j+1]] - joints[:, chain[j]]
92
+ v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
93
+ # print(u.shape, v.shape)
94
+ rot_u_v = qbetween_np(u, v)
95
+
96
+ R_loc = qmul_np(qinv_np(R), rot_u_v)
97
+
98
+ quat_params[:,chain[j + 1], :] = R_loc
99
+ R = qmul_np(R, R_loc)
100
+
101
+ return quat_params
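+
+ # Sketch of the IK -> FK round trip used by the data-processing scripts, assuming
+ # `joints` is a (seq_len, joints_num, 3) numpy array and `face_joint_indx`,
+ # `n_raw_offsets`, `kinematic_chain` are defined as in the KIT/T2M scripts:
+ #
+ #   skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
+ #   skel.get_offsets_joints(torch.from_numpy(joints[0]))
+ #   quat_params = skel.inverse_kinematics_np(joints, face_joint_indx, smooth_forward=True)
+ #   rebuilt = skel.forward_kinematics_np(quat_params, joints[:, 0])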
102
+
103
+ # Be sure root joint is at the beginning of kinematic chains
104
+ def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
105
+ # quat_params (batch_size, joints_num, 4)
106
+ # joints (batch_size, joints_num, 3)
107
+ # root_pos (batch_size, 3)
108
+ if skel_joints is not None:
109
+ offsets = self.get_offsets_joints_batch(skel_joints)
110
+ if len(self._offset.shape) == 2:
111
+ offsets = self._offset.expand(quat_params.shape[0], -1, -1)
112
+ joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
113
+ joints[:, 0] = root_pos
114
+ for chain in self._kinematic_tree:
115
+ if do_root_R:
116
+ R = quat_params[:, 0]
117
+ else:
118
+ R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
119
+ for i in range(1, len(chain)):
120
+ R = qmul(R, quat_params[:, chain[i]])
121
+ offset_vec = offsets[:, chain[i]]
122
+ joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
123
+ return joints
124
+
125
+ # Be sure root joint is at the beginning of kinematic chains
126
+ def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
127
+ # quat_params (batch_size, joints_num, 4)
128
+ # joints (batch_size, joints_num, 3)
129
+ # root_pos (batch_size, 3)
130
+ if skel_joints is not None:
131
+ skel_joints = torch.from_numpy(skel_joints)
132
+ offsets = self.get_offsets_joints_batch(skel_joints)
133
+ if len(self._offset.shape) == 2:
134
+ offsets = self._offset.expand(quat_params.shape[0], -1, -1)
135
+ offsets = offsets.numpy()
136
+ joints = np.zeros(quat_params.shape[:-1] + (3,))
137
+ joints[:, 0] = root_pos
138
+ for chain in self._kinematic_tree:
139
+ if do_root_R:
140
+ R = quat_params[:, 0]
141
+ else:
142
+ R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
143
+ for i in range(1, len(chain)):
144
+ R = qmul_np(R, quat_params[:, chain[i]])
145
+ offset_vec = offsets[:, chain[i]]
146
+ joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
147
+ return joints
148
+
149
+ def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
150
+ # cont6d_params (batch_size, joints_num, 6)
151
+ # joints (batch_size, joints_num, 3)
152
+ # root_pos (batch_size, 3)
153
+ if skel_joints is not None:
154
+ skel_joints = torch.from_numpy(skel_joints)
155
+ offsets = self.get_offsets_joints_batch(skel_joints)
156
+ if len(self._offset.shape) == 2:
157
+ offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
158
+ offsets = offsets.numpy()
159
+ joints = np.zeros(cont6d_params.shape[:-1] + (3,))
160
+ joints[:, 0] = root_pos
161
+ for chain in self._kinematic_tree:
162
+ if do_root_R:
163
+ matR = cont6d_to_matrix_np(cont6d_params[:, 0])
164
+ else:
165
+ matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
166
+ for i in range(1, len(chain)):
167
+ matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
168
+ offset_vec = offsets[:, chain[i]][..., np.newaxis]
169
+ # print(matR.shape, offset_vec.shape)
170
+ joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
171
+ return joints
172
+
173
+ def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
174
+ # cont6d_params (batch_size, joints_num, 6)
175
+ # joints (batch_size, joints_num, 3)
176
+ # root_pos (batch_size, 3)
177
+ if skel_joints is not None:
178
+ # skel_joints = torch.from_numpy(skel_joints)
179
+ offsets = self.get_offsets_joints_batch(skel_joints)
180
+ if len(self._offset.shape) == 2:
181
+ offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
182
+ joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
183
+ joints[..., 0, :] = root_pos
184
+ for chain in self._kinematic_tree:
185
+ if do_root_R:
186
+ matR = cont6d_to_matrix(cont6d_params[:, 0])
187
+ else:
188
+ matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
189
+ for i in range(1, len(chain)):
190
+ matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
191
+ offset_vec = offsets[:, chain[i]].unsqueeze(-1)
192
+ # print(matR.shape, offset_vec.shape)
193
+ joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
194
+ return joints
195
+
196
+
197
+
198
+
199
+
utils/utils.py ADDED
@@ -0,0 +1,131 @@
1
+ import os
2
+ import numpy as np
3
+ # import cv2
4
+ from PIL import Image
5
+ from utils import paramUtil
6
+ import math
7
+ import time
8
+ import matplotlib.pyplot as plt
9
+ from scipy.ndimage import gaussian_filter
10
+
11
+
12
+ def mkdir(path):
13
+ if not os.path.exists(path):
14
+ os.makedirs(path)
15
+
16
+ COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
17
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
18
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
19
+
20
+ MISSING_VALUE = -1
21
+
22
+ def save_image(image_numpy, image_path):
23
+ img_pil = Image.fromarray(image_numpy)
24
+ img_pil.save(image_path)
25
+
26
+
27
+ def save_logfile(log_loss, save_path):
28
+ with open(save_path, 'wt') as f:
29
+ for k, v in log_loss.items():
30
+ w_line = k
31
+ for digit in v:
32
+ w_line += ' %.3f' % digit
33
+ f.write(w_line + '\n')
34
+
35
+
36
+ def print_current_loss(start_time, niter_state, losses, epoch=None, inner_iter=None):
37
+
38
+ def as_minutes(s):
39
+ m = math.floor(s / 60)
40
+ s -= m * 60
41
+ return '%dm %ds' % (m, s)
42
+
43
+ def time_since(since, percent):
44
+ now = time.time()
45
+ s = now - since
46
+ es = s / percent
47
+ rs = es - s
48
+ return '%s (- %s)' % (as_minutes(s), as_minutes(rs))
49
+
50
+ if epoch is not None:
51
+ print('epoch: %3d niter: %6d inner_iter: %4d' % (epoch, niter_state, inner_iter), end=" ")
52
+
53
+ now = time.time()
54
+ message = '%s'%(as_minutes(now - start_time))
55
+
56
+ for k, v in losses.items():
57
+ message += ' %s: %.4f ' % (k, v)
58
+ print(message)
59
+
60
+
61
+ def compose_gif_img_list(img_list, fp_out, duration):
62
+ img, *imgs = [Image.fromarray(np.array(image)) for image in img_list]
63
+ img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False,
64
+ save_all=True, loop=0, duration=duration)
65
+
66
+
67
+ def save_images(visuals, image_path):
68
+ if not os.path.exists(image_path):
69
+ os.makedirs(image_path)
70
+
71
+ for i, (label, img_numpy) in enumerate(visuals.items()):
72
+ img_name = '%d_%s.jpg' % (i, label)
73
+ save_path = os.path.join(image_path, img_name)
74
+ save_image(img_numpy, save_path)
75
+
76
+
77
+ def save_images_test(visuals, image_path, from_name, to_name):
78
+ if not os.path.exists(image_path):
79
+ os.makedirs(image_path)
80
+
81
+ for i, (label, img_numpy) in enumerate(visuals.items()):
82
+ img_name = "%s_%s_%s" % (from_name, to_name, label)
83
+ save_path = os.path.join(image_path, img_name)
84
+ save_image(img_numpy, save_path)
85
+
86
+
87
+ def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)):
88
+ # print(col, row)
89
+ compose_img = compose_image(img_list, col, row, img_size)
90
+ if not os.path.exists(save_dir):
91
+ os.makedirs(save_dir)
92
+ img_path = os.path.join(save_dir, img_name)
93
+ # print(img_path)
94
+ compose_img.save(img_path)
95
+
96
+
97
+ def compose_image(img_list, col, row, img_size):
98
+ to_image = Image.new('RGB', (col * img_size[0], row * img_size[1]))
99
+ for y in range(0, row):
100
+ for x in range(0, col):
101
+ from_img = Image.fromarray(img_list[y * col + x])
102
+ # print((x * img_size[0], y*img_size[1],
103
+ # (x + 1) * img_size[0], (y + 1) * img_size[1]))
104
+ paste_area = (x * img_size[0], y*img_size[1],
105
+ (x + 1) * img_size[0], (y + 1) * img_size[1])
106
+ to_image.paste(from_img, paste_area)
107
+ # to_image[y*img_size[1]:(y + 1) * img_size[1], x * img_size[0] :(x + 1) * img_size[0]] = from_img
108
+ return to_image
109
+
110
+
111
+ def list_cut_average(ll, intervals):
112
+ if intervals == 1:
113
+ return ll
114
+
115
+ bins = math.ceil(len(ll) * 1.0 / intervals)
116
+ ll_new = []
117
+ for i in range(bins):
118
+ l_low = intervals * i
119
+ l_high = l_low + intervals
120
+ l_high = l_high if l_high < len(ll) else len(ll)
121
+ ll_new.append(np.mean(ll[l_low:l_high]))
122
+ return ll_new
123
+
124
+
125
+ def motion_temporal_filter(motion, sigma=1):
126
+ motion = motion.reshape(motion.shape[0], -1)
127
+ # print(motion.shape)
128
+ for i in range(motion.shape[1]):
129
+ motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest")
130
+ return motion.reshape(motion.shape[0], -1, 3)
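+
+ # Sketch: smooth a generated joint sequence before rendering, assuming `motion` has
+ # shape (seq_len, joints_num, 3); the filter usually writes into the input array
+ # (reshape tends to return a view), hence the .copy():
+ #
+ #   smoothed = motion_temporal_filter(motion.copy(), sigma=2.5)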
131
+
utils/word_vectorizer.py ADDED
@@ -0,0 +1,80 @@
1
+ import numpy as np
2
+ import pickle
3
+ from os.path import join as pjoin
4
+
5
+ POS_enumerator = {
6
+ 'VERB': 0,
7
+ 'NOUN': 1,
8
+ 'DET': 2,
9
+ 'ADP': 3,
10
+ 'NUM': 4,
11
+ 'AUX': 5,
12
+ 'PRON': 6,
13
+ 'ADJ': 7,
14
+ 'ADV': 8,
15
+ 'Loc_VIP': 9,
16
+ 'Body_VIP': 10,
17
+ 'Obj_VIP': 11,
18
+ 'Act_VIP': 12,
19
+ 'Desc_VIP': 13,
20
+ 'OTHER': 14,
21
+ }
22
+
23
+ Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
24
+ 'up', 'down', 'straight', 'curve')
25
+
26
+ Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
27
+
28
+ Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
29
+
30
+ Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
31
+ 'stumble', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
32
+ 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
33
+
34
+ Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
35
+ 'angrily', 'sadly')
36
+
37
+ VIP_dict = {
38
+ 'Loc_VIP': Loc_list,
39
+ 'Body_VIP': Body_list,
40
+ 'Obj_VIP': Obj_List,
41
+ 'Act_VIP': Act_list,
42
+ 'Desc_VIP': Desc_list,
43
+ }
44
+
45
+
46
+ class WordVectorizer(object):
47
+ def __init__(self, meta_root, prefix):
48
+ vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix))
49
+ words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb'))
50
+ word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb'))
51
+ self.word2vec = {w: vectors[word2idx[w]] for w in words}
52
+
53
+ def _get_pos_ohot(self, pos):
54
+ pos_vec = np.zeros(len(POS_enumerator))
55
+ if pos in POS_enumerator:
56
+ pos_vec[POS_enumerator[pos]] = 1
57
+ else:
58
+ pos_vec[POS_enumerator['OTHER']] = 1
59
+ return pos_vec
60
+
61
+ def __len__(self):
62
+ return len(self.word2vec)
63
+
64
+ def __getitem__(self, item):
65
+ word, pos = item.split('/')
66
+ if word in self.word2vec:
67
+ word_vec = self.word2vec[word]
68
+ vip_pos = None
69
+ for key, values in VIP_dict.items():
70
+ if word in values:
71
+ vip_pos = key
72
+ break
73
+ if vip_pos is not None:
74
+ pos_vec = self._get_pos_ohot(vip_pos)
75
+ else:
76
+ pos_vec = self._get_pos_ohot(pos)
77
+ else:
78
+ word_vec = self.word2vec['unk']
79
+ pos_vec = self._get_pos_ohot('OTHER')
80
+ return word_vec, pos_vec