#=================================================================================================================== # X Trasformer Module # Partial x-transformers code With useful modifications # # Version 1.0 # # Original source code courtesy of lucidrains # https://github.com/lucidrains/x-transformers # # Original source code retrieved on 05/10/2023 # # Project Los Angeles # Tegridy Code 2023 #=================================================================================================================== # Critical dependencies # # !pip install torch # !pip install einops #=================================================================================================================== from functools import partial import torch from torch import nn, einsum, Tensor import torch.nn.functional as F from collections import namedtuple from functools import wraps from packaging import version from dataclasses import dataclass from einops import rearrange import math from random import random from functools import partial from inspect import isfunction from dataclasses import dataclass from typing import List from einops import rearrange, repeat, reduce from einops.layers.torch import Rearrange from math import ceil from einops import rearrange, pack, unpack #=================================================================================================================== # constants EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) @dataclass class Intermediates: qk_similarities: Tensor = None pre_softmax_attn: Tensor = None post_softmax_attn: Tensor = None # helpers def exists(val): return val is not None def default(val, d): return val if exists(val) else d def once(fn): called = False @wraps(fn) def inner(x): nonlocal called if called: return called = True return fn(x) return inner print_once = once(print) # main class class Attend(nn.Module): def __init__( self, *, dropout = 0., causal = False, heads = None, talking_heads = False, scale = None, qk_norm = False, flash = False, ): super().__init__() self.scale = scale self.qk_norm = qk_norm self.causal = causal self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax self.dropout = dropout self.attn_dropout = nn.Dropout(dropout) # talking heads assert not (flash and talking_heads), 'talking heads not compatible with flash attention' self.talking_heads = talking_heads if talking_heads: self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) # flash attention self.flash = flash assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' # determine efficient attention configs for cuda and cpu self.cpu_config = EfficientAttentionConfig(True, True, True) self.cuda_config = None if not torch.cuda.is_available() or not flash: return device_properties = torch.cuda.get_device_properties(torch.device('cuda')) if device_properties.major == 8 and device_properties.minor == 0: print_once('A100 GPU detected, using flash attention if input tensor is on cuda') self.cuda_config = EfficientAttentionConfig(True, False, False) else: print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') self.cuda_config = EfficientAttentionConfig(False, True, True) def flash_attn( self, q, k, v, mask = None, attn_bias = None ): batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device # Recommended for multi-query single-key-value attention by Tri Dao # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) if k.ndim == 3: k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) if v.ndim == 3: v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention if self.qk_norm: default_scale = q.shape[-1] ** -0.5 q = q * (default_scale / self.scale) # Check if mask exists and expand to compatible shape # The mask is B L, so it would have to be expanded to B H N L causal = self.causal if exists(mask): assert mask.ndim == 4 mask = mask.expand(batch, heads, q_len, k_len) # manually handle causal mask, if another mask was given if causal: causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1) mask = mask | causal_mask causal = False # handle alibi positional bias # convert from bool to float if exists(attn_bias): attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, -1, -1, -1) # if mask given, the mask would already contain the causal mask from above logic # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number mask_value = -torch.finfo(q.dtype).max if exists(mask): attn_bias = attn_bias.masked_fill(mask, mask_value // 2) elif causal: causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1) attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2) causal = False # scaled_dot_product_attention handles attn_mask either as bool or additive bias # make it an additive bias here mask = attn_bias # Check if there is a compatible device for flash attention config = self.cuda_config if is_cuda else self.cpu_config # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale with torch.backends.cuda.sdp_kernel(**config._asdict()): out = F.scaled_dot_product_attention( q, k, v, attn_mask = mask, dropout_p = self.dropout if self.training else 0., is_causal = causal ) return out, Intermediates() def forward( self, q, k, v, mask = None, attn_bias = None, prev_attn = None ): """ einstein notation b - batch h - heads n, i, j - sequence length (base sequence length, source, target) d - feature dimension """ n, device = q.shape[-2], q.device scale = default(self.scale, q.shape[-1] ** -0.5) if self.flash: assert not exists(prev_attn), 'residual attention not compatible with flash attention' return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias) kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale if exists(prev_attn): dots = dots + prev_attn qk_similarities = dots.clone() if self.talking_heads: dots = self.pre_softmax_talking_heads(dots) if exists(attn_bias): dots = dots + attn_bias dtype = dots.dtype pre_softmax_attn = dots.clone() mask_value = -torch.finfo(dots.dtype).max if exists(mask): dots = dots.masked_fill(mask, mask_value) if self.causal: i, j = dots.shape[-2:] causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1) dots = dots.masked_fill(causal_mask, mask_value) attn = self.attn_fn(dots, dim = -1) attn = attn.type(dtype) post_softmax_attn = attn.clone() attn = self.attn_dropout(attn) if self.talking_heads: attn = self.post_softmax_talking_heads(attn) out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) intermediates = Intermediates( qk_similarities = qk_similarities, pre_softmax_attn = pre_softmax_attn, post_softmax_attn = post_softmax_attn ) return out, intermediates #=================================================================================================================== # constants DEFAULT_DIM_HEAD = 64 @dataclass class LayerIntermediates: hiddens: List[Tensor] = None, attn_intermediates: List[Intermediates] = None # helpers def exists(val): return val is not None def default(val, d): if exists(val): return val return d() if isfunction(d) else d def cast_tuple(val, depth): return val if isinstance(val, tuple) else (val,) * depth def maybe(fn): @wraps(fn) def inner(x, *args, **kwargs): if not exists(x): return x return fn(x, *args, **kwargs) return inner class always(): def __init__(self, val): self.val = val def __call__(self, *args, **kwargs): return self.val class not_equals(): def __init__(self, val): self.val = val def __call__(self, x, *args, **kwargs): return x != self.val class equals(): def __init__(self, val): self.val = val def __call__(self, x, *args, **kwargs): return x == self.val # tensor helpers def max_neg_value(tensor): return -torch.finfo(tensor.dtype).max def l2norm(t, groups = 1): t = rearrange(t, '... (g d) -> ... g d', g = groups) t = F.normalize(t, p = 2, dim = -1) return rearrange(t, '... g d -> ... (g d)') def pad_at_dim(t, pad, dim = -1, value = 0.): dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) zeros = ((0, 0) * dims_from_right) return F.pad(t, (*zeros, *pad), value = value) def or_reduce(masks): head, *body = masks for rest in body: head = head | rest return head # init helpers def init_zero_(layer): nn.init.constant_(layer.weight, 0.) if exists(layer.bias): nn.init.constant_(layer.bias, 0.) # keyword argument helpers def pick_and_pop(keys, d): values = list(map(lambda key: d.pop(key), keys)) return dict(zip(keys, values)) def group_dict_by_key(cond, d): return_val = [dict(),dict()] for key in d.keys(): match = bool(cond(key)) ind = int(not match) return_val[ind][key] = d[key] return (*return_val,) def string_begins_with(prefix, str): return str.startswith(prefix) def group_by_key_prefix(prefix, d): return group_dict_by_key(partial(string_begins_with, prefix), d) def groupby_prefix_and_trim(prefix, d): kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) return kwargs_without_prefix, kwargs # initializations def deepnorm_init( transformer, beta, module_name_match_list = ['.ff.', '.to_v', '.to_out'] ): for name, module in transformer.named_modules(): if type(module) != nn.Linear: continue needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list)) gain = beta if needs_beta_gain else 1 nn.init.xavier_normal_(module.weight.data, gain = gain) if exists(module.bias): nn.init.constant_(module.bias.data, 0) # structured dropout, more effective than traditional attention dropouts def dropout_seq(seq, mask, dropout): b, n, *_, device = *seq.shape, seq.device logits = torch.randn(b, n, device = device) if exists(mask): mask_value = max_neg_value(logits) logits = logits.masked_fill(~mask, mask_value) keep_prob = 1. - dropout num_keep = max(1, int(keep_prob * n)) keep_indices = logits.topk(num_keep, dim = 1).indices batch_indices = torch.arange(b, device = device) batch_indices = rearrange(batch_indices, 'b -> b 1') seq = seq[batch_indices, keep_indices] if exists(mask): seq_counts = mask.sum(dim = -1) seq_keep_counts = torch.ceil(seq_counts * keep_prob).int() keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1') mask = mask[batch_indices, keep_indices] & keep_mask return seq, mask # activations class ReluSquared(nn.Module): def forward(self, x): return F.relu(x) ** 2 # embedding class TokenEmbedding(nn.Module): def __init__(self, dim, num_tokens, l2norm_embed = False): super().__init__() self.l2norm_embed = l2norm_embed self.emb = nn.Embedding(num_tokens, dim) def forward(self, x): token_emb = self.emb(x) return l2norm(token_emb) if self.l2norm_embed else token_emb # positional embeddings class AbsolutePositionalEmbedding(nn.Module): def __init__(self, dim, max_seq_len, l2norm_embed = False): super().__init__() self.scale = dim ** -0.5 if not l2norm_embed else 1. self.max_seq_len = max_seq_len self.l2norm_embed = l2norm_embed self.emb = nn.Embedding(max_seq_len, dim) def forward(self, x, pos = None): seq_len, device = x.shape[1], x.device assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' if not exists(pos): pos = torch.arange(seq_len, device = device) pos_emb = self.emb(pos) pos_emb = pos_emb * self.scale return l2norm(pos_emb) if self.l2norm_embed else pos_emb class ScaledSinusoidalEmbedding(nn.Module): def __init__(self, dim, theta = 10000): super().__init__() assert (dim % 2) == 0 self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) half_dim = dim // 2 freq_seq = torch.arange(half_dim).float() / half_dim inv_freq = theta ** -freq_seq self.register_buffer('inv_freq', inv_freq, persistent = False) def forward(self, x, pos = None): seq_len, device = x.shape[1], x.device if not exists(pos): pos = torch.arange(seq_len, device = device) emb = einsum('i, j -> i j', pos, self.inv_freq) emb = torch.cat((emb.sin(), emb.cos()), dim = -1) return emb * self.scale class RelativePositionBias(nn.Module): def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8): super().__init__() self.scale = scale self.causal = causal self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128): ret = 0 n = -relative_position if not causal: num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) else: n = torch.max(n, torch.zeros_like(n)) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret @property def device(self): return next(self.parameters()).device def forward(self, i, j): device = self.device q_pos = torch.arange(j - i, j, dtype = torch.long, device = device) k_pos = torch.arange(j, dtype = torch.long, device = device) rel_pos = k_pos[None, :] - q_pos[:, None] rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance) values = self.relative_attention_bias(rp_bucket) bias = rearrange(values, 'i j h -> h i j') return bias * self.scale class DynamicPositionBias(nn.Module): def __init__(self, dim, *, heads, depth, log_distance = False, norm = False): super().__init__() assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1' self.log_distance = log_distance self.mlp = nn.ModuleList([]) self.mlp.append(nn.Sequential( nn.Linear(1, dim), nn.LayerNorm(dim) if norm else nn.Identity(), nn.SiLU() )) for _ in range(depth - 1): self.mlp.append(nn.Sequential( nn.Linear(dim, dim), nn.LayerNorm(dim) if norm else nn.Identity(), nn.SiLU() )) self.mlp.append(nn.Linear(dim, heads)) @property def device(self): return next(self.parameters()).device def forward(self, i, j): assert i == j n, device = j, self.device # get the (n x n) matrix of distances seq_arange = torch.arange(n, device = device) context_arange = torch.arange(n, device = device) indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j') indices += (n - 1) # input to continuous positions MLP pos = torch.arange(-n + 1, n, device = device).float() pos = rearrange(pos, '... -> ... 1') if self.log_distance: pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1) for layer in self.mlp: pos = layer(pos) # get position biases bias = pos[indices] bias = rearrange(bias, 'i j h -> h i j') return bias class AlibiPositionalBias(nn.Module): def __init__(self, heads, total_heads, **kwargs): super().__init__() self.heads = heads self.total_heads = total_heads slopes = Tensor(self._get_slopes(heads)) slopes = rearrange(slopes, 'h -> h 1 1') self.register_buffer('slopes', slopes, persistent = False) self.register_buffer('bias', None, persistent = False) def get_bias(self, i, j, device): i_arange = torch.arange(j - i, j, device = device) j_arange = torch.arange(j, device = device) bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1')) return bias @staticmethod def _get_slopes(heads): def get_slopes_power_of_2(n): start = (2**(-2**-(math.log2(n)-3))) ratio = start return [start*ratio**i for i in range(n)] if math.log2(heads).is_integer(): return get_slopes_power_of_2(heads) closest_power_of_2 = 2 ** math.floor(math.log2(heads)) return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2] @property def device(self): return next(self.buffers()).device def forward(self, i, j): h, device = self.total_heads, self.device if exists(self.bias) and self.bias.shape[-1] >= j: return self.bias[..., :i, :j] bias = self.get_bias(i, j, device) bias = bias * self.slopes num_heads_unalibied = h - bias.shape[0] bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0) self.register_buffer('bias', bias, persistent = False) return self.bias class LearnedAlibiPositionalBias(AlibiPositionalBias): def __init__(self, heads, total_heads): super().__init__(heads, total_heads) log_slopes = torch.log(self.slopes) self.learned_logslopes = nn.Parameter(log_slopes) def forward(self, i, j): h, i, j, device = self.heads, self.device def get_slopes(param): return pad_at_dim(param.exp(), (0, h - param.shape[0]), dim = -2) if exists(self.bias) and self.bias.shape[-1] >= j: bias = self.bias[..., :i, :j] else: bias = self.get_bias(i, j, device) self.register_buffer('bias', bias, persistent = False) slopes = get_slopes(self.learned_logslopes) bias = bias * slopes return bias class RotaryEmbedding(nn.Module): def __init__( self, dim, use_xpos = False, scale_base = 512 ): super().__init__() inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) if not use_xpos: self.register_buffer('scale', None) return scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) self.scale_base = scale_base self.register_buffer('scale', scale) def forward(self, seq_len, device): t = torch.arange(seq_len, device = device).type_as(self.inv_freq) freqs = torch.einsum('i , j -> i j', t, self.inv_freq) freqs = torch.cat((freqs, freqs), dim = -1) if not exists(self.scale): return freqs, 1. power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base scale = self.scale ** rearrange(power, 'n -> n 1') scale = torch.cat((scale, scale), dim = -1) return freqs, scale def rotate_half(x): x = rearrange(x, '... (j d) -> ... j d', j = 2) x1, x2 = x.unbind(dim = -2) return torch.cat((-x2, x1), dim = -1) def apply_rotary_pos_emb(t, freqs, scale = 1): seq_len = t.shape[-2] freqs = freqs[-seq_len:, :] return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) # norms class Scale(nn.Module): def __init__(self, value, fn): super().__init__() self.value = value self.fn = fn def forward(self, x, **kwargs): out = self.fn(x, **kwargs) scale_fn = lambda t: t * self.value if not isinstance(out, tuple): return scale_fn(out) return (scale_fn(out[0]), *out[1:]) class ScaleNorm(nn.Module): def __init__(self, dim, eps = 1e-5): super().__init__() self.eps = eps self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5)) def forward(self, x): norm = torch.norm(x, dim = -1, keepdim = True) return x / norm.clamp(min = self.eps) * self.g class RMSNorm(nn.Module): def __init__(self, dim, eps = 1e-8): super().__init__() self.scale = dim ** -0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): norm = torch.norm(x, dim = -1, keepdim = True) * self.scale return x / norm.clamp(min = self.eps) * self.g # residual and residual gates class Residual(nn.Module): def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.): super().__init__() self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None self.scale_residual_constant = scale_residual_constant def forward(self, x, residual): if exists(self.residual_scale): residual = residual * self.residual_scale if self.scale_residual_constant != 1: residual = residual * self.scale_residual_constant return x + residual class GRUGating(nn.Module): def __init__(self, dim, scale_residual = False, **kwargs): super().__init__() self.gru = nn.GRUCell(dim, dim) self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None def forward(self, x, residual): if exists(self.residual_scale): residual = residual * self.residual_scale gated_output = self.gru( rearrange(x, 'b n d -> (b n) d'), rearrange(residual, 'b n d -> (b n) d') ) return gated_output.reshape_as(x) # token shifting def shift(t, amount, mask = None): if amount == 0: return t else: amount = min(amount, t.shape[1]) if exists(mask): t = t.masked_fill(~mask[..., None], 0.) return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.) class ShiftTokens(nn.Module): def __init__(self, shifts, fn): super().__init__() self.fn = fn self.shifts = tuple(shifts) def forward(self, x, **kwargs): mask = kwargs.get('mask', None) shifts = self.shifts segments = len(shifts) feats_per_shift = x.shape[-1] // segments splitted = x.split(feats_per_shift, dim = -1) segments_to_shift, rest = splitted[:segments], splitted[segments:] segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts))) x = torch.cat((*segments_to_shift, *rest), dim = -1) return self.fn(x, **kwargs) # feedforward class GLU(nn.Module): def __init__(self, dim_in, dim_out, activation): super().__init__() self.act = activation self.proj = nn.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim = -1) return x * self.act(gate) class FeedForward(nn.Module): def __init__( self, dim, dim_out = None, mult = 4, glu = False, swish = False, relu_squared = False, post_act_ln = False, dropout = 0., no_bias = False, zero_init_output = False ): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) if relu_squared: activation = ReluSquared() elif swish: activation = nn.SiLU() else: activation = nn.GELU() project_in = nn.Sequential( nn.Linear(dim, inner_dim, bias = not no_bias), activation ) if not glu else GLU(dim, inner_dim, activation) self.ff = nn.Sequential( project_in, nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out, bias = not no_bias) ) # init last linear layer to 0 if zero_init_output: init_zero_(self.ff[-1]) def forward(self, x): return self.ff(x) # attention. it is all we need class Attention(nn.Module): def __init__( self, dim, dim_head = DEFAULT_DIM_HEAD, heads = 8, causal = False, flash = False, talking_heads = False, head_scale = False, sparse_topk = None, num_mem_kv = 0, dropout = 0., on_attn = False, gate_values = False, zero_init_output = False, max_attend_past = None, qk_norm = False, qk_norm_groups = 1, qk_norm_scale = 10, qk_norm_dim_scale = False, one_kv_head = False, shared_kv = False, value_dim_head = None, tensor_product = False # https://arxiv.org/abs/2208.06061 ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads self.causal = causal self.max_attend_past = max_attend_past value_dim_head = default(value_dim_head, dim_head) q_dim = k_dim = dim_head * heads v_dim = out_dim = value_dim_head * heads self.one_kv_head = one_kv_head if one_kv_head: k_dim = dim_head v_dim = value_dim_head out_dim = v_dim * heads self.to_q = nn.Linear(dim, q_dim, bias = False) self.to_k = nn.Linear(dim, k_dim, bias = False) # shared key / values, for further memory savings during inference assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values' self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None # relations projection from tp-attention self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None # add GLU gating for aggregated values, from alphafold2 self.to_v_gate = None if gate_values: self.to_v_gate = nn.Linear(dim, out_dim) nn.init.constant_(self.to_v_gate.weight, 0) nn.init.constant_(self.to_v_gate.bias, 1) # cosine sim attention self.qk_norm = qk_norm self.qk_norm_groups = qk_norm_groups self.qk_norm_scale = qk_norm_scale # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442 self.qk_norm_dim_scale = qk_norm_dim_scale self.qk_norm_q_scale = self.qk_norm_k_scale = 1 if qk_norm and qk_norm_dim_scale: self.qk_norm_q_scale = nn.Parameter(torch.ones(dim_head)) self.qk_norm_k_scale = nn.Parameter(torch.ones(dim_head)) assert (not qk_norm) or (dim_head % qk_norm_groups) == 0, 'dimension per attention head must be divisible by the qk norm groups' assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)' # attend class - includes core attention algorithm + talking heads self.attend = Attend( heads = heads, causal = causal, talking_heads = talking_heads, dropout = dropout, qk_norm = qk_norm, scale = qk_norm_scale if qk_norm else self.scale, flash = flash ) # head scaling self.head_scale = head_scale if head_scale: self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1)) # explicit topk sparse attention self.sparse_topk = sparse_topk # add memory key / values self.num_mem_kv = num_mem_kv if num_mem_kv > 0: self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) # attention on attention self.attn_on_attn = on_attn self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False) # init output projection 0 if zero_init_output: init_zero_(self.to_out) def forward( self, x, context = None, mask = None, context_mask = None, attn_mask = None, rel_pos = None, rotary_pos_emb = None, prev_attn = None, mem = None ): b, n, _, h, head_scale, device, has_context = *x.shape, self.heads, self.head_scale, x.device, exists(context) kv_input = default(context, x) q_input = x k_input = kv_input v_input = kv_input r_input = x if exists(mem): k_input = torch.cat((mem, k_input), dim = -2) v_input = torch.cat((mem, v_input), dim = -2) q = self.to_q(q_input) k = self.to_k(k_input) v = self.to_v(v_input) if exists(self.to_v) else k r = self.to_r(r_input) if exists(self.to_r) else None q = rearrange(q, 'b n (h d) -> b h n d', h = h) if not self.one_kv_head: k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = h), (k, v, r)) if self.qk_norm: qk_l2norm = partial(l2norm, groups = self.qk_norm_groups) q, k = map(qk_l2norm, (q, k)) scale = self.qk_norm_scale q = q * self.qk_norm_q_scale k = k * self.qk_norm_k_scale if exists(rotary_pos_emb) and not has_context: freqs, xpos_scale = rotary_pos_emb l = freqs.shape[-1] q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.) (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) ql, kl, vl = map(lambda arg: apply_rotary_pos_emb(arg[0], freqs, arg[1]), ((ql, q_xpos_scale), (kl, k_xpos_scale), (vl, k_xpos_scale))) q, k, v = map(lambda t: torch.cat(t, dim = -1), ((ql, qr), (kl, kr), (vl, vr))) input_mask = default(context_mask, mask) if self.num_mem_kv > 0: mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v)) if self.qk_norm: mem_k = l2norm(mem_k) mem_k = mem_k * self.qk_norm_k_scale k = torch.cat((mem_k, k), dim = -2) v = torch.cat((mem_v, v), dim = -2) if exists(input_mask): input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True) i, j = map(lambda t: t.shape[-2], (q, k)) # determine masking mask_value = max_neg_value(q) masks = [] final_attn_mask = None if exists(input_mask): input_mask = rearrange(input_mask, 'b j -> b 1 1 j') masks.append(~input_mask) if exists(attn_mask): assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4' if attn_mask.ndim == 2: attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j') elif attn_mask.ndim == 3: attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j') masks.append(~attn_mask) if exists(self.max_attend_past): range_q = torch.arange(j - i, j, device = device) range_k = torch.arange(j, device = device) dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j') max_attend_past_mask = dist > self.max_attend_past masks.append(max_attend_past_mask) if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: top, _ = dots.topk(self.sparse_topk, dim = -1) vk = rearrange(top[..., -1], '... -> ... 1') sparse_topk_mask = dots < vk masks.append(sparse_topk_mask) if len(masks) > 0: final_attn_mask = or_reduce(masks) # prepare relative positional bias, if needed attn_bias = None if exists(rel_pos): attn_bias = rel_pos(i, j) # attention is all we need out, intermediates = self.attend( q, k, v, mask = final_attn_mask, attn_bias = attn_bias, prev_attn = prev_attn ) # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients if exists(r): out = out * r + out # normformer scaling of heads if head_scale: out = out * self.head_scale_params # merge heads out = rearrange(out, 'b h n d -> b n (h d)') # alphafold2 styled gating of the values if exists(self.to_v_gate): gates = self.to_v_gate(x) out = out * gates.sigmoid() # combine the heads out = self.to_out(out) if exists(mask): mask = rearrange(mask, 'b n -> b n 1') out = out.masked_fill(~mask, 0.) return out, intermediates class AttentionLayers(nn.Module): def __init__( self, dim, depth, heads = 8, causal = False, cross_attend = False, only_cross = False, use_scalenorm = False, use_rmsnorm = False, alibi_pos_bias = False, alibi_num_heads = None, alibi_learned = False, rel_pos_bias = False, rel_pos_num_buckets = 32, rel_pos_max_distance = 128, dynamic_pos_bias = False, dynamic_pos_bias_log_distance = False, dynamic_pos_bias_mlp_depth = 2, dynamic_pos_bias_norm = False, rotary_pos_emb = False, rotary_emb_dim = None, rotary_xpos = False, rotary_xpos_scale_base = 512, custom_layers = None, sandwich_coef = None, par_ratio = None, residual_attn = False, cross_residual_attn = False, macaron = False, pre_norm = True, gate_residual = False, scale_residual = False, scale_residual_constant = 1., deepnorm = False, shift_tokens = 0, sandwich_norm = False, resi_dual = False, zero_init_branch_output = False, layer_dropout = 0., cross_attn_tokens_dropout = 0., **kwargs ): super().__init__() rotary_pos_emb = rotary_pos_emb or rotary_xpos ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs) dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) self.dim = dim self.depth = depth self.layers = nn.ModuleList([]) self.has_pos_emb = rel_pos_bias or rotary_pos_emb rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention' self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base) if rotary_pos_emb else None assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' # relative positional bias flash_attn = attn_kwargs.get('flash', False) assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias' self.rel_pos = None if rel_pos_bias: assert not flash_attn, 'flash attention not compatible with t5 relative positional bias' self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance) elif dynamic_pos_bias: assert not flash_attn, 'flash attention not compatible with dynamic positional bias' self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm) elif alibi_pos_bias: alibi_num_heads = default(alibi_num_heads, heads) assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned else AlibiPositionalBias self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, total_heads = heads) # determine deepnorm and residual scale if deepnorm: assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings' pre_norm = sandwich_norm = resi_dual = False scale_residual = True scale_residual_constant = (2 * depth) ** 0.25 assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both' assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm' assert not (not pre_norm and resi_dual), 'resiDualcannot be used when not using prenorm' self.pre_norm = pre_norm self.sandwich_norm = sandwich_norm self.resi_dual = resi_dual self.residual_attn = residual_attn self.cross_residual_attn = cross_residual_attn self.cross_attend = cross_attend norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm norm_class = RMSNorm if use_rmsnorm else norm_class norm_fn = partial(norm_class, dim) if cross_attend and not only_cross: default_block = ('a', 'c', 'f') elif cross_attend and only_cross: default_block = ('c', 'f') else: default_block = ('a', 'f') if macaron: default_block = ('f',) + default_block # zero init if zero_init_branch_output: attn_kwargs = {**attn_kwargs, 'zero_init_output': True} ff_kwargs = {**ff_kwargs, 'zero_init_output': True} # calculate layer block order if exists(custom_layers): layer_types = custom_layers elif exists(par_ratio): par_depth = depth * len(default_block) assert 1 < par_ratio <= par_depth, 'par ratio out of range' default_block = tuple(filter(not_equals('f'), default_block)) par_attn = par_depth // par_ratio depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper par_width = (depth_cut + depth_cut // par_attn) // par_attn assert len(default_block) <= par_width, 'default block is too large for par_ratio' par_block = default_block + ('f',) * (par_width - len(default_block)) par_head = par_block * par_attn layer_types = par_head + ('f',) * (par_depth - len(par_head)) elif exists(sandwich_coef): assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef else: layer_types = default_block * depth self.layer_types = layer_types self.num_attn_layers = len(list(filter(equals('a'), layer_types))) # stochastic depth self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types)) # structured dropout for cross attending self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # calculate token shifting shift_tokens = cast_tuple(shift_tokens, len(layer_types)) # iterate and construct layers for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): is_last_layer = ind == (len(self.layer_types) - 1) if layer_type == 'a': layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs) elif layer_type == 'c': layer = Attention(dim, heads = heads, **attn_kwargs) elif layer_type == 'f': layer = FeedForward(dim, **ff_kwargs) layer = layer if not macaron else Scale(0.5, layer) else: raise Exception(f'invalid layer type {layer_type}') if layer_shift_tokens > 0: shift_range_upper = layer_shift_tokens + 1 shift_range_lower = -layer_shift_tokens if not causal else 0 layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) residual_fn = GRUGating if gate_residual else Residual residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant) pre_branch_norm = norm_fn() if pre_norm else None post_branch_norm = norm_fn() if sandwich_norm else None post_main_norm = norm_fn() if (resi_dual or not pre_norm) and not is_last_layer else None norms = nn.ModuleList([ pre_branch_norm, post_branch_norm, post_main_norm ]) self.layers.append(nn.ModuleList([ norms, layer, residual ])) if deepnorm: init_gain = (8 * depth) ** -0.25 deepnorm_init(self, init_gain) def forward( self, x, context = None, mask = None, context_mask = None, attn_mask = None, self_attn_context_mask = None, mems = None, return_hiddens = False ): assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True' hiddens = [] intermediates = [] prev_attn = None prev_cross_attn = None mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers rotary_pos_emb = None if exists(self.rotary_pos_emb): max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems))) rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) outer_residual = x for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)): is_last = ind == (len(self.layers) - 1) if self.training and layer_dropout > 0. and random() < layer_dropout: continue if layer_type == 'a': if return_hiddens: hiddens.append(x) layer_mem = mems.pop(0) if mems else None if layer_type == 'c': if self.training and self.cross_attn_tokens_dropout > 0.: context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout) inner_residual = x pre_norm, post_branch_norm, post_main_norm = norm if exists(pre_norm) and not self.resi_dual: x = pre_norm(x) if layer_type == 'a': out, inter = block(x, mask = mask, context_mask = self_attn_context_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem) elif layer_type == 'c': out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn) elif layer_type == 'f': out = block(x) if self.resi_dual: outer_residual = residual_fn(out, outer_residual) if exists(post_branch_norm): out = post_branch_norm(out) x = residual_fn(out, inner_residual) if layer_type in ('a', 'c') and return_hiddens: intermediates.append(inter) if layer_type == 'a' and self.residual_attn: prev_attn = inter.pre_softmax_attn elif layer_type == 'c' and self.cross_residual_attn: prev_cross_attn = inter.pre_softmax_attn if exists(post_main_norm): x = post_main_norm(x) if self.resi_dual: x = x + pre_norm(outer_residual) if return_hiddens: intermediates = LayerIntermediates( hiddens = hiddens, attn_intermediates = intermediates ) return x, intermediates return x class Encoder(AttentionLayers): def __init__(self, **kwargs): assert 'causal' not in kwargs, 'cannot set causality on encoder' super().__init__(causal = False, **kwargs) class Decoder(AttentionLayers): def __init__(self, **kwargs): assert 'causal' not in kwargs, 'cannot set causality on decoder' super().__init__(causal = True, **kwargs) class CrossAttender(AttentionLayers): def __init__(self, **kwargs): super().__init__(cross_attend = True, only_cross = True, **kwargs) class ViTransformerWrapper(nn.Module): def __init__( self, *, image_size, patch_size, attn_layers, channels = 3, num_classes = None, dropout = 0., post_emb_norm = False, emb_dropout = 0. ): super().__init__() assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder' assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' dim = attn_layers.dim num_patches = (image_size // patch_size) ** 2 patch_dim = channels * patch_size ** 2 self.patch_size = patch_size self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) self.patch_to_embedding = nn.Sequential( nn.LayerNorm(patch_dim), nn.Linear(patch_dim, dim), nn.LayerNorm(dim) ) self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity() self.dropout = nn.Dropout(emb_dropout) self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity() def forward( self, img, return_embeddings = False ): p = self.patch_size x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) x = self.patch_to_embedding(x) n = x.shape[1] x = x + self.pos_embedding[:, :n] x = self.post_emb_norm(x) x = self.dropout(x) x = self.attn_layers(x) x = self.norm(x) if not exists(self.mlp_head) or return_embeddings: return x x = x.mean(dim = -2) return self.mlp_head(x) class TransformerWrapper(nn.Module): def __init__( self, *, num_tokens, max_seq_len, attn_layers, emb_dim = None, max_mem_len = 0., shift_mem_down = 0, emb_dropout = 0., post_emb_norm = False, num_memory_tokens = None, tie_embedding = False, logits_dim = None, use_abs_pos_emb = True, scaled_sinu_pos_emb = False, l2norm_embed = False, emb_frac_gradient = 1. # GLM-130B and Cogview successfully used this, set at 0.1 ): super().__init__() assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' dim = attn_layers.dim emb_dim = default(emb_dim, dim) self.emb_dim = emb_dim self.num_tokens = num_tokens self.token_pad = num_tokens self.max_seq_len = max_seq_len self.max_mem_len = max_mem_len self.shift_mem_down = shift_mem_down self.l2norm_embed = l2norm_embed self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed) if not (use_abs_pos_emb and not attn_layers.has_pos_emb): self.pos_emb = always(0) elif scaled_sinu_pos_emb: self.pos_emb = ScaledSinusoidalEmbedding(emb_dim) else: self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed) self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290 self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity() self.emb_dropout = nn.Dropout(emb_dropout) self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) self.init_() logits_dim = default(logits_dim, num_tokens) self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() # memory tokens (like [cls]) from Memory Transformers paper num_memory_tokens = default(num_memory_tokens, 0) self.num_memory_tokens = num_memory_tokens if num_memory_tokens > 0: self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) def init_(self): if self.l2norm_embed: nn.init.normal_(self.token_emb.emb.weight, std = 1e-5) if not isinstance(self.pos_emb, always): nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5) return nn.init.kaiming_normal_(self.token_emb.emb.weight) def forward( self, x, return_embeddings = False, return_logits_and_embeddings = False, return_intermediates = False, mask = None, return_mems = False, return_attn = False, mems = None, pos = None, prepend_embeds = None, sum_embeds = None, **kwargs ): b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient return_hiddens = return_mems | return_attn # absolute positional embedding external_pos_emb = exists(pos) and pos.dtype != torch.long pos_emb = self.pos_emb(x, pos = pos) if not external_pos_emb else pos x = self.token_emb(x) + pos_emb # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training if exists(sum_embeds): x = x + sum_embeds # post embedding norm, purportedly leads to greater stabilization x = self.post_emb_norm(x) # whether to append embeds, as in PaLI, for image embeddings if exists(prepend_embeds): prepend_seq, prepend_dim = prepend_embeds.shape[1:] assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions' x = torch.cat((prepend_embeds, x), dim = -2) # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model if emb_frac_gradient < 1: assert emb_frac_gradient > 0 x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient) # embedding dropout x = self.emb_dropout(x) x = self.project_emb(x) if num_mem > 0: mem = repeat(self.memory_tokens, 'n d -> b n d', b = b) x = torch.cat((mem, x), dim = 1) # auto-handle masking after appending memory tokens if exists(mask): mask = pad_at_dim(mask, (num_mem, 0), dim = -1, value = True) if self.shift_mem_down and exists(mems): mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] mems = [*mems_r, *mems_l] if return_hiddens: x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs) else: x = self.attn_layers(x, mask = mask, mems = mems, **kwargs) x = self.norm(x) mem, x = x[:, :num_mem], x[:, num_mem:] if return_logits_and_embeddings: out = (self.to_logits(x), x) elif return_embeddings: out = x else: out = self.to_logits(x) if return_intermediates: return out, intermediates if return_mems: hiddens = intermediates.hiddens new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) return out, new_mems if return_attn: attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) return out, attn_maps return out class ContinuousTransformerWrapper(nn.Module): def __init__( self, *, max_seq_len, attn_layers, dim_in = None, dim_out = None, emb_dim = None, post_emb_norm = False, emb_dropout = 0., use_abs_pos_emb = True, scaled_sinu_pos_emb = False ): super().__init__() assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' dim = attn_layers.dim self.max_seq_len = max_seq_len if not (use_abs_pos_emb and not attn_layers.has_pos_emb): self.pos_emb = always(0) elif scaled_sinu_pos_emb: self.pos_emb = ScaledSinusoidalEmbedding(dim) else: self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity() self.emb_dropout = nn.Dropout(emb_dropout) self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() def forward( self, x, return_embeddings = False, return_intermediates = False, mask = None, return_attn = False, mems = None, pos = None, prepend_embeds = None, **kwargs ): x = self.project_in(x) x = x + self.pos_emb(x, pos = pos) x = self.post_emb_norm(x) # whether to append embeds, as in PaLI, for image embeddings if exists(prepend_embeds): _, prepend_dim = prepend_embeds.shape[1:] assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions' x = torch.cat((prepend_embeds, x), dim = -2) x = self.emb_dropout(x) x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs) x = self.norm(x) out = self.project_out(x) if not return_embeddings else x if return_intermediates: return out, intermediates if return_attn: attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) return out, attn_maps return out class XTransformer(nn.Module): def __init__( self, *, dim, tie_token_emb = False, ignore_index = -100, pad_value = 0, deepnorm = False, cross_attn_tokens_dropout = 0., **kwargs ): super().__init__() enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs) dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs) assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword' enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs) enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0) enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None) enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False) enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True) dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs) dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0) dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False) dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True) self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # how many tokens from the encoder to dropout when cross attending from decoder - seen in a couple papers, including Perceiver AR - this will also be very effective regularization when cross attending to very long memories if deepnorm: enc_kwargs['scale_residual'] = True dec_kwargs['scale_residual'] = True enc_depth = enc_kwargs['depth'] dec_depth = dec_kwargs['depth'] enc_kwargs['scale_residual_constant'] = 0.81 * ((enc_depth ** 4) * dec_depth) ** .0625 dec_kwargs['scale_residual_constant'] = (3 * dec_depth) ** 0.25 self.encoder = TransformerWrapper( **enc_transformer_kwargs, attn_layers = Encoder(dim = dim, **enc_kwargs) ) self.decoder = TransformerWrapper( **dec_transformer_kwargs, attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs) ) if deepnorm: deepnorm_init(self.encoder, 0.87 * ((enc_depth ** 4) * dec_depth) ** -0.0625) deepnorm_init(self.decoder, (12 * dec_depth) ** -0.25) if tie_token_emb: self.decoder.token_emb = self.encoder.token_emb self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value) @torch.no_grad() def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs): encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True) return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs) def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None): if exists(src_prepend_embeds) and exists(mask): mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True) enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True) if self.training and self.cross_attn_tokens_dropout > 0: enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout) out = self.decoder(tgt, context = enc, context_mask = mask) return out #=================================================================================================================== def exists(val): return val is not None def eval_decorator(fn): def inner(self, *args, **kwargs): was_training = self.training self.eval() out = fn(self, *args, **kwargs) self.train(was_training) return out return inner # nucleus def top_p(logits, thres = 0.9): sorted_logits, sorted_indices = torch.sort(logits, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cum_probs > (1 - thres) sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() sorted_indices_to_remove[:, 0] = 0 sorted_logits[sorted_indices_to_remove] = float('-inf') return sorted_logits.scatter(1, sorted_indices, sorted_logits) # topk def top_k(logits, thres = 0.9): k = ceil((1 - thres) * logits.shape[-1]) val, ind = torch.topk(logits, k) probs = torch.full_like(logits, float('-inf')) probs.scatter_(1, ind, val) return probs # top_a def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02): probs = F.softmax(logits, dim=-1) limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio logits[probs < limit] = float('-inf') logits[probs >= limit] = 1 return logits # autoregressive wrapper class class AutoregressiveWrapper(nn.Module): def __init__( self, net, ignore_index = -100, pad_value = 0, mask_prob = 0. ): super().__init__() self.pad_value = pad_value self.ignore_index = ignore_index self.net = net self.max_seq_len = net.max_seq_len # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432 assert mask_prob < 1. self.mask_prob = mask_prob @torch.no_grad() @eval_decorator def generate( self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, min_p_pow = 2.0, min_p_ratio = 0.02, verbose=True, return_prime=False, **kwargs ): device = start_tokens.device num_dims = start_tokens.ndim start_tokens, ps = pack([start_tokens], '* n') b, t = start_tokens.shape out = start_tokens if verbose: print("Generating sequence of max length:", seq_len) for s in range(seq_len): x = out[:, -self.max_seq_len:] logits = self.net(x, **kwargs)[:, -1] if filter_logits_fn in {top_k, top_p}: filtered_logits = filter_logits_fn(logits, thres = filter_thres) probs = F.softmax(filtered_logits / temperature, dim=-1) elif filter_logits_fn is top_a: filtered_logits = filter_logits_fn(logits, min_p_pow = min_p_pow, min_p_ratio= min_p_ratio) probs = F.softmax(filtered_logits / temperature, dim=-1) sample = torch.multinomial(probs, 1) out = torch.cat((out, sample), dim=-1) if verbose: if s % 32 == 0: print(s, '/', seq_len) if exists(eos_token): is_eos_tokens = (out == eos_token) if is_eos_tokens.any(dim = -1).all(): # mask out everything after the eos tokens shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 out = out.masked_fill(mask, self.pad_value) if verbose: print('Model called the end of sequence at:', s, '/', seq_len) break if return_prime: return out[:, :] else: return out[:, t:] out, = unpack(out, ps, '* n') return out def compute_accuracy(self, logits, labels): out = torch.argmax(logits, dim=-1) out = out.flatten() labels = labels.flatten() mask = (labels != 999999) # dummy pad value / supposed to be self.token_pad / will fix later out = out[mask] labels = labels[mask] num_right = (out == labels) num_right = torch.sum(num_right).type(torch.float32) acc = num_right / len(labels) return acc def forward(self, x, labels = None, **kwargs): seq, ignore_index = x.shape[1], self.ignore_index inp, target = x[:, :-1], x[:, 1:] if self.mask_prob > 0.: rand = torch.randn(inp.shape, device = x.device) rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out num_mask = min(int(seq * self.mask_prob), seq - 1) indices = rand.topk(num_mask, dim = -1).indices mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool() kwargs.update(self_attn_context_mask = mask) logits = self.net(inp, **kwargs) acc = self.compute_accuracy(logits, target) loss = F.cross_entropy( rearrange(logits, 'b n c -> b c n'), target, ignore_index = ignore_index ) return loss, acc #===================================================================================================================