MeYourHint committed
Commit 023631a
1 Parent(s): 3d89bc0

Update gen func in app

Files changed (3):
  1. app.py +163 -21
  2. options/base_option.py +1 -1
  3. options/hgdemo_option.py +38 -0
app.py CHANGED

@@ -10,6 +10,30 @@ import shutil
 print(f"Is CUDA available: {torch.cuda.is_available()}")
 print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
 
+import os
+from os.path import join as pjoin
+
+import torch.nn.functional as F
+
+from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
+from models.vq.model import RVQVAE, LengthEstimator
+
+from options.hgdemo_option import EvalT2MOptions
+from utils.get_opt import get_opt
+
+from utils.fixseed import fixseed
+from visualization.joints2bvh import Joint2BVHConvertor
+from torch.distributions.categorical import Categorical
+
+from utils.motion_process import recover_from_ric
+from utils.plot_script import plot_3d_motion
+
+from utils.paramUtil import t2m_kinematic_chain
+
+from gen_t2m import load_vq_model, load_res_model, load_trans_model, load_len_estimator
+
+clip_version = 'ViT-B/32'
+
 WEBSITE = """
 <div class="embed_hidden">
 <h1 style='text-align: center'> MoMask: Generative Masked Modeling of 3D Human Motions </h1>
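The imports above support the commit's main change: instead of shelling out to gen_t2m.py on every request, the app now builds the models once at import time and reuses them. A minimal, runnable sketch of that pattern, with a stand-in nn.Linear instead of the MoMask models (the handle_request name is hypothetical, for illustration only):

import torch
import torch.nn as nn

# Stand-in for the MoMask models: built once at import time, so each
# request skips checkpoint loading (unlike the old os.system path).
model = nn.Linear(4, 2)
model.eval()  # inference mode: disables dropout and batch-norm updates

@torch.no_grad()  # same decorator the new generate() uses; skips autograd bookkeeping
def handle_request(x):
    return model(x)

print(handle_request(torch.randn(1, 4)))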
@@ -89,19 +113,120 @@ CSS = """
 
 DEFAULT_TEXT = "A person is "
 
+
+##########################
+######Preparing demo######
+##########################
+parser = EvalT2MOptions()
+opt = parser.parse()
+fixseed(opt.seed)
+opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
+dim_pose = 263
+root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
+model_dir = pjoin(root_dir, 'model')
+model_opt_path = pjoin(root_dir, 'opt.txt')
+model_opt = get_opt(model_opt_path, device=opt.device)
+
+######Loading RVQ######
+vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
+vq_opt = get_opt(vq_opt_path, device=opt.device)
+vq_opt.dim_pose = dim_pose
+vq_model, vq_opt = load_vq_model(vq_opt)
+
+model_opt.num_tokens = vq_opt.nb_code
+model_opt.num_quantizers = vq_opt.num_quantizers
+model_opt.code_dim = vq_opt.code_dim
+
+######Loading R-Transformer######
+res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
+res_opt = get_opt(res_opt_path, device=opt.device)
+res_model = load_res_model(res_opt, vq_opt, opt)
+
+assert res_opt.vq_name == model_opt.vq_name
+
+######Loading M-Transformer######
+t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')
+
+#####Loading Length Predictor#####
+length_estimator = load_len_estimator(model_opt)
+
+t2m_transformer.eval()
+vq_model.eval()
+res_model.eval()
+length_estimator.eval()
+
+res_model.to(opt.device)
+t2m_transformer.to(opt.device)
+vq_model.to(opt.device)
+length_estimator.to(opt.device)
+
+opt.nb_joints = 22
+mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))
+std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))
+def inv_transform(data):
+    return data * std + mean
+
+kinematic_chain = t2m_kinematic_chain
+converter = Joint2BVHConvertor()
+cached_dir = './cached'
+os.makedirs(cached_dir, exist_ok=True)
+
+@torch.no_grad()
 def generate(
-    text, uid, motion_length=0, seed=10107, repeat_times=1,
+    text, uid, motion_length=0, use_ik=True, seed=10107, repeat_times=1,
 ):
-    os.system(f'python gen_t2m.py --gpu_id 0 --seed {seed} --ext {uid} --repeat_times {repeat_times} --motion_length {motion_length} --text_prompt "{text}"')
+    fixseed(seed)
+    prompt_list = []
+    length_list = []
+    est_length = False
+    prompt_list.append(text)
+    if motion_length == 0:
+        est_length = True
+    else:
+        length_list.append(motion_length)
+
+    if est_length:
+        print("Since no motion length is specified, we will use the estimated motion length!")
+        text_embedding = t2m_transformer.encode_text(prompt_list)
+        pred_dis = length_estimator(text_embedding)
+        probs = F.softmax(pred_dis, dim=-1)  # (b, ntoken)
+        token_lens = Categorical(probs).sample()  # (b, seqlen)
+    else:
+        token_lens = torch.LongTensor(length_list) // 4
+        token_lens = token_lens.to(opt.device).long()
+
+    m_length = token_lens * 4
+    captions = prompt_list
     datas = []
-    file_name = [name for name in os.listdir(f"./generation/{uid}/animations/0/") if name.endswith('_ik.mp4')][0]
-    motion_length = int(file_name.split('len')[-1].replace('_ik.mp4', ''))
-    for n in range(repeat_times):
+    for r in range(repeat_times):
+        mids = t2m_transformer.generate(captions, token_lens,
+                                        timesteps=opt.time_steps,
+                                        cond_scale=opt.cond_scale,
+                                        temperature=opt.temperature,
+                                        topk_filter_thres=opt.topkr,
+                                        gsample=opt.gumbel_sample)
+        mids = res_model.generate(mids, captions, token_lens, temperature=1, cond_scale=5)
+        pred_motions = vq_model.forward_decoder(mids)
+        pred_motions = pred_motions.detach().cpu().numpy()
+        data = inv_transform(pred_motions)
+        for k, (caption, joint_data) in enumerate(zip(captions, data)):
+            animation_path = pjoin(cached_dir, uid, str(k))
+            os.makedirs(animation_path, exist_ok=True)
+            joint_data = joint_data[:m_length[k]]
+            joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()
+            bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k]))
+            save_path = pjoin(animation_path, "sample%d_repeat%d_len%d.mp4" % (k, r, m_length[k]))
+            if use_ik:
+                _, joint = converter.convert(joint, filename=bvh_path, iterations=100)
+            else:
+                _, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)
+            plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)
+            np.save(pjoin(animation_path, "sample%d_repeat%d_len%d.npy" % (k, r, m_length[k])), joint)
         data_unit = {
-            "url": f"generation/{uid}/animations/0/sample0_repeat{n}_len{motion_length}_ik.mp4"
+            "url": f"generation/{uid}/animations/0/sample0_repeat{r}_len{motion_length}.mp4"
         }
         datas.append(data_unit)
-    print(datas)
+
     return datas
 
 
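When motion_length is 0, the hunk above samples a duration instead of requiring one: the length estimator maps the text embedding to logits over token-count bins, a categorical draw picks a bin, and multiplying by 4 converts tokens to frames (each RVQ token spans 4 motion frames). A self-contained sketch with random logits standing in for length_estimator(text_embedding); the 50-bin width is a placeholder, not the model's actual output size:

import torch
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

torch.manual_seed(10107)             # the demo's default seed
pred_dis = torch.randn(1, 50)        # placeholder logits: (b, ntoken)
probs = F.softmax(pred_dis, dim=-1)  # distribution over token-length bins
token_lens = Categorical(probs).sample()  # (b,) sampled token count
m_length = token_lens * 4            # tokens -> frames (4 frames per token)
print(int(token_lens), int(m_length))

Downstream, the sampled token_lens conditions all three stages wired above: t2m_transformer.generate predicts the base-layer tokens, res_model.generate refines the residual quantization layers, and vq_model.forward_decoder maps tokens back to 263-dimensional motion features, which inv_transform de-normalizes (data * std + mean).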
@@ -121,11 +246,16 @@ autoplay loop disablepictureinpicture id="{video_id}">
     return video_html
 
 
-def generate_component(generate_function, text):
+def generate_component(generate_function, text, motion_len='0', postprocess='IK'):
     if text == DEFAULT_TEXT or text == "" or text is None:
         return [None for _ in range(1)]
     uid = random.randrange(99999)
-    datas = generate_function(text, uid)
+    try:
+        motion_len = max(0, min(int(float(motion_len) * 20), 196))
+    except:
+        motion_len = 0
+    use_ik = postprocess == 'IK'
+    datas = generate_function(text, uid, motion_len, use_ik)
     htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
     return htmls
 
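The new motion_len handling parses the textbox value as seconds, converts to frames at 20 fps, and clamps to the model's 196-frame maximum, which is why the UI labels the field "<10s" (196 / 20 = 9.8 s); 0 falls through to length estimation. The same logic as a standalone helper (the seconds_to_frames name is hypothetical):

def seconds_to_frames(motion_len, fps=20, max_frames=196):
    # Mirrors the demo's clamping; 0 means "estimate the length".
    try:
        return max(0, min(int(float(motion_len) * fps), max_frames))
    except (TypeError, ValueError):  # the app itself uses a bare except
        return 0

assert seconds_to_frames("3") == 60     # 3 s at 20 fps
assert seconds_to_frames("30") == 196   # clamped to the ~9.8 s maximum
assert seconds_to_frames("abc") == 0    # unparseable -> estimate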
@@ -148,15 +278,27 @@ with gr.Blocks(css=CSS, theme=theme) as demo:
 
     with gr.Row():
         with gr.Column(scale=3):
-            with gr.Column(scale=2):
-                text = gr.Textbox(
-                    show_label=True,
-                    label="Text prompt",
-                    value=DEFAULT_TEXT,
-                )
-            with gr.Column(scale=1):
-                gen_btn = gr.Button("Generate", variant="primary")
-                clear = gr.Button("Clear", variant="secondary")
+            text = gr.Textbox(
+                show_label=True,
+                label="Text prompt",
+                value=DEFAULT_TEXT,
+            )
+            with gr.Row():
+                with gr.Column(scale=1):
+                    motion_len = gr.Textbox(
+                        show_label=True,
+                        label="Motion length (<10s)",
+                        value=0,
+                    )
+                with gr.Column(scale=1):
+                    use_ik = gr.Radio(
+                        ["Raw", "IK"],
+                        label="Post-processing",
+                        value="IK",
+                        info="Use basic inverse kinematics (IK) for foot contact locking",
+                    )
+            gen_btn = gr.Button("Generate", variant="primary")
+            clear = gr.Button("Clear", variant="secondary")
 
         with gr.Column(scale=2):
 
@@ -166,7 +308,7 @@ with gr.Blocks(css=CSS, theme=theme) as demo:
     examples = gr.Examples(
         examples=[[x, None, None] for x in EXAMPLES],
         inputs=[text],
-        examples_per_page=20,
+        examples_per_page=10,
         run_on_click=False,
         cache_examples=False,
         fn=generate_example,
@@ -201,12 +343,12 @@ with gr.Blocks(css=CSS, theme=theme) as demo:
 
     gen_btn.click(
         fn=generate_and_show,
-        inputs=[text],
+        inputs=[text, motion_len, use_ik],
         outputs=videos,
     )
    text.submit(
         fn=generate_and_show,
-        inputs=[text],
+        inputs=[text, motion_len, use_ik],
         outputs=videos,
     )
 
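Gradio passes component values positionally, so the order of inputs=[text, motion_len, use_ik] must match generate_and_show's parameter order; the Radio's selected string ("Raw" or "IK") arrives as-is and is compared against "IK" in generate_component. A reduced, runnable sketch of the same wiring, with a toy echo handler in place of the real pipeline:

import gradio as gr

def generate_and_show(text, motion_len, postprocess):
    # Toy handler: the real one calls generate_component -> generate.
    return f"prompt={text!r}, length={motion_len}s, post-processing={postprocess}"

with gr.Blocks() as demo:
    text = gr.Textbox(label="Text prompt")
    motion_len = gr.Textbox(label="Motion length (<10s)", value="0")
    use_ik = gr.Radio(["Raw", "IK"], label="Post-processing", value="IK")
    out = gr.Textbox()
    gr.Button("Generate").click(fn=generate_and_show,
                                inputs=[text, motion_len, use_ik],
                                outputs=out)
# demo.launch()  # uncomment to serve locally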
options/base_option.py CHANGED

@@ -12,7 +12,7 @@ class BaseOptions():
 
         self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')
 
-        self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id')
+        self.parser.add_argument("--gpu_id", type=int, default=0, help='GPU id')
         self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
         self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')
 
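This one-character default change matters because app.py derives its device directly from gpu_id; with the old default of -1, the hosted demo would silently fall back to CPU. The resolution expression, lifted into a helper for illustration (the resolve_device name is hypothetical):

import torch

def resolve_device(gpu_id):
    # Same expression as app.py: -1 selects CPU, anything else a numbered GPU.
    return torch.device("cpu" if gpu_id == -1 else "cuda:" + str(gpu_id))

print(resolve_device(-1))  # cpu
print(resolve_device(0))   # cuda:0, the new default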
options/hgdemo_option.py ADDED

@@ -0,0 +1,38 @@
+from options.base_option import BaseOptions
+
+class EvalT2MOptions(BaseOptions):
+    def initialize(self):
+        BaseOptions.initialize(self)
+        self.parser.add_argument('--which_epoch', type=str, default="latest", help='Checkpoint you want to use, {latest, net_best_fid, etc}')
+        self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
+
+        self.parser.add_argument('--ext', type=str, default='text2motion', help='Extension of the result file or folder')
+        self.parser.add_argument("--num_batch", default=2, type=int,
+                                 help="Number of batches for generation")
+        self.parser.add_argument("--repeat_times", default=1, type=int,
+                                 help="Number of repetitions per sample text prompt")
+        self.parser.add_argument("--cond_scale", default=4, type=float,
+                                 help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")
+        self.parser.add_argument("--temperature", default=1., type=float,
+                                 help="Sampling temperature.")
+        self.parser.add_argument("--topkr", default=0.9, type=float,
+                                 help="Filter out the percentile of low-probability entries.")
+        self.parser.add_argument("--time_steps", default=18, type=int,
+                                 help="Masked generation steps.")
+        self.parser.add_argument("--seed", default=10107, type=int)
+
+        self.parser.add_argument('--gumbel_sample', action="store_true", help='True: gumbel sampling, False: categorical sampling.')
+        self.parser.add_argument('--use_res_model', action="store_true", help='Whether to use residual transformer.')
+        # self.parser.add_argument('--est_length', action="store_true", help='Training iterations')
+
+        self.parser.add_argument('--res_name', type=str, default='tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw', help='Model name of residual transformer')
+        self.parser.add_argument('--text_path', type=str, default="", help='Text prompt file')
+
+
+        self.parser.add_argument('-msec', '--mask_edit_section', nargs='*', type=str, help='Indicate sections for editing; use a comma to separate the start and end of a section. '
+                                 'An int specifies the token frame; a float specifies the ratio of seq_len.')
+        self.parser.add_argument('--text_prompt', default='', type=str, help="A text prompt to be generated. If empty, will take text prompts from the dataset.")
+        self.parser.add_argument('--source_motion', default='example_data/000612.npy', type=str, help="Source motion path for editing (new_joint_vecs format .npy file)")
+        self.parser.add_argument("--motion_length", default=0, type=int,
+                                 help="Motion length for generation, only applicable with a single text prompt.")
+        self.is_train = False
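app.py consumes this file via parser = EvalT2MOptions(); opt = parser.parse(), so every default above (cond_scale, time_steps, topkr, ...) reaches t2m_transformer.generate without any CLI flags being set. A reduced sketch of the same mechanism, assuming plain argparse in place of the BaseOptions parse machinery:

import argparse

# Stand-in for EvalT2MOptions with a few of the flags the demo reads.
parser = argparse.ArgumentParser()
parser.add_argument("--cond_scale", default=4, type=float)   # classifier-free guidance scale
parser.add_argument("--time_steps", default=18, type=int)    # masked-generation iterations
parser.add_argument("--gumbel_sample", action="store_true")  # stays False unless the flag is given
parser.add_argument("--seed", default=10107, type=int)

opt = parser.parse_args([])  # no CLI args, as in the hosted demo
print(opt.cond_scale, opt.time_steps, opt.gumbel_sample, opt.seed)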