Mathis Petrovich committed
Commit 83f52e6
1 Parent(s): efba2f3

First commit

amass-annotations/amass_to_babel.json ADDED
The diff for this file is too large to render.

amass-annotations/humanml3d.json ADDED
The diff for this file is too large to render.
 
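These two JSON files back the metadata lookups in app.py below. Judging from how humanml3d_keyid_to_babel_rendered_url reads them, their entries look roughly as sketched here (the field names come from the code; the concrete keyid, path and BABEL id are made-up placeholders):

# humanml3d.json: HumanML3D keyid -> AMASS path + text annotations
h3d_index = {
    "000123": {                                       # hypothetical keyid
        "path": "KIT/3/kick_high_left02_poses.npz",   # hypothetical AMASS path
        "annotations": [
            {"start": 0.0, "end": 2.5, "text": "a person kicks with the left leg"},
        ],
    },
}

# amass_to_babel.json: AMASS path -> BABEL sequence id
# (app.py zero-pads the id to 6 digits to build the babel-renders video URL)
amass_to_babel = {
    "KIT/3/kick_high_left02_poses.npz": "4991",       # hypothetical BABEL id
}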
app.py ADDED
@@ -0,0 +1,265 @@
+ from functools import partial
+ import os
+
+ import torch
+ import numpy as np
+ import gradio as gr
+ import gdown
+
+ from load import load_model, load_json
+ from load import load_unit_motion_embs_splits, load_keyids_splits
+
+
+ EXAMPLES = [
+     "A person is walking in a circle",
+     "A person is jumping rope",
+     "Someone is doing a backflip",
+     "A person is doing a moonwalk",
+     "A person walks forward and then turns back",
+     "Picking up an object",
+     "A person is swimming in the sea",
+     "A human is squatting",
+     "Someone is jumping with one foot",
+     "A person is chopping vegetables",
+     "Someone walks backward",
+     "Somebody is ascending a staircase",
+     "A person is sitting down",
+     "A person is taking the stairs",
+     "Someone is doing jumping jacks",
+     "The person walked forward and is picking up his toolbox",
+     "The person angrily punching the air."
+ ]
+
+ # Show closest text in the training
+
+
+ # css to make videos look nice
+ CSS = """
+ video {
+     position: relative;
+     margin: 0;
+     box-shadow: var(--block-shadow);
+     border-width: var(--block-border-width);
+     border-color: var(--block-border-color);
+     border-radius: var(--block-radius);
+     background: var(--block-background-fill);
+     width: 100%;
+     line-height: var(--line-sm);
+ }
+ """
+
+
+ def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
+     # Don't show the mirrored version of HumanML3D
+     if "M" in keyid:
+         return None
+
+     dico = h3d_index[keyid]
+     path = dico["path"]
+
+     # HumanAct12 motions are not rendered online
+     # so we skip them for now
+     if "humanact12" in path:
+         return None
+
+     # This motion is not rendered in BABEL
+     # so we skip it for now
+     if path not in amass_to_babel:
+         return None
+
+     babel_id = amass_to_babel[path].zfill(6)
+     url = f"https://babel-renders.s3.eu-central-1.amazonaws.com/{babel_id}.mp4"
+
+     # For the demo, we retrieve from the first annotation only
+     ann = dico["annotations"][0]
+     start = ann["start"]
+     end = ann["end"]
+     text = ann["text"]
+
+     data = {
+         "url": url,
+         "start": start,
+         "end": end,
+         "text": text,
+         "keyid": keyid,
+         "babel_id": babel_id
+     }
+
+     return data
+
+
+ def retrieve(model, keyid_to_url, all_unit_motion_embs, all_keyids, text, splits=["test"], nmax=8):
+     unit_motion_embs = torch.cat([all_unit_motion_embs[s] for s in splits])
+     keyids = np.concatenate([all_keyids[s] for s in splits])
+
+     scores = model.compute_scores(text, unit_embs=unit_motion_embs)
+
+     sorted_idxs = np.argsort(-scores)
+     best_keyids = keyids[sorted_idxs]
+     best_scores = scores[sorted_idxs]
+
+     datas = []
+     for keyid, score in zip(best_keyids, best_scores):
+         if len(datas) == nmax:
+             break
+
+         data = keyid_to_url(keyid)
+         if data is None:
+             continue
+         data["score"] = round(float(score), 2)
+         datas.append(data)
+     return datas
+
+
+ # HTML component
+ def get_video_html(url, video_id, start=None, end=None, score=None, width=350, height=350):
+     trim = ""
+     if start is not None:
+         if end is not None:
+             trim = f"#t={start},{end}"
+         else:
+             trim = f"#t={start}"
+
+     score_t = ""
+     if score is not None:
+         score_t = f'title="Score = {score}"'
+
+     video_html = f'''
+ <video preload="auto" muted playsinline onpause="this.load()"
+ autoplay loop disablepictureinpicture id="{video_id}" width="{width}" height="{height}" {score_t}>
+ <source src="{url}{trim}" type="video/mp4">
+ Your browser does not support the video tag.
+ </video>
+ '''
+     return video_html
+
+
+ def retrieve_component(retrieve_function, text, splits, nvids, n_component=16):
+     # cannot produce more than n_component videos
+     nvids = min(nvids, n_component)
+     if not splits:
+         return [None for _ in range(n_component)]
+
+     splits_l = [x.lower() for x in splits]
+     datas = retrieve_function(text, splits=splits_l, nmax=nvids)
+     htmls = [
+         get_video_html(
+             url["url"], idx, start=url["start"],
+             end=url["end"], score=url["score"]
+         )
+         for idx, url in enumerate(datas)
+     ]
+     # always return exactly n_component outputs:
+     # pad with empty blocks if fewer videos were requested
+     htmls = htmls + [None for _ in range(max(0, n_component - nvids))]
+     return htmls
+
+
+ def main():
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     # LOADING
+     model = load_model(device)
+     splits = ["train", "val", "test"]
+     all_unit_motion_embs = load_unit_motion_embs_splits(splits, device)
+     all_keyids = load_keyids_splits(splits)
+
+     h3d_index = load_json("amass-annotations/humanml3d.json")
+     amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
+
+     keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
+     retrieve_function = partial(retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids)
+
+     # DEMO
+     theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
+     retrieve_and_show = partial(retrieve_component, retrieve_function)
+
+     default_text = "A person is "
+
+     with gr.Blocks(css=CSS, theme=theme) as demo:
+         title = "<h1 style='text-align: center'>TMR: Text-to-Motion Retrieval Using Contrastive 3D Human Motion Synthesis</h1>"
+         gr.Markdown(title)
+
+         authors = """
+ <h2 style='text-align: center'>
+ <a href="https://mathis.petrovich.fr" target="_blank"><nobr>Mathis Petrovich</nobr></a> &emsp;
+ <a href="https://ps.is.mpg.de/~black" target="_blank"><nobr>Michael J. Black</nobr></a> &emsp;
+ <a href="https://imagine.enpc.fr/~varolg" target="_blank"><nobr>G&uuml;l Varol</nobr></a>
+ </h2>
+ """
+         gr.Markdown(authors)
+
+         conf = """
+ <h2 style='text-align: center'>
+ <nobr>arXiv 2023</nobr>
+ </h2>
+ """
+         gr.Markdown(conf)
+
+         videos = []
+
+         with gr.Row():
+             with gr.Column(scale=3):
+                 with gr.Column(scale=2):
+                     text = gr.Textbox(placeholder="Type, in natural language, the motion to retrieve",
+                                       show_label=True, label="Text prompt", value=default_text)
+                 with gr.Column(scale=1):
+                     btn = gr.Button("Retrieve", variant='primary')
+                     clear = gr.Button("Clear", variant='secondary')
+
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     splits = gr.Dropdown(["Train", "Val", "Test"],
+                                          value=["Test"], multiselect=True, label="Splits",
+                                          info="HumanML3D data used for the motion database")
+                 with gr.Column(scale=1):
+                     nvideo_slider = gr.Slider(minimum=4, maximum=16, step=4, value=8, label="Number of videos")
+                 with gr.Column(scale=2):
+                     examples = gr.Examples(examples=EXAMPLES, inputs=text, examples_per_page=15)
+
+         i = -1
+         # should indent
+         for _ in range(4):
+             with gr.Row():
+                 for _ in range(4):
+                     i += 1
+                     with gr.Column():
+                         video = gr.HTML()
+                         videos.append(video)
+
+         def check_error(splits):
+             if not splits:
+                 raise gr.Error("At least one split should be selected!")
+             return splits
+
+         btn.click(fn=retrieve_and_show, inputs=[text, splits, nvideo_slider], outputs=videos).then(
+             fn=check_error, inputs=splits
+         )
+
+         text.submit(fn=retrieve_and_show, inputs=[text, splits, nvideo_slider], outputs=videos).then(
+             fn=check_error, inputs=splits
+         )
+
+         def keep_test(splits):
+             if len(splits) == 0:
+                 return ["Test"]
+             return splits
+
+         def clear_videos():
+             return [None for x in range(16)] + [default_text]
+
+         clear.click(fn=clear_videos, outputs=videos + [text])
+     demo.launch()
+
+
+ def prepare():
+     if not os.path.exists("data"):
+         gdown.download_folder("https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08", use_cookies=False)
+
+
+ if __name__ == "__main__":
+     prepare()
+     main()
+
+ # new
+ # A person is walking slowly
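For reference, the retrieval path above can also be exercised without the Gradio interface. A minimal sketch reusing the functions from this commit (it assumes prepare() has already downloaded the data/ folder and that the annotation JSONs are present):

from functools import partial

import torch

from load import load_model, load_json, load_unit_motion_embs_splits, load_keyids_splits
from app import humanml3d_keyid_to_babel_rendered_url, retrieve

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(device)
motion_embs = load_unit_motion_embs_splits(["test"], device)
keyids = load_keyids_splits(["test"])

h3d_index = load_json("amass-annotations/humanml3d.json")
amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)

# top-4 retrieved motions for a free-form text query
results = retrieve(model, keyid_to_url, motion_embs, keyids,
                   "A person is walking in a circle", splits=["test"], nmax=4)
for hit in results:
    print(hit["score"], hit["keyid"], hit["text"], hit["url"])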
load.py ADDED
@@ -0,0 +1,53 @@
+ import os
+ import orjson
+ import torch
+ import numpy as np
+ from model import TMR_textencoder
+
+ EMBS = "data/unit_motion_embs"
+
+
+ def load_json(path):
+     with open(path, "rb") as ff:
+         return orjson.loads(ff.read())
+
+
+ def load_keyids(split):
+     path = os.path.join(EMBS, f"{split}.keyids")
+     with open(path) as ff:
+         keyids = np.array([x.strip() for x in ff.readlines()])
+     return keyids
+
+
+ def load_keyids_splits(splits):
+     return {
+         split: load_keyids(split)
+         for split in splits
+     }
+
+
+ def load_unit_motion_embs(split, device):
+     path = os.path.join(EMBS, f"{split}_motion_embs_unit.npy")
+     tensor = torch.from_numpy(np.load(path)).to(device)
+     return tensor
+
+
+ def load_unit_motion_embs_splits(splits, device):
+     return {
+         split: load_unit_motion_embs(split, device)
+         for split in splits
+     }
+
+
+ def load_model(device):
+     text_params = {
+         'latent_dim': 256, 'ff_size': 1024, 'num_layers': 6, 'num_heads': 4,
+         'activation': 'gelu', 'modelpath': 'distilbert-base-uncased'
+     }
+     "unit_motion_embs"
+     model = TMR_textencoder(**text_params)
+     state_dict = torch.load("data/textencoder.pt", map_location=device)
+     # load values for the transformer only
+     model.load_state_dict(state_dict, strict=False)
+     model = model.eval()
+     return model
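load_model builds the text encoder from text_params and loads data/textencoder.pt with strict=False, since only the text-encoder weights are expected to match. A quick sanity check of what actually got loaded could look like this (a sketch, assuming the same checkpoint and parameters):

import torch
from model import TMR_textencoder

model = TMR_textencoder(latent_dim=256, ff_size=1024, num_layers=6, num_heads=4,
                        activation="gelu", modelpath="distilbert-base-uncased")
state_dict = torch.load("data/textencoder.pt", map_location="cpu")

# with strict=False, load_state_dict reports the parameters it could not match
result = model.load_state_dict(state_dict, strict=False)
print("missing keys:   ", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)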
model.py ADDED
@@ -0,0 +1,129 @@
+ from typing import List, Union
+ import torch.nn as nn
+ import os
+
+ import torch
+ import numpy as np
+ from torch import Tensor
+ from transformers import AutoTokenizer, AutoModel
+ from transformers import logging
+ from torch.nn.functional import normalize
+
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super().__init__()
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+
+         self.register_buffer('pe', pe, persistent=False)
+
+     def forward(self, x):
+         return x + self.pe[:x.shape[0], :]
+
+
+ class TMR_textencoder(nn.Module):
+     def __init__(self, modelpath: str, latent_dim: int, ff_size: int,
+                  num_layers: int, num_heads: int, activation: str, **kwargs) -> None:
+         super().__init__()
+
+         logging.set_verbosity_error()
+
+         # Tokenizer
+         os.environ["TOKENIZERS_PARALLELISM"] = "false"
+         self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
+
+         # Text model
+         self.text_model = AutoModel.from_pretrained(modelpath)
+         # Then configure the model
+         self.text_encoded_dim = self.text_model.config.hidden_size
+
+         # Projection of the text-outputs into the latent space
+         self.projection = nn.Sequential(
+             nn.ReLU(),
+             nn.Linear(self.text_encoded_dim, latent_dim)
+         )
+
+         self.mu_token = nn.Parameter(torch.randn(latent_dim))
+         self.logvar_token = nn.Parameter(torch.randn(latent_dim))
+         self.sequence_pos_encoding = PositionalEncoding(latent_dim)
+
+         seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim,
+                                                              nhead=num_heads,
+                                                              dim_feedforward=ff_size,
+                                                              dropout=0.0,
+                                                              activation=activation)
+         self.seqTransEncoder = nn.TransformerEncoder(
+             seq_trans_encoder_layer,
+             num_layers=num_layers
+         )
+
+     def get_last_hidden_state(self, texts: List[str],
+                               return_mask: bool = False
+                               ) -> Union[Tensor, tuple[Tensor, Tensor]]:
+         encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
+         output = self.text_model(**encoded_inputs.to(self.text_model.device))
+         if not return_mask:
+             return output.last_hidden_state
+         return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)
+
+     def forward(self, texts: List[str]) -> Tensor:
+         text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)
+
+         x = self.projection(text_encoded)
+         bs, nframes, _ = x.shape
+         # bs, nframes, totjoints, nfeats = x.shape
+         # Switch sequence and batch_size because the input of
+         # the PyTorch Transformer is [Sequence, Batch size, ...]
+         x = x.permute(1, 0, 2)  # now it is [nframes, bs, latent_dim]
+
+         mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1)
+         logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1)
+
+         # adding the distribution tokens for all sequences
+         xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)
+
+         # extend the mask so the mu and logvar tokens can be attended to
+         token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
+         aug_mask = torch.cat((token_mask, mask), 1)
+
+         # add positional encoding
+         xseq = self.sequence_pos_encoding(xseq)
+         final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
+
+         # only mu is used for inference
+         mu = final[0]
+         return mu
+
+     # compute scores for retrieval
+     def compute_scores(self, texts, unit_embs=None, embs=None):
+         # at least one of unit_embs / embs must be given
+         assert not (unit_embs is None and embs is None)
+         # but not both
+         assert not (unit_embs is not None and embs is not None)
+
+         output_str = False
+         # if one input, squeeze the output
+         if isinstance(texts, str):
+             texts = [texts]
+             output_str = True
+
+         # compute unit_embs from embs if not given
+         if embs is not None:
+             unit_embs = normalize(embs)
+
+         with torch.no_grad():
+             latent_unit_texts = normalize(self(texts))
+             # rescale the cosine similarity from [-1, 1] to [0, 1]
+             scores = (unit_embs @ latent_unit_texts.T).T / 2 + 0.5
+             scores = scores.cpu().numpy()
+
+         if output_str:
+             scores = scores[0]
+
+         return scores
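compute_scores accepts either a single string or a list of strings: a string yields a 1-D array with one score per motion embedding, a list yields one row per text. A small sketch with random placeholder embeddings (without loading data/textencoder.pt the numbers are meaningless; only the shapes and the call pattern are illustrated):

import torch
from torch.nn.functional import normalize

from model import TMR_textencoder

model = TMR_textencoder(latent_dim=256, ff_size=1024, num_layers=6, num_heads=4,
                        activation="gelu", modelpath="distilbert-base-uncased").eval()

# 100 fake unit-norm motion embeddings standing in for data/unit_motion_embs
unit_embs = normalize(torch.randn(100, 256))

scores = model.compute_scores("A person is walking in a circle", unit_embs=unit_embs)
print(scores.shape)  # (100,) -- one score in [0, 1] per motion

scores = model.compute_scores(["walking", "jumping"], unit_embs=unit_embs)
print(scores.shape)  # (2, 100)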