from functools import partial

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

from celle.reversible import SequentialSequence
from celle.attention import Attention

from rotary_embedding_torch import RotaryEmbedding, broadcat
from celle.utils import exists, default, cast_tuple
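
# Transformer backbone for CELL-E: a stack of pre-norm attention and
# feed-forward blocks wrapped in LayerScale, with optional rotary position
# embeddings that combine a 1D text component with 2D axial image components
# for mixed text + image token sequences.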


class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        # Depth-dependent initial value for the learned per-channel residual
        # scale; deeper networks start with smaller scales.
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Scale the wrapped module's output by the learned per-channel factor.
        return self.fn(x, **kwargs) * self.scale


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Output norm kept as Identity, so this is a plain pre-norm wrapper:
        # x -> fn(LayerNorm(x)).
        self.norm_out = nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)


class GEGLU(nn.Module):
    def forward(self, x):
        # Split the last dimension in half: one half is the value path, the
        # other is passed through GELU and used as a multiplicative gate.
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)
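
# GEGLU halves the feature dimension, e.g. (illustrative shapes only):
#   GEGLU()(torch.randn(2, 7, 1024)).shape == torch.Size([2, 7, 512])
# FeedForward below therefore projects to dim * mult * 2 before the GEGLU so
# that the gated hidden width is dim * mult.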


class FeedForward(nn.Module):
    def __init__(self, dim, dropout=0.0, mult=4.0):
        super().__init__()
        # int() guards against a float `mult` (e.g. the 4.0 default), which
        # nn.Linear would otherwise reject.
        self.net = nn.Sequential(
            nn.Linear(dim, int(dim * mult * 2)),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(int(dim * mult), dim),
        )

    def forward(self, x):
        return self.net(x)
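
# For example, with dim=512 and mult=4 the feed-forward width flow is
# 512 -> 4096 -> (GEGLU) 2048 -> 512.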


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal=True,
        heads=8,
        dim_head=64,
        ff_mult=4,
        attn_dropout=0.0,
        ff_dropout=0.0,
        image_fmap_size=None,
        num_images=None,
        stable=False,
        rotary_emb=True,
    ):
        super().__init__()
        layers = nn.ModuleList([])

        self.seq_len = seq_len
        self.image_fmap_size = image_fmap_size

        for ind in range(depth):
            attn_class = partial(Attention, stable=stable)

            attn = attn_class(
                dim,
                causal=causal,
                seq_len=seq_len,
                heads=heads,
                dim_head=dim_head,
                dropout=attn_dropout,
            )

            ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)

            # Each block is an (attention, feed-forward) pair, both pre-normed
            # and wrapped in LayerScale keyed to the 1-indexed layer depth.
            layers.append(
                nn.ModuleList(
                    [
                        LayerScale(dim, ind + 1, PreNorm(dim, attn)),
                        LayerScale(dim, ind + 1, PreNorm(dim, ff)),
                    ]
                )
            )

        route_attn = ((True, False),) * depth
        attn_route_map = {
            "mask": route_attn,
            "rotary_pos_emb": route_attn,
        }

        self.layers = SequentialSequence(layers, args_route=attn_route_map)
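
        # `args_route` sends the routed kwargs ("mask", "rotary_pos_emb") only
        # to the first module of each (attn, ff) pair, i.e. to attention; the
        # feed-forward blocks never receive them.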

        pos_emb = None
        if rotary_emb:
            # Split each attention head's dimension three ways: roughly one
            # third for the 1D text position and one third each for the image
            # row / column axial positions.
            rot_dim = dim_head // 3
            img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images

            text_len = seq_len - img_seq_len + 1

            text_pos_emb = RotaryEmbedding(dim=rot_dim)
            img_axial_pos_emb = RotaryEmbedding(dim=rot_dim, freqs_for="pixel")

            # Text tokens get ordinary 1D rotary frequencies; image tokens all
            # share a single out-of-range text position (8192).
            text_freqs = text_pos_emb(torch.arange(text_len))
            img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192))
            text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim=0)

            img_freqs_axial = img_axial_pos_emb(
                torch.linspace(-1, 1, steps=image_fmap_size)
            )

            if num_images > 1:
                # Treat the feature map as `num_images` stacked images: build
                # a 2D axial grid per image, then concatenate along the
                # sequence dimension.
                split_img_freqs_axial = torch.split(
                    img_freqs_axial, image_fmap_size // num_images, dim=0
                )

                split_img_freqs = [
                    broadcat(
                        (
                            rearrange(img_freqs_axial_per_image, "i d -> i () d"),
                            rearrange(img_freqs_axial_per_image, "j d -> () j d"),
                        ),
                        dim=-1,
                    )
                    for img_freqs_axial_per_image in split_img_freqs_axial
                ]

                split_img_freqs = [
                    rearrange(img_freqs_per_image, "h w d -> (h w) d")
                    for img_freqs_per_image in split_img_freqs
                ]

                img_freqs = torch.cat(split_img_freqs, dim=0)

            elif num_images == 1:
                # Single image: one 2D axial grid over the full feature map.
                img_freqs = broadcat(
                    (
                        rearrange(img_freqs_axial, "i d -> i () d"),
                        rearrange(img_freqs_axial, "j d -> () j d"),
                    ),
                    dim=-1,
                )
                img_freqs = rearrange(img_freqs, "h w d -> (h w) d")

            else:
                raise ValueError("num_images must be an int greater than 0")

            self.img_axial_pos_emb = img_axial_pos_emb
            self.text_pos_emb = text_pos_emb

            # Text tokens get a constant off-grid axial position (-10.0),
            # duplicated for the row and column components.
            text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.0))
            text_axial_freqs = torch.cat(
                (text_axial_freqs, text_axial_freqs), dim=-1
            )

            img_freqs = torch.cat((text_axial_freqs, img_freqs), dim=0)

            pos_emb = torch.cat((text_freqs, img_freqs), dim=-1)
            pos_emb = rearrange(pos_emb, "n d -> () n d")

        self.register_buffer("pos_emb", pos_emb)
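
        # self.pos_emb is a fixed buffer (or None when rotary_emb=False); it is
        # passed to every attention layer via the kwarg routing set up above.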

    def forward(self, x, **kwargs):
        return self.layers(x, rotary_pos_emb=self.pos_emb, **kwargs)
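

# Minimal usage sketch (illustrative only): the hyperparameters below are
# assumptions, not values taken from a real configuration, and the forward
# pass relies on celle.attention.Attention / celle.reversible.SequentialSequence
# behaving as they are used in Transformer.__init__ and Transformer.forward.
if __name__ == "__main__":
    image_fmap_size = 16  # e.g. a 16x16 grid of image tokens
    num_images = 1
    img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images  # 256
    text_len = 256
    seq_len = text_len + img_seq_len - 1  # matches text_len = seq_len - img_seq_len + 1

    model = Transformer(
        dim=512,
        depth=2,
        seq_len=seq_len,
        image_fmap_size=image_fmap_size,
        num_images=num_images,
    )

    x = torch.randn(1, seq_len, 512)  # (batch, sequence, dim) token embeddings
    out = model(x)
    print(out.shape)  # expected: torch.Size([1, seq_len, 512])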