""" | |
Copyright 2021 S-Lab | |
""" | |
from cv2 import norm | |
import torch | |
import torch.nn.functional as F | |
from torch import layer_norm, nn | |
import numpy as np | |
import clip | |
import math | |


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
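
# Example: timestep_embedding(torch.tensor([0, 10, 999]), 128) returns a
# (3, 128) tensor whose first 64 columns are cosines and last 64 are sines
# at geometrically spaced frequencies.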


def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network.
        requires_grad (bool): Whether the networks require gradients or not.
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
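
# zero_module is applied below to the last projection of each residual branch
# (StylizationBlock, FFN) and to the output head, so every block starts out as
# the identity and the model's initial prediction is zero.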


class StylizationBlock(nn.Module):

    def __init__(self, latent_dim, time_embed_dim, dropout):
        super().__init__()
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, 2 * latent_dim),
        )
        self.norm = nn.LayerNorm(latent_dim)
        self.out_layers = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(nn.Linear(latent_dim, latent_dim)),
        )

    def forward(self, h, emb):
        """
        h: B, T, D
        emb: B, D
        """
        # B, 1, 2D
        emb_out = self.emb_layers(emb).unsqueeze(1)
        # scale: B, 1, D / shift: B, 1, D
        scale, shift = torch.chunk(emb_out, 2, dim=2)
        h = self.norm(h) * (1 + scale) + shift
        h = self.out_layers(h)
        return h


class LinearTemporalSelfAttention(nn.Module):

    def __init__(self, seq_len, latent_dim, num_head, dropout, time_embed_dim):
        super().__init__()
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(latent_dim, latent_dim)
        self.value = nn.Linear(latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x, emb, src_mask):
        """
        x: B, T, D
        """
        B, T, D = x.shape
        H = self.num_head
        # B, T, D
        query = self.query(self.norm(x))
        # B, T, D
        key = self.key(self.norm(x)) + (1 - src_mask) * -1000000
        query = F.softmax(query.view(B, T, H, -1), dim=-1)
        key = F.softmax(key.view(B, T, H, -1), dim=1)
        # B, T, H, HD
        value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1)
        # B, H, HD, HD
        attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
        y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
        y = x + self.proj_out(y, emb)
        return y
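
# LinearTemporalSelfAttention first contracts key and value into a per-head
# (HD x HD) context matrix and only then applies the query, so its cost grows
# linearly with the sequence length T rather than quadratically.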


class LinearTemporalCrossAttention(nn.Module):

    def __init__(self, seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim):
        super().__init__()
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(text_latent_dim, latent_dim)
        self.value = nn.Linear(text_latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x, xf, emb):
        """
        x: B, T, D
        xf: B, N, L
        """
        B, T, D = x.shape
        N = xf.shape[1]
        H = self.num_head
        # B, T, D
        query = self.query(self.norm(x))
        # B, N, D
        key = self.key(self.text_norm(xf))
        query = F.softmax(query.view(B, T, H, -1), dim=-1)
        key = F.softmax(key.view(B, N, H, -1), dim=1)
        # B, N, H, HD
        value = self.value(self.text_norm(xf)).view(B, N, H, -1)
        # B, H, HD, HD
        attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
        y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
        y = x + self.proj_out(y, emb)
        return y


class FFN(nn.Module):

    def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, ffn_dim)
        self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x, emb):
        y = self.linear2(self.dropout(self.activation(self.linear1(x))))
        y = x + self.proj_out(y, emb)
        return y


class LinearTemporalDiffusionTransformerDecoderLayer(nn.Module):

    def __init__(self,
                 seq_len=60,
                 latent_dim=32,
                 text_latent_dim=512,
                 time_embed_dim=128,
                 ffn_dim=256,
                 num_head=4,
                 dropout=0.1):
        super().__init__()
        self.sa_block = LinearTemporalSelfAttention(
            seq_len, latent_dim, num_head, dropout, time_embed_dim)
        self.ca_block = LinearTemporalCrossAttention(
            seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim)
        self.ffn = FFN(latent_dim, ffn_dim, dropout, time_embed_dim)

    def forward(self, x, xf, emb, src_mask):
        x = self.sa_block(x, emb, src_mask)
        x = self.ca_block(x, xf, emb)
        x = self.ffn(x, emb)
        return x


class TemporalSelfAttention(nn.Module):

    def __init__(self, seq_len, latent_dim, num_head, dropout, time_embed_dim):
        super().__init__()
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(latent_dim, latent_dim)
        self.value = nn.Linear(latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x, emb, src_mask):
        """
        x: B, T, D
        """
        B, T, D = x.shape
        H = self.num_head
        # B, T, 1, D
        query = self.query(self.norm(x)).unsqueeze(2)
        # B, 1, T, D
        key = self.key(self.norm(x)).unsqueeze(1)
        query = query.view(B, T, H, -1)
        key = key.view(B, T, H, -1)
        # B, T, T, H
        attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
        attention = attention + (1 - src_mask.unsqueeze(-1)) * -100000
        weight = self.dropout(F.softmax(attention, dim=2))
        value = self.value(self.norm(x)).view(B, T, H, -1)
        y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
        y = x + self.proj_out(y, emb)
        return y
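
# TemporalSelfAttention and TemporalCrossAttention (below) are the ordinary
# scaled dot-product variants with an additive mask; MotionTransformer uses
# them instead of the linear-attention layers when no_eff=True.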


class TemporalCrossAttention(nn.Module):

    def __init__(self, seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim):
        super().__init__()
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(text_latent_dim, latent_dim)
        self.value = nn.Linear(text_latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x, xf, emb):
        """
        x: B, T, D
        xf: B, N, L
        """
        B, T, D = x.shape
        N = xf.shape[1]
        H = self.num_head
        # B, T, 1, D
        query = self.query(self.norm(x)).unsqueeze(2)
        # B, 1, N, D
        key = self.key(self.text_norm(xf)).unsqueeze(1)
        query = query.view(B, T, H, -1)
        key = key.view(B, N, H, -1)
        # B, T, N, H
        attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
        weight = self.dropout(F.softmax(attention, dim=2))
        value = self.value(self.text_norm(xf)).view(B, N, H, -1)
        y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
        y = x + self.proj_out(y, emb)
        return y


class TemporalDiffusionTransformerDecoderLayer(nn.Module):

    def __init__(self,
                 seq_len=60,
                 latent_dim=32,
                 text_latent_dim=512,
                 time_embed_dim=128,
                 ffn_dim=256,
                 num_head=4,
                 dropout=0.1):
        super().__init__()
        self.sa_block = TemporalSelfAttention(
            seq_len, latent_dim, num_head, dropout, time_embed_dim)
        self.ca_block = TemporalCrossAttention(
            seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim)
        self.ffn = FFN(latent_dim, ffn_dim, dropout, time_embed_dim)

    def forward(self, x, xf, emb, src_mask):
        x = self.sa_block(x, emb, src_mask)
        x = self.ca_block(x, xf, emb)
        x = self.ffn(x, emb)
        return x


class MotionTransformer(nn.Module):

    def __init__(self,
                 input_feats,
                 num_frames=240,
                 latent_dim=512,
                 ff_size=1024,
                 num_layers=8,
                 num_heads=8,
                 dropout=0,
                 activation="gelu",
                 num_text_layers=4,
                 text_latent_dim=256,
                 text_ff_size=2048,
                 text_num_heads=4,
                 no_clip=False,
                 no_eff=False,
                 **kwargs):
        super().__init__()

        self.num_frames = num_frames
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation
        self.input_feats = input_feats
        self.time_embed_dim = latent_dim * 4
        self.sequence_embedding = nn.Parameter(torch.randn(num_frames, latent_dim))

        # Text Transformer
        self.clip, _ = clip.load('ViT-B/32', "cpu")
        if no_clip:
            self.clip.initialize_parameters()
        else:
            set_requires_grad(self.clip, False)
        if text_latent_dim != 512:
            self.text_pre_proj = nn.Linear(512, text_latent_dim)
        else:
            self.text_pre_proj = nn.Identity()
        textTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=text_latent_dim,
            nhead=text_num_heads,
            dim_feedforward=text_ff_size,
            dropout=dropout,
            activation=activation)
        self.textTransEncoder = nn.TransformerEncoder(
            textTransEncoderLayer,
            num_layers=num_text_layers)
        self.text_ln = nn.LayerNorm(text_latent_dim)
        self.text_proj = nn.Sequential(
            nn.Linear(text_latent_dim, self.time_embed_dim)
        )

        # Input Embedding
        self.joint_embed = nn.Linear(self.input_feats, self.latent_dim)
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
        )
        self.temporal_decoder_blocks = nn.ModuleList()
        for i in range(num_layers):
            if no_eff:
                self.temporal_decoder_blocks.append(
                    TemporalDiffusionTransformerDecoderLayer(
                        seq_len=num_frames,
                        latent_dim=latent_dim,
                        text_latent_dim=text_latent_dim,
                        time_embed_dim=self.time_embed_dim,
                        ffn_dim=ff_size,
                        num_head=num_heads,
                        dropout=dropout
                    )
                )
            else:
                self.temporal_decoder_blocks.append(
                    LinearTemporalDiffusionTransformerDecoderLayer(
                        seq_len=num_frames,
                        latent_dim=latent_dim,
                        text_latent_dim=text_latent_dim,
                        time_embed_dim=self.time_embed_dim,
                        ffn_dim=ff_size,
                        num_head=num_heads,
                        dropout=dropout
                    )
                )

        # Output Module
        self.out = zero_module(nn.Linear(self.latent_dim, self.input_feats))

    def encode_text(self, text, device):
        with torch.no_grad():
            text = clip.tokenize(text, truncate=True).to(device)
            x = self.clip.token_embedding(text).type(self.clip.dtype)  # [batch_size, n_ctx, d_model]
            x = x + self.clip.positional_embedding.type(self.clip.dtype)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.clip.transformer(x)
            x = self.clip.ln_final(x).type(self.clip.dtype)

        # T, B, D
        x = self.text_pre_proj(x)
        xf_out = self.textTransEncoder(x)
        xf_out = self.text_ln(xf_out)
        xf_proj = self.text_proj(xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])])
        # B, T, D
        xf_out = xf_out.permute(1, 0, 2)
        return xf_proj, xf_out
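
    # xf_proj pools the feature at each sequence's end-of-text token (the EOT
    # token has the highest id in CLIP's vocabulary, hence argmax) and projects
    # it to time_embed_dim; xf_out keeps the per-token features for cross-attention.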

    def generate_src_mask(self, T, length):
        B = len(length)
        src_mask = torch.ones(B, T)
        for i in range(B):
            for j in range(length[i], T):
                src_mask[i, j] = 0
        return src_mask
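
    # Equivalent vectorized form (a sketch, assuming `length` can be converted
    # to a tensor):
    #   (torch.arange(T)[None, :] < torch.as_tensor(length)[:, None]).float()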

    def forward(self, x, timesteps, length=None, text=None, xf_proj=None, xf_out=None):
        """
        x: B, T, D
        """
        B, T = x.shape[0], x.shape[1]
        if xf_proj is None or xf_out is None:
            xf_proj, xf_out = self.encode_text(text, x.device)

        emb = self.time_embed(timestep_embedding(timesteps, self.latent_dim)) + xf_proj

        # B, T, latent_dim
        h = self.joint_embed(x)
        h = h + self.sequence_embedding.unsqueeze(0)[:, :T, :]

        src_mask = self.generate_src_mask(T, length).to(x.device).unsqueeze(-1)
        for module in self.temporal_decoder_blocks:
            h = module(h, xf_out, emb, src_mask)

        output = self.out(h).view(B, T, -1).contiguous()
        return output
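

if __name__ == "__main__":
    # Minimal usage sketch: assumes the CLIP ViT-B/32 weights are cached or can
    # be downloaded, and uses 263-dim motion features (HumanML3D-style) purely
    # as an example size; the small layer counts are only for a quick smoke test.
    model = MotionTransformer(input_feats=263, num_frames=60, latent_dim=128,
                              ff_size=256, num_layers=2, num_heads=4)
    x = torch.randn(2, 60, 263)               # B, T, D noisy motion
    timesteps = torch.randint(0, 1000, (2,))  # one diffusion step per sample
    length = [60, 40]                         # valid frames per sample
    out = model(x, timesteps, length=length,
                text=["a person walks forward", "a person waves"])
    print(out.shape)  # torch.Size([2, 60, 263])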