Mathis Petrovich committed on
Commit 5e4fa5e
1 Parent(s): bdb661d
Files changed (2)
  1. app.py +71 -37
  2. load.py +9 -11
app.py CHANGED
@@ -56,7 +56,7 @@ EXAMPLES = [
     "A person is taking the stairs",
     "Someone is doing jumping jacks",
     "The person walked forward and is picking up his toolbox",
-    "The person angrily punching the air"
+    "The person angrily punching the air",
 ]
 
 # Show closest text in the training
@@ -94,6 +94,7 @@ CSS = """
 
 DEFAULT_TEXT = "A person is "
 
+
 def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
     # Don't show the mirrored version of HumanMl3D
     if "M" in keyid:
@@ -128,13 +129,15 @@ def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
         "text": text,
         "keyid": keyid,
         "babel_id": babel_id,
-        "path": path
+        "path": path,
     }
 
     return data
 
 
-def retrieve(model, keyid_to_url, all_unit_motion_embs, all_keyids, text, splits=["test"], nmax=8):
+def retrieve(
+    model, keyid_to_url, all_unit_motion_embs, all_keyids, text, splits=["test"], nmax=8
+):
     unit_motion_embs = torch.cat([all_unit_motion_embs[s] for s in splits])
     keyids = np.concatenate([all_keyids[s] for s in splits])
 
@@ -169,7 +172,7 @@ def get_video_html(data, video_id, width=700, height=700):
     path = data["path"]
 
     trim = f"#t={start},{end}"
-    title = f'''Score = {score}
+    title = f"""Score = {score}
 
 Corresponding text: {text}
 
@@ -177,18 +180,18 @@ HumanML3D keyid: {keyid}
 
 BABEL keyid: {babel_id}
 
-AMASS path: {path}'''
+AMASS path: {path}"""
 
     # class="wrap default svelte-gjihhp hide"
     # <div class="contour_video" style="position: absolute; padding: 10px;">
     # width="{width}" height="{height}"
-    video_html = f'''
+    video_html = f"""
 <video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
 autoplay loop disablepictureinpicture id="{video_id}" title="{title}">
     <source src="{url}{trim}" type="video/mp4">
     Your browser does not support the video tag.
 </video>
-'''
+"""
     return video_html
 
 
@@ -208,16 +211,18 @@ def retrieve_component(retrieve_function, text, splits_choice, nvids, n_componen
     htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
     # get n_component exactly if asked less
     # pad with dummy blocks
-    htmls = htmls + [None for _ in range(max(0, n_component-nvids))]
+    htmls = htmls + [None for _ in range(max(0, n_component - nvids))]
    return htmls
 
 
 if not os.path.exists("data"):
-    gdown.download_folder("https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08",
-                          use_cookies=False)
+    gdown.download_folder(
+        "https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08",
+        use_cookies=False,
+    )
 
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # LOADING
 model = load_model(device)
@@ -229,7 +234,9 @@ h3d_index = load_json("amass-annotations/humanml3d.json")
 amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
 
 keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
-retrieve_function = partial(retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids)
+retrieve_function = partial(
+    retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids
+)
 
 # DEMO
 theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
@@ -242,33 +249,48 @@ with gr.Blocks(css=CSS, theme=theme) as demo:
     with gr.Row():
         with gr.Column(scale=3):
             with gr.Column(scale=2):
-                text = gr.Textbox(placeholder="Type the motion you want to search with a sentence",
-                                  show_label=True, label="Text prompt", value=DEFAULT_TEXT)
+                text = gr.Textbox(
+                    placeholder="Type the motion you want to search with a sentence",
+                    show_label=True,
+                    label="Text prompt",
+                    value=DEFAULT_TEXT,
+                )
             with gr.Column(scale=1):
-                btn = gr.Button("Retrieve", variant='primary')
-                clear = gr.Button("Clear", variant='secondary')
+                btn = gr.Button("Retrieve", variant="primary")
+                clear = gr.Button("Clear", variant="secondary")
 
     with gr.Row():
         with gr.Column(scale=1):
-            splits_choice = gr.Radio(["All motions", "Unseen motions"], label="Gallery of motion",
-                                     value="All motions",
-                                     info="The motion gallery is coming from HumanML3D")
+            splits_choice = gr.Radio(
+                ["All motions", "Unseen motions"],
+                label="Gallery of motion",
+                value="All motions",
+                info="The motion gallery is coming from HumanML3D",
+            )
 
         with gr.Column(scale=1):
            # nvideo_slider = gr.Slider(minimum=4, maximum=24, step=4, value=8, label="Number of videos")
-            nvideo_slider = gr.Radio([4, 8, 12, 16, 24], label="Videos",
-                                     value=8,
-                                     info="Number of videos to display")
+            nvideo_slider = gr.Radio(
+                [4, 8, 12, 16, 24],
+                label="Videos",
+                value=8,
+                info="Number of videos to display",
+            )
 
     with gr.Column(scale=2):
+
         def retrieve_example(text, splits_choice, nvideo_slider):
             return retrieve_and_show(text, splits_choice, nvideo_slider)
 
-        examples = gr.Examples(examples=[[x, None, None] for x in EXAMPLES],
-                               inputs=[text, splits_choice, nvideo_slider],
-                               examples_per_page=20,
-                               run_on_click=False, cache_examples=False,
-                               fn=retrieve_example, outputs=[])
+        examples = gr.Examples(
+            examples=[[x, None, None] for x in EXAMPLES],
+            inputs=[text, splits_choice, nvideo_slider],
+            examples_per_page=20,
+            run_on_click=False,
+            cache_examples=False,
+            fn=retrieve_example,
+            outputs=[],
+        )
 
     i = -1
     # should indent
@@ -294,16 +316,28 @@ with gr.Blocks(css=CSS, theme=theme) as demo:
         show_progress=False,
         postprocess=False,
         queue=False,
-    ).then(
-        fn=retrieve_example,
-        inputs=examples.inputs,
-        outputs=videos
-    )
-
-    btn.click(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
-    text.submit(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
-    splits_choice.change(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
-    nvideo_slider.change(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
+    ).then(fn=retrieve_example, inputs=examples.inputs, outputs=videos)
+
+    btn.click(
+        fn=retrieve_and_show,
+        inputs=[text, splits_choice, nvideo_slider],
+        outputs=videos,
+    )
+    text.submit(
+        fn=retrieve_and_show,
+        inputs=[text, splits_choice, nvideo_slider],
+        outputs=videos,
+    )
+    splits_choice.change(
+        fn=retrieve_and_show,
+        inputs=[text, splits_choice, nvideo_slider],
+        outputs=videos,
+    )
+    nvideo_slider.change(
+        fn=retrieve_and_show,
+        inputs=[text, splits_choice, nvideo_slider],
+        outputs=videos,
+    )
 
     def clear_videos():
         return [None for x in range(24)] + [DEFAULT_TEXT]
load.py CHANGED
@@ -20,10 +20,7 @@ def load_keyids(split):
 
 
 def load_keyids_splits(splits):
-    return {
-        split: load_keyids(split)
-        for split in splits
-    }
+    return {split: load_keyids(split) for split in splits}
 
 
 def load_unit_motion_embs(split, device):
@@ -33,16 +30,17 @@ def load_unit_motion_embs(split, device):
 
 
 def load_unit_motion_embs_splits(splits, device):
-    return {
-        split: load_unit_motion_embs(split, device)
-        for split in splits
-    }
+    return {split: load_unit_motion_embs(split, device) for split in splits}
 
 
 def load_model(device):
     text_params = {
-        'latent_dim': 256, 'ff_size': 1024, 'num_layers': 6, 'num_heads': 4,
-        'activation': 'gelu', 'modelpath': 'distilbert-base-uncased'
+        "latent_dim": 256,
+        "ff_size": 1024,
+        "num_layers": 6,
+        "num_heads": 4,
+        "activation": "gelu",
+        "modelpath": "distilbert-base-uncased",
     }
     "unit_motion_embs"
     model = TMR_textencoder(**text_params)
@@ -50,4 +48,4 @@ def load_model(device):
     # load values for the transformer only
     model.load_state_dict(state_dict, strict=False)
     model = model.eval()
-    return model
+    return model.to(device)