# MotionCLR / models/unet.py
# EvanTHU's picture
# Update models/unet.py
# df42902 verified
# raw
# history blame
# 36.4 kB
import clip
import math
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from einops.layers.torch import Rearrange
from einops import rearrange
import matplotlib.pyplot as plt
import os
import torch.nn as nn
# Custom LayerNorm class to handle fp16
class CustomLayerNorm(nn.LayerNorm):
    """LayerNorm that normalises in float32 when its parameters are float32.

    The input is upcast to float32, normalised, and cast back to the input's
    original dtype, so fp16 activations can pass through float32 LayerNorm
    parameters. If the parameters have another dtype, the stock
    ``nn.LayerNorm.forward`` is used unchanged.
    """

    def forward(self, x: torch.Tensor):
        # Without affine parameters there is no parameter dtype to match, so
        # take the float32 path instead of failing on ``None.dtype``.
        if self.weight is None or self.weight.dtype == torch.float32:
            orig_type = x.dtype
            ret = super().forward(x.type(torch.float32))
            return ret.type(orig_type)
        return super().forward(x)


def _to_custom_layer_norm(module):
    """Build a CustomLayerNorm equivalent to ``module`` (shape, eps, affine values, device)."""
    new_ln = CustomLayerNorm(
        module.normalized_shape,
        eps=module.eps,
        elementwise_affine=module.elementwise_affine,
    )
    with torch.no_grad():
        for attr in ("weight", "bias"):
            src = getattr(module, attr, None)
            dst = getattr(new_ln, attr, None)
            if src is not None and dst is not None:
                # copy_ also converts to the new (float32) parameter dtype
                dst.copy_(src)
                # keep a frozen model frozen: mirror the original trainability
                dst.requires_grad_(src.requires_grad)
    ref = next(module.parameters(), None)
    if ref is not None:
        new_ln = new_ln.to(ref.device)
    return new_ln


# Function to replace LayerNorm in CLIP model with CustomLayerNorm
def replace_layer_norm(model):
    """Recursively swap every ``nn.LayerNorm`` inside ``model`` for a CustomLayerNorm.

    Each replacement keeps the original ``normalized_shape`` and ``eps``. It
    receives a copy of the original affine parameter values (stored as
    float32) with the same ``requires_grad`` flag. It is placed on the
    original module's device. The swap happens in place on ``model``;
    nothing is returned.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.LayerNorm):
            setattr(model, name, _to_custom_layer_norm(module))
        else:
            replace_layer_norm(module)  # Recursively apply to all submodules
# Module-level attention logs shared by every attention layer in this file.
# MotionCLR.forward_with_cfg appends one new empty list to each at the start of
# every call and rebinds both to [] whenever len(SELF_ATTN) is a multiple of 10.
# CrossAttention appends per-layer maps to MONITOR_ATTN[-1] and SelfAttention to
# SELF_ATTN[-1]. The editing code also reads their lengths as a step counter.
MONITOR_ATTN = []  # cross-attention maps, one inner list per forward_with_cfg call
SELF_ATTN = []  # self-attention maps, one inner list per forward_with_cfg call
def vis_attn(att, out_path, step, layer, shape, type_="self", lines=True):
    """Save a plot of an attention map, plus the raw array as ``.npy``.

    Args:
        att: 2-D array-like attention map. With ``lines=True`` every column
            ``att[:, k]`` is drawn as one curve over time; otherwise
            ``att.transpose()`` is drawn as a heatmap.
        out_path: root directory; files go under ``vis-{type_}/step{step}/``.
        step: step number used in the output directory name.
        layer: layer number used in the output file name.
        shape: string appended to the file name.
        type_: directory tag, e.g. ``"self"`` or ``"cross"``.
        lines: line plot (True) or heatmap (False).

    Writes ``layer{layer}_{lines|heatmap}_{shape}.png`` and the matching ``.npy``.
    """
    if lines:
        fig = plt.figure(figsize=(10, 3))
        for token_index in range(att.shape[1]):
            plt.plot(att[:, token_index], label=f"Token {token_index}")
        plt.title("Attention Values for Each Token")
        plt.xlabel("time")
        plt.ylabel("Attention Value")
        plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1))
        plot_kind = "lines"
    else:
        fig = plt.figure(figsize=(10, 10))
        plt.imshow(att.transpose(), cmap="viridis", aspect="auto")
        plt.colorbar()
        plt.title("Attention Matrix Heatmap")
        plt.ylabel("time")
        plt.xlabel("time")
        plot_kind = "heatmap"
    # save image
    savepath = os.path.join(
        out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_{plot_kind}_{shape}.png"
    )
    os.makedirs(os.path.dirname(savepath), exist_ok=True)
    try:
        plt.savefig(savepath, bbox_inches="tight")
    finally:
        # This is called for every layer at every step; without closing,
        # pyplot keeps every figure alive and memory grows without bound.
        plt.close(fig)
    np.save(savepath.replace(".png", ".npy"), att)
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and hand the same module back.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class FFN(nn.Module):
    """Residual feed-forward block applied over the last dimension.

    Computes ``x + Linear2(Dropout(GELU(Linear1(x))))``. ``linear2`` starts at
    zero, so a freshly built block is the identity map.
    """

    def __init__(self, latent_dim, ffn_dim, dropout):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, ffn_dim)
        self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.linear1(x)
        hidden = self.dropout(self.activation(hidden))
        return x + self.linear2(hidden)
class Conv1dAdaGNBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> scale,shift --> Mish

    GroupNorm is applied to each frame's channel vector on its own: time is
    folded into the batch axis before normalising. The result is then modulated
    as ``x * (1 + scale) + shift``.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4):
        super().__init__()
        self.out_channels = out_channels
        self.block = nn.Conv1d(
            inp_channels, out_channels, kernel_size, padding=kernel_size // 2
        )
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        # attribute name (sic) kept as-is for compatibility
        self.avtication = nn.Mish()

    def forward(self, x, scale, shift):
        """
        Args:
            x: [bs, nfeat, nframes]
            scale: [bs, out_feat, 1]
            shift: [bs, out_feat, 1]
        """
        h = self.block(x)
        n, c, t = h.shape
        # (n, c, t) -> (n*t, c): normalise each frame independently
        flat = self.group_norm(h.permute(0, 2, 1).reshape(n * t, c))
        h = flat.reshape(n, t, c).permute(0, 2, 1)
        return self.avtication(ada_shift_scale(h, shift, scale))
class SelfAttention(nn.Module):
    """Multi-head self-attention over the time axis with MotionCLR editing hooks.

    Depending on ``edit_config``, batch items ``1..B-1`` are rewritten from
    batch item 0 before or after the attention:

    * ``style_tranfer``: query of item 1 := query of item 0;
    * ``example_based``: queries of items 1.. := seeded chunk-shuffles of query 0;
    * ``time_shift``: q/k/v of item 1 := item 0 rotated along time;
    * ``manipulation``: attention weights of item 1 := those of item 0.

    With ``edit_config=None`` no edit is applied. ``text_latent_dim`` and
    ``log_attn`` are accepted for signature parity but are not used here.
    """

    def __init__(
        self,
        latent_dim,
        text_latent_dim,
        num_heads: int = 8,
        dropout: float = 0.0,
        log_attn=False,
        edit_config=None,
    ):
        super().__init__()
        self.num_head = num_heads
        self.norm = nn.LayerNorm(latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(latent_dim, latent_dim)
        self.value = nn.Linear(latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.edit_config = edit_config
        self.log_attn = log_attn

    def _enabled_edit(self, name):
        """Return sub-config ``name`` of ``edit_config`` when its ``use`` flag is set, else None."""
        if self.edit_config is None:
            return None
        cfg = getattr(self.edit_config, name)
        return cfg if cfg.use else None

    @staticmethod
    def _rotate_frames(t, split):
        """Rotate ``t`` along dim 0 so frames ``split:`` come first, then ``:split``."""
        return torch.cat((t[split:], t[:split]), dim=0)

    def forward(self, x):
        """
        x: B, T, D motion features.
        Returns B, T, D.
        """
        B, T, D = x.shape
        H = self.num_head
        normed = self.norm(x)
        # B, T, H, D/H
        query = self.query(normed).view(B, T, H, -1)
        key = self.key(normed).view(B, T, H, -1)

        # style transfer motion editing
        style_cfg = self._enabled_edit("style_tranfer")
        if style_cfg is not None and len(SELF_ATTN) <= style_cfg.style_transfer_steps_end:
            query[1] = query[0]

        # example based motion generation
        example_cfg = self._enabled_edit("example_based")
        if example_cfg is not None and len(SELF_ATTN) == example_cfg.example_based_steps_end:
            temp_seed = example_cfg.temp_seed
            for id_ in range(query.shape[0] - 1):
                # fork_rng so the fixed seed does not disturb the global RNG stream
                with torch.random.fork_rng():
                    torch.manual_seed(temp_seed)
                    chunks = torch.split(query[0], example_cfg.chunk_size, dim=0)
                    shuffled_indices = torch.randperm(len(chunks))
                    query[1 + id_] = torch.cat([chunks[i] for i in shuffled_indices], dim=0)
                temp_seed += example_cfg.temp_seed_bar

        # time shift motion editing (q, k); the step counter here is MONITOR_ATTN
        shift_split = None
        shift_cfg = self._enabled_edit("time_shift")
        if shift_cfg is not None and len(MONITOR_ATTN) <= shift_cfg.time_shift_steps_end:
            # Rotating by T - split keeps the result exactly T frames long. The
            # earlier separate floor of T*(1-ratio) could lose a frame and cause a
            # shape mismatch.
            shift_split = int(T * shift_cfg.time_shift_ratio // 1)
            query[1] = self._rotate_frames(query[0], shift_split)
            key[1] = self._rotate_frames(key[0], shift_split)

        # B, T, T, H
        attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H)
        weight = self.dropout(F.softmax(attention, dim=2))

        # Log the head-averaged map of item 0, only while a log is open. This
        # skips the GPU->CPU copy otherwise. .float() lets bf16 weights convert too.
        if SELF_ATTN:
            SELF_ATTN[-1].append(
                weight[0].mean(dim=-1).detach().float().cpu().numpy().astype(float)
            )

        # attention manipulation for replacement
        manip_cfg = self._enabled_edit("manipulation")
        if manip_cfg is not None and len(SELF_ATTN) <= manip_cfg.manipulation_steps_end:
            weight[1] = weight[0]

        value = self.value(normed).view(B, T, H, -1)
        # time shift motion editing (v), same split as for q and k
        if shift_split is not None:
            value[1] = self._rotate_frames(value[0], shift_split)

        return torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D)
class TimestepEmbedder(nn.Module):
    """Sinusoidal lookup table for integer timesteps.

    ``pe[p, 2i] = sin(p * w_i)`` and ``pe[p, 2i+1] = cos(p * w_i)`` with
    ``w_i = 10000 ** (-2i / d_model)``, for ``p < max_len``. An odd
    ``d_model`` is supported: its last column is a sine column.
    """

    def __init__(self, d_model, max_len=5000):
        super(TimestepEmbedder, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # there are only d_model // 2 cosine columns (one fewer than the sine
        # columns when d_model is odd)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``pe[x]``; the table is moved to the default CUDA device first."""
        self.pe = self.pe.cuda()
        return self.pe[x]
class Downsample1d(nn.Module):
    """Halve the frame axis with a Conv1d(kernel 3, stride 2, padding 1): L -> ceil(L / 2)."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # the conv is (re)placed on the default CUDA device on every call
        self.conv = self.conv.cuda()
        return self.conv(x)
class Upsample1d(nn.Module):
    """Double the frame axis with ConvTranspose1d(kernel 4, stride 2, padding 1): L -> 2L.

    ``dim_out`` falls back to ``dim_in`` when it is not given (or falsy).
    """

    def __init__(self, dim_in, dim_out=None):
        super().__init__()
        out_channels = dim_out if dim_out else dim_in
        self.conv = nn.ConvTranspose1d(dim_in, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        # the conv is (re)placed on the default CUDA device on every call
        self.conv = self.conv.cuda()
        return self.conv(x)
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    GroupNorm normalises each frame's channel vector on its own: time is folded
    into the batch axis first. With ``zero=True`` the convolution's weight and
    bias start at zero.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4, zero=False):
        super().__init__()
        self.out_channels = out_channels
        self.block = nn.Conv1d(
            inp_channels, out_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm = nn.GroupNorm(n_groups, out_channels)
        self.activation = nn.Mish()
        if zero:
            for tensor in (self.block.weight, self.block.bias):
                nn.init.zeros_(tensor)

    def forward(self, x):
        """
        Args:
            x: [bs, nfeat, nframes]
        """
        h = self.block(x)
        n, c, t = h.shape
        # (n, c, t) -> (n*t, c) so the norm sees one frame per row
        flat = self.norm(h.permute(0, 2, 1).reshape(n * t, c))
        return self.activation(flat.reshape(n, t, c).permute(0, 2, 1))
def ada_shift_scale(x, shift, scale):
    """Adaptive modulation: scale ``x`` by ``1 + scale``, then add ``shift``."""
    modulated = x * (1 + scale)
    return modulated + shift
class ResidualTemporalBlock(nn.Module):
    """Two 1-D conv blocks conditioned on a time embedding, with a residual path.

    With ``adagn=True`` the first block is a Conv1dAdaGNBlock modulated by a
    (scale, shift) pair projected from the time embedding. Otherwise the
    projected embedding is added to the first block's output. With
    ``zero=True`` the second conv and the time projection start at zero. A 1x1
    conv matches channel counts on the residual path when they differ.
    """

    def __init__(
        self,
        inp_channels,
        out_channels,
        embed_dim,
        kernel_size=5,
        zero=True,
        n_groups=8,
        dropout: float = 0.1,
        adagn=True,
    ):
        super().__init__()
        self.adagn = adagn
        # adagn only the first conv (following guided-diffusion)
        if adagn:
            first = Conv1dAdaGNBlock(inp_channels, out_channels, kernel_size, n_groups)
        else:
            first = Conv1dBlock(inp_channels, out_channels, kernel_size)
        second = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups, zero=zero)
        self.blocks = nn.ModuleList([first, second])

        # adagn needs a (scale, shift) pair, hence twice the channels
        proj_channels = out_channels * 2 if adagn else out_channels
        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, proj_channels),
            Rearrange("batch t -> batch t 1"),
        )
        self.dropout = nn.Dropout(dropout)
        if zero:
            projection = self.time_mlp[1]
            nn.init.zeros_(projection.weight)
            nn.init.zeros_(projection.bias)
        if inp_channels != out_channels:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, time_embeds=None):
        """
        x : [ batch_size x inp_channels x nframes ]
        time_embeds : [ batch_size x embed_dim ]
        returns: [ batch_size x out_channels x nframes ]
        """
        first, second = self.blocks[0], self.blocks[1]
        if not self.adagn:
            h = first(x) + self.time_mlp(time_embeds)
        else:
            scale, shift = self.time_mlp(time_embeds).chunk(2, dim=1)
            h = first(x, scale, shift)
        h = self.dropout(second(h))
        return h + self.residual_conv(x)
class CrossAttention(nn.Module):
    """Multi-head cross-attention from motion frames to text tokens, with editing hooks.

    Queries come from ``x`` [B, T, D]; keys and values come from ``xf`` [B, N, L].
    Depending on ``edit_config``, the attention weights are boosted
    (``reweighting_attn``), scaled in a time window (``erasing_motion``), or
    copied from batch item 0 except for one word (``manipulation``). The
    weights are then applied to the values. With ``edit_config=None`` no edit
    is applied. ``log_attn`` is stored but not read here.
    """

    def __init__(
        self,
        latent_dim,
        text_latent_dim,
        num_heads: int = 8,
        dropout: float = 0.0,
        log_attn=False,
        edit_config=None,
    ):
        super().__init__()
        self.num_head = num_heads
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(text_latent_dim, latent_dim)
        self.value = nn.Linear(text_latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.edit_config = edit_config
        self.log_attn = log_attn

    def _enabled_edit(self, name):
        """Return sub-config ``name`` of ``edit_config`` when its ``use`` flag is set, else None."""
        if self.edit_config is None:
            return None
        cfg = getattr(self.edit_config, name)
        return cfg if cfg.use else None

    def forward(self, x, xf):
        """
        x: B, T, D motion features
        xf: B, N, L text features
        Returns B, T, D.
        """
        B, T, D = x.shape
        N = xf.shape[1]
        H = self.num_head
        text = self.text_norm(xf)
        # B, T, H, D/H and B, N, H, D/H
        query = self.query(self.norm(x)).view(B, T, H, -1)
        key = self.key(text).view(B, N, H, -1)
        # B, T, N, H
        attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H)
        weight = self.dropout(F.softmax(attention, dim=2))

        # attention reweighting for (de)-emphasizing motion
        reweight_cfg = self._enabled_edit("reweighting_attn")
        if reweight_cfg is not None:
            boost = reweight_cfg.reweighting_attn_weight
            if reweight_cfg.idx == -1:
                # read one word index per batch item; blank lines are ignored
                with open("./assets/reweighting_idx.txt", "r") as f:
                    idxs = [line for line in f if line.strip()]
            else:
                # gradio demo mode
                idxs = [0, reweight_cfg.idx]
            for i, idx in enumerate(int(v) for v in idxs):
                # token 0 is skipped (+1 offset); the word and the token after it are boosted
                weight[i, :, 1 + idx] = weight[i, :, 1 + idx] + boost
                weight[i, :, 1 + idx + 1] = weight[i, :, 1 + idx + 1] + boost

        # Log item 0's head-averaged weights for tokens 1..3, only while a log is
        # open. .float() lets bf16 weights convert too.
        if MONITOR_ATTN:
            MONITOR_ATTN[-1].append(
                weight[0, :, 1 : 1 + 3].mean(dim=-1).detach().float().cpu().numpy().astype(float)
            )

        # erasing motion (actually de-emphasizing motion) on batch item 1
        erase_cfg = self._enabled_edit("erasing_motion")
        if erase_cfg is not None:
            scale = erase_cfg.erasing_motion_weight
            begin = erase_cfg.time_start
            end = erase_cfg.time_end
            # NOTE: unlike reweighting_attn, idx is used without the +1 offset
            idx = erase_cfg.idx
            # weights within +-0.01 are treated as "no erasing"
            if abs(scale) > 0.01:
                t0, t1 = int(T * begin), int(T * end)
                weight[1, t0:t1, idx] = weight[1, t0:t1, idx] * scale
                weight[1, t0:t1, idx + 1] = weight[1, t0:t1, idx + 1] * scale

        # attention manipulation for motion replacement: copy every token column
        # from item 0 except token 1 + word_idx
        manip_cfg = self._enabled_edit("manipulation")
        if (
            manip_cfg is not None
            and len(MONITOR_ATTN) <= manip_cfg.manipulation_steps_end_crossattn
        ):
            word_idx = manip_cfg.word_idx
            weight[1, :, : 1 + word_idx, :] = weight[0, :, : 1 + word_idx, :]
            weight[1, :, 1 + word_idx + 1 :, :] = weight[0, :, 1 + word_idx + 1 :, :]

        value = self.value(text).view(B, N, H, -1)
        return torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D)
class ResidualCLRAttentionLayer(nn.Module):
    """Residual (optional) self-attention followed by residual cross-attention.

    Tensors are channel-first ([B, D, L]); they are permuted to [B, L, D] for
    the attention modules and back. Self-attention (if enabled) runs on the
    whole batch. Cross-attention runs only on the rows listed in
    ``cond_indices``, which are updated in place.
    """

    def __init__(
        self,
        dim1,
        dim2,
        num_heads: int = 8,
        dropout: float = 0.1,
        no_eff: bool = False,
        self_attention: bool = False,
        log_attn=False,
        edit_config=None,
    ):
        super(ResidualCLRAttentionLayer, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2
        self.num_heads = num_heads
        # Multi-Head Attention Layer
        if no_eff:
            self.cross_attention = CrossAttention(
                latent_dim=dim1,
                text_latent_dim=dim2,
                num_heads=num_heads,
                dropout=dropout,
                log_attn=log_attn,
                edit_config=edit_config,
            )
        else:
            # NOTE(review): LinearCrossAttention is not defined in this module;
            # it must be provided elsewhere or this branch raises NameError.
            self.cross_attention = LinearCrossAttention(
                latent_dim=dim1,
                text_latent_dim=dim2,
                num_heads=num_heads,
                dropout=dropout,
                log_attn=log_attn,
            )
        if self_attention:
            self.self_attn_use = True
            self.self_attention = SelfAttention(
                latent_dim=dim1,
                text_latent_dim=dim2,
                num_heads=num_heads,
                dropout=dropout,
                log_attn=log_attn,
                edit_config=edit_config,
            )
        else:
            self.self_attn_use = False

    def forward(self, input_tensor, condition_tensor, cond_indices):
        """
        input_tensor: B, D, L
        condition_tensor: B, N, D_text
        cond_indices: 1-D index tensor of the rows that receive the condition
        """
        # Self-attention does not depend on the condition, so it runs even when
        # no row is conditioned. Skipping it would make a fully unconditional
        # forward differ from the unconditional rows of a mixed batch.
        if self.self_attn_use:
            x = input_tensor.permute(0, 2, 1)  # (batch_size, seq_length, feat_dim)
            x = self.self_attention(x).permute(0, 2, 1)  # (batch_size, feat_dim, seq_length)
            input_tensor = input_tensor + x

        if cond_indices.numel() == 0:
            return input_tensor

        # cross attention on the conditioned rows only
        x = input_tensor[cond_indices].permute(0, 2, 1)  # (batch_size, seq_length, feat_dim)
        x = self.cross_attention(x, condition_tensor[cond_indices])
        x = x.permute(0, 2, 1)  # (batch_size, feat_dim, seq_length)
        input_tensor[cond_indices] = input_tensor[cond_indices] + x
        return input_tensor
class CLRBlock(nn.Module):
    """One UNet stage: residual temporal conv -> CLR attention -> residual FFN.

    Input and output are channel-first ([B, C, L]). The FFN is applied on the
    channel axis after moving it last.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        cond_dim,
        time_dim,
        adagn=True,
        zero=True,
        no_eff=False,
        self_attention=False,
        dropout: float = 0.1,
        log_attn=False,
        edit_config=None,
    ) -> None:
        super().__init__()
        self.conv1d = ResidualTemporalBlock(
            dim_in,
            dim_out,
            embed_dim=time_dim,
            adagn=adagn,
            zero=zero,
            dropout=dropout,
        )
        self.clr_attn = ResidualCLRAttentionLayer(
            dim1=dim_out,
            dim2=cond_dim,
            no_eff=no_eff,
            dropout=dropout,
            self_attention=self_attention,
            log_attn=log_attn,
            edit_config=edit_config,
        )
        self.ffn = FFN(dim_out, dim_out * 4, dropout=dropout)

    def forward(self, x, t, cond, cond_indices=None):
        h = self.conv1d(x, t)
        h = self.clr_attn(h, cond, cond_indices)
        # [B, C, L] -> [B, L, C] for the FFN, then back
        return self.ffn(h.transpose(1, 2)).transpose(1, 2)
class CondUnet1D(nn.Module):
    """
    1-D UNet for denoising motion sequences in a diffusion model.

    Each level stacks two CLRBlocks (temporal conv with adaptive GroupNorm,
    optional self-attention, cross-attention to the condition, FFN). The down
    path ends each level with a stride-2 downsample. The up path upsamples and
    concatenates the matching skip tensor. A timestep MLP provides the time
    embedding shared by every block.
    """

    def __init__(
        self,
        input_dim,
        cond_dim,
        dim=128,
        dim_mults=(1, 2, 4, 8),
        dims=None,
        time_dim=512,
        adagn=True,
        zero=True,
        dropout=0.1,
        no_eff=False,
        self_attention=False,
        log_attn=False,
        edit_config=None,
    ):
        super().__init__()
        if not dims:
            dims = [input_dim] + [int(dim * m) for m in dim_mults]  ##[d, d,2d,4d]
        print("dims: ", dims, "mults: ", dim_mults)
        in_out = list(zip(dims[:-1], dims[1:]))

        def make_block(c_in, c_out):
            # all CLRBlocks share the same conditioning / editing options
            return CLRBlock(
                c_in,
                c_out,
                cond_dim,
                time_dim,
                adagn=adagn,
                zero=zero,
                no_eff=no_eff,
                dropout=dropout,
                self_attention=self_attention,
                log_attn=log_attn,
                edit_config=edit_config,
            )

        self.time_mlp = nn.Sequential(
            TimestepEmbedder(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.Mish(),
            nn.Linear(time_dim * 4, time_dim),
        )

        # Modules are created in the same order as before (downs, mid, ups,
        # final conv), so parameter init and parameter order are unchanged.
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        for c_in, c_out in in_out:
            self.downs.append(
                nn.ModuleList(
                    [make_block(c_in, c_out), make_block(c_out, c_out), Downsample1d(c_out)]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = make_block(mid_dim, mid_dim)
        self.mid_block2 = make_block(mid_dim, mid_dim)

        prev_width = mid_dim
        for width in reversed(dims[1:]):
            self.ups.append(
                nn.ModuleList(
                    [
                        Upsample1d(prev_width, width),
                        # input is the upsampled tensor concatenated with its skip
                        make_block(width * 2, width),
                        make_block(width, width),
                    ]
                )
            )
            prev_width = width

        # after the loop, width == dims[1] (the shallowest level)
        self.final_conv = nn.Conv1d(width, input_dim, 1)
        if zero:
            nn.init.zeros_(self.final_conv.weight)
            nn.init.zeros_(self.final_conv.bias)

    def forward(
        self,
        x,
        t,
        cond,
        cond_indices,
    ):
        """x: [bs, input_dim, nframes]; t: timesteps; returns [bs, input_dim, nframes]."""
        self.time_mlp = self.time_mlp.cuda()
        temb = self.time_mlp(t)

        skips = []
        for first, second, down in self.downs:
            first.cuda()
            second.cuda()
            x = first(x, temb, cond, cond_indices)
            x = second(x, temb, cond, cond_indices)
            skips.append(x)
            x = down(x)

        self.mid_block1 = self.mid_block1.cuda()
        self.mid_block2 = self.mid_block2.cuda()
        for mid in (self.mid_block1, self.mid_block2):
            x = mid(x, temb, cond, cond_indices)

        for up, first, second in self.ups:
            x = torch.cat((up(x), skips.pop()), dim=1)
            first.cuda()
            second.cuda()
            x = first(x, temb, cond, cond_indices)
            x = second(x, temb, cond, cond_indices)

        self.final_conv = self.final_conv.cuda()
        return self.final_conv(x)
class MotionCLR(nn.Module):
    """
    Diffuser's style UNET for text-to-motion task.

    A frozen CLIP text encoder is followed by a linear projection, a
    Transformer encoder and a LayerNorm. These produce the text features that
    condition a CondUnet1D denoiser over motion sequences.
    """

    def __init__(
        self,
        input_feats,
        base_dim=128,
        dim_mults=(1, 2, 2, 2),
        dims=None,
        adagn=True,
        zero=True,
        dropout=0.1,
        no_eff=False,
        time_dim=512,
        latent_dim=256,
        cond_mask_prob=0.1,
        clip_dim=512,
        clip_version="ViT-B/32",
        text_latent_dim=256,
        text_ff_size=2048,
        text_num_heads=4,
        activation="gelu",
        num_text_layers=4,
        self_attention=False,
        vis_attn=False,
        edit_config=None,
        out_path=None,
    ):
        super().__init__()
        self.input_feats = input_feats
        self.dim_mults = dim_mults
        self.base_dim = base_dim
        self.latent_dim = latent_dim
        self.cond_mask_prob = cond_mask_prob
        self.vis_attn = vis_attn
        self.counting_map = []
        self.out_path = out_path
        print(
            f"The T2M Unet mask the text prompt by {self.cond_mask_prob} prob. in training"
        )
        # text encoder
        self.embed_text = nn.Linear(clip_dim, text_latent_dim)
        self.clip_version = clip_version
        self.clip_model = self.load_and_freeze_clip(clip_version)
        replace_layer_norm(self.clip_model)
        textTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=text_latent_dim,
            nhead=text_num_heads,
            dim_feedforward=text_ff_size,
            dropout=dropout,
            activation=activation,
        )
        self.textTransEncoder = nn.TransformerEncoder(
            textTransEncoderLayer, num_layers=num_text_layers
        )
        self.text_ln = nn.LayerNorm(text_latent_dim)
        self.unet = CondUnet1D(
            input_dim=self.input_feats,
            cond_dim=text_latent_dim,
            dim=self.base_dim,
            dim_mults=self.dim_mults,
            adagn=adagn,
            zero=zero,
            dropout=dropout,
            no_eff=no_eff,
            dims=dims,
            time_dim=time_dim,
            self_attention=self_attention,
            log_attn=self.vis_attn,
            edit_config=edit_config,
        )
        self.clip_model = self.clip_model.cuda()
        self.embed_text = self.embed_text.cuda()
        self.textTransEncoder = self.textTransEncoder.cuda()
        self.text_ln = self.text_ln.cuda()
        self.unet = self.unet.cuda()

    def encode_text(self, raw_text, device):
        """Encode a list of prompts into text features of shape [bs, seqlen, text_latent_dim]."""
        self.clip_model.token_embedding = self.clip_model.token_embedding.to(device)
        self.clip_model.transformer = self.clip_model.transformer.to(device)
        self.clip_model.ln_final = self.clip_model.ln_final.to(device)
        with torch.no_grad():
            texts = clip.tokenize(raw_text, truncate=True).to(
                device
            )  # [bs, context_length] # if n_tokens > 77 -> will truncate
            x = self.clip_model.token_embedding(texts).type(self.clip_model.dtype).to(device)  # [batch_size, n_ctx, d_model]
            x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype).to(device)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.clip_model.transformer(x)
            x = self.clip_model.ln_final(x).type(
                self.clip_model.dtype
            )  # [len, batch_size, 512]
        self.embed_text = self.embed_text.to(device)
        x = self.embed_text(x)  # [len, batch_size, 256]
        self.textTransEncoder = self.textTransEncoder.to(device)
        x = self.textTransEncoder(x)
        self.text_ln = self.text_ln.to(device)
        x = self.text_ln(x)
        # T, B, D -> B, T, D
        return x.permute(1, 0, 2)

    def load_and_freeze_clip(self, clip_version):
        """Load CLIP on the CPU (jit disabled), put it in eval mode and freeze all its parameters."""
        clip_model, _ = clip.load(  # clip_model.dtype=float32
            clip_version, device="cpu", jit=False
        )  # Must set jit=False for training
        # Freeze CLIP weights
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False
        return clip_model

    def mask_cond(self, bs, force_mask=False):
        """
        Return the indices of the batch items that keep their text condition.

        force_mask -> no item is conditioned (empty long tensor).
        training with cond_mask_prob > 0 -> each item is dropped with that probability.
        otherwise -> every item is conditioned.
        """
        if force_mask:
            # long dtype so the result is always a valid index tensor
            cond_indices = torch.empty(0, dtype=torch.long)
        elif self.training and self.cond_mask_prob > 0.0:
            mask = torch.bernoulli(
                torch.ones(
                    bs,
                )
                * self.cond_mask_prob
            )  # 1-> use null_cond, 0-> use real cond
            mask = 1.0 - mask
            cond_indices = torch.nonzero(mask).squeeze(-1)
        else:
            cond_indices = torch.arange(bs)
        return cond_indices

    def _pad_frames(self, x):
        """Zero-pad the frame axis of ``x`` ([bs, nfeats, nframes]) for the UNet.

        Every UNet level halves the frame count and the up path must restore it
        exactly, so nframes must be a multiple of 2 ** levels. A multiple of 16
        is always used (the historical value for 4-level models); deeper UNets
        get the larger power of two they need.
        """
        multiple = max(16, 2 ** len(self.unet.downs))
        padding_needed = (multiple - x.shape[-1] % multiple) % multiple
        return F.pad(x, (0, padding_needed), value=0)

    def forward(
        self,
        x,
        timesteps,
        text=None,
        uncond=False,
        enc_text=None,
    ):
        """
        Args:
            x: [batch_size, nframes, nfeats],
            timesteps: [batch_size] (int)
            text: list (batch_size length) of strings with input text prompts
            uncond: if True, no batch item receives the text condition
            enc_text: optional pre-computed text features (skips encode_text)
        Returns: [batch_size, seq_length, nfeats]
        """
        T = x.shape[1]
        x = x.transpose(1, 2)  # [bs, nfeats, nframes]
        if enc_text is None:
            enc_text = self.encode_text(text, x.device)  # [bs, seqlen, text_dim]
        cond_indices = self.mask_cond(x.shape[0], force_mask=uncond)
        x = self._pad_frames(x)
        x = self.unet(
            x,
            t=timesteps,
            cond=enc_text,
            cond_indices=cond_indices,
        )  # [bs, nfeats, nframes]
        return x[:, :, :T].transpose(1, 2)  # [bs, nframes, nfeats]

    def forward_with_cfg(self, x, timesteps, text=None, enc_text=None, cfg_scale=2.5):
        """
        Classifier-free guidance: run a conditional and an unconditional copy of
        the batch together and return ``uncond + cfg_scale * (cond - uncond)``.

        Also opens a new entry in the module-level attention logs, optionally
        saves attention plots, and clears the logs every 10 calls.

        Args:
            x: [batch_size, nframes, nfeats],
            timesteps: [batch_size] (int)
            text: list (batch_size length) of strings with input text prompts
        Returns: [batch_size, max_frames, nfeats]
        """
        global SELF_ATTN
        global MONITOR_ATTN
        MONITOR_ATTN.append([])
        SELF_ATTN.append([])

        B, T, _ = x.shape
        x = x.transpose(1, 2)  # [bs, nfeats, nframes]
        if enc_text is None:
            enc_text = self.encode_text(text, x.device)  # [bs, seqlen, text_dim]
        # only the first B rows (the conditional copy) get the text condition
        cond_indices = self.mask_cond(B)
        x = self._pad_frames(x)
        combined_x = torch.cat([x, x], dim=0)
        combined_t = torch.cat([timesteps, timesteps], dim=0)
        out = self.unet(
            x=combined_x,
            t=combined_t,
            cond=enc_text,
            cond_indices=cond_indices,
        )  # [2*bs, nfeats, nframes]
        out = out[:, :, :T].transpose(1, 2)  # [2*bs, nframes, nfeats]
        out_cond, out_uncond = torch.split(out, len(out) // 2, dim=0)

        if self.vis_attn:
            i = len(MONITOR_ATTN)
            attnlist = MONITOR_ATTN[-1]
            print(i, "cross", len(attnlist))
            for j, att in enumerate(attnlist):
                vis_attn(
                    att,
                    out_path=self.out_path,
                    step=i,
                    layer=j,
                    shape="_".join(map(str, att.shape)),
                    type_="cross",
                )
            attnlist = SELF_ATTN[-1]
            print(i, "self", len(attnlist))
            for j, att in enumerate(attnlist):
                vis_attn(
                    att,
                    out_path=self.out_path,
                    step=i,
                    layer=j,
                    shape="_".join(map(str, att.shape)),
                    type_="self",
                    lines=False,
                )

        # The log lengths double as a step counter for the attention editing
        # code; reset them every 10 calls.
        if len(SELF_ATTN) % 10 == 0:
            SELF_ATTN = []
            MONITOR_ATTN = []
        return out_uncond + (cfg_scale * (out_cond - out_uncond))
if __name__ == "__main__":
    # Developer smoke test and checkpoint conversion:
    #   1. build a MotionCLR model on the GPU;
    #   2. rename "cross_attn" -> "clr_attn" in an older checkpoint and save it;
    #   3. run one plain forward and one classifier-free-guidance forward.
    # NOTE(review): the checkpoint paths are specific to the authors' machines.
    device = "cuda:0"
    n_feats = 263
    num_frames = 196
    text_latent_dim = 256
    dim_mults = [2, 2, 2, 2]
    base_dim = 512
    model = MotionCLR(
        input_feats=n_feats,
        text_latent_dim=text_latent_dim,
        base_dim=base_dim,
        dim_mults=dim_mults,
        adagn=True,
        zero=True,
        dropout=0.1,
        no_eff=True,
        cond_mask_prob=0.1,
        self_attention=True,
    ).to(device)

    from utils.model_load import load_model_weights  # imported but not used below

    checkpoint_path = "/comp_robot/chenlinghao/StableMoFusion/checkpoints/t2m/self_attn—fulllayer-ffn-drop0_1-lr1e4/model/latest.tar"
    checkpoint = torch.load(checkpoint_path)
    ckpt2 = checkpoint.copy()
    for section in ("model_ema", "encoder"):
        # older checkpoints named the attention sub-module "cross_attn"
        ckpt2[section] = {
            name.replace("cross_attn", "clr_attn"): tensor
            for name, tensor in checkpoint[section].items()
        }
    torch.save(
        ckpt2,
        "/comp_robot/chenlinghao/CLRpreview/checkpoints/t2m/release/model/latest.tar",
    )

    dtype = torch.float32
    bs = 1
    x = torch.rand((bs, 196, 263), dtype=dtype).to(device)
    timesteps = torch.randint(low=0, high=1000, size=(bs,)).to(device)
    y = ["A man jumps to his left."] * bs
    # unused, but kept so the random-number stream is unchanged
    length = torch.randint(low=20, high=196, size=(bs,)).to(device)

    out = model(x, timesteps, text=y)
    print(out.shape)

    model.eval()
    out = model.forward_with_cfg(x, timesteps, text=y)
    print(out.shape)