import torch
import torch.nn as nn
import torch.nn.functional as F

from model_utils import TimestepEmbedderMDM
from model_utils import PositionalEncoding


class TMED_denoiser(nn.Module):

    def __init__(self,
                 nfeats: int = 207,
                 condition: str = "text",
                 latent_dim: int = 512,
                 ff_size: int = 1024,
                 num_layers: int = 8,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 activation: str = "gelu",
                 text_encoded_dim: int = 768,
                 pred_delta_motion: bool = False,
                 use_sep: bool = True,
                 motion_condition: str = 'source',
                 **kwargs) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        self.pred_delta_motion = pred_delta_motion
        self.text_encoded_dim = text_encoded_dim
        self.condition = condition
        self.feat_comb_coeff = nn.Parameter(torch.tensor([1.0]))
        # project raw pose features (nfeats) into and out of the latent space
        self.pose_proj_in_source = nn.Linear(nfeats, self.latent_dim)
        self.pose_proj_in_target = nn.Linear(nfeats, self.latent_dim)
        self.pose_proj_out = nn.Linear(self.latent_dim, nfeats)
        self.first_pose_proj = nn.Linear(self.latent_dim, nfeats)
        self.motion_condition = motion_condition
        # embedding projections
        if self.condition in ["text", "text_uncond"]:
            # timestep embedding, produced directly in latent_dim
            self.embed_timestep = TimestepEmbedderMDM(self.latent_dim)
            # project text embeddings to latent_dim when the dims differ
            if text_encoded_dim != self.latent_dim:
                self.emb_proj = nn.Linear(text_encoded_dim, self.latent_dim)
        else:
            raise TypeError(f"condition type {self.condition} not supported")
        self.use_sep = use_sep
        self.query_pos = PositionalEncoding(self.latent_dim, dropout)
        self.mem_pos = PositionalEncoding(self.latent_dim, dropout)
        if self.motion_condition == "source":
            if self.use_sep:
                # learnable separator token between source and target frames
                self.sep_token = nn.Parameter(torch.randn(1, self.latent_dim))
        # vanilla torch transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation)
        self.encoder = nn.TransformerEncoder(encoder_layer,
                                             num_layers=num_layers)

    def forward(self,
                noised_motion,
                timestep,
                in_motion_mask,
                text_embeds,
                condition_mask,
                motion_embeds=None,
                lengths=None,
                **kwargs):
        # 0. dimension matching
        # noised_motion: [bs, nframes, nfeats] -> [nframes, bs, nfeats]
        bs = noised_motion.shape[0]
        noised_motion = noised_motion.permute(1, 0, 2)
        motion_in_mask = in_motion_mask
        # sequence layout fed to the encoder:
        # time_embedding | text_embedding | frames_source | frames_target
        # 1 x lat_d | max_text x lat_d | max_frames x lat_d | max_frames x lat_d
        # 1. timestep embedding
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timestep.expand(noised_motion.shape[1]).clone()
        time_emb = self.embed_timestep(timesteps).to(dtype=noised_motion.dtype)
        if self.condition in ["text", "text_uncond"]:
            # text embeddings: [bs, n_tokens, dim] -> [n_tokens, bs, dim]
            text_embeds = text_embeds.permute(1, 0, 2)
            if self.text_encoded_dim != self.latent_dim:
                # [n_tokens, bs, latent_dim] <= [n_tokens, bs, text_encoded_dim]
                text_emb_latent = self.emb_proj(text_embeds)
            else:
                text_emb_latent = text_embeds
            emb_latent = torch.cat((time_emb, text_emb_latent), 0)
            if motion_embeds is not None:
                # frames that are all-zero correspond to padding / dropped condition
                zeroes_mask = (motion_embeds == 0).all(dim=-1)
                if motion_embeds.shape[-1] != self.latent_dim:
                    motion_embeds_proj = self.pose_proj_in_source(motion_embeds)
                    motion_embeds_proj[zeroes_mask] = 0
                else:
                    motion_embeds_proj = motion_embeds
        else:
            raise TypeError(f"condition type {self.condition} not supported")
        # 2. assemble the transformer input sequence
        proj_noised_motion = self.pose_proj_in_target(noised_motion)
        if motion_embeds is None:
            xseq = torch.cat((emb_latent, proj_noised_motion), axis=0)
        else:
            if self.use_sep:
                sep_token_batch = torch.tile(self.sep_token,
                                             (bs,)).reshape(bs, -1)
                xseq = torch.cat((emb_latent, motion_embeds_proj,
                                  sep_token_batch[None],
                                  proj_noised_motion), axis=0)
            else:
                xseq = torch.cat((emb_latent, motion_embeds_proj,
                                  proj_noised_motion), axis=0)
        xseq = self.query_pos(xseq)
        # 3. build the padding mask for the assembled sequence
        if motion_embeds is None:
            time_token_mask = torch.ones((bs, time_emb.shape[0]),
                                         dtype=torch.bool, device=xseq.device)
            aug_mask = torch.cat((time_token_mask,
                                  condition_mask[:, :text_emb_latent.shape[0]],
                                  motion_in_mask), 1)
        else:
            time_token_mask = torch.ones((bs, time_emb.shape[0]),
                                         dtype=torch.bool,
                                         device=xseq.device)
            if self.use_sep:
                sep_token_mask = torch.ones((bs, self.sep_token.shape[0]),
                                            dtype=torch.bool,
                                            device=xseq.device)
                aug_mask = torch.cat((time_token_mask,
                                      condition_mask[:, :text_emb_latent.shape[0]],
                                      condition_mask[:, text_emb_latent.shape[0]:],
                                      sep_token_mask,
                                      motion_in_mask,
                                      ), 1)
            else:
                aug_mask = torch.cat((time_token_mask,
                                      condition_mask[:, :text_emb_latent.shape[0]],
                                      condition_mask[:, text_emb_latent.shape[0]:],
                                      motion_in_mask,
                                      ), 1)
        tokens = self.encoder(xseq, src_key_padding_mask=~aug_mask)
        # 4. keep only the target-frame tokens and project back to pose features
        if motion_embeds is not None:
            denoised_motion_proj = tokens[emb_latent.shape[0]:]
            if self.use_sep:
                useful_tokens = motion_embeds_proj.shape[0] + 1
            else:
                useful_tokens = motion_embeds_proj.shape[0]
            denoised_motion_proj = denoised_motion_proj[useful_tokens:]
        else:
            denoised_motion_proj = tokens[emb_latent.shape[0]:]
        denoised_motion = self.pose_proj_out(denoised_motion_proj)
        if self.pred_delta_motion and motion_embeds is not None:
            # predict a delta on top of the source motion; pad the source
            # along time if the target is longer
            tgt_size = len(denoised_motion)
            if len(denoised_motion) > len(motion_embeds):
                pad_for_src = tgt_size - len(motion_embeds)
                motion_embeds = F.pad(motion_embeds,
                                      (0, 0, 0, 0, 0, pad_for_src))
            denoised_motion = denoised_motion + motion_embeds[:tgt_size]
        # zero out the padded frames
        denoised_motion[~motion_in_mask.T] = 0
        # 5. [bs, nframes, nfeats] <= [nframes, bs, nfeats]
        denoised_motion = denoised_motion.permute(1, 0, 2)
        return denoised_motion

    def forward_with_guidance(self,
                              noised_motion,
                              timestep,
                              in_motion_mask,
                              text_embeds,
                              condition_mask,
                              guidance_motion,
                              guidance_text_n_motion,
                              motion_embeds=None,
                              lengths=None,
                              inpaint_dict=None,
                              max_steps=None,
                              prob_way='3way',
                              **kwargs):
        # optionally scale the guidance weights linearly with the current timestep
        if max_steps is not None:
            curr_ts = timestep[0].item()
            guidance_motion = max(1, guidance_motion * 2 * curr_ts / max_steps)
            guidance_text_n_motion = max(1, guidance_text_n_motion * 2 * curr_ts
                                         / max_steps)
        if motion_embeds is None:
            # 2-way guidance: unconditional vs. text-conditioned
            half = noised_motion[: len(noised_motion) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.forward(combined, timestep,
                                     in_motion_mask=in_motion_mask,
                                     text_embeds=text_embeds,
                                     condition_mask=condition_mask,
                                     motion_embeds=motion_embeds,
                                     lengths=lengths)
            uncond_eps, cond_eps_text = torch.split(model_out,
                                                    len(model_out) // 2,
                                                    dim=0)
            if inpaint_dict is not None:
                # source motion as BxSxfeatures
                source_mot = inpaint_dict['start_motion'].permute(1, 0, 2)
                if source_mot.shape[1] >= uncond_eps.shape[1]:
                    source_mot = source_mot[:, :uncond_eps.shape[1]]
                else:
                    # pad the source motion along the time dimension
                    pad = uncond_eps.shape[1] - source_mot.shape[1]
                    source_mot = F.pad(source_mot, (0, 0, 0, pad),
                                       'constant', 0)
                mot_len = source_mot.shape[1]
                # broadcast the per-feature mask over all frames
                mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(
                    1, mot_len, 1)
                uncond_eps = uncond_eps * mask_src_parts \
                    + source_mot * (~mask_src_parts)
                cond_eps_text = cond_eps_text * mask_src_parts \
                    + source_mot * (~mask_src_parts)
            half_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text
                                                              - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
        else:
            # 3-way guidance: unconditional, motion-conditioned,
            # and text+motion-conditioned
            third = noised_motion[: len(noised_motion) // 3]
            combined = torch.cat([third, third, third], dim=0)
            model_out = self.forward(combined, timestep,
                                     in_motion_mask=in_motion_mask,
                                     text_embeds=text_embeds,
                                     condition_mask=condition_mask,
                                     motion_embeds=motion_embeds,
                                     lengths=lengths)
            uncond_eps, cond_eps_motion, cond_eps_text_n_motion = torch.split(
                model_out, len(model_out) // 3, dim=0)
            if inpaint_dict is not None:
                # source motion as BxSxfeatures
                source_mot = inpaint_dict['start_motion'].permute(1, 0, 2)
                if source_mot.shape[1] >= uncond_eps.shape[1]:
                    source_mot = source_mot[:, :uncond_eps.shape[1]]
                else:
                    # pad the source motion along the time dimension
                    pad = uncond_eps.shape[1] - source_mot.shape[1]
                    source_mot = F.pad(source_mot, (0, 0, 0, pad),
                                       'constant', 0)
                mot_len = source_mot.shape[1]
                # broadcast the per-feature mask over all frames
                mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(
                    1, mot_len, 1)
                uncond_eps = uncond_eps * (~mask_src_parts) \
                    + source_mot * mask_src_parts
                cond_eps_motion = cond_eps_motion * (~mask_src_parts) \
                    + source_mot * mask_src_parts
                cond_eps_text_n_motion = cond_eps_text_n_motion * (~mask_src_parts) \
                    + source_mot * mask_src_parts
            if prob_way == '3way':
                third_eps = uncond_eps \
                    + guidance_motion * (cond_eps_motion - uncond_eps) \
                    + guidance_text_n_motion * (cond_eps_text_n_motion
                                                - cond_eps_motion)
            elif prob_way == '2way':
                third_eps = uncond_eps \
                    + guidance_text_n_motion * (cond_eps_text_n_motion
                                                - uncond_eps)
            eps = torch.cat([third_eps, third_eps, third_eps], dim=0)
        return eps

    def _diffusion_reverse(self, text_embeds, text_masks_from_enc,
                           motion_embeds, cond_motion_masks,
                           inp_motion_mask, diff_process,
                           init_vec=None,
                           init_from='noise',
                           gd_text=None, gd_motion=None,
                           mode='full_cond',
                           return_init_noise=False,
                           steps_num=None,
                           inpaint_dict=None,
                           use_linear=False,
                           prob_way='3way'):
        bsz = inp_motion_mask.shape[0]
        assert mode in ['full_cond', 'text_cond', 'mot_cond']
        assert inp_motion_mask is not None
        # initial latents: [bs, nframes, nfeats]; 207 matches the default nfeats
        if init_vec is None:
            initial_latents = torch.randn(
                (bsz, inp_motion_mask.shape[1], 207),
                device=inp_motion_mask.device,
                dtype=torch.float,
            )
        else:
            initial_latents = init_vec
        # default guidance scales (unused here; gd_text / gd_motion are passed on)
        gd_scale_text = 2.0
        gd_scale_motion = 4.0
        if text_embeds is not None:
            max_text_len = text_embeds.shape[1]
        else:
            max_text_len = 0
        max_motion_len = cond_motion_masks.shape[1]
        text_masks = text_masks_from_enc.clone()
        # the unconditional replica gets an all-False mask for the source motion
        nomotion_mask = torch.zeros(bsz, max_motion_len,
                                    dtype=torch.bool,
                                    device=inp_motion_mask.device)
        motion_masks = torch.cat([nomotion_mask,
                                  cond_motion_masks,
                                  cond_motion_masks],
                                 dim=0)
        aug_mask = torch.cat([text_masks,
                              motion_masks],
                             dim=1)
        # set up classifier-free guidance: replicate the latents per branch
        if motion_embeds is not None:
            z = torch.cat([initial_latents, initial_latents, initial_latents], 0)
        else:
            z = torch.cat([initial_latents, initial_latents], 0)
        if use_linear:
            max_steps_diff = diff_process.num_timesteps
        else:
            max_steps_diff = None
        if motion_embeds is not None:
            model_kwargs = dict(
                in_motion_mask=torch.cat([inp_motion_mask,
                                          inp_motion_mask,
                                          inp_motion_mask], 0),
                text_embeds=text_embeds,
                condition_mask=aug_mask,
                # the unconditional branch sees a zeroed source motion
                motion_embeds=torch.cat([torch.zeros_like(motion_embeds),
                                         motion_embeds,
                                         motion_embeds], 1),
                guidance_motion=gd_motion,
                guidance_text_n_motion=gd_text,
                inpaint_dict=inpaint_dict,
                max_steps=max_steps_diff,
                prob_way=prob_way)
        else:
            model_kwargs = dict(
                in_motion_mask=torch.cat([inp_motion_mask,
                                          inp_motion_mask], 0),
                text_embeds=text_embeds,
                condition_mask=aug_mask,
                motion_embeds=None,
                guidance_motion=gd_motion,
                guidance_text_n_motion=gd_text,
                inpaint_dict=inpaint_dict,
                max_steps=max_steps_diff)
        # run the reverse diffusion sampling loop
        samples = diff_process.p_sample_loop(self.forward_with_guidance,
                                             z.shape, z,
                                             clip_denoised=False,
                                             model_kwargs=model_kwargs,
                                             progress=True,
                                             device=initial_latents.device)
        # keep only the fully-conditioned chunk of the guided batch
        _, _, samples = samples.chunk(3, dim=0)
        final_diffout = samples.permute(1, 0, 2)
        if return_init_noise:
            return initial_latents, final_diffout
        else:
            return final_diffout
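

# Illustrative smoke test, not part of the original training or inference
# pipeline: a minimal sketch of calling the denoiser directly with the default
# hyper-parameters, assuming TimestepEmbedderMDM returns a [1, bs, latent_dim]
# token. All shapes and values below are made up for the check.
if __name__ == "__main__":
    bs, nframes, text_len = 2, 16, 4
    denoiser = TMED_denoiser()
    noised_motion = torch.randn(bs, nframes, 207)            # [bs, frames, nfeats]
    timestep = torch.tensor([10])                            # broadcast to bs inside forward
    in_motion_mask = torch.ones(bs, nframes, dtype=torch.bool)
    text_embeds = torch.randn(bs, text_len, 768)             # [bs, tokens, text_encoded_dim]
    condition_mask = torch.ones(bs, text_len, dtype=torch.bool)
    out = denoiser(noised_motion, timestep, in_motion_mask,
                   text_embeds, condition_mask)
    print(out.shape)                                         # expected: [bs, frames, nfeats]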