|
|
|
from typing import Optional, Tuple, Union, List |
|
import math |
|
import copy |
|
import sys |
|
from importlib import import_module |
|
|
|
import torch |
|
from torch import nn, Tensor |
|
import torch.nn.init as init |
|
from torch.nn import functional as F |
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutput, CausalLMOutputWithPast |
|
from transformers import ( |
|
PreTrainedModel, |
|
PretrainedConfig, |
|
AutoConfig, |
|
AutoModel, |
|
AutoModelForCausalLM, |
|
) |
|
|
|
from transformers.utils import logging |
|
|
|
from transformers.cache_utils import Cache, DynamicCache |
|
|
|
from transformers.utils import ( |
|
is_flash_attn_2_available, |
|
is_flash_attn_greater_or_equal_2_10, |
|
) |
|
|
|
|
|
logger = logging.get_logger(__name__)

# flash_attn is optional and imported exactly once, guarded: a broken or partial install is
# reported as a warning instead of failing import of this module. The catch is deliberately
# broad because this import is best-effort.
if is_flash_attn_2_available():
    try:
        from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
    except Exception as err:
        logger.warning("Could not import flash2: %s", err)

# Identifier under which Config / HFCausalModel are registered with the transformers Auto* API.
model_type = "walsh-causal-v1"
|
|
|
class Config(PretrainedConfig):
    """Configuration for the Walsh causal language model (``HFCausalModel``).

    Component classes are given as dotted-path strings that are resolved with
    ``get_dynamic_class`` when the model is built; a leading "." names a definition in this
    module. The ``*_args`` dicts carry extra constructor keyword arguments for the matching
    component (HFCausalModel splats most of them with ``**``; ``norm_args`` is currently not
    forwarded there).

    Note: ``embdding_cls`` is misspelled, but it is part of the serialized config, so the name
    is kept for compatibility with existing checkpoints.
    """

    model_type = model_type

    # transformers attribute alias: ``hidden_size`` is stored/read as ``d_embed``.
    attribute_map = {
        "hidden_size": "d_embed",
    }

    def __init__(
        self,
        vocab_size=16000,
        pad_index=None,
        hidden_size=1024,
        num_attention_heads=8,
        num_hidden_layers=6,
        max_sequence_length=2048,
        dim_feedforward = 4096,
        dropout=0.1,
        loss_function = "causal_loss",

        positional_encoder_cls='.PositionalEncoder',
        attention_cls='.CausalSelfAttention',
        activation_cls='torch.nn.ReLU',
        feedforward_cls='.FeedforwardLayer',
        layer_stack_cls='.TransformerLayerStack',
        layer_cls='.PostLayerNorm',
        transformer_cls='.Transformer',
        norm_cls='torch.nn.LayerNorm',
        embdding_cls='torch.nn.Embedding',
        output_proj_cls='torch.nn.Linear',

        positional_encoder_args={
            'd_model': 1024,
            'max_seq_len': 2048,
        },

        transformer_args=dict(),
        attention_args=dict(),
        feedforward_args=dict(),
        activation_args=dict(),
        norm_args={
            'normalized_shape': 1024,
        },
        layer_stack_args=dict(),
        layer_args=dict(),
        embedding_args=dict(),
        output_proj_args=dict(),

        output_attentions=False,
        output_hidden_states=False,
        use_cache=True,

        **kwargs,
    ):
        self.vocab_size = vocab_size
        # Passed to the embedding constructor as its padding index (may be None).
        self.pad_index = pad_index
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.max_sequence_length = max_sequence_length
        # Dotted name of the loss callable, resolved with get_dynamic_class by the model.
        self.loss_function = loss_function

        self.dim_feedforward = dim_feedforward
        self.dropout = dropout

        self.positional_encoder_cls = positional_encoder_cls
        self.attention_cls = attention_cls
        self.activation_cls = activation_cls
        self.feedforward_cls = feedforward_cls
        self.layer_stack_cls = layer_stack_cls
        self.layer_cls = layer_cls
        self.transformer_cls = transformer_cls
        self.norm_cls = norm_cls
        self.embdding_cls = embdding_cls
        self.output_proj_cls = output_proj_cls

        # The default dicts above are evaluated once, at function-definition time, and are
        # shared by every call. Store private copies so that mutating one config's args can
        # neither alter the defaults nor leak into other Config instances.
        self.positional_encoder_args = copy.deepcopy(positional_encoder_args)
        self.transformer_args = copy.deepcopy(transformer_args)
        self.attention_args = copy.deepcopy(attention_args)
        self.feedforward_args = copy.deepcopy(feedforward_args)
        self.activation_args = copy.deepcopy(activation_args)
        self.norm_args = copy.deepcopy(norm_args)
        self.layer_stack_args = copy.deepcopy(layer_stack_args)
        self.layer_args = copy.deepcopy(layer_args)
        self.embedding_args = copy.deepcopy(embedding_args)
        self.output_proj_args = copy.deepcopy(output_proj_args)

        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.use_cache = use_cache

        super().__init__(**kwargs)
|
|
|
def causal_loss(logits: Tensor, labels: Tensor, input_ids: Tensor, ignore_index=-100) -> Tensor:
    """
    Mean next-token cross-entropy.

    The logits at position t are scored against ``labels[..., t + 1]``; the final position has
    no target and the first label is never predicted. Targets equal to ``ignore_index`` are
    skipped. ``input_ids`` is unused and only kept so every loss shares one signature.
    Non-finite results (e.g. a NaN mean when every target is ignored) go through
    ``nan_to_num``, so NaN becomes 0.
    """
    n_classes = logits.size(-1)
    predictions = logits[..., :-1, :].reshape(-1, n_classes)
    targets = labels[..., 1:].reshape(-1)

    mean_nll = F.cross_entropy(
        predictions,
        targets,
        ignore_index=ignore_index,
        reduction='mean',
    )
    return mean_nll.nan_to_num()
|
|
|
|
|
|
|
def ditto_loss(
    logits: Tensor,
    labels: Tensor,
    input_ids: Tensor,
    rep_reduce_gamma: float = 0.5,
    ditto_weight: float = 1.0e5,
) -> Tensor:
    """Causal loss on a context prefix plus a repetition penalty on repeated sentences.

    ``labels`` is not a token-label tensor here: row i carries per-sample metadata in its first
    three entries, ``[context_len, sentence_len, n_repeats]``. The sequence is laid out as
    ``context | sentence | repeated sentence ...``:

    * ``causal_loss`` is applied to ``input_ids[i, :context_len]``.
    * Every position from ``context_len + sentence_len`` onward is compared, over the full
      vocabulary distribution, with the detached distribution at the matching position of the
      first sentence copy (tiled ``n_repeats`` times and truncated). The penalty is
      ``-log(max(1 - |p - gamma * p_baseline|, 1e-20))``, averaged and scaled by
      ``ditto_weight``.

    Args:
        logits: presumably (batch, seq_len, vocab) — softmax is taken over the last dim.
        labels: per-sample metadata as described above.
        input_ids: token ids used as targets for the context loss.
        rep_reduce_gamma: scale applied to the baseline probabilities (default 0.5).
        ditto_weight: multiplier on the repetition penalty (default 1e5).

    Returns:
        The per-sample loss averaged over the batch.
    """
    batch_size = logits.size(0)
    probs = torch.softmax(logits, dim=-1)

    total_loss = None
    for i in range(batch_size):
        # One host sync per sample instead of three separate .item() calls.
        context_len, sentence_len, n_repeats = (int(v) for v in labels[i, :3].tolist())

        sentence_start = context_len
        sentence_end = sentence_start + sentence_len
        target_start = sentence_end

        # Ordinary next-token loss on the context prefix only.
        context_ids = input_ids[i:i + 1, :context_len]
        c_loss = causal_loss(
            logits=logits[i:i + 1, :context_len],
            labels=context_ids,
            input_ids=context_ids,
        )

        # Distributions at every position after the first copy of the sentence.
        target_probs = probs[i, target_start:, :]
        # First sentence copy, tiled to cover the target span; detached, so the penalty only
        # pushes the repeated positions, not the baseline.
        baseline_probs = (
            probs[i, sentence_start:sentence_end, :]
            .detach()
            .repeat(n_repeats, 1)[:target_probs.size(0), :]
        )

        one_minus_probs = torch.clamp(
            1.0 - torch.abs(target_probs - baseline_probs * rep_reduce_gamma),
            min=1e-20,
        )
        r_loss = -torch.log(one_minus_probs).mean() * ditto_weight

        loss = c_loss + r_loss
        # Out-of-place add: `total_loss += loss` would mutate the first sample's loss tensor,
        # which total_loss aliases after the first iteration.
        total_loss = loss if total_loss is None else total_loss + loss

    return total_loss / batch_size
|
|
|
|
|
def get_dynamic_class(name):
    """Resolve a dotted name string to the object it names.

    ``"pkg.mod.Name"`` imports ``pkg.mod`` and returns its ``Name`` attribute. A name with an
    empty module part (``".Name"``) or with no dot at all (``"Name"``) is looked up in this
    module. The dot-less form previously raised an uncaught ValueError, which broke e.g. the
    default ``loss_function="causal_loss"``.

    Raises:
        ImportError: if the module cannot be imported or does not define the attribute; the
            underlying error is chained as ``__cause__``.
    """
    try:
        module_path, _, attr_name = name.rpartition('.')
        if not module_path:
            # Relative / bare names refer to definitions in this file.
            return getattr(sys.modules[__name__], attr_name)
        return getattr(import_module(module_path), attr_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(name) from e
|
|
|
|
|
|
|
class HFCausalModel(PreTrainedModel): |
|
config_class = Config |
|
model_type = 'Transformer' |
|
supports_gradient_checkpointing = True |
|
|
|
_no_split_modules = ["DeepNetLayer"] |
|
_supports_flash_attn_2 = True |
|
_supports_sdpa = True |
|
_supports_cache_class = True |
|
_skip_keys_device_placement = "past_key_values" |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
self.d_model = config.hidden_size |
|
self.transformer_head = self._make_transformer(config) |
|
self.loss_function = get_dynamic_class(config.loss_function) |
|
self.gradient_checkpointing = False |
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
token_type_ids: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
**kwargs, |
|
) -> (Tensor, dict[str, Tensor]): |
|
|
|
batch_size, seq_len = input_ids.shape |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
if use_cache: |
|
|
|
use_legacy_cache = not isinstance(past_key_values, Cache) |
|
if use_legacy_cache: |
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values) |
|
|
|
|
|
if self.gradient_checkpointing and self.training: |
|
if use_cache: |
|
logger.warning_once( |
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
|
) |
|
use_cache = False |
|
gradient_checkpointing_func = self._gradient_checkpointing_func |
|
else: |
|
gradient_checkpointing_func = None |
|
|
|
|
|
outputs = self.transformer_head( |
|
input_ids=input_ids, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
gradient_checkpointing_func=gradient_checkpointing_func, |
|
past_key_values=past_key_values, |
|
use_cache=use_cache, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
logits = outputs["logits"].float() |
|
attentions = outputs["attentions"] |
|
|
|
|
|
if labels is not None: |
|
loss = self.loss_function(logits=logits, labels=labels, input_ids=input_ids) |
|
else: |
|
loss = None |
|
|
|
|
|
new_cache = outputs["past_key_values"] |
|
if use_cache and new_cache is not None and use_legacy_cache: |
|
new_cache = new_cache.to_legacy_cache() |
|
|
|
return CausalLMOutputWithPast( |
|
loss=loss, |
|
logits=logits, |
|
past_key_values=new_cache, |
|
hidden_states=outputs["hidden_states"], |
|
attentions=outputs["attentions"], |
|
) |
|
|
|
|
|
|
|
|
|
|
|
def prepare_inputs_for_generation( |
|
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs |
|
): |
|
|
|
if past_key_values is not None: |
|
if isinstance(past_key_values, Cache): |
|
cache_length = past_key_values.get_seq_length() |
|
past_length = past_key_values.seen_tokens |
|
max_cache_length = past_key_values.get_max_length() |
|
else: |
|
cache_length = past_length = past_key_values[0][0].shape[2] |
|
max_cache_length = None |
|
|
|
|
|
|
|
|
|
|
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: |
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] |
|
|
|
|
|
elif past_length < input_ids.shape[1]: |
|
input_ids = input_ids[:, past_length:] |
|
|
|
|
|
|
|
if ( |
|
max_cache_length is not None |
|
and attention_mask is not None |
|
and cache_length + input_ids.shape[1] > max_cache_length |
|
): |
|
attention_mask = attention_mask[:, -max_cache_length:] |
|
|
|
position_ids = kwargs.get("position_ids", None) |
|
if attention_mask is not None and position_ids is None: |
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1 |
|
position_ids.masked_fill_(attention_mask == 0, 1) |
|
if past_key_values: |
|
position_ids = position_ids[:, -input_ids.shape[1] :] |
|
|
|
|
|
|
|
if inputs_embeds is not None and past_key_values is None: |
|
model_inputs = {"inputs_embeds": inputs_embeds} |
|
else: |
|
model_inputs = {"input_ids": input_ids} |
|
|
|
model_inputs.update( |
|
{ |
|
"position_ids": position_ids, |
|
"past_key_values": past_key_values, |
|
"use_cache": kwargs.get("use_cache"), |
|
"attention_mask": attention_mask, |
|
} |
|
) |
|
return model_inputs |
|
|
|
@staticmethod |
|
def _reorder_cache(past_key_values, beam_idx): |
|
reordered_past = () |
|
for layer_past in past_key_values: |
|
reordered_past += ( |
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), |
|
) |
|
return reordered_past |
|
|
|
def _make_embedding(self, config): |
|
embedding_cls = get_dynamic_class(config.embdding_cls) |
|
return embedding_cls(config.vocab_size, self.d_model, config.pad_index, **config.embedding_args) |
|
|
|
def _make_pos_encoder(self, config): |
|
pos_enc_cls = get_dynamic_class(config.positional_encoder_cls) |
|
return pos_enc_cls(**config.positional_encoder_args) |
|
|
|
def _make_output_projection(self, config): |
|
output_proj_cls = get_dynamic_class(config.output_proj_cls) |
|
return output_proj_cls(self.d_model, config.vocab_size, **config.output_proj_args) |
|
|
|
def _make_dropout(self, config): |
|
return nn.Dropout(config.dropout) |
|
|
|
def _make_activation(self, config): |
|
activation_cls = get_dynamic_class(config.activation_cls) |
|
return activation_cls(**config.activation_args) |
|
|
|
def _make_norm(self, config): |
|
norm_cls = get_dynamic_class(config.norm_cls) |
|
return norm_cls(self.d_model) |
|
|
|
def _make_self_attention(self, layer_idx, config): |
|
attention_cls = get_dynamic_class(config.attention_cls) |
|
|
|
match config._attn_implementation: |
|
case "flash_attention_2": |
|
if is_flash_attn_2_available(): |
|
if not is_flash_attn_greater_or_equal_2_10(): |
|
raise Exception("flash_attn_2 >= 2.10 is required") |
|
attn_type = "flash2" |
|
else: |
|
attn_type = "torch" |
|
case "sdpa": |
|
attn_type = "torch" |
|
case "eager": |
|
attn_type = "native" |
|
case _: |
|
raise Exception(f"Unimplemented attention type '{config._attn_implementation}'") |
|
return attention_cls( |
|
d_model=self.d_model, |
|
num_heads=config.num_attention_heads, |
|
attn_type=attn_type, |
|
layer_idx=layer_idx, |
|
config=config, |
|
**config.attention_args, |
|
) |
|
|
|
def _make_feedforward(self, layer_idx, config): |
|
feedforward_cls = get_dynamic_class(config.feedforward_cls) |
|
return feedforward_cls( |
|
d_model=self.d_model, |
|
feedforward_dim=config.dim_feedforward, |
|
dropout=config.dropout, |
|
activation=self._make_activation(config), |
|
layer_idx=layer_idx, |
|
**config.feedforward_args, |
|
) |
|
|
|
def _make_layer(self, layer_idx, config): |
|
layer_cls = get_dynamic_class(config.layer_cls) |
|
return layer_cls( |
|
d_model=self.d_model, |
|
dropout=self._make_dropout(config), |
|
attention=self._make_self_attention(layer_idx, config), |
|
feedforward=self._make_feedforward(layer_idx, config), |
|
norm1=self._make_norm(config), |
|
norm2=self._make_norm(config), |
|
layer_idx=layer_idx, |
|
**config.layer_args, |
|
) |
|
|
|
def _make_layer_stack(self, config): |
|
layer_stack_cls = get_dynamic_class(config.layer_stack_cls) |
|
return layer_stack_cls( |
|
layers=nn.ModuleList([ |
|
self._make_layer(layer_idx, config) for layer_idx in range(config.num_hidden_layers) |
|
]), |
|
**config.layer_stack_args, |
|
) |
|
|
|
def _make_transformer(self, config): |
|
transformer_cls = get_dynamic_class(config.transformer_cls) |
|
return transformer_cls( |
|
d_model=self.d_model, |
|
embedding=self._make_embedding(config), |
|
positional_encoder=self._make_pos_encoder(config), |
|
layer_stack=self._make_layer_stack(config), |
|
output_projection=self._make_output_projection(config), |
|
**config.transformer_args, |
|
) |
|
|
|
@torch.no_grad() |
|
def _init_weights(self, module): |
|
pass |
|
|
|
|
|
# Register Config under `model_type` ("walsh-causal-v1") and map it to HFCausalModel, so the
# transformers AutoConfig / AutoModelForCausalLM factories can build this model from a config.
# Runs at import time; the config must be registered before the model mapping that uses it.
AutoConfig.register(model_type, Config)
AutoModelForCausalLM.register(Config, HFCausalModel)
|
|
|
|
|
class Transformer(nn.Module):
    """Embedding -> positional encoder -> layer stack -> output projection.

    Extra ``**kwargs`` (from ``config.transformer_args``) are accepted and ignored.
    """

    def __init__(self, d_model, embedding, positional_encoder, layer_stack, output_projection, **kwargs):
        super().__init__()
        self.embedding = embedding
        self.positional_encoder = positional_encoder
        self.layer_stack = layer_stack
        self.output_projection = output_projection
        self.d_model = d_model
        self.sqrt_d_model = d_model**0.5
        self.reset_parameters()

    def forward(
        self,
        input_ids,
        position_ids,
        output_attentions,
        gradient_checkpointing_func,
        past_key_values,
        use_cache,
        output_hidden_states,
    ):
        """Return the layer-stack output dict with "last_hidden_state" replaced by "logits"."""
        # Embeddings are scaled by sqrt(d_model) before positional information is added.
        embedded = self.positional_encoder(self.embedding(input_ids) * self.sqrt_d_model, position_ids)
        outputs = self.layer_stack(
            embedded,
            output_attentions=output_attentions,
            gradient_checkpointing_func=gradient_checkpointing_func,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )
        outputs["logits"] = self.output_projection(outputs.pop("last_hidden_state"))
        return outputs

    def reset_parameters(self):
        init.xavier_uniform_(self.output_projection.weight)
        # The projection may be built without a bias (e.g. output_proj_args={"bias": False}).
        if getattr(self.output_projection, "bias", None) is not None:
            init.constant_(self.output_projection.bias, 0.)

        init.normal_(self.embedding.weight, std=self.d_model**-0.5)
        # normal_ overwrites the whole table, including the padding row that nn.Embedding
        # zeroes at construction (and never updates, since it receives no gradient).
        # Restore it so padding tokens embed to zero.
        padding_idx = getattr(self.embedding, "padding_idx", None)
        if padding_idx is not None:
            with torch.no_grad():
                self.embedding.weight[padding_idx].fill_(0.)
|
|
|
|
|
def binary_tensor(x, bits):
    """Expand integer tensor ``x`` into its low ``bits`` binary digits.

    Returns a uint8 tensor of shape ``x.shape + (bits,)`` whose entry ``[..., j]`` is bit j of
    the corresponding element (least-significant bit first).
    """
    exponents = torch.arange(bits, device=x.device, dtype=x.dtype)
    place_values = torch.pow(2, exponents)
    is_set = (x[..., None] & place_values) != 0
    return is_set.to(torch.uint8)
|
|
|
def hadamard_walsh_matrix(k: int):
    """Return the 2**k x 2**k Sylvester-Hadamard (Walsh) matrix as a float32 tensor.

    Built by repeated doubling, ``H_{n+1} = [[H_n, H_n], [H_n, -H_n]]``, starting from the
    2x2 base ``[[1, 1], [1, -1]]``. ``k`` must be at least 1 (checked with ``assert``).
    """
    assert k > 0

    walsh = torch.tensor([[1, 1], [1, -1]], dtype=torch.float)
    for _ in range(1, k):
        top = torch.cat((walsh, walsh), dim=1)
        bottom = torch.cat((walsh, -walsh), dim=1)
        walsh = torch.cat((top, bottom), dim=0)
    return walsh
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RSWalshPositionalEncoder(nn.Module):
    """Additive positional encoding built from the binary digits of each position.

    Position p is written in binary (``bits`` digits, least-significant first). Its encoding is
    the sum of the Walsh-matrix rows for the set bits, scaled by ``gain``:
    ``binary(p) @ walsh``. During training the sequence is placed at a random offset inside
    ``[0, max_seq)`` (shift augmentation). In eval mode explicit ``position_ids`` are used when
    given, otherwise positions ``0 .. seq_len - 1``.
    """

    def __init__(self, d_embed, max_seq, gain=0.333):
        super().__init__()
        self.max_seq = max_seq
        self.d_embed = d_embed

        # Walsh matrix order: 2**k >= d_embed, so it has at least d_embed columns.
        k = math.ceil(math.log2(d_embed))
        # Binary digits needed to represent every position in [0, max_seq).
        bits = math.ceil(math.log2(max_seq))
        self.gain = gain

        assert bits <= d_embed, "max_seq exceeds n-bits available for d_embed"

        # Row p holds the binary digits of position p (uint8, least-significant bit first).
        binary_code = binary_tensor(torch.arange(0, max_seq, 1), bits)
        self.register_buffer('binary_code', binary_code, persistent=False)

        # One Walsh row per bit, truncated to d_embed columns and scaled by gain.
        walsh = hadamard_walsh_matrix(k)[:bits, :d_embed] * self.gain
        self.register_buffer('walsh', walsh, persistent=False)

    def forward(self, x, position_ids=None):
        """Return ``x`` plus the positional encoding; assumes positions run along dim -2."""
        seq_len = x.size(-2)

        if self.training:
            # Random shift: use positions shift .. shift + seq_len - 1.
            shift = torch.randint(self.max_seq - seq_len + 1, (1,)).item()
            seq = self.binary_code[shift:seq_len + shift, :]
        elif position_ids is not None:
            seq = self.binary_code[position_ids, :]
        else:
            seq = self.binary_code[:seq_len, :]

        # Cast a local copy. Reassigning the registered buffer here would permanently lower its
        # precision after a single half-precision call.
        walsh = self.walsh.to(dtype=x.dtype)
        return x + (seq.to(dtype=x.dtype) @ walsh)
|
|
|
|
|
class TransformerLayerStack(nn.Module):
    """Run a ``ModuleList`` of layers in sequence and collect optional per-layer outputs.

    Each layer is called as ``layer(hidden_states, output_attentions, past_key_values,
    use_cache)`` and must return a dict with "hidden_states", "attentions" and
    "past_key_values".
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(
        self,
        hidden_states,
        output_attentions,
        past_key_values,
        use_cache,
        output_hidden_states,
        gradient_checkpointing_func=None,
    ):
        """Return a dict with the following keys:

        - last_hidden_state: output of the final layer.
        - past_key_values: the cache returned by the last layer when ``use_cache``, else None.
        - hidden_states: when ``output_hidden_states``, a tuple of the stack input followed by
          every layer's output (len(layers) + 1 entries); else None.
        - attentions: when ``output_attentions``, a tuple with one entry per layer; else None.
        """
        present_key_value = None
        all_attentions = [] if output_attentions else None
        all_hidden_states = [hidden_states] if output_hidden_states else None

        for layer in self.layers:
            if gradient_checkpointing_func is not None:
                layer_outputs = gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    output_attentions,
                    past_key_values,
                    use_cache,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    output_attentions,
                    past_key_values,
                    use_cache,
                )

            hidden_states = layer_outputs["hidden_states"]

            if output_hidden_states:
                all_hidden_states.append(hidden_states)

            if use_cache:
                present_key_value = layer_outputs["past_key_values"]

            if output_attentions:
                all_attentions.append(layer_outputs["attentions"])

        return dict(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value,
            # Previously this returned the final tensor, discarding the collected states.
            hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
            attentions=tuple(all_attentions) if output_attentions else None,
        )
|
|
|
|
|
|
|
|
|
class DeepnetLayer(nn.Module):
    """Post-norm transformer block with an alpha-scaled residual (DeepNet style).

    Each sub-layer computes ``norm(alpha * x + dropout(sublayer(x)))``: first self-attention
    followed by ``norm1``, then the feedforward network followed by ``norm2``. The same
    ``dropout`` module is applied to both sub-layer outputs.
    """

    def __init__(
        self,
        d_model,
        attention,
        feedforward,
        norm1,
        norm2,
        dropout,
        layer_idx,
        alpha=1.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.attention = attention
        self.feedforward = feedforward
        self.norm1 = norm1
        self.norm2 = norm2
        self.dropout = dropout

        # Residual scale; 1.0 gives a plain post-LN block.
        self.alpha = alpha
        self.layer_idx = layer_idx

    def _post_norm(self, norm, skip, branch):
        """Scaled-residual add followed by normalization."""
        return norm(skip * self.alpha + self.dropout(branch))

    def forward(
        self,
        hidden_states,
        output_attentions,
        past_key_values,
        use_cache,
    ):
        """Return a dict with the block output, attention weights and the (updated) cache."""
        attn_result = self.attention(
            hidden_states,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        after_attention = self._post_norm(self.norm1, hidden_states, attn_result["hidden_states"])
        block_output = self._post_norm(self.norm2, after_attention, self.feedforward(after_attention))

        return dict(
            hidden_states=block_output,
            attentions=attn_result["attentions"],
            past_key_values=attn_result["past_key_values"],
        )
|
|
|
|
|
class FeedforwardLayer(nn.Module):
    """Position-wise feedforward network: ``linear2(dropout(activation(linear1(x))))``.

    Args:
        d_model: input / output width.
        feedforward_dim: hidden width.
        dropout: dropout probability applied after the activation.
        layer_idx: accepted for a uniform constructor signature; unused.
        activation: activation module; defaults to a fresh ``nn.ReLU()`` per instance.
        beta: xavier gain for both linear weights.
        bias: whether the linear layers have biases.
    """

    def __init__(
        self,
        d_model: int,
        feedforward_dim: int,
        dropout,
        layer_idx,
        activation=None,
        beta=1.0,
        bias=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.beta = beta
        # A module instance as a default argument would be one shared object registered as a
        # submodule of every layer relying on the default; build a new one per instance.
        self.activation = activation if activation is not None else nn.ReLU()
        self.linear1 = nn.Linear(d_model, feedforward_dim, bias=bias)
        self.linear2 = nn.Linear(feedforward_dim, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

    def reset_parameters(self):
        init.xavier_uniform_(self.linear1.weight, gain=self.beta)
        init.xavier_uniform_(self.linear2.weight, gain=self.beta)
        # With bias=False the bias attributes are None; skip them instead of crashing.
        for linear in (self.linear1, self.linear2):
            if linear.bias is not None:
                init.constant_(linear.bias, 0.)
|
|
|
class CausalSelfAttention(nn.Module): |
|
def __init__( |
|
self, |
|
d_model, |
|
num_heads, |
|
|
|
|
|
|
|
|
|
attn_type, |
|
layer_idx, |
|
config, |
|
beta=1.0, |
|
dropout=0.1, |
|
): |
|
super().__init__() |
|
self.d_model = d_model |
|
self.num_heads = num_heads |
|
self.beta = beta |
|
self.attn_type = attn_type |
|
self.layer_idx = layer_idx |
|
self.config = config |
|
|
|
assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads" |
|
|
|
|
|
self.d_head = d_model // num_heads |
|
|
|
|
|
|
|
self.dot_product_scale = 1.0 / math.sqrt(self.d_head) |
|
|
|
self.in_proj = nn.Linear(self.d_model, 3 * self.d_model, bias=True) |
|
self.output_linear = nn.Linear(self.d_model, self.d_model, bias=True) |
|
|
|
self.dropout = nn.Dropout(dropout) |
|
self.reset_parameters() |
|
|
|
def extra_repr(self) -> str: |
|
return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, dropout={self.dropout}' |
|
|
|
def reset_parameters(self): |
|
|
|
|
|
q, k, v = self.in_proj.weight.chunk(3) |
|
init.xavier_uniform_(q, gain=1.0) |
|
init.xavier_uniform_(k, gain=1.0) |
|
init.xavier_uniform_(v, gain=self.beta) |
|
init.xavier_uniform_(self.output_linear.weight, gain=self.beta) |
|
init.constant_(self.in_proj.bias, 0.) |
|
init.constant_(self.output_linear.bias, 0.) |
|
|
|
|
|
def _project_input(self, qkv, past_key_values): |
|
batch_size, seq_len, d_embed = qkv.shape |
|
proj = self.in_proj(qkv) |
|
query, key, value = proj.chunk(chunks=3, dim=-1) |
|
|
|
|
|
query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2) |
|
key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2) |
|
value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2) |
|
|
|
|
|
if past_key_values is not None: |
|
key, value = past_key_values.update(key, value, self.layer_idx) |
|
return query, key, value |
|
|
|
def forward( |
|
self, |
|
qkv, |
|
output_attentions, |
|
past_key_values, |
|
use_cache, |
|
): |
|
attn_type = self.attn_type |
|
if output_attentions and attn_type != "native": |
|
logger.warning_once( |
|
"CausalSelfAttention(output_attentions=True) and attn_type is not 'native': " |
|
"Forcing native attention." |
|
) |
|
attn_type = "native" |
|
|
|
if attn_type == "flash2": |
|
if use_cache is None or use_cache == False: |
|
return self._flash2_forward(qkv) |
|
else: |
|
return self._flash2_forward_cached(qkv, past_key_values) |
|
|
|
|
|
batch_size, seq_len, d_embed = qkv.shape |
|
|
|
|
|
query, key, value = self._project_input(qkv, past_key_values) |
|
kv_seq_len = key.shape[-2] |
|
|
|
|
|
attentions = None |
|
|
|
|
|
|
|
if attn_type == "torch": |
|
|
|
|
|
attended_values = F.scaled_dot_product_attention( |
|
query, |
|
key, |
|
value, |
|
attn_mask=None, |
|
dropout_p=self.dropout.p if self.training else 0.0, |
|
is_causal=(seq_len > 1), |
|
scale=self.dot_product_scale |
|
) |
|
|
|
else: |
|
|
|
scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale |
|
|
|
|
|
if seq_len > 1: |
|
scores.masked_fill_( |
|
torch.tril( |
|
torch.ones(seq_len, kv_seq_len, dtype=torch.bool, device=qkv.device), |
|
diagonal=0, |
|
).logical_not(), |
|
float('-inf'), |
|
) |
|
|
|
|
|
attentions = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10)) |
|
del scores |
|
|
|
|
|
attended_values = torch.matmul(attentions, value) |
|
if not output_attentions: |
|
del attentions |
|
attentions = None |
|
|
|
|
|
attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed) |
|
|
|
|
|
attended_values = self.output_linear(attended_values) |
|
return dict( |
|
hidden_states=attended_values, |
|
attentions=attentions, |
|
past_key_values=past_key_values |
|
) |
|
|
|
|
|
def _flash2_forward( |
|
self, |
|
qkv, |
|
): |
|
batch_size, seq_len, d_embed = qkv.shape |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qkv = self.in_proj(qkv).unflatten( |
|
-1, |
|
(3, self.num_heads, self.d_head) |
|
) |
|
|
|
attended_values = flash_attn_qkvpacked_func( |
|
self._downcast_to_float16(qkv)[0], |
|
dropout_p=self.dropout.p if self.training else 0.0, |
|
softmax_scale=self.dot_product_scale, |
|
causal=True, |
|
) |
|
|
|
|
|
|
|
attended_values = attended_values.view(batch_size, seq_len, d_embed) |
|
|
|
|
|
attended_values = self.output_linear(attended_values) |
|
return dict( |
|
hidden_states=attended_values, |
|
attentions=None, |
|
past_key_values=None |
|
) |
|
|
|
|
|
|
|
def _flash2_forward_cached( |
|
self, |
|
qkv, |
|
past_key_values, |
|
): |
|
batch_size, seq_len, d_embed = qkv.shape |
|
|
|
|
|
query, key, value = self._project_input(qkv, past_key_values) |
|
query, key, value = self._downcast_to_float16(query, key, value) |
|
|
|
|
|
|
|
|
|
|
|
query = query.transpose(1, 2) |
|
key = key.transpose(1, 2) |
|
value = value.transpose(1, 2) |
|
|
|
attended_values = flash_attn_func( |
|
q=query, |
|
k=key, |
|
v=value, |
|
dropout_p=self.dropout.p if self.training else 0.0, |
|
softmax_scale=self.dot_product_scale, |
|
causal=True, |
|
) |
|
|
|
|
|
|
|
attended_values = attended_values.view(batch_size, seq_len, d_embed) |
|
|
|
|
|
attended_values = self.output_linear(attended_values) |
|
return dict( |
|
hidden_states=attended_values, |
|
attentions=None, |
|
past_key_values=past_key_values |
|
) |
|
|
|
def _downcast_to_float16(self, *args): |
|
if args[0].dtype != torch.float32: |
|
return args |
|
|
|
if torch.is_autocast_enabled(): |
|
target_dtype = torch.get_autocast_gpu_dtype() |
|
|
|
elif hasattr(self.config, "_pre_quantization_dtype"): |
|
target_dtype = self.config._pre_quantization_dtype |
|
else: |
|
target_dtype = self.output_linear.weight.dtype |
|
|
|
logger.warning_once( |
|
f"The input hidden states seems to be silently casted in float32, this might be related to" |
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" |
|
f" {target_dtype}." |
|
) |
|
|
|
return (arg.to(target_dtype) for arg in args) |