import torch
import torch.nn as nn
import torch.nn.functional as F
from model_utils import TimestepEmbedderMDM, PositionalEncoding

class TMED_denoiser(nn.Module):
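    """Transformer denoiser conditioned on text and, optionally, a source motion.

    The diffusion timestep, the text embeddings, the (optional) projected source
    motion and the projected noised target motion are concatenated into one
    sequence-first token stream and processed by a Transformer encoder; the
    target-frame outputs are projected back to pose features.
    """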

    def __init__(self,
                 nfeats: int = 207,
                 condition: str = "text",
                 latent_dim: int = 512,
                 ff_size: int = 1024,
                 num_layers: int = 8,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 activation: str = "gelu",
                 text_encoded_dim: int = 768,
                 pred_delta_motion: bool = False,
                 use_sep: bool = True,
                 motion_condition: str = 'source',
                 **kwargs) -> None:

        super().__init__()
        self.latent_dim = latent_dim
        self.pred_delta_motion = pred_delta_motion
        self.text_encoded_dim = text_encoded_dim
        self.condition = condition
        self.feat_comb_coeff = nn.Parameter(torch.tensor([1.0]))
        self.pose_proj_in_source = nn.Linear(nfeats, self.latent_dim)
        self.pose_proj_in_target = nn.Linear(nfeats, self.latent_dim)
        self.pose_proj_out = nn.Linear(self.latent_dim, nfeats)
        self.first_pose_proj = nn.Linear(self.latent_dim, nfeats)
        self.motion_condition = motion_condition

        # embedding projections
        if self.condition in ["text", "text_uncond"]:
            # timestep embedding, produced directly in latent_dim
            self.embed_timestep = TimestepEmbedderMDM(self.latent_dim)

            # project the text embeddings to latent_dim when the dimensions differ
            if text_encoded_dim != self.latent_dim:
                self.emb_proj = nn.Linear(text_encoded_dim, self.latent_dim)
        else:
            raise TypeError(f"condition type {self.condition} not supported")
        self.use_sep = use_sep
        self.query_pos = PositionalEncoding(self.latent_dim, dropout)
        self.mem_pos = PositionalEncoding(self.latent_dim, dropout)
        if self.motion_condition == "source":
            if self.use_sep:
                self.sep_token = nn.Parameter(torch.randn(1, self.latent_dim))

        # use torch transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation)
        self.encoder = nn.TransformerEncoder(encoder_layer,
                                                num_layers=num_layers)

    def forward(self,
                noised_motion,
                timestep,
                in_motion_mask,
                text_embeds,
                condition_mask, 
                motion_embeds=None,
                lengths=None,
                **kwargs):
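        """Denoise one diffusion step.

        Expected shapes (inferred from the permutes and masks below):
            noised_motion:  [bs, nframes, nfeats]
            timestep:       scalar or [1] tensor, broadcast to the batch
            in_motion_mask: [bs, nframes] bool, True for valid target frames
            text_embeds:    [bs, n_text_tokens, text_encoded_dim]
            condition_mask: [bs, n_text_tokens (+ n_source_frames)] bool
            motion_embeds:  [n_source_frames, bs, nfeats or latent_dim]
                            (sequence-first) or None
        Returns the denoised motion as [bs, nframes, nfeats].
        """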
        # 0. dimension matching
        # noised_motion: [nframes, bs, nfeats] <= [bs, nframes, nfeats]
        bs = noised_motion.shape[0]
        noised_motion = noised_motion.permute(1, 0, 2)
        motion_in_mask = in_motion_mask

        # token layout fed to the transformer (sequence-first):
        # time_embedding | text_embedding | frames_source (optional) | frames_target
        # 1 x lat_d | max_text x lat_d | max_frames x lat_d | max_frames x lat_d

        # 1. timestep embedding
        # broadcast to the batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timestep.expand(noised_motion.shape[1]).clone()
        time_emb = self.embed_timestep(timesteps).to(dtype=noised_motion.dtype)
        if self.condition in ["text", "text_uncond"]:
            # make the text embeddings sequence-first
            text_embeds = text_embeds.permute(1, 0, 2)
            if self.text_encoded_dim != self.latent_dim:
                # [n_text, bs, latent_dim] <= [n_text, bs, text_encoded_dim]
                text_emb_latent = self.emb_proj(text_embeds)
            else:
                text_emb_latent = text_embeds
            emb_latent = torch.cat((time_emb, text_emb_latent), 0)

            if motion_embeds is not None:
                # keep positions whose features are all zero exactly zero after projection
                zeroes_mask = (motion_embeds == 0).all(dim=-1)
                if motion_embeds.shape[-1] != self.latent_dim:
                    motion_embeds_proj = self.pose_proj_in_source(motion_embeds)
                    motion_embeds_proj[zeroes_mask] = 0
                else:
                    motion_embeds_proj = motion_embeds
        else:
            raise TypeError(f"condition type {self.condition} not supported")
        # 4. transformer
        proj_noised_motion = self.pose_proj_in_target(noised_motion)

        if motion_embeds is None:
            xseq = torch.cat((emb_latent, proj_noised_motion), dim=0)
        else:
            if self.use_sep:
                sep_token_batch = torch.tile(self.sep_token,
                                             (bs,)).reshape(bs, -1)
                xseq = torch.cat((emb_latent, motion_embeds_proj,
                                  sep_token_batch[None],
                                  proj_noised_motion), dim=0)
            else:
                xseq = torch.cat((emb_latent, motion_embeds_proj,
                                  proj_noised_motion), dim=0)
        # add positional encodings to the token sequence
        xseq = self.query_pos(xseq)
        # build the key-padding mask matching the token layout above
        time_token_mask = torch.ones((bs, time_emb.shape[0]),
                                     dtype=torch.bool, device=xseq.device)
        if motion_embeds is None:
            aug_mask = torch.cat((time_token_mask,
                                  condition_mask[:, :text_emb_latent.shape[0]],
                                  motion_in_mask), 1)
        else:
            if self.use_sep:
                sep_token_mask = torch.ones((bs, self.sep_token.shape[0]),
                                            dtype=torch.bool,
                                            device=xseq.device)
                aug_mask = torch.cat((time_token_mask,
                                      condition_mask[:, :text_emb_latent.shape[0]],
                                      condition_mask[:, text_emb_latent.shape[0]:],
                                      sep_token_mask,
                                      motion_in_mask,
                                      ), 1)
            else:
                aug_mask = torch.cat((time_token_mask,
                                      condition_mask[:, :text_emb_latent.shape[0]],
                                      condition_mask[:, text_emb_latent.shape[0]:],
                                      motion_in_mask,
                                      ), 1)
        tokens = self.encoder(xseq, src_key_padding_mask=~aug_mask)

        # keep only the target-frame outputs (drop time/text/source tokens)
        if motion_embeds is not None:
            denoised_motion_proj = tokens[emb_latent.shape[0]:]
            if self.use_sep:
                useful_tokens = motion_embeds_proj.shape[0]+1
            else:
                useful_tokens = motion_embeds_proj.shape[0]
            denoised_motion_proj = denoised_motion_proj[useful_tokens:]
        else:
            denoised_motion_proj = tokens[emb_latent.shape[0]:]

        denoised_motion = self.pose_proj_out(denoised_motion_proj)
        if self.pred_delta_motion and motion_embeds is not None:
            # predict a residual on top of the (possibly padded) source motion
            tgt_size = len(denoised_motion)
            if len(denoised_motion) > len(motion_embeds):
                pad_for_src = tgt_size - len(motion_embeds)
                motion_embeds = F.pad(motion_embeds,
                                      (0, 0, 0, 0, 0, pad_for_src))
            denoised_motion = denoised_motion + motion_embeds[:tgt_size]

        # zero out the padded frames
        denoised_motion[~motion_in_mask.T] = 0
        # 5. [bs, nframes, nfeats] <= [nframes, bs, nfeats]
        denoised_motion = denoised_motion.permute(1, 0, 2)
        return denoised_motion

    def forward_with_guidance(self,
                              noised_motion,
                              timestep,
                              in_motion_mask,
                              text_embeds,
                              condition_mask,
                              guidance_motion,
                              guidance_text_n_motion, 
                              motion_embeds=None,
                              lengths=None,
                              inpaint_dict=None,
                              max_steps=None,
                              prob_way='3way',
                              **kwargs):
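        """Run `forward` with classifier-free guidance and recombine the
        unconditional and conditional predictions using the guidance scales."""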
        # Two cases: without a motion condition the batch holds two replicas
        # (unconditional, text-conditioned); with a motion condition it holds
        # three (unconditional, motion-conditioned, text+motion-conditioned).
        if max_steps is not None:
            # linearly anneal the guidance scales over the diffusion steps,
            # never dropping below 1
            curr_ts = timestep[0].item()
            guidance_motion = max(1, guidance_motion * 2 * curr_ts / max_steps)
            guidance_text_n_motion = max(1, guidance_text_n_motion * 2 * curr_ts / max_steps)

        if motion_embeds is None:
            half = noised_motion[: len(noised_motion) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.forward(combined, timestep,
                                    in_motion_mask=in_motion_mask,
                                    text_embeds=text_embeds,
                                    condition_mask=condition_mask, 
                                    motion_embeds=motion_embeds,
                                    lengths=lengths)
            uncond_eps, cond_eps_text = torch.split(model_out, len(model_out) // 2,
                                                     dim=0)
            # predictions are batch-first: B x S x feats
            if inpaint_dict is not None:
                # start_motion is sequence-first; make it B x S x feats
                source_mot = inpaint_dict['start_motion'].permute(1, 0, 2)
                if source_mot.shape[1] >= uncond_eps.shape[1]:
                    source_mot = source_mot[:, :uncond_eps.shape[1]]
                else:
                    # pad the source motion along the time dimension
                    pad = uncond_eps.shape[1] - source_mot.shape[1]
                    source_mot = F.pad(source_mot, (0, 0, 0, pad), 'constant', 0)

                mot_len = source_mot.shape[1]
                # broadcast the per-part mask across all frames
                mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1,
                                                                          mot_len,
                                                                          1)
                uncond_eps = uncond_eps * mask_src_parts + source_mot * (~mask_src_parts)
                cond_eps_text = cond_eps_text * mask_src_parts + source_mot * (~mask_src_parts)
            half_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text - uncond_eps) 
            eps = torch.cat([half_eps, half_eps], dim=0)
        else:
            third = noised_motion[: len(noised_motion) // 3]
            combined = torch.cat([third, third, third], dim=0)
            model_out = self.forward(combined, timestep,
                                     in_motion_mask=in_motion_mask,
                                     text_embeds=text_embeds,
                                     condition_mask=condition_mask, 
                                     motion_embeds=motion_embeds,
                                     lengths=lengths)
            # split the replicated batch into unconditional, motion-conditioned,
            # and text+motion-conditioned predictions
            uncond_eps, cond_eps_motion, cond_eps_text_n_motion = torch.split(model_out,
                                                                            len(model_out) // 3,
                                                                            dim=0)
            if inpaint_dict is not None:
                # start_motion is sequence-first; make it B x S x feats
                source_mot = inpaint_dict['start_motion'].permute(1, 0, 2)
                if source_mot.shape[1] >= uncond_eps.shape[1]:
                    source_mot = source_mot[:, :uncond_eps.shape[1]]
                else:
                    # pad the source motion along the time dimension
                    pad = uncond_eps.shape[1] - source_mot.shape[1]
                    source_mot = F.pad(source_mot, (0, 0, 0, pad), 'constant', 0)

                mot_len = source_mot.shape[1]
                # broadcast the per-part mask across all frames
                mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1,
                                                                          mot_len,
                                                                          1)
                uncond_eps = uncond_eps * (~mask_src_parts) + source_mot * mask_src_parts
                cond_eps_motion = cond_eps_motion * (~mask_src_parts) + source_mot * mask_src_parts
                cond_eps_text_n_motion = cond_eps_text_n_motion * (~mask_src_parts) + source_mot * mask_src_parts
            if prob_way == '3way':
                third_eps = uncond_eps + guidance_motion * (cond_eps_motion - uncond_eps) + \
                            guidance_text_n_motion * (cond_eps_text_n_motion - cond_eps_motion)
            elif prob_way == '2way':
                third_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text_n_motion - uncond_eps)

            eps = torch.cat([third_eps, third_eps, third_eps], dim=0)
        return eps

    def _diffusion_reverse(self, text_embeds, text_masks_from_enc, 
                            motion_embeds, cond_motion_masks,
                            inp_motion_mask, diff_process,
                            init_vec=None,
                            init_from='noise',
                            gd_text=None, gd_motion=None, 
                            mode='full_cond',
                            return_init_noise=False,
                            steps_num=None,
                            inpaint_dict=None,
                            use_linear=False,
                            prob_way='3way'):
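        """Run the reverse diffusion loop with classifier-free guidance.

        The initial noise is replicated (twice for text-only conditioning,
        three times when a source motion is given), the matching condition
        masks and model kwargs are built, and `diff_process.p_sample_loop`
        is called with `forward_with_guidance` as the model. The sampled
        motion is returned permuted to sequence-first; the initial noise is
        also returned when `return_init_noise` is True.
        """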
        # reference guidance scales: guidance_scale_text 7.5, guidance_scale_motion 1.5

        bsz = inp_motion_mask.shape[0]
        assert mode in ['full_cond', 'text_cond', 'mot_cond']
        assert inp_motion_mask is not None
        if init_vec is None:
            initial_latents = torch.randn(
                (bsz, inp_motion_mask.shape[1], 207),
                device=inp_motion_mask.device,
                dtype=torch.float,
            )
        else:
            initial_latents = init_vec

        if text_embeds is not None:
            max_text_len = text_embeds.shape[1]
        else:
            max_text_len = 0
        max_motion_len = cond_motion_masks.shape[1]
        text_masks = text_masks_from_enc.clone()
        nomotion_mask = torch.zeros(bsz, max_motion_len,
                                    dtype=torch.bool,
                                    device=inp_motion_mask.device)
        motion_masks = torch.cat([nomotion_mask, 
                                cond_motion_masks, 
                                cond_motion_masks],
                                dim=0)
        aug_mask = torch.cat([text_masks,
                                motion_masks],
                                dim=1)


        # Setup classifier-free guidance:
        if motion_embeds is not None:
            z = torch.cat([initial_latents, initial_latents, initial_latents], 0)
        else:
            z = torch.cat([initial_latents, initial_latents], 0)

        if use_linear:
            max_steps_diff = diff_process.num_timesteps
        else:
            max_steps_diff = None
        if motion_embeds is not None:
            model_kwargs = dict(# noised_motion=latent_model_input,
                                # timestep=t,
                                in_motion_mask=torch.cat([inp_motion_mask,
                                                        inp_motion_mask,
                                                        inp_motion_mask], 0),
                                text_embeds=text_embeds,
                                condition_mask=aug_mask,
                                motion_embeds=torch.cat([torch.zeros_like(motion_embeds),
                                                        motion_embeds,
                                                        motion_embeds], 1),
                                guidance_motion=gd_motion,
                                guidance_text_n_motion=gd_text,
                                inpaint_dict=inpaint_dict,
                                max_steps=max_steps_diff,
                                prob_way=prob_way)
        else:
            model_kwargs = dict(# noised_motion=latent_model_input,
                    # timestep=t,
                    in_motion_mask=torch.cat([inp_motion_mask,
                                            inp_motion_mask], 0),
                    text_embeds=text_embeds,
                    condition_mask=aug_mask,
                    motion_embeds=None,
                    guidance_motion=gd_motion,
                    guidance_text_n_motion=gd_text,
                    inpaint_dict=inpaint_dict,
                    max_steps=max_steps_diff)

        # run the reverse diffusion sampling loop
        samples = diff_process.p_sample_loop(self.forward_with_guidance,
                                            z.shape, z, 
                                            clip_denoised=False, 
                                            model_kwargs=model_kwargs,
                                            progress=True,
                                            device=initial_latents.device,)
        # keep only the fully conditioned replica of the batch
        if motion_embeds is not None:
            _, _, samples = samples.chunk(3, dim=0)
        else:
            _, samples = samples.chunk(2, dim=0)

        final_diffout = samples.permute(1, 0, 2)
        if return_init_noise:
            return initial_latents, final_diffout
        else:
            return final_diffout
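

# Minimal shape-check sketch: it instantiates the denoiser with its default
# sizes and runs a single text-conditioned denoising step on random tensors,
# so the expected `forward` input/output shapes can be verified quickly.
# All tensor sizes below are arbitrary illustrative values, and `model_utils`
# (imported at the top) is assumed to be available.
if __name__ == "__main__":
    bs, nframes, nfeats, n_text = 2, 30, 207, 12

    model = TMED_denoiser(nfeats=nfeats)
    model.eval()

    noised_motion = torch.randn(bs, nframes, nfeats)            # noised target motion
    timestep = torch.tensor([25])                               # diffusion step, broadcast inside forward
    in_motion_mask = torch.ones(bs, nframes, dtype=torch.bool)  # all target frames valid
    text_embeds = torch.randn(bs, n_text, 768)                  # e.g. frozen text-encoder features
    condition_mask = torch.ones(bs, n_text, dtype=torch.bool)   # all text tokens valid

    with torch.no_grad():
        out = model(noised_motion, timestep,
                    in_motion_mask=in_motion_mask,
                    text_embeds=text_embeds,
                    condition_mask=condition_mask)
    print(out.shape)  # expected: torch.Size([2, 30, 207])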