import clip import math import torch import torch.nn.functional as F from torch import nn import numpy as np from einops.layers.torch import Rearrange from einops import rearrange import matplotlib.pyplot as plt import os import torch.nn as nn # Custom LayerNorm class to handle fp16 class CustomLayerNorm(nn.LayerNorm): def forward(self, x: torch.Tensor): if self.weight.dtype == torch.float32: orig_type = x.dtype ret = super().forward(x.type(torch.float32)) return ret.type(orig_type) else: return super().forward(x) # Function to replace LayerNorm in CLIP model with CustomLayerNorm def replace_layer_norm(model): for name, module in model.named_children(): if isinstance(module, nn.LayerNorm): setattr(model, name, CustomLayerNorm(module.normalized_shape, elementwise_affine=module.elementwise_affine)) else: replace_layer_norm(module) # Recursively apply to all submodules MONITOR_ATTN = [] SELF_ATTN = [] def vis_attn(att, out_path, step, layer, shape, type_="self", lines=True): if lines: plt.figure(figsize=(10, 3)) for token_index in range(att.shape[1]): plt.plot(att[:, token_index], label=f"Token {token_index}") plt.title("Attention Values for Each Token") plt.xlabel("time") plt.ylabel("Attention Value") plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1)) # save image savepath = os.path.join(out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_lines_{shape}.png") os.makedirs(os.path.dirname(savepath), exist_ok=True) plt.savefig(savepath, bbox_inches="tight") np.save(savepath.replace(".png", ".npy"), att) else: plt.figure(figsize=(10, 10)) plt.imshow(att.transpose(), cmap="viridis", aspect="auto") plt.colorbar() plt.title("Attention Matrix Heatmap") plt.ylabel("time") plt.xlabel("time") # save image savepath = os.path.join(out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_heatmap_{shape}.png") os.makedirs(os.path.dirname(savepath), exist_ok=True) plt.savefig(savepath, bbox_inches="tight") np.save(savepath.replace(".png", ".npy"), att) def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module class FFN(nn.Module): def __init__(self, latent_dim, ffn_dim, dropout): super().__init__() self.linear1 = nn.Linear(latent_dim, ffn_dim) self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) self.activation = nn.GELU() self.dropout = nn.Dropout(dropout) def forward(self, x): y = self.linear2(self.dropout(self.activation(self.linear1(x)))) y = x + y return y class Conv1dAdaGNBlock(nn.Module): """ Conv1d --> GroupNorm --> scale,shift --> Mish """ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4): super().__init__() self.out_channels = out_channels self.block = nn.Conv1d( inp_channels, out_channels, kernel_size, padding=kernel_size // 2 ) self.group_norm = nn.GroupNorm(n_groups, out_channels) self.avtication = nn.Mish() def forward(self, x, scale, shift): """ Args: x: [bs, nfeat, nframes] scale: [bs, out_feat, 1] shift: [bs, out_feat, 1] """ x = self.block(x) batch_size, channels, horizon = x.size() x = rearrange( x, "batch channels horizon -> (batch horizon) channels" ) # [bs*seq, nfeats] x = self.group_norm(x) x = rearrange( x.reshape(batch_size, horizon, channels), "batch horizon channels -> batch channels horizon", ) x = ada_shift_scale(x, shift, scale) return self.avtication(x) class SelfAttention(nn.Module): def __init__( self, latent_dim, text_latent_dim, num_heads: int = 8, dropout: float = 0.0, log_attn=False, edit_config=None, ): super().__init__() self.num_head = num_heads self.norm = nn.LayerNorm(latent_dim) self.query = nn.Linear(latent_dim, latent_dim) self.key = nn.Linear(latent_dim, latent_dim) self.value = nn.Linear(latent_dim, latent_dim) self.dropout = nn.Dropout(dropout) self.edit_config = edit_config self.log_attn = log_attn def forward(self, x): """ x: B, T, D xf: B, N, L """ B, T, D = x.shape N = x.shape[1] assert N == T H = self.num_head # B, T, 1, D query = self.query(self.norm(x)).unsqueeze(2) # B, 1, N, D key = self.key(self.norm(x)).unsqueeze(1) query = query.view(B, T, H, -1) key = key.view(B, N, H, -1) # style transfer motion editing style_tranfer = self.edit_config.style_tranfer.use if style_tranfer: if ( len(SELF_ATTN) <= self.edit_config.style_tranfer.style_transfer_steps_end ): query[1] = query[0] # example based motion generation example_based = self.edit_config.example_based.use if example_based: if len(SELF_ATTN) == self.edit_config.example_based.example_based_steps_end: temp_seed = self.edit_config.example_based.temp_seed for id_ in range(query.shape[0] - 1): with torch.random.fork_rng(): torch.manual_seed(temp_seed) tensor = query[0] chunks = torch.split( tensor, self.edit_config.example_based.chunk_size, dim=0 ) shuffled_indices = torch.randperm(len(chunks)) shuffled_chunks = [chunks[i] for i in shuffled_indices] shuffled_tensor = torch.cat(shuffled_chunks, dim=0) query[1 + id_] = shuffled_tensor temp_seed += self.edit_config.example_based.temp_seed_bar # time shift motion editing (q, k) time_shift = self.edit_config.time_shift.use if time_shift: if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end: part1 = int( key.shape[1] * self.edit_config.time_shift.time_shift_ratio // 1 ) part2 = int( key.shape[1] * (1 - self.edit_config.time_shift.time_shift_ratio) // 1 ) q_front_part = query[0, :part1, :, :] q_back_part = query[0, -part2:, :, :] new_q = torch.cat((q_back_part, q_front_part), dim=0) query[1] = new_q k_front_part = key[0, :part1, :, :] k_back_part = key[0, -part2:, :, :] new_k = torch.cat((k_back_part, k_front_part), dim=0) key[1] = new_k # B, T, N, H attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H) weight = self.dropout(F.softmax(attention, dim=2)) # for counting the step and logging attention maps try: attention_matrix = ( weight[0, :, :].mean(dim=-1).detach().cpu().numpy().astype(float) ) SELF_ATTN[-1].append(attention_matrix) except: pass # attention manipulation for replacement attention_manipulation = self.edit_config.manipulation.use if attention_manipulation: if len(SELF_ATTN) <= self.edit_config.manipulation.manipulation_steps_end: weight[1, :, :, :] = weight[0, :, :, :] value = self.value(self.norm(x)).view(B, N, H, -1) # time shift motion editing (v) if time_shift: if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end: v_front_part = value[0, :part1, :, :] v_back_part = value[0, -part2:, :, :] new_v = torch.cat((v_back_part, v_front_part), dim=0) value[1] = new_v y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D) return y class TimestepEmbedder(nn.Module): def __init__(self, d_model, max_len=5000): super(TimestepEmbedder, self).__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp( torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer("pe", pe) def forward(self, x): self.pe = self.pe.cuda() return self.pe[x] class Downsample1d(nn.Module): def __init__(self, dim): super().__init__() self.conv = nn.Conv1d(dim, dim, 3, 2, 1) def forward(self, x): return self.conv(x) class Upsample1d(nn.Module): def __init__(self, dim_in, dim_out=None): super().__init__() dim_out = dim_out or dim_in self.conv = nn.ConvTranspose1d(dim_in, dim_out, 4, 2, 1) def forward(self, x): return self.conv(x) class Conv1dBlock(nn.Module): """ Conv1d --> GroupNorm --> Mish """ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4, zero=False): super().__init__() self.out_channels = out_channels self.block = nn.Conv1d( inp_channels, out_channels, kernel_size, padding=kernel_size // 2 ) self.norm = nn.GroupNorm(n_groups, out_channels) self.activation = nn.Mish() if zero: # zero init the convolution nn.init.zeros_(self.block.weight) nn.init.zeros_(self.block.bias) def forward(self, x): """ Args: x: [bs, nfeat, nframes] """ x = self.block(x) batch_size, channels, horizon = x.size() x = rearrange( x, "batch channels horizon -> (batch horizon) channels" ) # [bs*seq, nfeats] x = self.norm(x) x = rearrange( x.reshape(batch_size, horizon, channels), "batch horizon channels -> batch channels horizon", ) return self.activation(x) def ada_shift_scale(x, shift, scale): return x * (1 + scale) + shift class ResidualTemporalBlock(nn.Module): def __init__( self, inp_channels, out_channels, embed_dim, kernel_size=5, zero=True, n_groups=8, dropout: float = 0.1, adagn=True, ): super().__init__() self.adagn = adagn self.blocks = nn.ModuleList( [ # adagn only the first conv (following guided-diffusion) ( Conv1dAdaGNBlock(inp_channels, out_channels, kernel_size, n_groups) if adagn else Conv1dBlock(inp_channels, out_channels, kernel_size) ), Conv1dBlock( out_channels, out_channels, kernel_size, n_groups, zero=zero ), ] ) self.time_mlp = nn.Sequential( nn.Mish(), # adagn = scale and shift nn.Linear(embed_dim, out_channels * 2 if adagn else out_channels), Rearrange("batch t -> batch t 1"), ) self.dropout = nn.Dropout(dropout) if zero: nn.init.zeros_(self.time_mlp[1].weight) nn.init.zeros_(self.time_mlp[1].bias) self.residual_conv = ( nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() ) def forward(self, x, time_embeds=None): """ x : [ batch_size x inp_channels x nframes ] t : [ batch_size x embed_dim ] returns: [ batch_size x out_channels x nframes ] """ if self.adagn: scale, shift = self.time_mlp(time_embeds).chunk(2, dim=1) out = self.blocks[0](x, scale, shift) else: out = self.blocks[0](x) + self.time_mlp(time_embeds) out = self.blocks[1](out) out = self.dropout(out) return out + self.residual_conv(x) class CrossAttention(nn.Module): def __init__( self, latent_dim, text_latent_dim, num_heads: int = 8, dropout: float = 0.0, log_attn=False, edit_config=None, ): super().__init__() self.num_head = num_heads self.norm = nn.LayerNorm(latent_dim) self.text_norm = nn.LayerNorm(text_latent_dim) self.query = nn.Linear(latent_dim, latent_dim) self.key = nn.Linear(text_latent_dim, latent_dim) self.value = nn.Linear(text_latent_dim, latent_dim) self.dropout = nn.Dropout(dropout) self.edit_config = edit_config self.log_attn = log_attn def forward(self, x, xf): """ x: B, T, D xf: B, N, L """ B, T, D = x.shape N = xf.shape[1] H = self.num_head # B, T, 1, D query = self.query(self.norm(x)).unsqueeze(2) # B, 1, N, D key = self.key(self.text_norm(xf)).unsqueeze(1) query = query.view(B, T, H, -1) key = key.view(B, N, H, -1) # B, T, N, H attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H) weight = self.dropout(F.softmax(attention, dim=2)) # attention reweighting for (de)-emphasizing motion if self.edit_config.reweighting_attn.use: reweighting_attn = self.edit_config.reweighting_attn.reweighting_attn_weight if self.edit_config.reweighting_attn.idx == -1: # read idxs from txt file with open("./assets/reweighting_idx.txt", "r") as f: idxs = f.readlines() else: # gradio demo mode idxs = [0, self.edit_config.reweighting_attn.idx] idxs = [int(idx) for idx in idxs] for i in range(len(idxs)): weight[i, :, 1 + idxs[i]] = weight[i, :, 1 + idxs[i]] + reweighting_attn weight[i, :, 1 + idxs[i] + 1] = ( weight[i, :, 1 + idxs[i] + 1] + reweighting_attn ) # for counting the step and logging attention maps try: attention_matrix = ( weight[0, :, 1 : 1 + 3] .mean(dim=-1) .detach() .cpu() .numpy() .astype(float) ) MONITOR_ATTN[-1].append(attention_matrix) except: pass # erasing motion (autually is the deemphasizing motion) erasing_motion = self.edit_config.erasing_motion.use if erasing_motion: reweighting_attn = self.edit_config.erasing_motion.erasing_motion_weight begin = self.edit_config.erasing_motion.time_start end = self.edit_config.erasing_motion.time_end idx = self.edit_config.erasing_motion.idx if reweighting_attn > 0.01 or reweighting_attn < -0.01: weight[1, int(T * begin) : int(T * end), idx] = ( weight[1, int(T * begin) : int(T * end) :, idx] * reweighting_attn ) weight[1, int(T * begin) : int(T * end), idx + 1] = ( weight[1, int(T * begin) : int(T * end), idx + 1] * reweighting_attn ) # attention manipulation for motion replacement manipulation = self.edit_config.manipulation.use if manipulation: if ( len(MONITOR_ATTN) <= self.edit_config.manipulation.manipulation_steps_end_crossattn ): word_idx = self.edit_config.manipulation.word_idx weight[1, :, : 1 + word_idx, :] = weight[0, :, : 1 + word_idx, :] weight[1, :, 1 + word_idx + 1 :, :] = weight[ 0, :, 1 + word_idx + 1 :, : ] value = self.value(self.text_norm(xf)).view(B, N, H, -1) y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D) return y class ResidualCLRAttentionLayer(nn.Module): def __init__( self, dim1, dim2, num_heads: int = 8, dropout: float = 0.1, no_eff: bool = False, self_attention: bool = False, log_attn=False, edit_config=None, ): super(ResidualCLRAttentionLayer, self).__init__() self.dim1 = dim1 self.dim2 = dim2 self.num_heads = num_heads # Multi-Head Attention Layer if no_eff: self.cross_attention = CrossAttention( latent_dim=dim1, text_latent_dim=dim2, num_heads=num_heads, dropout=dropout, log_attn=log_attn, edit_config=edit_config, ) else: self.cross_attention = LinearCrossAttention( latent_dim=dim1, text_latent_dim=dim2, num_heads=num_heads, dropout=dropout, log_attn=log_attn, ) if self_attention: self.self_attn_use = True self.self_attention = SelfAttention( latent_dim=dim1, text_latent_dim=dim2, num_heads=num_heads, dropout=dropout, log_attn=log_attn, edit_config=edit_config, ) else: self.self_attn_use = False def forward(self, input_tensor, condition_tensor, cond_indices): """ input_tensor :B, D, L condition_tensor: B, L, D """ if cond_indices.numel() == 0: return input_tensor # self attention if self.self_attn_use: x = input_tensor x = x.permute(0, 2, 1) # (batch_size, seq_length, feat_dim) x = self.self_attention(x) x = x.permute(0, 2, 1) # (batch_size, feat_dim, seq_length) input_tensor = input_tensor + x x = input_tensor # cross attention x = x[cond_indices].permute(0, 2, 1) # (batch_size, seq_length, feat_dim) x = self.cross_attention(x, condition_tensor[cond_indices]) x = x.permute(0, 2, 1) # (batch_size, feat_dim, seq_length) input_tensor[cond_indices] = input_tensor[cond_indices] + x return input_tensor class CLRBlock(nn.Module): def __init__( self, dim_in, dim_out, cond_dim, time_dim, adagn=True, zero=True, no_eff=False, self_attention=False, dropout: float = 0.1, log_attn=False, edit_config=None, ) -> None: super().__init__() self.conv1d = ResidualTemporalBlock( dim_in, dim_out, embed_dim=time_dim, adagn=adagn, zero=zero, dropout=dropout ) self.clr_attn = ResidualCLRAttentionLayer( dim1=dim_out, dim2=cond_dim, no_eff=no_eff, dropout=dropout, self_attention=self_attention, log_attn=log_attn, edit_config=edit_config, ) # import pdb; pdb.set_trace() self.ffn = FFN(dim_out, dim_out * 4, dropout=dropout) def forward(self, x, t, cond, cond_indices=None): x = self.conv1d(x, t) x = self.clr_attn(x, cond, cond_indices) x = self.ffn(x.permute(0, 2, 1)).permute(0, 2, 1) return x class CondUnet1D(nn.Module): """ Diffusion's style UNET with 1D convolution and adaptive group normalization for motion suquence denoising, cross-attention to introduce conditional prompts (like text). """ def __init__( self, input_dim, cond_dim, dim=128, dim_mults=(1, 2, 4, 8), dims=None, time_dim=512, adagn=True, zero=True, dropout=0.1, no_eff=False, self_attention=False, log_attn=False, edit_config=None, ): super().__init__() if not dims: dims = [input_dim, *map(lambda m: int(dim * m), dim_mults)] ##[d, d,2d,4d] print("dims: ", dims, "mults: ", dim_mults) in_out = list(zip(dims[:-1], dims[1:])) self.time_mlp = nn.Sequential( TimestepEmbedder(time_dim), nn.Linear(time_dim, time_dim * 4), nn.Mish(), nn.Linear(time_dim * 4, time_dim), ) self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(in_out): self.downs.append( nn.ModuleList( [ CLRBlock( dim_in, dim_out, cond_dim, time_dim, adagn=adagn, zero=zero, no_eff=no_eff, dropout=dropout, self_attention=self_attention, log_attn=log_attn, edit_config=edit_config, ), CLRBlock( dim_out, dim_out, cond_dim, time_dim, adagn=adagn, zero=zero, no_eff=no_eff, dropout=dropout, self_attention=self_attention, log_attn=log_attn, edit_config=edit_config, ), Downsample1d(dim_out), ] ) ) mid_dim = dims[-1] self.mid_block1 = CLRBlock( dim_in=mid_dim, dim_out=mid_dim, cond_dim=cond_dim, time_dim=time_dim, adagn=adagn, zero=zero, no_eff=no_eff, dropout=dropout, self_attention=self_attention, log_attn=log_attn, edit_config=edit_config, ) self.mid_block2 = CLRBlock( dim_in=mid_dim, dim_out=mid_dim, cond_dim=cond_dim, time_dim=time_dim, adagn=adagn, zero=zero, no_eff=no_eff, dropout=dropout, self_attention=self_attention, log_attn=log_attn, edit_config=edit_config, ) last_dim = mid_dim for ind, dim_out in enumerate(reversed(dims[1:])): self.ups.append( nn.ModuleList( [ Upsample1d(last_dim, dim_out), CLRBlock( dim_out * 2, dim_out, cond_dim, time_dim, adagn=adagn, zero=zero, no_eff=no_eff, dropout=dropout, self_attention=self_attention, log_attn=log_attn, edit_config=edit_config, ), CLRBlock( dim_out, dim_out, cond_dim, time_dim, adagn=adagn, zero=zero, no_eff=no_eff, dropout=dropout, self_attention=self_attention, log_attn=log_attn, edit_config=edit_config, ), ] ) ) last_dim = dim_out self.final_conv = nn.Conv1d(dim_out, input_dim, 1) if zero: nn.init.zeros_(self.final_conv.weight) nn.init.zeros_(self.final_conv.bias) def forward( self, x, t, cond, cond_indices, ): self.time_mlp = self.time_mlp.cuda() temb = self.time_mlp(t) h = [] for block1, block2, downsample in self.downs: block1 = block1.cuda() block2 = block2.cuda() x = block1(x, temb, cond, cond_indices) x = block2(x, temb, cond, cond_indices) h.append(x) x = downsample(x) self.mid_block1 = self.mid_block1.cuda() self.mid_block2 = self.mid_block2.cuda() x = self.mid_block1(x, temb, cond, cond_indices) x = self.mid_block2(x, temb, cond, cond_indices) for upsample, block1, block2 in self.ups: x = upsample(x) x = torch.cat((x, h.pop()), dim=1) block1 = block1.cuda() block2 = block2.cuda() x = block1(x, temb, cond, cond_indices) x = block2(x, temb, cond, cond_indices) self.final_conv = self.final_conv.cuda() x = self.final_conv(x) return x class MotionCLR(nn.Module): """ Diffuser's style UNET for text-to-motion task. """ def __init__( self, input_feats, base_dim=128, dim_mults=(1, 2, 2, 2), dims=None, adagn=True, zero=True, dropout=0.1, no_eff=False, time_dim=512, latent_dim=256, cond_mask_prob=0.1, clip_dim=512, clip_version="ViT-B/32", text_latent_dim=256, text_ff_size=2048, text_num_heads=4, activation="gelu", num_text_layers=4, self_attention=False, vis_attn=False, edit_config=None, out_path=None, ): super().__init__() self.input_feats = input_feats self.dim_mults = dim_mults self.base_dim = base_dim self.latent_dim = latent_dim self.cond_mask_prob = cond_mask_prob self.vis_attn = vis_attn self.counting_map = [] self.out_path = out_path print( f"The T2M Unet mask the text prompt by {self.cond_mask_prob} prob. in training" ) # text encoder self.embed_text = nn.Linear(clip_dim, text_latent_dim) self.clip_version = clip_version self.clip_model = self.load_and_freeze_clip(clip_version) replace_layer_norm(self.clip_model) textTransEncoderLayer = nn.TransformerEncoderLayer( d_model=text_latent_dim, nhead=text_num_heads, dim_feedforward=text_ff_size, dropout=dropout, activation=activation, ) self.textTransEncoder = nn.TransformerEncoder( textTransEncoderLayer, num_layers=num_text_layers ) self.text_ln = nn.LayerNorm(text_latent_dim) self.unet = CondUnet1D( input_dim=self.input_feats, cond_dim=text_latent_dim, dim=self.base_dim, dim_mults=self.dim_mults, adagn=adagn, zero=zero, dropout=dropout, no_eff=no_eff, dims=dims, time_dim=time_dim, self_attention=self_attention, log_attn=self.vis_attn, edit_config=edit_config, ) self.clip_model = self.clip_model.cuda() self.embed_text = self.embed_text.cuda() self.textTransEncoder = self.textTransEncoder.cuda() self.text_ln = self.text_ln.cuda() self.unet = self.unet.cuda() def encode_text(self, raw_text, device): self.clip_model.token_embedding = self.clip_model.token_embedding.to(device) self.clip_model.transformer = self.clip_model.transformer.to(device) self.clip_model.ln_final = self.clip_model.ln_final.to(device) with torch.no_grad(): texts = clip.tokenize(raw_text, truncate=True).to( device ) # [bs, context_length] # if n_tokens > 77 -> will truncate x = self.clip_model.token_embedding(texts).type(self.clip_model.dtype).to(device) # [batch_size, n_ctx, d_model] x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype).to(device) x = x.permute(1, 0, 2) # NLD -> LND x = self.clip_model.transformer(x) x = self.clip_model.ln_final(x).type( self.clip_model.dtype ) # [len, batch_size, 512] self.embed_text = self.embed_text.to(device) x = self.embed_text(x) # [len, batch_size, 256] self.textTransEncoder = self.textTransEncoder.to(device) x = self.textTransEncoder(x) self.text_ln = self.text_ln.to(device) x = self.text_ln(x) # T, B, D -> B, T, D xf_out = x.permute(1, 0, 2) ablation_text = False if ablation_text: xf_out[:, 1:, :] = xf_out[:, 0, :].unsqueeze(1) return xf_out def load_and_freeze_clip(self, clip_version): clip_model, _ = clip.load( # clip_model.dtype=float32 clip_version, device="cpu", jit=False ) # Must set jit=False for training # Freeze CLIP weights clip_model.eval() for p in clip_model.parameters(): p.requires_grad = False return clip_model def mask_cond(self, bs, force_mask=False): """ mask motion condition , return contitional motion index in the batch """ if force_mask: cond_indices = torch.empty(0) elif self.training and self.cond_mask_prob > 0.0: mask = torch.bernoulli( torch.ones( bs, ) * self.cond_mask_prob ) # 1-> use null_cond, 0-> use real cond mask = 1.0 - mask cond_indices = torch.nonzero(mask).squeeze(-1) else: cond_indices = torch.arange(bs) return cond_indices def forward( self, x, timesteps, text=None, uncond=False, enc_text=None, ): """ Args: x: [batch_size, nframes, nfeats], timesteps: [batch_size] (int) text: list (batch_size length) of strings with input text prompts uncond: whethere using text condition Returns: [batch_size, seq_length, nfeats] """ B, T, _ = x.shape x = x.transpose(1, 2) # [bs, nfeats, nframes] if enc_text is None: enc_text = self.encode_text(text, x.device) # [bs, seqlen, text_dim] cond_indices = self.mask_cond(x.shape[0], force_mask=uncond) # NOTE: need to pad to be the multiplier of 8 for the unet PADDING_NEEEDED = (16 - (T % 16)) % 16 padding = (0, PADDING_NEEEDED) x = F.pad(x, padding, value=0) x = self.unet( x, t=timesteps, cond=enc_text, cond_indices=cond_indices, ) # [bs, nfeats,, nframes] x = x[:, :, :T].transpose(1, 2) # [bs, nframes, nfeats,] return x def forward_with_cfg(self, x, timesteps, text=None, enc_text=None, cfg_scale=2.5): """ Args: x: [batch_size, nframes, nfeats], timesteps: [batch_size] (int) text: list (batch_size length) of strings with input text prompts Returns: [batch_size, max_frames, nfeats] """ global SELF_ATTN global MONITOR_ATTN MONITOR_ATTN.append([]) SELF_ATTN.append([]) B, T, _ = x.shape x = x.transpose(1, 2) # [bs, nfeats, nframes] if enc_text is None: enc_text = self.encode_text(text, x.device) # [bs, seqlen, text_dim] cond_indices = self.mask_cond(B) # NOTE: need to pad to be the multiplier of 8 for the unet PADDING_NEEEDED = (16 - (T % 16)) % 16 padding = (0, PADDING_NEEEDED) x = F.pad(x, padding, value=0) combined_x = torch.cat([x, x], dim=0) combined_t = torch.cat([timesteps, timesteps], dim=0) out = self.unet( x=combined_x, t=combined_t, cond=enc_text, cond_indices=cond_indices, ) # [bs, nfeats, nframes] out = out[:, :, :T].transpose(1, 2) # [bs, nframes, nfeats,] out_cond, out_uncond = torch.split(out, len(out) // 2, dim=0) if self.vis_attn == True: i = len(MONITOR_ATTN) attnlist = MONITOR_ATTN[-1] print(i, "cross", len(attnlist)) for j, att in enumerate(attnlist): vis_attn( att, out_path=self.out_path, step=i, layer=j, shape="_".join(map(str, att.shape)), type_="cross", ) attnlist = SELF_ATTN[-1] print(i, "self", len(attnlist)) for j, att in enumerate(attnlist): vis_attn( att, out_path=self.out_path, step=i, layer=j, shape="_".join(map(str, att.shape)), type_="self", lines=False, ) if len(SELF_ATTN) % 10 == 0: SELF_ATTN = [] MONITOR_ATTN = [] return out_uncond + (cfg_scale * (out_cond - out_uncond)) if __name__ == "__main__": device = "cuda:0" n_feats = 263 num_frames = 196 text_latent_dim = 256 dim_mults = [2, 2, 2, 2] base_dim = 512 model = MotionCLR( input_feats=n_feats, text_latent_dim=text_latent_dim, base_dim=base_dim, dim_mults=dim_mults, adagn=True, zero=True, dropout=0.1, no_eff=True, cond_mask_prob=0.1, self_attention=True, ) model = model.to(device) from utils.model_load import load_model_weights checkpoint_path = "/comp_robot/chenlinghao/StableMoFusion/checkpoints/t2m/self_attn—fulllayer-ffn-drop0_1-lr1e4/model/latest.tar" new_state_dict = {} checkpoint = torch.load(checkpoint_path) ckpt2 = checkpoint.copy() ckpt2["model_ema"] = {} ckpt2["encoder"] = {} for key, value in list(checkpoint["model_ema"].items()): new_key = key.replace( "cross_attn", "clr_attn" ) # Replace 'cross_attn' with 'clr_attn' ckpt2["model_ema"][new_key] = value for key, value in list(checkpoint["encoder"].items()): new_key = key.replace( "cross_attn", "clr_attn" ) # Replace 'cross_attn' with 'clr_attn' ckpt2["encoder"][new_key] = value torch.save( ckpt2, "/comp_robot/chenlinghao/CLRpreview/checkpoints/t2m/release/model/latest.tar", ) dtype = torch.float32 bs = 1 x = torch.rand((bs, 196, 263), dtype=dtype).to(device) timesteps = torch.randint(low=0, high=1000, size=(bs,)).to(device) y = ["A man jumps to his left." for i in range(bs)] length = torch.randint(low=20, high=196, size=(bs,)).to(device) out = model(x, timesteps, text=y) print(out.shape) model.eval() out = model.forward_with_cfg(x, timesteps, text=y) print(out.shape)