# ai-forever
# Commit message: "add files"
# Commit: 9d3c2b7
import time
import math
import torch
from torch import nn
from flash_attn import flash_attn_varlen_qkvpacked_func
from .utils import exist, get_freqs, cat_interleave, split_interleave, to_1dimension, to_3dimension
def apply_rotary(x, rope):
    """Rotate consecutive channel pairs of ``x`` with the 2x2 matrices in ``rope``.

    The last dimension of ``x`` is viewed as pairs ``(x0, x1)``. Each pair is
    multiplied by its matching 2x2 matrix: column 0 of the matrix scales ``x0``
    and column 1 scales ``x1``. ``x`` is upcast to float32 before the product.
    The result is reshaped back to ``x.shape`` but is NOT cast back to
    ``x.dtype``; callers do that themselves.

    Args:
        x: tensor whose last dimension is even.
        rope: tensor broadcastable to ``(*x.shape[:-1], x.shape[-1] // 2, 2, 2)``.

    Returns:
        The rotated tensor with the same shape as ``x``.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    from_first = rope[..., 0] * pairs[..., 0]
    from_second = rope[..., 1] * pairs[..., 1]
    rotated = from_first + from_second
    return rotated.reshape(*x.shape)
class TimeEmbeddings(nn.Module):
    """Sinusoidal (cos/sin) timestep features followed by a two-layer SiLU MLP."""

    def __init__(self, model_dim, time_dim, max_period=10000.):
        super().__init__()
        assert model_dim % 2 == 0
        half_dim = model_dim // 2
        # Kept as a plain attribute rather than a registered buffer, so forward
        # moves it to the input's device on every call.
        # NOTE(review): get_freqs comes from .utils; presumably it returns a 1-D
        # tensor of `half_dim` frequencies, as required by in_layer's input size.
        self.freqs = get_freqs(half_dim, max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        """Embed ``time`` (1-D, as required by ``torch.outer``) into ``time_dim`` features."""
        freqs = self.freqs.to(device=time.device)
        phases = torch.outer(time, freqs)
        # cos features first, then sin features, along the last axis.
        sinusoid = torch.cat([phases.cos(), phases.sin()], dim=-1)
        hidden = self.activation(self.in_layer(sinusoid))
        return self.out_layer(hidden)
class TextEmbeddings(nn.Module):
    """Affine projection of text features from ``text_dim`` to ``model_dim``."""

    def __init__(self, text_dim, model_dim):
        super().__init__()
        self.in_layer = nn.Linear(text_dim, model_dim, bias=True)

    def forward(self, text_embed):
        """Project the last dimension of ``text_embed`` to ``model_dim``."""
        projected = self.in_layer(text_embed)
        return projected
class VisualEmbeddings(nn.Module):
    """Split a 4-D latent into 3-D patches and project each patch to ``model_dim``."""

    def __init__(self, visual_dim, model_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        patch_features = math.prod(patch_size) * visual_dim
        self.in_layer = nn.Linear(patch_features, model_dim)

    def forward(self, x):
        """Patchify ``x`` of shape ``(duration, height, width, dim)`` and embed the patches.

        Each of duration/height/width must be divisible by the matching patch size,
        otherwise ``view`` raises. The output has shape
        ``(duration // pd, height // ph, width // pw, model_dim)``. Each patch is
        flattened in ``(pd, ph, pw, dim)`` order.
        """
        duration, height, width, dim = x.shape
        patch_d = self.patch_size[0]
        patch_h = self.patch_size[1]
        patch_w = self.patch_size[2]
        blocks = x.view(
            duration // patch_d, patch_d,
            height // patch_h, patch_h,
            width // patch_w, patch_w,
            dim,
        )
        # Move the three grid axes to the front, then flatten each patch into one vector.
        patches = blocks.permute(0, 2, 4, 1, 3, 5, 6).flatten(3, 6)
        return self.in_layer(patches)
class RoPE3D(nn.Module):
    """Rotary position embedding over three axes (duration, height, width).

    For each axis ``i`` an angle table ``args_i`` (positions x frequencies) is
    registered as a buffer. ``forward`` turns the angles into 2x2 rotation
    matrices ``[[cos, -sin], [sin, cos]]``.
    """

    def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.):
        super().__init__()
        for axis, (axis_dim, axis_len) in enumerate(zip(axes_dims, max_pos)):
            # NOTE(review): get_freqs comes from .utils; presumably it returns
            # `axis_dim // 2` frequencies as a 1-D tensor.
            freqs = get_freqs(axis_dim // 2, max_period)
            positions = torch.arange(axis_len, dtype=freqs.dtype)
            self.register_buffer(f'args_{axis}', torch.outer(positions, freqs))

    def args(self, i, cu_seqlens):
        """Return angle rows for axis ``i``.

        If ``cu_seqlens`` is a tensor of cumulative boundaries, positions restart
        at 0 for every segment and the per-segment rows are concatenated.
        Otherwise ``cu_seqlens`` is used as a plain length and the first
        ``cu_seqlens`` rows are returned.
        """
        table = getattr(self, f'args_{i}')
        if not torch.is_tensor(cu_seqlens):
            return table[:cu_seqlens]
        return torch.cat([table[:length] for length in torch.diff(cu_seqlens)])

    def forward(self, x, cu_seqlens, scale_factor=(1., 1., 1.)):
        """Build rotation matrices for every (duration, height, width) position of ``x``.

        Only ``x.shape`` is read; ``x`` must have 4 dims. ``cu_seqlens`` drives
        the duration axis, and each axis's angles are divided by its
        ``scale_factor`` entry. The result has shape
        ``(duration, height, width, 1, n_freqs, 2, 2)``. The singleton axis at
        -4 lets it broadcast against per-head tensors.
        """
        duration, height, width = x.shape[:-1]
        extents = [cu_seqlens, height, width]
        per_axis = [
            self.args(axis, extent) / factor
            for axis, (extent, factor) in enumerate(zip(extents, scale_factor))
        ]
        angles = torch.cat(
            [
                per_axis[0].view(duration, 1, 1, -1).repeat(1, height, width, 1),
                per_axis[1].view(1, height, 1, -1).repeat(duration, 1, width, 1),
                per_axis[2].view(1, 1, width, -1).repeat(duration, height, 1, 1),
            ],
            dim=-1,
        )
        cos = torch.cos(angles)
        sin = torch.sin(angles)
        matrices = torch.stack([cos, -sin, sin, cos], dim=-1).view(*angles.shape, 2, 2)
        return matrices.unsqueeze(-4)
class Modulation(nn.Module):
    """Map a conditioning vector to ``6 * model_dim`` modulation parameters.

    The output projection is zero-initialised, so every parameter starts at 0.
    """

    def __init__(self, time_dim, model_dim):
        super().__init__()
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, 6 * model_dim)
        nn.init.zeros_(self.out_layer.weight)
        nn.init.zeros_(self.out_layer.bias)

    def forward(self, x, cu_seqlens):
        """Compute parameters per row of ``x`` and expand them to one row per token.

        Row ``k`` of ``x`` is repeated ``cu_seqlens[k + 1] - cu_seqlens[k]`` times.
        The result is split in half along the last axis into
        ``(self_attn_params, ff_params)``, each with ``3 * model_dim`` columns.
        """
        params = self.out_layer(self.activation(x))
        segment_lengths = torch.diff(cu_seqlens)
        params = params.repeat_interleave(segment_lengths, dim=0)
        self_attn_params, ff_params = params.chunk(2, dim=-1)
        return self_attn_params, ff_params
class MultiheadSelfAttention(nn.Module):
    """Joint self-attention over visual and text tokens with a fused QKV projection.

    Visual and text tokens share the QKV projection, the per-head query/key
    LayerNorms and the output layer. Rotary embeddings are applied to the visual
    queries/keys only. The attention kernel is selected by ``attention_type``;
    only ``'flash'`` (flash-attn varlen) is implemented.
    """

    def __init__(self, num_channels, head_dim=64, attention_type='flash'):
        super().__init__()
        assert num_channels % head_dim == 0
        self.attention_type = attention_type
        self.num_heads = num_channels // head_dim

        self.to_query_key_value = nn.Linear(num_channels, 3 * num_channels, bias=True)
        self.query_norm = nn.LayerNorm(head_dim)
        self.key_norm = nn.LayerNorm(head_dim)

        self.output_layer = nn.Linear(num_channels, num_channels, bias=True)

    def _split_heads(self, query_key_value, token_shape):
        """Split a fused QKV projection into per-head ``(query, key, value)``.

        Each part is reshaped to ``(*token_shape, num_heads, head_dim)``. Query
        and key are LayerNorm-ed over ``head_dim`` and cast back to the
        projection's dtype; value is only reshaped.
        """
        query, key, value = torch.chunk(query_key_value, 3, dim=-1)
        query = self.query_norm(query.reshape(*token_shape, self.num_heads, -1)).type_as(query)
        key = self.key_norm(key.reshape(*token_shape, self.num_heads, -1)).type_as(key)
        value = value.reshape(*token_shape, self.num_heads, -1)
        return query, key, value

    def scaled_dot_product_attention(
        self, visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type,
        return_attn_probs=False
    ):
        """Attend over visual tokens interleaved with a shared copy of the text tokens.

        ``self.attention_type`` selects the kernel. The ``attention_type``
        argument is only forwarded to the ``to_1dimension`` / ``to_3dimension``
        helpers (its meaning is defined there — TODO confirm).

        Returns:
            ``(visual_out, text_out)``, or ``((visual_out, text_out), softmax_lse, None)``
            when ``return_attn_probs`` is true.

        Raises:
            NotImplementedError: if ``self.attention_type`` is not ``'flash'``.
        """
        if self.attention_type != 'flash':
            # Fail loudly rather than implicitly returning None.
            raise NotImplementedError(
                f"attention_type={self.attention_type!r} is not supported; only 'flash' is implemented"
            )

        groups = math.prod(num_groups)
        # NOTE(review): text_cu_seqlens[1] is used as the text length, which assumes
        # text_cu_seqlens starts at 0 — confirm with callers.
        visual_shape, text_len = visual_query_key_value.shape[:3], text_cu_seqlens[1]
        visual_query_key_value, visual_cu_seqlens = to_1dimension(
            visual_query_key_value, visual_cu_seqlens, visual_shape, num_groups, attention_type
        )
        # Every group sees the same text tokens; expand creates a broadcast view without copying.
        text_query_key_value = text_query_key_value.unsqueeze(0).expand(groups, *text_query_key_value.size())
        query_key_value = cat_interleave(visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens)
        cu_seqlens = visual_cu_seqlens + text_cu_seqlens
        # flash-attn documents max_seqlen as a Python int, not a tensor.
        max_seqlen = int(torch.diff(cu_seqlens).max())

        # Pack all groups into one varlen batch by offsetting each group's boundaries.
        # NOTE(review): adjacent copies repeat the boundary value (e.g. [0, a, b, b, a+b, 2b]),
        # producing an empty sequence between groups — confirm this is intended.
        query_key_value = query_key_value.flatten(0, 1)
        large_cu_seqlens = torch.cat([cu_seqlens + i * cu_seqlens[-1] for i in range(groups)])
        out, softmax_lse, _ = flash_attn_varlen_qkvpacked_func(
            query_key_value, large_cu_seqlens, max_seqlen, return_attn_probs=True
        )
        # Un-pack the groups and merge (heads, head_dim) back into channels.
        out = out.reshape(groups, -1, *out.shape[1:]).flatten(-2, -1)
        visual_out, text_out = split_interleave(out, cu_seqlens, text_len)
        visual_out = to_3dimension(visual_out, visual_shape, num_groups, attention_type)
        if return_attn_probs:
            return (visual_out, text_out), softmax_lse, None
        return visual_out, text_out

    def forward(self, visual_embed, text_embed, rope, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type):
        """Attend jointly over visual and text tokens and apply the output projection.

        Args:
            visual_embed: visual tokens. The QKV stack at ``dim=3`` and ``shape[:3]``
                in ``scaled_dot_product_attention`` assume shape
                ``(duration, height, width, channels)``.
            text_embed: text tokens, reshaped as ``(text_embed.shape[0], num_heads, head_dim)``.
            rope: rotation matrices applied to visual queries/keys via ``apply_rotary``.
            visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type: forwarded
                to ``scaled_dot_product_attention``.

        Returns:
            ``(visual_out, text_out)`` after ``output_layer``.
        """
        visual_shape = visual_embed.shape[:-1]
        visual_query, visual_key, visual_value = self._split_heads(
            self.to_query_key_value(visual_embed), visual_shape
        )
        visual_query = apply_rotary(visual_query, rope).type_as(visual_query)
        visual_key = apply_rotary(visual_key, rope).type_as(visual_key)
        visual_query_key_value = torch.stack([visual_query, visual_key, visual_value], dim=3)

        text_len = text_embed.shape[0]
        text_query, text_key, text_value = self._split_heads(
            self.to_query_key_value(text_embed), (text_len,)
        )
        text_query_key_value = torch.stack([text_query, text_key, text_value], dim=1)

        visual_out, text_out = self.scaled_dot_product_attention(
            visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type
        )
        visual_out = self.output_layer(visual_out)
        text_out = self.output_layer(text_out)
        return visual_out, text_out
class MultiheadSelfAttentionTP(nn.Module):
    """Self-attention with separate query/key/value projections.

    Built from an existing fused-QKV attention module. The fused weight and bias
    are sliced row-wise into three ``nn.Linear`` layers, and the source's norms,
    output layer, ``num_heads`` and ``attention_type`` are reused. The
    ``nn.Parameter`` wrappers share storage with the source slices; no copy is
    made. (NOTE(review): "TP" presumably means tensor parallel — confirm.)
    """

    def __init__(self, initial_multihead_self_attention):
        super().__init__()
        source = initial_multihead_self_attention
        num_channels = source.to_query_key_value.weight.shape[1]
        self.num_heads = source.num_heads
        self.attention_type = source.attention_type

        self.to_query = nn.Linear(num_channels, num_channels, bias=True)
        self.to_key = nn.Linear(num_channels, num_channels, bias=True)
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)

        # Assumes the fused rows are laid out as [query; key; value], matching a
        # torch.chunk(..., 3, dim=-1) split of the fused output.
        weight = source.to_query_key_value.weight
        bias = source.to_query_key_value.bias
        self.to_query.weight = torch.nn.Parameter(weight[:num_channels])
        self.to_key.weight = torch.nn.Parameter(weight[num_channels:2 * num_channels])
        self.to_value.weight = torch.nn.Parameter(weight[2 * num_channels:])
        self.to_query.bias = torch.nn.Parameter(bias[:num_channels])
        self.to_key.bias = torch.nn.Parameter(bias[num_channels:2 * num_channels])
        self.to_value.bias = torch.nn.Parameter(bias[2 * num_channels:])

        self.query_norm = source.query_norm
        self.key_norm = source.key_norm
        self.output_layer = source.output_layer

    def _split_heads(self, query, key, value, token_shape):
        """Reshape projections to ``(*token_shape, num_heads, head_dim)`` and normalize q/k.

        Query and key are LayerNorm-ed over ``head_dim`` and cast back to their
        input dtype; value is only reshaped.
        """
        query = self.query_norm(query.reshape(*token_shape, self.num_heads, -1)).type_as(query)
        key = self.key_norm(key.reshape(*token_shape, self.num_heads, -1)).type_as(key)
        value = value.reshape(*token_shape, self.num_heads, -1)
        return query, key, value

    def scaled_dot_product_attention(
        self, visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type,
        return_attn_probs=False
    ):
        """Attend over visual tokens interleaved with a shared copy of the text tokens.

        ``self.attention_type`` selects the kernel. The ``attention_type``
        argument is only forwarded to the ``to_1dimension`` / ``to_3dimension``
        helpers (its meaning is defined there — TODO confirm).

        Returns:
            ``(visual_out, text_out)``, or ``((visual_out, text_out), softmax_lse, None)``
            when ``return_attn_probs`` is true.

        Raises:
            NotImplementedError: if ``self.attention_type`` is not ``'flash'``.
        """
        if self.attention_type != 'flash':
            # Fail loudly rather than implicitly returning None.
            raise NotImplementedError(
                f"attention_type={self.attention_type!r} is not supported; only 'flash' is implemented"
            )

        groups = math.prod(num_groups)
        # NOTE(review): text_cu_seqlens[1] is used as the text length, which assumes
        # text_cu_seqlens starts at 0 — confirm with callers.
        visual_shape, text_len = visual_query_key_value.shape[:3], text_cu_seqlens[1]
        visual_query_key_value, visual_cu_seqlens = to_1dimension(
            visual_query_key_value, visual_cu_seqlens, visual_shape, num_groups, attention_type
        )
        # Every group sees the same text tokens; expand creates a broadcast view without copying.
        text_query_key_value = text_query_key_value.unsqueeze(0).expand(groups, *text_query_key_value.size())
        query_key_value = cat_interleave(visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens)
        cu_seqlens = visual_cu_seqlens + text_cu_seqlens
        # flash-attn documents max_seqlen as a Python int, not a tensor.
        max_seqlen = int(torch.diff(cu_seqlens).max())

        # Pack all groups into one varlen batch by offsetting each group's boundaries.
        # NOTE(review): adjacent copies repeat the boundary value (e.g. [0, a, b, b, a+b, 2b]),
        # producing an empty sequence between groups — confirm this is intended.
        query_key_value = query_key_value.flatten(0, 1)
        large_cu_seqlens = torch.cat([cu_seqlens + i * cu_seqlens[-1] for i in range(groups)])
        out, softmax_lse, _ = flash_attn_varlen_qkvpacked_func(
            query_key_value, large_cu_seqlens, max_seqlen, return_attn_probs=True
        )
        # Un-pack the groups and merge (heads, head_dim) back into channels.
        out = out.reshape(groups, -1, *out.shape[1:]).flatten(-2, -1)
        visual_out, text_out = split_interleave(out, cu_seqlens, text_len)
        visual_out = to_3dimension(visual_out, visual_shape, num_groups, attention_type)
        if return_attn_probs:
            return (visual_out, text_out), softmax_lse, None
        return visual_out, text_out

    def forward(self, visual_embed, text_embed, rope, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type):
        """Attend jointly over visual and text tokens and apply the output projection.

        Args:
            visual_embed: visual tokens. The QKV stack at ``dim=3`` and ``shape[:3]``
                in ``scaled_dot_product_attention`` assume shape
                ``(duration, height, width, channels)``.
            text_embed: text tokens, reshaped as ``(text_embed.shape[0], num_heads, head_dim)``.
            rope: rotation matrices applied to visual queries/keys via ``apply_rotary``.
            visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type: forwarded
                to ``scaled_dot_product_attention``.

        Returns:
            ``(visual_out, text_out)`` after ``output_layer``.
        """
        visual_shape = visual_embed.shape[:-1]
        visual_query, visual_key, visual_value = self._split_heads(
            self.to_query(visual_embed), self.to_key(visual_embed), self.to_value(visual_embed), visual_shape
        )
        visual_query = apply_rotary(visual_query, rope).type_as(visual_query)
        visual_key = apply_rotary(visual_key, rope).type_as(visual_key)
        visual_query_key_value = torch.stack([visual_query, visual_key, visual_value], dim=3)

        text_len = text_embed.shape[0]
        text_query, text_key, text_value = self._split_heads(
            self.to_query(text_embed), self.to_key(text_embed), self.to_value(text_embed), (text_len,)
        )
        text_query_key_value = torch.stack([text_query, text_key, text_value], dim=1)

        visual_out, text_out = self.scaled_dot_product_attention(
            visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type
        )
        visual_out = self.output_layer(visual_out)
        text_out = self.output_layer(text_out)
        return visual_out, text_out
class FeedForward(nn.Module):
    """Position-wise MLP: ``dim -> ff_dim -> dim`` with a GELU in between."""

    def __init__(self, dim, ff_dim):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=True)
        self.activation = nn.GELU()
        self.out_layer = nn.Linear(ff_dim, dim, bias=True)

    def forward(self, x):
        """Apply expand -> GELU -> project over the last dimension of ``x``."""
        hidden = self.in_layer(x)
        hidden = self.activation(hidden)
        return self.out_layer(hidden)
class OutLayer(nn.Module):
    """Final layer: modulated LayerNorm, linear projection, then un-patchify.

    The shift/scale projection is zero-initialised, so the modulation starts
    as the identity: ``norm(x) * 1 + 0``.
    """

    def __init__(self, model_dim, time_dim, visual_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(model_dim, elementwise_affine=True)
        self.out_layer = nn.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True)
        self.modulation_activation = nn.SiLU()
        self.modulation_out = nn.Linear(time_dim, 2 * model_dim, bias=True)
        nn.init.zeros_(self.modulation_out.weight)
        nn.init.zeros_(self.modulation_out.bias)

    def forward(self, visual_embed, text_embed, time_embed, visual_cu_seqlens):
        """Project ``visual_embed`` back to pixel/latent patches.

        ``text_embed`` is not used by this method. Row ``k`` of ``time_embed``
        yields one ``(shift, scale)`` pair, repeated for the
        ``visual_cu_seqlens[k + 1] - visual_cu_seqlens[k]`` frames of segment
        ``k``. ``visual_embed`` is indexed as 4-D ``(duration, height, width,
        channels)``. The output has shape
        ``(duration * pd, height * ph, width * pw, visual_dim)``.
        """
        params = self.modulation_out(self.modulation_activation(time_embed))
        params = params.repeat_interleave(torch.diff(visual_cu_seqlens), dim=0)
        shift, scale = params.chunk(2, dim=-1)
        # Broadcast the per-frame parameters over height and width.
        frame_scale = scale[:, None, None, :]
        frame_shift = shift[:, None, None, :]
        modulated = self.norm(visual_embed) * (frame_scale + 1) + frame_shift

        x = self.out_layer(modulated)
        duration, height, width, _ = x.shape
        patch_d = self.patch_size[0]
        patch_h = self.patch_size[1]
        patch_w = self.patch_size[2]
        # Features of each patch are laid out as (channels, pd, ph, pw).
        x = x.view(duration, height, width, -1, patch_d, patch_h, patch_w)
        # -> (D, pd, H, ph, W, pw, C), then merge each (grid, patch) axis pair.
        x = x.permute(0, 4, 1, 5, 2, 6, 3)
        return x.flatten(0, 1).flatten(1, 2).flatten(2, 3)