from functools import partial

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

from celle.reversible import SequentialSequence
from celle.attention import Attention
from rotary_embedding_torch import RotaryEmbedding, broadcat
from celle.utils import exists, default, cast_tuple

# LayerScale, from CaiT: https://arxiv.org/abs/2103.17239


class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        # depth-dependent init values follow the CaiT recipe
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


# layer norm


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.norm_out = nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)


# feed forward


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim, dropout=0.0, mult=4):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


# main transformer class


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal=True,
        heads=8,
        dim_head=64,
        ff_mult=4,
        attn_dropout=0.0,
        ff_dropout=0.0,
        image_fmap_size=None,
        num_images=None,
        stable=False,
        rotary_emb=True,
    ):
        super().__init__()
        layers = nn.ModuleList([])
        self.seq_len = seq_len
        self.image_fmap_size = image_fmap_size

        for ind in range(depth):
            attn_class = partial(Attention, stable=stable)

            attn = attn_class(
                dim,
                causal=causal,
                seq_len=seq_len,
                heads=heads,
                dim_head=dim_head,
                dropout=attn_dropout,
            )

            ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)

            layers.append(
                nn.ModuleList(
                    [
                        LayerScale(dim, ind + 1, PreNorm(dim, attn)),
                        LayerScale(dim, ind + 1, PreNorm(dim, ff)),
                    ]
                )
            )

        # route the mask and rotary embeddings only to the attention layers
        route_attn = ((True, False),) * depth
        attn_route_map = {
            "mask": route_attn,
            "rotary_pos_emb": route_attn,
        }

        self.layers = SequentialSequence(layers, args_route=attn_route_map)

        # generate rotary positional embeddings
        pos_emb = None
        if rotary_emb:
            rot_dim = dim_head // 3
            img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images
            text_len = seq_len - img_seq_len + 1

            text_pos_emb = RotaryEmbedding(dim=rot_dim)
            img_axial_pos_emb = RotaryEmbedding(dim=rot_dim, freqs_for="pixel")

            text_freqs = text_pos_emb(torch.arange(text_len))

            # image tokens are given a text position far away from any real text token
            img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192))
            text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim=0)

            img_freqs_axial = img_axial_pos_emb(
                torch.linspace(-1, 1, steps=image_fmap_size)
            )

            if num_images > 1:
                split_img_freqs_axial = torch.split(
                    img_freqs_axial, image_fmap_size // num_images, dim=0
                )

                split_img_freqs = [
                    broadcat(
                        (
                            rearrange(img_freqs_axial_per_image, "i d -> i () d"),
                            rearrange(img_freqs_axial_per_image, "j d -> () j d"),
                        ),
                        dim=-1,
                    )
                    for img_freqs_axial_per_image in split_img_freqs_axial
                ]

                split_img_freqs = [
                    rearrange(img_freqs_per_image, "h w d -> (h w) d")
                    for img_freqs_per_image in split_img_freqs
                ]

                # concatenate the per-image frequencies
                img_freqs = torch.cat(split_img_freqs, dim=0)

            elif num_images == 1:
                img_freqs = broadcat(
                    (
                        rearrange(img_freqs_axial, "i d -> i () d"),
                        rearrange(img_freqs_axial, "j d -> () j d"),
                    ),
                    dim=-1,
                )
                img_freqs = rearrange(img_freqs, "h w d -> (h w) d")

            else:
                raise ValueError("num_images must be an integer greater than 0")

            self.img_axial_pos_emb = img_axial_pos_emb
            self.text_pos_emb = text_pos_emb

            # text tokens are given an axial position of -10, far outside the
            # image axial range of [-1, 1]
            text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.0))
            text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim=-1)

            img_freqs = torch.cat((text_axial_freqs, img_freqs), dim=0)

            pos_emb = torch.cat((text_freqs, img_freqs), dim=-1)
            pos_emb = rearrange(pos_emb, "n d -> () n d")

        self.register_buffer("pos_emb", pos_emb)

    def forward(self, x, **kwargs):
        return self.layers(x, rotary_pos_emb=self.pos_emb, **kwargs)
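

# Minimal usage sketch (not part of the original module). The sizes below are
# illustrative assumptions rather than values from any CELL-E config; the snippet
# only instantiates the Transformer and checks that a forward pass preserves the
# (batch, sequence, dim) shape of its input.
if __name__ == "__main__":
    image_fmap_size = 16  # assumed side length of the image token grid
    num_images = 1
    text_seq_len = 64  # assumed text length
    img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images
    seq_len = text_seq_len + img_seq_len

    model = Transformer(
        dim=512,
        depth=2,
        seq_len=seq_len,
        heads=8,
        dim_head=64,
        image_fmap_size=image_fmap_size,
        num_images=num_images,
    )

    x = torch.randn(1, seq_len, 512)  # (batch, sequence, dim)
    out = model(x)
    print(out.shape)  # expected: torch.Size([1, seq_len, 512])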