# See: https://huggingface.co/docs/transformers/custom_models
from typing import Optional, Tuple, Union
import math
import copy
import sys
from importlib import import_module

import torch
from torch import nn, Tensor
import torch.nn.init as init
from torch.nn import functional as F

from transformers.modeling_outputs import CausalLMOutput
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
)
from transformers.utils import (
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
)

if is_flash_attn_2_available():
    from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

# The model type string to bind.
model_type = "walsh-causal-v1"


class Config(PretrainedConfig):
    model_type = model_type

    attribute_map = {
        "hidden_size": "d_embed",
    }

    def __init__(
        # All of these MUST have defaults, even if unused.
        self,
        vocab_size=16000,
        pad_index=None,
        hidden_size=1024,
        num_attention_heads=8,
        num_hidden_layers=6,
        max_sequence_length=2048,
        dim_feedforward=4096,
        dropout=0.1,
        loss_function="causal_loss",

        # Default class to use for each of these components.
        positional_encoder_cls='.PositionalEncoder',
        attention_cls='.CausalSelfAttention',
        activation_cls='torch.nn.ReLU',
        feedforward_cls='.FeedforwardLayer',
        layer_stack_cls='.TransformerLayerStack',
        layer_cls='.PostLayerNorm',
        transformer_cls='.Transformer',
        norm_cls='torch.nn.LayerNorm',
        embdding_cls='torch.nn.Embedding',
        output_proj_cls='torch.nn.Linear',

        # Note: the keys must match the argument names of the chosen positional encoder class.
        positional_encoder_args={
            'd_embed': 1024,
            'max_seq': 2048,
        },

        # Arg groups, passed to the factory classes above.
        transformer_args=dict(),
        attention_args=dict(),
        feedforward_args=dict(),
        activation_args=dict(),
        norm_args={
            'normalized_shape': 1024,
        },
        layer_stack_args=dict(),
        layer_args=dict(),
        embedding_args=dict(),
        output_proj_args=dict(),

        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.pad_index = pad_index
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.max_sequence_length = max_sequence_length
        self.loss_function = loss_function

        self.dim_feedforward = dim_feedforward
        self.dropout = dropout

        self.positional_encoder_cls = positional_encoder_cls
        self.attention_cls = attention_cls
        self.activation_cls = activation_cls
        self.feedforward_cls = feedforward_cls
        self.layer_stack_cls = layer_stack_cls
        self.layer_cls = layer_cls
        self.transformer_cls = transformer_cls
        self.norm_cls = norm_cls
        self.embdding_cls = embdding_cls
        self.output_proj_cls = output_proj_cls

        self.positional_encoder_args = positional_encoder_args
        self.transformer_args = transformer_args
        self.attention_args = attention_args
        self.feedforward_args = feedforward_args
        self.activation_args = activation_args
        self.norm_args = norm_args
        self.layer_stack_args = layer_stack_args
        self.layer_args = layer_args
        self.embedding_args = embedding_args
        self.output_proj_args = output_proj_args

        super().__init__(**kwargs)


def causal_loss(logits: Tensor, labels: Tensor, input_ids: Tensor, ignore_index=-100) -> Tensor:
    """
    Compute and return the loss using logits and labels.
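    The logits are shifted left by one position so that the prediction at position i is
    scored against the token at position i+1; positions whose shifted label equals
    ignore_index are excluded from the mean.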
""" # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = torch.nn.functional.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=ignore_index, reduction='mean', ) return loss.nan_to_num() # Learning to Break the Loop: Analyzing and Mitigating Repetitions for Neural Text Generation # https://arxiv.org/abs/2206.02369 def ditto_loss(logits: Tensor, labels: Tensor, input_ids: Tensor) -> Tensor: batch_size, seq_len, vocab_size = logits.shape rep_reduce_gamma = 0.5 ditto_weight = 1.0e5 probs = torch.softmax(logits, dim=-1) total_loss = None for i in range(batch_size): context_len = labels[i, 0].item() sentence_len = labels[i, 1].item() n_repeats = labels[i, 2].item() # For readability context_end = context_len sentence_start = context_len sentence_end = sentence_start + sentence_len target_start = sentence_end # Get causal loss for context tokens causal_ids = input_ids[i:i+1, :context_end] c_loss = causal_loss( logits=logits[i:i+1, :context_end], labels=causal_ids, input_ids=causal_ids ) # Slice out target probabilities target_probs = probs[i , target_start:, :] # Slice out first instance of repeated sentence, detach is (prevents back-prop), repeat in N times, # and trim to length of target_probs. baseline_probs = probs[i, sentence_start:sentence_end, :].detach().repeat(n_repeats, 1)[:target_probs.size(0), :] # Compute DITTO loss. one_minus_probs = torch.clamp((1.0 - torch.abs((target_probs - baseline_probs * rep_reduce_gamma))), min=1e-20) r_loss = -torch.log(one_minus_probs).mean() * ditto_weight # Combine repitition and causal loss loss = c_loss + r_loss # Add this to the total if total_loss is None: total_loss = loss else: total_loss += loss return total_loss / batch_size # Dynamically lookup class name and return factory for class. def get_dynamic_class(name): try: module_path, class_name = name.rsplit('.', 1) if module_path == "": return getattr(sys.modules[__name__], class_name) module = import_module(module_path) return getattr(module, class_name) except (ImportError, AttributeError) as e: raise ImportError(name) # An easily extensible dynamic transformer class # Many variations can be specified entirely in the configuration, without touching this code. class HFCausalModel(PreTrainedModel): config_class = Config model_type = 'Transformer' supports_gradient_checkpointing = True # Presently needs to be manually set to match transformer layer class... 
    _no_split_modules = ["DeepnetLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config):
        super().__init__(config)
        self.d_model = config.hidden_size
        self.transformer_head = self._make_transformer(config)
        self.loss_function = get_dynamic_class(config.loss_function)
        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> CausalLMOutput:
        if self.gradient_checkpointing and self.training:
            gradient_checkpointing_func = self._gradient_checkpointing_func
        else:
            gradient_checkpointing_func = None

        logits, attentions = self.transformer_head(
            input_ids=input_ids,
            need_weights=output_attentions,
            gradient_checkpointing_func=gradient_checkpointing_func,
        )

        # Compute loss.
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, input_ids=input_ids)
        else:
            loss = None

        return CausalLMOutput(loss=loss, logits=logits, attentions=attentions)

    # Needed for generate() method.
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        attention_mask = kwargs.get("attention_mask", None)
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return model_inputs

    def _make_embedding(self, config):
        embedding_cls = get_dynamic_class(config.embdding_cls)
        return embedding_cls(config.vocab_size, self.d_model, config.pad_index, **config.embedding_args)

    def _make_pos_encoder(self, config):
        pos_enc_cls = get_dynamic_class(config.positional_encoder_cls)
        return pos_enc_cls(**config.positional_encoder_args)

    def _make_output_projection(self, config):
        output_proj_cls = get_dynamic_class(config.output_proj_cls)
        return output_proj_cls(self.d_model, config.vocab_size, **config.output_proj_args)

    def _make_dropout(self, config):
        return nn.Dropout(config.dropout)

    def _make_activation(self, config):
        activation_cls = get_dynamic_class(config.activation_cls)
        return activation_cls(**config.activation_args)

    def _make_norm(self, config):
        norm_cls = get_dynamic_class(config.norm_cls)
        return norm_cls(self.d_model)

    def _make_self_attention(self, config):
        attention_cls = get_dynamic_class(config.attention_cls)
        # Map the HF _attn_implementation setting to attn_type.
        match config._attn_implementation:
            case "flash_attention_2":
                if is_flash_attn_2_available():
                    if not is_flash_attn_greater_or_equal_2_10():
                        raise Exception("flash_attn >= 2.1.0 is required")
                    attn_type = "flash2"
                else:
                    attn_type = "torch"
            case "sdpa":
                attn_type = "torch"
            case "eager":
                attn_type = "native"
            case _:
                raise Exception(f"Unimplemented attention type '{config._attn_implementation}'")
        return attention_cls(
            d_model=self.d_model,
            num_heads=config.num_attention_heads,
            attn_type=attn_type,
            **config.attention_args,
        )

    def _make_feedforward(self, config):
        feedforward_cls = get_dynamic_class(config.feedforward_cls)
        return feedforward_cls(
            d_model=self.d_model,
            feedforward_dim=config.dim_feedforward,
            dropout=config.dropout,
            activation=self._make_activation(config),
            **config.feedforward_args,
        )

    def _make_layer(self, config):
        layer_cls = get_dynamic_class(config.layer_cls)
        return layer_cls(
            d_model=self.d_model,
            dropout=self._make_dropout(config),
            attention=self._make_self_attention(config),
            feedforward=self._make_feedforward(config),
            norm1=self._make_norm(config),
            norm2=self._make_norm(config),
            **config.layer_args,
        )

    def _make_layer_stack(self, config):
        layer_stack_cls = get_dynamic_class(config.layer_stack_cls)
        return layer_stack_cls(
            layers=nn.ModuleList([
                self._make_layer(config) for _ in range(config.num_hidden_layers)
            ]),
            **config.layer_stack_args,
        )

    def _make_transformer(self, config):
        transformer_cls = get_dynamic_class(config.transformer_cls)
        return transformer_cls(
            d_model=self.d_model,
            embedding=self._make_embedding(config),
            positional_encoder=self._make_pos_encoder(config),
            layer_stack=self._make_layer_stack(config),
            output_projection=self._make_output_projection(config),
            **config.transformer_args,
        )

    @torch.no_grad()
    def _init_weights(self, module):
        pass


# Register model type and configuration
AutoConfig.register(model_type, Config)
AutoModelForCausalLM.register(Config, HFCausalModel)


# A generic container class for standard transformer components.
class Transformer(nn.Module):
    def __init__(self, d_model, embedding, positional_encoder, layer_stack, output_projection, **kwargs):
        super().__init__()
        self.embedding = embedding
        self.positional_encoder = positional_encoder
        self.layer_stack = layer_stack
        self.output_projection = output_projection
        self.d_model = d_model
        self.sqrt_d_model = d_model**0.5
        self.reset_parameters()

    def forward(self, input_ids, need_weights, gradient_checkpointing_func):
        x = self.positional_encoder(self.embedding(input_ids) * self.sqrt_d_model)

        x, attentions = self.layer_stack(
            x,
            need_weights,
            gradient_checkpointing_func,
        )

        # Translate output embedding to logits.
        logits = self.output_projection(x)
        return logits, attentions

    def reset_parameters(self):
        init.xavier_uniform_(self.output_projection.weight)
        init.constant_(self.output_projection.bias, 0.)
        init.normal_(self.embedding.weight, std=self.d_model**-0.5)


# A vanilla positional encoder
class PositionalEncoder(nn.Module):
    def __init__(self, d_embed, max_seq):
        super().__init__()
        self.d_embed = d_embed
        self.max_seq = max_seq

        weight = torch.zeros(max_seq, d_embed)
        position = torch.arange(0, max_seq, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed))
        weight[:, 0::2] = torch.sin(position * div_term)
        weight[:, 1::2] = torch.cos(position * div_term)
        weight = weight.unsqueeze(0)
        self.register_buffer('weight', weight)

    def forward(self, x):
        seq_len = x.size(-2)
        return x + self.weight[:, :seq_len]


# Converts a torch array of integers into their equivalent binary codes.
def binary_tensor(x, bits):
    mask = 2**torch.arange(bits).to(x.device, x.dtype)
    return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte()


def hadamard_walsh_matrix(k: int):
    # k: The dimension of the matrix is 2^k
    assert k > 0

    # Start with the Hadamard H2^1 matrix.
    h1 = torch.tensor([[1, 1], [1, -1]], dtype=torch.float)

    # The series of matrices can be computed by recursively applying the Kronecker product,
    # starting with h1.
    #
    # This will produce the series of Hadamard-Walsh matrices in natural order.
    w = h1
    for _ in range(k-1):
        w = torch.kron(h1, w)

    return w


# This positional encoder adds absolute binary positions to the embedding, encoded via a
# Hadamard-Walsh matrix.
# See: https://en.wikipedia.org/wiki/Hadamard_code
# Each bit in the binary code word is encoded via a row of the Hadamard-Walsh matrix, with a
# 1 being encoded by the presence of the row and a 0 by its absence.
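# For example, position 5 has the binary code 101 (least-significant bit first), so its
# positional term is the sum of rows 0 and 2 of the (gain-scaled) matrix, added to the embedding.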
# While training, the base sequence offset is randomly selected, which appears to allow the
# model to generalize to sequences longer than it was trained on. This is similar to what is
# described here:
# https://arxiv.org/pdf/2305.16843.pdf
# I have tried this approach and found that my approach works better for generalization.
#
# Note: Without random shifting, the early performance of this encoder is exceptionally good.
# The drawback is that the model can't generalize to longer sequences than it was trained on
# and can't easily accommodate additional bits later in the training process.
class RSWalshPositionalEncoder(nn.Module):
    def __init__(self, d_embed, max_seq, gain=0.333):
        super().__init__()
        self.max_seq = max_seq
        self.d_embed = d_embed

        # Hadamard-Walsh k, where the dimension of the matrix is 2^k
        k = math.ceil(math.log2(d_embed))

        # The number of bits required to encode max_seq
        bits = math.ceil(math.log2(max_seq))

        # Gain controls the weight given to the encodings.
        # When a trainable parameter, the value appears to settle at around 0.333.
        self.gain = gain

        assert bits <= d_embed, "max_seq exceeds n-bits available for d_embed"

        # Generate sequential binary codes for absolute positions.
        # The implementation originally used Gray codes, in which successive symbols
        # differ by only one bit. See: https://en.wikipedia.org/wiki/Gray_code
        # This, along with a few other coding schemes, was tested, with a simple
        # binary code having the best performance.
        binary_code = binary_tensor(torch.arange(0, max_seq, 1), bits)
        self.register_buffer('binary_code', binary_code, persistent=False)

        # Each bit is encoded via a row of a Hadamard-Walsh matrix.
        # We slice off the unused rows and columns -- ideally, d_embed should be
        # the same dimension as the matrix.
        walsh = hadamard_walsh_matrix(k)[:bits, :d_embed] * self.gain

        # This alternative appears superior to the original.
        # If starting from scratch, use this.
        # walsh = (hadamard_walsh_matrix(k)[:bits, :d_embed] - 0.5) * self.gain
        self.register_buffer('walsh', walsh, persistent=False)

    def forward(self, x):
        seq_len = x.size(-2)

        # Get the sequence of binary codes...
        # We use a random base offset when training.
        # This results in slower initial gains, but appears to allow the model to generalize to
        # the value of max_seq, even if never trained with sequences of this length. I also have
        # a suspicion that this has a regularizing effect on training, similar to dropout. Models with
        # random base offset shifting, despite slower initial improvement, appear to perform better in the long run.
        # TODO: Set up a controlled experiment to test this hypothesis.
        if self.training:
            shift = torch.randint(self.max_seq - seq_len + 1, (1,)).item()
            seq = self.binary_code[shift:seq_len + shift, :]

        # Disable shifting when not training. This does not appear to change the evaluation loss, but
        # it does make predictions easier to analyse when the attention weights are not shifting with each step.
        else:
            seq = self.binary_code[:seq_len, :]

        # For reasons I have yet to identify, when the model is running in Textgenwebui, this matrix appears
        # to evade conversion to bfloat16, despite everything else having been converted.
        # This is a work-around for that.
        self.walsh = self.walsh.to(dtype=x.dtype)

        # Encode the binary sequence with the Hadamard-Walsh codes and apply to the embeddings.
        # If nothing else, the Walsh encodings make the positional information exceptionally
        # robust with respect to dropout and other adversities. They can still be easily detected
        # at the final layer.
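        # seq has shape (seq_len, bits) and self.walsh has shape (bits, d_embed), so the
        # positional term below is (seq_len, d_embed), broadcast over the batch dimension.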
        return x + (seq.to(dtype=x.dtype) @ self.walsh)


# A generic stack of transformer layers.
class TransformerLayerStack(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(self, x, need_weights, gradient_checkpointing_func=None):
        attentions = []
        for layer in self.layers:
            if gradient_checkpointing_func is not None:
                x, attention_weights = gradient_checkpointing_func(
                    layer.__call__,
                    x,
                    need_weights,
                    use_reentrant=False
                )
            else:
                x, attention_weights = layer(x, need_weights=need_weights)
            if need_weights:
                attentions.append(attention_weights)

        return x, attentions


# DeepNet: Scaling Transformers to 1,000 Layers
# https://arxiv.org/abs/2203.00555
class DeepnetLayer(nn.Module):
    def __init__(
        self,
        d_model,
        attention,
        feedforward,
        norm1,
        norm2,
        dropout,
        alpha=1.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.attention = attention
        self.feedforward = feedforward
        self.norm1 = norm1
        self.norm2 = norm2
        self.dropout = dropout
        # Deepnet alpha
        self.alpha = alpha

    def forward(self, x, need_weights=False):
        # Keep input as residual
        residual = x * self.alpha

        # Compute attention
        x, attention_weights = self.attention(x, need_weights)

        # Add attention with residual and normalize.
        x = self.norm1(residual + self.dropout(x))

        # Keep output as next residual.
        residual = x * self.alpha

        # Pass through feedforward network.
        x = self.feedforward(x)

        # Combine residual and feedforward output, then normalize again.
        x = self.norm2(residual + self.dropout(x))

        return x, attention_weights


# A vanilla MLP transformer layer.
class FeedforwardLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        feedforward_dim: int,
        dropout,
        activation=nn.ReLU(),
        beta=1.0,
        bias=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.beta = beta
        self.activation = activation
        self.linear1 = nn.Linear(d_model, feedforward_dim, bias=bias)
        self.linear2 = nn.Linear(feedforward_dim, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

    def reset_parameters(self):
        init.xavier_uniform_(self.linear1.weight, gain=self.beta)
        init.xavier_uniform_(self.linear2.weight, gain=self.beta)
        init.constant_(self.linear1.bias, 0.)
        init.constant_(self.linear2.bias, 0.)


# GLU Variants Improve Transformer
# https://arxiv.org/pdf/2002.05202v1.pdf
class SwiGLUFeedforwardLayer(nn.Module):
    def __init__(
        self,
        d_model,
        d_feedforward,
        beta=1.0,
        dropout=0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.d_feedforward = d_feedforward
        self.beta = beta

        self.linear1 = nn.Linear(self.d_model, self.d_feedforward * 2, bias=False)
        self.linear2 = nn.Linear(self.d_feedforward, self.d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def forward(self, x):
        x, gate = self.linear1(x).chunk(2, dim=-1)
        x = x * F.silu(gate)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

    def reset_parameters(self):
        # Deepnet initialization
        # https://arxiv.org/pdf/2203.00555.pdf
        w, g = self.linear1.weight.chunk(2, dim=0)
        init.xavier_uniform_(w, gain=self.beta)
        init.xavier_uniform_(g, gain=self.beta)
        init.xavier_uniform_(self.linear2.weight, gain=self.beta)


class CausalSelfAttention(nn.Module):
    def __init__(
        self,
        d_model,
        num_heads,
        # values:
        # native: Use the local implementation; slowest option; good for debugging; useful when experimenting with non-standard stuff.
        # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
        # flash2: Use the Flash-Attention2 implementation; fastest; limited to fp16 and bfloat16 types; least memory usage.
        attn_type,
        beta=1.0,
        dropout=0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.beta = beta
        self.attn_type = attn_type

        assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"

        # The dimension of each head.
        self.d_head = d_model // num_heads

        # We scale the attention scores by the inverse-square-root of the head dimension;
        # this shifts the temperature of softmax.
        self.dot_product_scale = 1.0 / math.sqrt(self.d_head)

        self.in_proj = nn.Linear(self.d_model, 3 * self.d_model, bias=True)
        self.output_linear = nn.Linear(self.d_model, self.d_model, bias=True)

        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def extra_repr(self) -> str:
        return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, dropout={self.dropout}'

    def reset_parameters(self):
        # Deepnet initialization
        # https://arxiv.org/pdf/2203.00555.pdf
        q, k, v = self.in_proj.weight.chunk(3)
        init.xavier_uniform_(q, gain=1.0)
        init.xavier_uniform_(k, gain=1.0)
        init.xavier_uniform_(v, gain=self.beta)
        init.xavier_uniform_(self.output_linear.weight, gain=self.beta)
        init.constant_(self.in_proj.bias, 0.)
        init.constant_(self.output_linear.bias, 0.)

    def project_input(self, qkv):
        proj = self.in_proj(qkv)
        return proj.chunk(chunks=3, dim=-1)

    def forward(self, qkv, need_weights):
        if self.attn_type == "flash2":
            return self.flash2_forward(qkv)

        # qkv: (batch_size, seq_len, d_embed)
        batch_size, seq_len, d_embed = qkv.shape

        # Feed the inputs through the K, Q, V matrices.
        query, key, value = self.project_input(qkv)

        # Split the projections into multiple heads and swap the sequence / heads dimensions.
        query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        # Default to returning empty attention weights.
        attention_weights = None

        if self.attn_type == "torch":
            # This context manager can be used to force which implementation to use.
            # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            attended_values = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=None,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=True,
                scale=self.dot_product_scale
            )
        # "native" scaled-dot-product attention implementation.
        else:
            # Compute attention scores
            scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale

            # Mask out future positions (causal mask).
            scores.masked_fill_(
                torch.tril(
                    torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device),
                    diagonal=0,
                ).logical_not(),
                float('-inf'),
            )

            # Calculate the attention weights; avoid NaNs that might emerge from zeros in softmax's denominator.
            attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
            del scores

            # Use the attention weights to get a weighted combination of value vectors.
            attended_values = torch.matmul(attention_weights, value)

        if not need_weights:
            del attention_weights
            attention_weights = None

        # Concatenate attention heads and project to original embedding size using the output linear layer.
        attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)

        # Project the concatenated output through the output matrix.
        attended_values = self.output_linear(attended_values)
        return attended_values, attention_weights

    def flash2_forward(self, qkv):
        batch_size, seq_len, d_embed = qkv.shape

        # Feed the inputs through the K, Q, V matrices.
        # query : (batch_size, seq_len, d_model)
        # qkv   : (batch_size, seq_len, 3, num_heads, d_kq)
        qkv = self.in_proj(qkv).unflatten(
            -1,
            (3, self.num_heads, self.d_head)
        )

        attended_values = flash_attn_qkvpacked_func(
            qkv.bfloat16(),
            dropout_p=self.dropout.p if self.training else 0.0,
            softmax_scale=self.dot_product_scale,
            causal=True,
        ).to(dtype=qkv.dtype)
        # attended_values: (batch_size, seqlen, nheads, headdim)

        # Concatenate heads back into d_embed
        attended_values = attended_values.view(batch_size, seq_len, d_embed)

        # Project the concatenated output through the output matrix.
        attended_values = self.output_linear(attended_values)
        return attended_values, None


# Attention layer with ALiBi relative positional encoding
# TRAIN SHORT, TEST LONG: ATTENTION WITH LINEAR BIASES ENABLES INPUT LENGTH EXTRAPOLATION
# https://arxiv.org/pdf/2108.12409.pdf
def alibi_biases(query_len, key_len, device='cpu'):
    x = torch.arange(key_len, device=device)[None, :]
    y = torch.arange(query_len, device=device)[:, None]
    return x - y


class CausalAlibiAttention(nn.Module):
    def __init__(
        self,
        d_model,
        num_heads,
        beta=1.0,
        dropout=0.1,
        # values:
        # native: Use the local implementation; slowest option; good for debugging; useful when experimenting with non-standard stuff.
        # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
        # flash2: Use the Flash-Attention2 implementation; fastest; limited to fp16 and bfloat16 types; can't train ALiBi weights; least memory usage.
        # Note: You can perform initial training with "torch", then switch to "flash2" after the ALiBi weights have settled.
        window_size=None,
        attn_type="native",
        freeze_alibi=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.beta = beta
        self.attn_type = attn_type

        assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"

        # The dimension of each head.
        self.d_head = d_model // num_heads

        # We scale the attention scores by the inverse-square-root of the head dimension;
        # this shifts the temperature of softmax.
        self.dot_product_scale = 1.0 / math.sqrt(self.d_head)

        self.in_proj = nn.Parameter(torch.empty(3 * self.d_model, self.d_model))
        self.output_linear = nn.Linear(self.d_model, self.d_model, bias=False)

        if window_size is not None:
            self.window_size = (window_size, -1)
        else:
            self.window_size = (-1, -1)

        self.dropout = nn.Dropout(dropout)

        # This generates the original slope distribution from the paper.
        # Observations with trainable slopes suggest that the high half of the slopes shift
        # towards / past 1.0 and the low half approach zero or even go slightly negative.
        # alibi_slopes = 1.0 / torch.logspace(1, 8, self.num_heads, base=2, dtype=torch.float)

        # These appear to work better, as initial values, in practice.
        alibi_slopes = 1.0 / torch.logspace(0, 7, self.num_heads, base=2, dtype=torch.float)

        # If not trainable, it can improve performance somewhat if the low half are set to zero. Apparently,
        # making roughly half of the slopes position-agnostic is somehow closer to optimal?
        # alibi_slopes.masked_fill_(torch.where(torch.arange(0, self.num_heads) >= (self.num_heads / 2), True, False), 0)
        self.alibi_slopes = nn.Parameter(alibi_slopes)

        # Optionally, allow/disallow training of ALiBi slopes.
        self.alibi_slopes.requires_grad = (not freeze_alibi)

        self.reset_parameters()

    def extra_repr(self) -> str:
        return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, window_size={self.window_size}, dropout={self.dropout}'

    def reset_parameters(self):
        # Deepnet initialization
        # https://arxiv.org/pdf/2203.00555.pdf
        q, k, v = self.in_proj.chunk(3)
        init.xavier_uniform_(q, gain=1.0)
        init.xavier_uniform_(k, gain=1.0)
        init.xavier_uniform_(v, gain=self.beta)
        init.xavier_uniform_(self.output_linear.weight, gain=self.beta)

    def project_input(self, qkv):
        proj = F.linear(qkv, self.in_proj)
        return proj.chunk(chunks=3, dim=-1)

    def forward(self, qkv, need_weights):
        if self.attn_type == "flash2":
            return self.flash2_forward(qkv)

        # qkv: (batch_size, seq_len, d_embed)
        batch_size, seq_len, d_embed = qkv.shape

        # Feed the inputs through the K, Q, V matrices.
        query, key, value = self.project_input(qkv)

        # Split the projections into multiple heads and swap the sequence / heads dimensions.
        query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        # Apply ALiBi relative positional biases.
        attn_bias = alibi_biases(seq_len, seq_len, device=query.device) * self.alibi_slopes.view(-1, 1, 1)

        # Mask out future positions (causal mask).
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device), diagonal=0)
        attn_bias.masked_fill_(causal_mask.logical_not(), float('-inf'))
        del causal_mask

        # Default to returning empty attention weights.
        attention_weights = None

        if self.attn_type == "torch":
            # This context manager can be used to force which implementation to use.
            # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            attended_values = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attn_bias.to(dtype=query.dtype),
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False,
                scale=self.dot_product_scale
            )
        # "native" scaled-dot-product attention implementation.
        else:
            # Compute attention scores
            scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale

            # Adjust scores with attn_bias
            scores += attn_bias

            # Calculate the attention weights; avoid NaNs that might emerge from zeros in softmax's denominator.
            attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))

            # Use the attention weights to get a weighted combination of value vectors.
            attended_values = torch.matmul(attention_weights, value)

        if not need_weights:
            attention_weights = None

        # Concatenate attention heads and project to original embedding size using the output linear layer.
        attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)

        # Project the concatenated output through the output matrix.
        attended_values = self.output_linear(attended_values)
        return attended_values, attention_weights

    def flash2_forward(self, qkv):
        batch_size, seq_len, d_embed = qkv.shape

        # Feed the inputs through the K, Q, V matrices.
        # query : (batch_size, seq_len, d_model)
        # qkv   : (batch_size, seq_len, 3, num_heads, d_kq)
        qkv = F.linear(
            qkv,
            self.in_proj,
        ).unflatten(
            -1,
            (3, self.num_heads, self.d_head)
        )

        attended_values = flash_attn_qkvpacked_func(
            qkv.bfloat16(),
            dropout_p=self.dropout.p if self.training else 0.0,
            softmax_scale=self.dot_product_scale,
            causal=True,
            window_size=self.window_size,
            alibi_slopes=self.alibi_slopes.float(),
        ).to(dtype=qkv.dtype)
        # attended_values: (batch_size, seqlen, nheads, headdim)

        # Concatenate heads back into d_embed
        attended_values = attended_values.view(batch_size, seq_len, d_embed)

        # Project the concatenated output through the output matrix.
        attended_values = self.output_linear(attended_values)
        return attended_values, None
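

# Minimal usage sketch (not part of the model definition): builds a small model directly from
# this file's components. The component overrides below point at classes defined above; the
# sizes are illustrative only, and this assumes a transformers version compatible with the
# registration and attention-implementation flags used in HFCausalModel.
if __name__ == "__main__":
    config = Config(
        vocab_size=1000,
        hidden_size=256,
        num_attention_heads=4,
        num_hidden_layers=2,
        max_sequence_length=128,
        dim_feedforward=512,
        positional_encoder_cls='.RSWalshPositionalEncoder',
        positional_encoder_args={'d_embed': 256, 'max_seq': 128},
        layer_cls='.DeepnetLayer',
        norm_args={'normalized_shape': 256},
    )
    model = HFCausalModel(config)

    # Run a forward pass with the inputs used as their own labels (causal_loss shifts internally).
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    output = model(input_ids=input_ids, labels=input_ids)
    print(output.logits.shape, output.loss.item())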