# motionfix-demo / tmed_denoiser.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_utils import TimestepEmbedderMDM, PositionalEncoding


class TMED_denoiser(nn.Module):
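    """Transformer-encoder denoiser for text- and source-motion-conditioned
    motion-editing diffusion (TMED).

    The noised target motion, the (optional) source motion and the
    timestep/text embeddings are projected into a shared latent space,
    concatenated into one sequence (optionally separated by a learned SEP
    token) and denoised with a vanilla nn.TransformerEncoder.
    """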
def __init__(self,
nfeats: int = 207,
condition: str = "text",
                 latent_dim: int = 512,
ff_size: int = 1024,
num_layers: int = 8,
num_heads: int = 4,
dropout: float = 0.1,
activation: str = "gelu",
text_encoded_dim: int = 768,
pred_delta_motion: bool = False,
use_sep: bool = True,
motion_condition: str = 'source',
**kwargs) -> None:
super().__init__()
self.latent_dim = latent_dim
self.pred_delta_motion = pred_delta_motion
self.text_encoded_dim = text_encoded_dim
self.condition = condition
self.feat_comb_coeff = nn.Parameter(torch.tensor([1.0]))
self.pose_proj_in_source = nn.Linear(nfeats, self.latent_dim)
self.pose_proj_in_target = nn.Linear(nfeats, self.latent_dim)
self.pose_proj_out = nn.Linear(self.latent_dim, nfeats)
self.first_pose_proj = nn.Linear(self.latent_dim, nfeats)
self.motion_condition = motion_condition
# emb proj
if self.condition in ["text", "text_uncond"]:
            # text condition: embed the diffusion timestep into latent_dim
            self.embed_timestep = TimestepEmbedderMDM(self.latent_dim)
            # project the text embeddings to latent_dim when the dimensions differ
            if text_encoded_dim != self.latent_dim:
self.emb_proj = nn.Linear(text_encoded_dim, self.latent_dim)
else:
raise TypeError(f"condition type {self.condition} not supported")
self.use_sep = use_sep
self.query_pos = PositionalEncoding(self.latent_dim, dropout)
self.mem_pos = PositionalEncoding(self.latent_dim, dropout)
if self.motion_condition == "source":
if self.use_sep:
self.sep_token = nn.Parameter(torch.randn(1, self.latent_dim))
# use torch transformer
encoder_layer = nn.TransformerEncoderLayer(
d_model=self.latent_dim,
nhead=num_heads,
dim_feedforward=ff_size,
dropout=dropout,
activation=activation)
self.encoder = nn.TransformerEncoder(encoder_layer,
num_layers=num_layers)
def forward(self,
noised_motion,
timestep,
in_motion_mask,
text_embeds,
condition_mask,
motion_embeds=None,
lengths=None,
**kwargs):
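        """Denoise one diffusion step.

        Expected shapes (inferred from the code below):
            noised_motion:  (bs, seq_len, nfeats), batch-first
            timestep:       broadcastable to (bs,)
            in_motion_mask: (bs, seq_len) bool, True = valid target frame
            text_embeds:    (bs, text_len, text_encoded_dim)
            condition_mask: (bs, text_len [+ src_len]) bool, True = valid token
            motion_embeds:  (src_len, bs, nfeats) source motion
                            (or pre-projected to latent_dim), or None

        Returns the denoised motion, shape (bs, seq_len, nfeats).
        """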
# 0. dimension matching
        # noised_motion: (seq_len, bs, nfeats) <= (bs, seq_len, nfeats)
bs = noised_motion.shape[0]
noised_motion = noised_motion.permute(1, 0, 2)
# 0. check lengths for no vae (diffusion only)
# if lengths not in [None, []]:
motion_in_mask = in_motion_mask
# time_embedding | text_embedding | frames_source | frames_target
# 1 * lat_d | max_text * lat_d | max_frames * lat_d | max_frames * lat_d
        # 1. time embedding
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timestep.expand(noised_motion.shape[1]).clone()
time_emb = self.embed_timestep(timesteps).to(dtype=noised_motion.dtype)
# make it S first
# time_emb = self.time_embedding(time_emb).unsqueeze(0)
if self.condition in ["text", "text_uncond"]:
# make it seq first
text_embeds = text_embeds.permute(1, 0, 2)
if self.text_encoded_dim != self.latent_dim:
# [1 or 2, bs, latent_dim] <= [1 or 2, bs, text_encoded_dim]
text_emb_latent = self.emb_proj(text_embeds)
else:
text_emb_latent = text_embeds
# source_motion_zeros = torch.zeros(*noised_motion.shape[:2],
# self.latent_dim,
# device=noised_motion.device)
# aux_fake_mask = torch.zeros(condition_mask.shape[0],
# noised_motion.shape[0],
# device=noised_motion.device)
# condition_mask = torch.cat((condition_mask, aux_fake_mask),
# 1).bool().to(noised_motion.device)
emb_latent = torch.cat((time_emb, text_emb_latent), 0)
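            # emb_latent: timestep token(s) followed by text tokens, stacked
            # along the (first) sequence dimension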
if motion_embeds is not None:
zeroes_mask = (motion_embeds == 0).all(dim=-1)
if motion_embeds.shape[-1] != self.latent_dim:
motion_embeds_proj = self.pose_proj_in_source(motion_embeds)
motion_embeds_proj[zeroes_mask] = 0
else:
motion_embeds_proj = motion_embeds
else:
raise TypeError(f"condition type {self.condition} not supported")
# 4. transformer
# if self.diffusion_only:
proj_noised_motion = self.pose_proj_in_target(noised_motion)
if motion_embeds is None:
xseq = torch.cat((emb_latent, proj_noised_motion), axis=0)
else:
if self.use_sep:
                sep_token_batch = self.sep_token.expand(bs, -1)
xseq = torch.cat((emb_latent, motion_embeds_proj,
sep_token_batch[None],
proj_noised_motion), axis=0)
else:
xseq = torch.cat((emb_latent, motion_embeds_proj,
proj_noised_motion), axis=0)
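        # xseq (sequence-first): [time | text | source motion | SEP | noised target]
        # when a source motion is given, otherwise [time | text | noised target]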
# if self.ablation_skip_connection:
# xseq = self.query_pos(xseq)
# tokens = self.encoder(xseq)
# else:
# # adding the timestep embed
# # [seqlen+1, bs, d]
# # todo change to query_pos_decoder
xseq = self.query_pos(xseq)
        # build the key-padding mask for the concatenated sequence
if motion_embeds is None:
time_token_mask = torch.ones((bs, time_emb.shape[0]),
dtype=bool, device=xseq.device)
aug_mask = torch.cat((time_token_mask,
condition_mask[:, :text_emb_latent.shape[0]],
motion_in_mask), 1)
else:
time_token_mask = torch.ones((bs, time_emb.shape[0]),
dtype=bool,
device=xseq.device)
if self.use_sep:
sep_token_mask = torch.ones((bs, self.sep_token.shape[0]),
dtype=bool,
device=xseq.device)
if self.use_sep:
aug_mask = torch.cat((time_token_mask,
condition_mask[:, :text_emb_latent.shape[0]],
condition_mask[:, text_emb_latent.shape[0]:],
sep_token_mask,
motion_in_mask,
), 1)
else:
aug_mask = torch.cat((time_token_mask,
condition_mask[:, :text_emb_latent.shape[0]],
condition_mask[:, text_emb_latent.shape[0]:],
motion_in_mask,
), 1)
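        # aug_mask marks valid tokens with True; nn.TransformerEncoder expects
        # True at padded positions, hence the inversion below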
tokens = self.encoder(xseq, src_key_padding_mask=~aug_mask)
# if self.diffusion_only:
if motion_embeds is not None:
denoised_motion_proj = tokens[emb_latent.shape[0]:]
if self.use_sep:
useful_tokens = motion_embeds_proj.shape[0]+1
else:
useful_tokens = motion_embeds_proj.shape[0]
denoised_motion_proj = denoised_motion_proj[useful_tokens:]
else:
denoised_motion_proj = tokens[emb_latent.shape[0]:]
denoised_motion = self.pose_proj_out(denoised_motion_proj)
if self.pred_delta_motion and motion_embeds is not None:
tgt_size = len(denoised_motion)
if len(denoised_motion) > len(motion_embeds):
pad_for_src = tgt_size - len(motion_embeds)
motion_embeds = F.pad(motion_embeds,
(0, 0, 0, 0, 0, pad_for_src))
denoised_motion = denoised_motion + motion_embeds[:tgt_size]
denoised_motion[~motion_in_mask.T] = 0
# zero for padded area
# else:
# sample = tokens[:sample.shape[0]]
        # 5. (bs, seq_len, nfeats) <= (seq_len, bs, nfeats)
denoised_motion = denoised_motion.permute(1, 0, 2)
return denoised_motion
def forward_with_guidance(self,
noised_motion,
timestep,
in_motion_mask,
text_embeds,
condition_mask,
guidance_motion,
guidance_text_n_motion,
motion_embeds=None,
lengths=None,
inpaint_dict=None,
max_steps=None,
prob_way='3way',
**kwargs):
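        """Classifier-free guidance wrapper around forward().

        With a source motion the replicated batch is split into three chunks
        (unconditional / motion-only / text+motion) and recombined as
            e = e_uncond + g_mot * (e_mot - e_uncond)
                         + g_text_mot * (e_text_mot - e_mot)        ('3way')
            e = e_uncond + g_text_mot * (e_text_mot - e_uncond)     ('2way')
        Without a source motion only two chunks (unconditional / text) are
        used. With inpaint_dict, the given source motion is blended back into
        the per-chunk predictions before combining.
        """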
        # Two cases: with a source motion the batch holds three replicas
        # (unconditional / motion / text+motion); without it, two replicas
        # (unconditional / text).
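        # when max_steps is given, the guidance weights are annealed linearly
        # with the current timestep (g * 2 * t / max_steps), clamped below at 1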
if max_steps is not None:
curr_ts = timestep[0].item()
g_m = max(1, guidance_motion*2*curr_ts/max_steps)
guidance_motion = g_m
g_t_tm = max(1, guidance_text_n_motion*2*curr_ts/max_steps)
guidance_text_n_motion = g_t_tm
if motion_embeds is None:
half = noised_motion[: len(noised_motion) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.forward(combined, timestep,
in_motion_mask=in_motion_mask,
text_embeds=text_embeds,
condition_mask=condition_mask,
motion_embeds=motion_embeds,
lengths=lengths)
uncond_eps, cond_eps_text = torch.split(model_out, len(model_out) // 2,
dim=0)
# make it BxSxfeatures
if inpaint_dict is not None:
source_mot = inpaint_dict['start_motion'].permute(1, 0, 2)
if source_mot.shape[1] >= uncond_eps.shape[1]:
source_mot = source_mot[:, :uncond_eps.shape[1]]
else:
pad = uncond_eps.shape[1] - source_mot.shape[1]
# Pad the tensor on the second dimension (time)
source_mot = F.pad(source_mot, (0, 0, 0, pad), 'constant', 0)
mot_len = source_mot.shape[1]
                # repeat the inpainting (source-part) mask across all frames
mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1,
mot_len,
1)
uncond_eps = uncond_eps*(mask_src_parts) + source_mot*(~mask_src_parts)
cond_eps_text = cond_eps_text*(mask_src_parts) + source_mot*(~mask_src_parts)
half_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
else:
third = noised_motion[: len(noised_motion) // 3]
combined = torch.cat([third, third, third], dim=0)
model_out = self.forward(combined, timestep,
in_motion_mask=in_motion_mask,
text_embeds=text_embeds,
condition_mask=condition_mask,
motion_embeds=motion_embeds,
lengths=lengths)
            # split the stacked batch into the unconditional, motion-conditioned
            # and text+motion-conditioned predictions
uncond_eps, cond_eps_motion, cond_eps_text_n_motion = torch.split(model_out,
len(model_out) // 3,
dim=0)
if inpaint_dict is not None:
source_mot = inpaint_dict['start_motion'].permute(1, 0, 2)
if source_mot.shape[1] >= uncond_eps.shape[1]:
source_mot = source_mot[:, :uncond_eps.shape[1]]
else:
pad = uncond_eps.shape[1] - source_mot.shape[1]
# Pad the tensor on the second dimension (time)
source_mot = F.pad(source_mot, (0, 0, 0, pad), 'constant', 0)
mot_len = source_mot.shape[1]
                # repeat the inpainting (source-part) mask across all frames
mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1,
mot_len,
1)
uncond_eps = uncond_eps*(~mask_src_parts) + source_mot*mask_src_parts
                cond_eps_motion = cond_eps_motion*(~mask_src_parts) + source_mot*mask_src_parts
cond_eps_text_n_motion = cond_eps_text_n_motion*(~mask_src_parts) + source_mot*mask_src_parts
if prob_way=='3way':
third_eps = uncond_eps + guidance_motion * (cond_eps_motion - uncond_eps) + \
guidance_text_n_motion * (cond_eps_text_n_motion - cond_eps_motion)
if prob_way=='2way':
third_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text_n_motion - uncond_eps)
eps = torch.cat([third_eps, third_eps, third_eps], dim=0)
return eps
def _diffusion_reverse(self, text_embeds, text_masks_from_enc,
motion_embeds, cond_motion_masks,
inp_motion_mask, diff_process,
init_vec=None,
init_from='noise',
gd_text=None, gd_motion=None,
mode='full_cond',
return_init_noise=False,
steps_num=None,
inpaint_dict=None,
use_linear=False,
prob_way='3way'):
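        """Run the full reverse-diffusion sampling loop with guidance.

        Replicates the initial latents, masks and conditions as expected by
        forward_with_guidance and samples with diff_process.p_sample_loop.
        Returns the sampled motion permuted to sequence-first layout; with
        return_init_noise=True the initial (batch-first) latents are returned
        as well.
        """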
# guidance_scale_text: 7.5 #
# guidance_scale_motion: 1.5
# init latents
bsz = inp_motion_mask.shape[0]
assert mode in ['full_cond', 'text_cond', 'mot_cond']
assert inp_motion_mask is not None
# len_to_gen = max(lengths) if not self.input_deltas else max(lengths) + 1
if init_vec is None:
initial_latents = torch.randn(
(bsz, inp_motion_mask.shape[1], 207),
device=inp_motion_mask.device,
dtype=torch.float,
)
else:
initial_latents = init_vec
gd_scale_text = 2.0
gd_scale_motion = 4.0
if text_embeds is not None:
max_text_len = text_embeds.shape[1]
else:
max_text_len = 0
max_motion_len = cond_motion_masks.shape[1]
text_masks = text_masks_from_enc.clone()
        nomotion_mask = torch.zeros(bsz, max_motion_len, dtype=torch.bool,
                                    device=inp_motion_mask.device)
motion_masks = torch.cat([nomotion_mask,
cond_motion_masks,
cond_motion_masks],
dim=0)
aug_mask = torch.cat([text_masks,
motion_masks],
dim=1)
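        # motion part of the mask per CFG replica: all-False for the
        # unconditional replica, the real source-motion mask for the other two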
# Setup classifier-free guidance:
if motion_embeds is not None:
z = torch.cat([initial_latents, initial_latents, initial_latents], 0)
else:
z = torch.cat([initial_latents, initial_latents], 0)
# y_null = torch.tensor([1000] * n, device=device)
# y = torch.cat([y, y_null], 0)
if use_linear:
max_steps_diff = diff_process.num_timesteps
else:
max_steps_diff = None
if motion_embeds is not None:
model_kwargs = dict(# noised_motion=latent_model_input,
# timestep=t,
in_motion_mask=torch.cat([inp_motion_mask,
inp_motion_mask,
inp_motion_mask], 0),
text_embeds=text_embeds,
condition_mask=aug_mask,
motion_embeds=torch.cat([torch.zeros_like(motion_embeds),
motion_embeds,
motion_embeds], 1),
guidance_motion=gd_motion,
guidance_text_n_motion=gd_text,
inpaint_dict=inpaint_dict,
max_steps=max_steps_diff,
prob_way=prob_way)
else:
model_kwargs = dict(# noised_motion=latent_model_input,
# timestep=t,
in_motion_mask=torch.cat([inp_motion_mask,
inp_motion_mask], 0),
text_embeds=text_embeds,
condition_mask=aug_mask,
motion_embeds=None,
guidance_motion=gd_motion,
guidance_text_n_motion=gd_text,
inpaint_dict=inpaint_dict,
max_steps=max_steps_diff)
# model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
        # Sample motions:
samples = diff_process.p_sample_loop(self.forward_with_guidance,
z.shape, z,
clip_denoised=False,
model_kwargs=model_kwargs,
progress=True,
device=initial_latents.device,)
        _, _, samples = samples.chunk(3, dim=0)  # keep only the last chunk of the replicated guidance batch
final_diffout = samples.permute(1, 0, 2)
if return_init_noise:
return initial_latents, final_diffout
else:
return final_diffout
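

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original demo):
    # the shapes below are assumptions, and running it requires the repo's
    # model_utils (TimestepEmbedderMDM, PositionalEncoding).
    torch.manual_seed(0)
    bs, tgt_len, src_len, text_len = 2, 30, 30, 1
    denoiser = TMED_denoiser()  # defaults: nfeats=207, latent_dim=512

    noised_motion = torch.randn(bs, tgt_len, 207)       # noised target motion
    timestep = torch.full((1,), 50, dtype=torch.long)   # broadcast to the batch inside forward
    in_motion_mask = torch.ones(bs, tgt_len, dtype=torch.bool)
    text_embeds = torch.randn(bs, text_len, 768)        # text features from an external encoder
    condition_mask = torch.ones(bs, text_len + src_len, dtype=torch.bool)
    motion_embeds = torch.randn(src_len, bs, 207)       # source motion, sequence-first

    denoised = denoiser(noised_motion, timestep, in_motion_mask,
                        text_embeds, condition_mask, motion_embeds=motion_embeds)
    print(denoised.shape)  # expected: (bs, tgt_len, 207)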