import time
import math

import torch
from torch import nn
from flash_attn import flash_attn_varlen_qkvpacked_func

from .utils import exist, get_freqs, cat_interleave, split_interleave, to_1dimension, to_3dimension


def apply_rotary(x, rope):
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    x_out = rope[..., 0] * x_[..., 0] + rope[..., 1] * x_[..., 1]
    return x_out.reshape(*x.shape)
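
# apply_rotary applies rotary position embeddings in the real-valued 2x2-rotation form:
# the last dimension of `x` is viewed as pairs, and each pair is rotated by the cos/sin
# matrix stored in `rope` (built by RoPE3D below). Sketch of the expected shapes
# (illustrative, inferred from the code rather than stated in this file):
#   x:    (duration, height, width, num_heads, head_dim)
#   rope: (duration, height, width, 1, head_dim // 2, 2, 2)  # broadcasts over heads
#   apply_rotary(x, rope) -> same shape as x, computed in float32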


class TimeEmbeddings(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.):
        super().__init__()
        assert model_dim % 2 == 0
        self.freqs = get_freqs(model_dim // 2, max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time, self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.out_layer(self.activation(self.in_layer(time_embed)))
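
# Usage sketch for TimeEmbeddings (hypothetical dimensions, for illustration only):
#   time_embedder = TimeEmbeddings(model_dim=1024, time_dim=4096)
#   t = torch.rand(batch_size)            # one diffusion timestep per sample
#   time_embed = time_embedder(t)         # -> (batch_size, time_dim)
# The sinusoidal features are cos/sin of timestep * frequencies, then projected through
# a two-layer MLP with SiLU, as in standard DiT-style time conditioning.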


class TextEmbeddings(nn.Module):
    def __init__(self, text_dim, model_dim):
        super().__init__()
        self.in_layer = nn.Linear(text_dim, model_dim, bias=True)

    def forward(self, text_embed):
        return self.in_layer(text_embed)


class VisualEmbeddings(nn.Module):
    def __init__(self, visual_dim, model_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)

    def forward(self, x):
        duration, height, width, dim = x.shape
        x = x.view(
            duration // self.patch_size[0], self.patch_size[0],
            height // self.patch_size[1], self.patch_size[1],
            width // self.patch_size[2], self.patch_size[2], dim
        ).permute(0, 2, 4, 1, 3, 5, 6).flatten(3, 6)
        return self.in_layer(x)
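
# Usage sketch for VisualEmbeddings (illustrative shapes; the actual latent sizes depend
# on the surrounding pipeline and are not specified in this file):
#   patcher = VisualEmbeddings(visual_dim=16, model_dim=1024, patch_size=(1, 2, 2))
#   latent = torch.randn(8, 32, 32, 16)   # (duration, height, width, channels)
#   tokens = patcher(latent)              # -> (8, 16, 16, 1024)
# Each (1, 2, 2) patch is flattened and projected to model_dim, so the token grid shrinks
# by the patch size along each axis.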


class RoPE3D(nn.Module):
    def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.):
        super().__init__()
        for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)):
            freq = get_freqs(axes_dim // 2, max_period)
            pos = torch.arange(ax_max_pos, dtype=freq.dtype)
            self.register_buffer(f'args_{i}', torch.outer(pos, freq))

    def args(self, i, cu_seqlens):
        args = self.__getattr__(f'args_{i}')
        if torch.is_tensor(cu_seqlens):
            args = torch.cat([args[:end] for end in torch.diff(cu_seqlens)])
        else:
            args = args[:cu_seqlens]
        return args

    def forward(self, x, cu_seqlens, scale_factor=(1., 1., 1.)):
        duration, height, width = x.shape[:-1]
        args = [
            self.args(i, ax_cu_seqlens) / ax_scale_factor
            for i, (ax_cu_seqlens, ax_scale_factor) in enumerate(zip([cu_seqlens, height, width], scale_factor))
        ]
        args = torch.cat([
            args[0].view(duration, 1, 1, -1).repeat(1, height, width, 1),
            args[1].view(1, height, 1, -1).repeat(duration, 1, width, 1),
            args[2].view(1, 1, width, -1).repeat(duration, height, 1, 1)
        ], dim=-1)
        rope = torch.stack([torch.cos(args), -torch.sin(args), torch.sin(args), torch.cos(args)], dim=-1)
        rope = rope.view(*rope.shape[:-1], 2, 2)
        return rope.unsqueeze(-4)
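
# Usage sketch for RoPE3D (hypothetical axes_dims; their sum must equal the attention
# head_dim so the rotary table covers every channel pair):
#   rope3d = RoPE3D(axes_dims=(16, 24, 24))       # 16 + 24 + 24 == head_dim == 64
#   rope = rope3d(tokens, visual_cu_seqlens)      # tokens: (duration, height, width, model_dim)
#   # rope: (duration, height, width, 1, head_dim // 2, 2, 2), consumed by apply_rotary above
# Temporal positions are read per-sample from `cu_seqlens`, spatial positions simply run
# over the height/width of the grid, and `scale_factor` rescales the rotation angles per
# axis (e.g. for resolutions other than the cached position grid).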


class Modulation(nn.Module):
    def __init__(self, time_dim, model_dim):
        super().__init__()
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, 6 * model_dim)
        self.out_layer.weight.data.zero_()
        self.out_layer.bias.data.zero_()

    def forward(self, x, cu_seqlens):
        modulation_params = self.out_layer(self.activation(x))
        modulation_params = modulation_params.repeat_interleave(torch.diff(cu_seqlens), dim=0)
        self_attn_params, ff_params = torch.chunk(modulation_params, 2, dim=-1)
        return self_attn_params, ff_params
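
# Modulation maps the pooled time embedding to 6 * model_dim adaLN-style parameters,
# repeats them along the sequence axis according to `cu_seqlens`, and splits them into a
# self-attention half and a feed-forward half (3 * model_dim each). A likely downstream
# use in a transformer block (assumed here, not shown in this file):
#   self_attn_params, ff_params = modulation(time_embed, visual_cu_seqlens)
#   shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)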


class MultiheadSelfAttention(nn.Module):
    def __init__(self, num_channels, head_dim=64, attention_type='flash'):
        super().__init__()
        assert num_channels % head_dim == 0
        self.attention_type = attention_type
        self.num_heads = num_channels // head_dim
        self.to_query_key_value = nn.Linear(num_channels, 3 * num_channels, bias=True)
        self.query_norm = nn.LayerNorm(head_dim)
        self.key_norm = nn.LayerNorm(head_dim)
        self.output_layer = nn.Linear(num_channels, num_channels, bias=True)

    def scaled_dot_product_attention(
        self, visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type,
        return_attn_probs=False
    ):
        if self.attention_type == 'flash':
            visual_shape, text_len = visual_query_key_value.shape[:3], text_cu_seqlens[1]
            visual_query_key_value, visual_cu_seqlens = to_1dimension(
                visual_query_key_value, visual_cu_seqlens, visual_shape, num_groups, attention_type
            )
            text_query_key_value = text_query_key_value.unsqueeze(0).expand(math.prod(num_groups), *text_query_key_value.size())
            query_key_value = cat_interleave(visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens)
            cu_seqlens = visual_cu_seqlens + text_cu_seqlens
            max_seqlen = torch.diff(cu_seqlens).max()
            query_key_value = query_key_value.flatten(0, 1)
            large_cu_seqlens = torch.cat([cu_seqlens + i * cu_seqlens[-1] for i in range(math.prod(num_groups))])
            out, softmax_lse, _ = flash_attn_varlen_qkvpacked_func(query_key_value, large_cu_seqlens, max_seqlen, return_attn_probs=True)
            out = out.reshape(math.prod(num_groups), -1, *out.shape[1:]).flatten(-2, -1)
            visual_out, text_out = split_interleave(out, cu_seqlens, text_len)
            visual_out = to_3dimension(visual_out, visual_shape, num_groups, attention_type)
            if return_attn_probs:
                return (visual_out, text_out), softmax_lse, None
            return visual_out, text_out

    def forward(self, visual_embed, text_embed, rope, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type):
        visual_shape = visual_embed.shape[:-1]
        visual_query_key_value = self.to_query_key_value(visual_embed)
        visual_query, visual_key, visual_value = torch.chunk(visual_query_key_value, 3, dim=-1)
        visual_query = self.query_norm(visual_query.reshape(*visual_shape, self.num_heads, -1)).type_as(visual_query)
        visual_key = self.key_norm(visual_key.reshape(*visual_shape, self.num_heads, -1)).type_as(visual_key)
        visual_value = visual_value.reshape(*visual_shape, self.num_heads, -1)
        visual_query = apply_rotary(visual_query, rope).type_as(visual_query)
        visual_key = apply_rotary(visual_key, rope).type_as(visual_key)
        visual_query_key_value = torch.stack([visual_query, visual_key, visual_value], dim=3)

        text_len = text_embed.shape[0]
        text_query_key_value = self.to_query_key_value(text_embed)
        text_query, text_key, text_value = torch.chunk(text_query_key_value, 3, dim=-1)
        text_query = self.query_norm(text_query.reshape(text_len, self.num_heads, -1)).type_as(text_query)
        text_key = self.key_norm(text_key.reshape(text_len, self.num_heads, -1)).type_as(text_key)
        text_value = text_value.reshape(text_len, self.num_heads, -1)
        text_query_key_value = torch.stack([text_query, text_key, text_value], dim=1)

        visual_out, text_out = self.scaled_dot_product_attention(
            visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type
        )
        visual_out = self.output_layer(visual_out)
        text_out = self.output_layer(text_out)
        return visual_out, text_out
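
# Usage sketch for MultiheadSelfAttention (flash path; illustrative shapes only). Visual
# and text tokens are packed into one varlen QKV sequence per sample, so attention is
# joint over both modalities:
#   attn = MultiheadSelfAttention(num_channels=1024, head_dim=64)
#   visual_out, text_out = attn(
#       visual_embed, text_embed, rope,
#       visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type='flash',
#   )
# `num_groups` controls how the 3D visual grid is flattened for flash attention by the
# helpers in .utils (to_1dimension / to_3dimension), which are not shown in this file.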


class MultiheadSelfAttentionTP(nn.Module):
    def __init__(self, initial_multihead_self_attention):
        super().__init__()
        num_channels = initial_multihead_self_attention.to_query_key_value.weight.shape[1]
        self.num_heads = initial_multihead_self_attention.num_heads
        head_dim = num_channels // self.num_heads
        self.attention_type = initial_multihead_self_attention.attention_type

        self.to_query = nn.Linear(num_channels, num_channels, bias=True)
        self.to_key = nn.Linear(num_channels, num_channels, bias=True)
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)

        weight = initial_multihead_self_attention.to_query_key_value.weight
        bias = initial_multihead_self_attention.to_query_key_value.bias
        self.to_query.weight = torch.nn.Parameter(weight[:num_channels])
        self.to_key.weight = torch.nn.Parameter(weight[num_channels:2 * num_channels])
        self.to_value.weight = torch.nn.Parameter(weight[2 * num_channels:])
        self.to_query.bias = torch.nn.Parameter(bias[:num_channels])
        self.to_key.bias = torch.nn.Parameter(bias[num_channels:2 * num_channels])
        self.to_value.bias = torch.nn.Parameter(bias[2 * num_channels:])

        self.query_norm = initial_multihead_self_attention.query_norm
        self.key_norm = initial_multihead_self_attention.key_norm
        self.output_layer = initial_multihead_self_attention.output_layer

    def scaled_dot_product_attention(
        self, visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type,
        return_attn_probs=False
    ):
        if self.attention_type == 'flash':
            visual_shape, text_len = visual_query_key_value.shape[:3], text_cu_seqlens[1]
            visual_query_key_value, visual_cu_seqlens = to_1dimension(
                visual_query_key_value, visual_cu_seqlens, visual_shape, num_groups, attention_type
            )
            text_query_key_value = text_query_key_value.unsqueeze(0).expand(math.prod(num_groups), *text_query_key_value.size())
            query_key_value = cat_interleave(visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens)
            cu_seqlens = visual_cu_seqlens + text_cu_seqlens
            max_seqlen = torch.diff(cu_seqlens).max()
            query_key_value = query_key_value.flatten(0, 1)
            large_cu_seqlens = torch.cat([cu_seqlens + i * cu_seqlens[-1] for i in range(math.prod(num_groups))])
            out, softmax_lse, _ = flash_attn_varlen_qkvpacked_func(query_key_value, large_cu_seqlens, max_seqlen, return_attn_probs=True)
            out = out.reshape(math.prod(num_groups), -1, *out.shape[1:]).flatten(-2, -1)
            visual_out, text_out = split_interleave(out, cu_seqlens, text_len)
            visual_out = to_3dimension(visual_out, visual_shape, num_groups, attention_type)
            if return_attn_probs:
                return (visual_out, text_out), softmax_lse, None
            return visual_out, text_out

    def forward(self, visual_embed, text_embed, rope, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type):
        visual_shape = visual_embed.shape[:-1]
        visual_query, visual_key, visual_value = self.to_query(visual_embed), self.to_key(visual_embed), self.to_value(visual_embed)
        visual_query = self.query_norm(visual_query.reshape(*visual_shape, self.num_heads, -1)).type_as(visual_query)
        visual_key = self.key_norm(visual_key.reshape(*visual_shape, self.num_heads, -1)).type_as(visual_key)
        visual_value = visual_value.reshape(*visual_shape, self.num_heads, -1)
        visual_query = apply_rotary(visual_query, rope).type_as(visual_query)
        visual_key = apply_rotary(visual_key, rope).type_as(visual_key)
        visual_query_key_value = torch.stack([visual_query, visual_key, visual_value], dim=3)

        text_len = text_embed.shape[0]
        text_query, text_key, text_value = self.to_query(text_embed), self.to_key(text_embed), self.to_value(text_embed)
        text_query = self.query_norm(text_query.reshape(text_len, self.num_heads, -1)).type_as(text_query)
        text_key = self.key_norm(text_key.reshape(text_len, self.num_heads, -1)).type_as(text_key)
        text_value = text_value.reshape(text_len, self.num_heads, -1)
        text_query_key_value = torch.stack([text_query, text_key, text_value], dim=1)

        visual_out, text_out = self.scaled_dot_product_attention(
            visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type
        )
        visual_out = self.output_layer(visual_out)
        text_out = self.output_layer(text_out)
        return visual_out, text_out
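
# MultiheadSelfAttentionTP re-packages a trained MultiheadSelfAttention with separate
# query/key/value projections, splitting the fused QKV weight and bias into three Linear
# layers; the "TP" suffix suggests this layout is intended for tensor-parallel sharding,
# though that is an inference from the name rather than something stated in this file.
#   attn_tp = MultiheadSelfAttentionTP(attn)   # `attn` is an existing MultiheadSelfAttention
# The forward pass is otherwise identical to MultiheadSelfAttention.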


class FeedForward(nn.Module):
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=True)
        self.activation = nn.GELU()
        self.out_layer = nn.Linear(ff_dim, dim, bias=True)

    def forward(self, x):
        return self.out_layer(self.activation(self.in_layer(x)))


class OutLayer(nn.Module):
    def __init__(self, model_dim, time_dim, visual_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(model_dim, elementwise_affine=True)
        self.out_layer = nn.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True)
        self.modulation_activation = nn.SiLU()
        self.modulation_out = nn.Linear(time_dim, 2 * model_dim, bias=True)
        self.modulation_out.weight.data.zero_()
        self.modulation_out.bias.data.zero_()

    def forward(self, visual_embed, text_embed, time_embed, visual_cu_seqlens):
        modulation_params = self.modulation_out(self.modulation_activation(time_embed))
        modulation_params = modulation_params.repeat_interleave(torch.diff(visual_cu_seqlens), dim=0)
        shift, scale = torch.chunk(modulation_params, 2, dim=-1)
        visual_embed = self.norm(visual_embed) * (scale[:, None, None, :] + 1) + shift[:, None, None, :]
        x = self.out_layer(visual_embed)
        duration, height, width, dim = x.shape
        x = x.view(
            duration, height, width,
            -1, self.patch_size[0], self.patch_size[1], self.patch_size[2]
        ).permute(0, 4, 1, 5, 2, 6, 3).flatten(0, 1).flatten(1, 2).flatten(2, 3)
        return x
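
# Usage sketch for OutLayer (illustrative shapes, continuing the VisualEmbeddings example
# above): it applies adaLN-style shift/scale derived from the time embedding, projects
# tokens back to patch pixels, and un-patchifies to the original latent grid:
#   out_layer = OutLayer(model_dim=1024, time_dim=4096, visual_dim=16, patch_size=(1, 2, 2))
#   latent_pred = out_layer(tokens, text_embed, time_embed, visual_cu_seqlens)
#   # tokens: (8, 16, 16, 1024) -> latent_pred: (8, 32, 32, 16)
# Note that `text_embed` is accepted for interface symmetry but not used here.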