LingLong-317M / modeling_linglong.py
AlumiK's picture
push model
96b7ded
raw
history blame
26.9 kB
import math

import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutput, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import Conv1D
from transformers.utils import logging

from .configuration_linglong import LingLongConfig
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
class LingLongAttention(nn.Module):
    """Multi-head causal self-attention with an optional strided sparse pattern.

    When ``config.attn_mode == 'sparse'``, layers with an odd ``layer_idx`` use the
    mask built by :meth:`_sparse_causal_mask`; every other layer uses the dense
    lower-triangular mask stored in the ``bias`` buffer.

    Args:
        config: Model configuration. The attributes read here are ``n_position``,
            ``n_embd``, ``n_head``, ``scale_attn_weights``,
            ``scale_attn_by_inverse_layer_idx``, ``reorder_and_upcast_attn``,
            ``attn_pdrop``, ``resid_pdrop``, ``attn_mode``, ``attn_stride`` and
            ``attn_c``.
        layer_idx: Index of the layer this module belongs to. It is required (must
            not be ``None``) when sparse mode or inverse-layer-index scaling is on,
            because both perform arithmetic on it.

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        n_position = config.n_position
        # Dense lower-triangular causal mask, sliced on demand by `_causal_mask`.
        self.register_buffer(
            'bias',
            torch.tril(torch.ones((n_position, n_position), dtype=torch.bool)).view(
                1, 1, n_position, n_position
            ),
            persistent=False,
        )
        self.register_buffer('masked_bias', torch.tensor(-1e4), persistent=False)

        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_dim = self.n_embd // self.n_head
        self.split_size = self.n_embd
        if self.head_dim * self.n_head != self.n_embd:
            raise ValueError(
                f'`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.n_embd} and `num_heads`:'
                f' {self.n_head}).'
            )

        self.scale_attn_weights = config.scale_attn_weights

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        self.c_attn = Conv1D(3 * self.n_embd, self.n_embd)
        self.c_proj = Conv1D(self.n_embd, self.n_embd)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        # LingLong sparse attention.
        self.mode = config.attn_mode
        self.stride = config.attn_stride
        self.c = config.attn_c
        # Most recently used mask; reused while the (query, key) lengths stay the same.
        self.causal_mask = None

    def _causal_mask(self, query_length, key_length):
        """Return the dense causal mask slice of shape ``(1, 1, query_length, key_length)``."""
        return self.bias[:, :, key_length - query_length: key_length, :key_length]

    def _sparse_causal_mask(self, query_length, key_length):
        """Build the strided sparse causal mask.

        A query at position ``q`` may attend to a key at position ``k <= q`` when
        either ``k`` lies in the same ``stride``-sized block as ``q``, or ``k`` is one
        of the last ``c`` positions of its block.

        Returns:
            Boolean tensor of shape ``(1, 1, query_length, key_length)`` holding the
            last ``query_length`` rows of the ``key_length x key_length`` layout.
        """
        positions = torch.arange(key_length, device=self.bias.device)
        query_pos = positions.view(-1, 1)
        key_pos = positions.view(1, -1)

        query_block = torch.div(query_pos, self.stride, rounding_mode='floor')
        key_block = torch.div(key_pos, self.stride, rounding_mode='floor')
        same_block = query_block == key_block
        # Columns whose offset inside the block is among the last `c` offsets.
        summary_columns = (key_pos % self.stride) >= (self.stride - self.c)
        # Any query cannot attend to keys above it.
        layout = (same_block | summary_columns) & (key_pos <= query_pos)
        return layout[(key_length - query_length):].view(1, 1, query_length, key_length)

    def _get_causal_mask(self, query_length, key_length, device):
        """Return the mask for the given lengths on ``device``, reusing the cache when valid.

        The cached mask is rebuilt when its shape no longer matches, and moved when it
        lives on a different device than the attention weights (for example after the
        module has been moved following an earlier forward pass).
        """
        mask = self.causal_mask
        if mask is None or mask.size() != torch.Size([1, 1, query_length, key_length]):
            if self.mode == 'sparse' and self.layer_idx % 2 != 0:
                mask = self._sparse_causal_mask(query_length, key_length)
            else:
                mask = self._causal_mask(query_length, key_length)
        if mask.device != device:
            mask = mask.to(device)
        self.causal_mask = mask
        return mask

    def _attn(self, query, key, value, attention_mask=None):
        """Compute attention in the input dtype.

        Args:
            query, key, value: Per-head tensors as produced by :meth:`_split_heads`.
            attention_mask: Optional additive mask added to the scores before softmax.

        Returns:
            Tuple ``(attn_output, attn_weights)``.
        """
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        causal_mask = self._get_causal_mask(query.size(-2), key.size(-2), attn_weights.device)
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.full(
            [], torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype, device=attn_weights.device
        )
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None):
        """Compute attention with the score matrix evaluated in float32.

        Scaling is folded into the ``alpha`` argument of ``torch.baddbmm`` and autocast
        is disabled while the scores are computed.

        Returns:
            Tuple ``(attn_output, attn_weights)``.

        Raises:
            RuntimeError: If the softmax output is unexpectedly not float32.
        """
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        causal_mask = self._get_causal_mask(query.size(-2), key.size(-2), attn_weights.device)
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.full(
            [], torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype, device=attn_weights.device
        )
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError('Error with upcasting, attn_weights does not have dtype torch.float32.')
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    @staticmethod
    def _split_heads(tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    @staticmethod
    def _merge_heads(tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
            self,
            hidden_states,
            layer_past=None,
            attention_mask=None,
            use_cache=False,
            output_attentions=False,
    ):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input tensor; it is projected by ``c_attn`` and split into
                query/key/value along dimension 2.
            layer_past: Optional ``(past_key, past_value)`` pair that is prepended to
                the new keys/values along the sequence dimension.
            attention_mask: Optional additive mask added to the attention scores.
            use_cache: When ``True``, the (concatenated) key/value pair is returned.
            output_attentions: When truthy, the attention weights are appended to the
                returned tuple.

        Returns:
            ``(attn_output, present)`` or ``(attn_output, present, attn_weights)``;
            ``present`` is ``None`` unless ``use_cache is True``.
        """
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.n_head, self.head_dim)
        key = self._split_heads(key, self.n_head, self.head_dim)
        value = self._split_heads(value, self.n_head, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = (key, value) if use_cache is True else None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask)

        attn_output = self._merge_heads(attn_output, self.n_head, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
class LingLongMLP(nn.Module):
    """Position-wise feed-forward network used inside a transformer block.

    Args:
        intermediate_size: Width of the hidden projection.
        config: Model configuration; ``n_embd``, ``activation_function`` and
            ``resid_pdrop`` are read.
    """

    def __init__(self, intermediate_size, config):
        super().__init__()
        n_embd = config.n_embd
        self.c_fc = Conv1D(intermediate_size, n_embd)
        self.c_proj = Conv1D(n_embd, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        """Apply expand -> activation -> project -> dropout and return the result."""
        expanded = self.act(self.c_fc(hidden_states))
        projected = self.c_proj(expanded)
        return self.dropout(projected)
class LingLongBlock(nn.Module):
    """One pre-LayerNorm transformer block: attention and MLP, each with a residual connection.

    Args:
        config: Model configuration; ``n_embd``, ``n_inner`` and
            ``layer_norm_epsilon`` are read here, the rest is forwarded to the
            sub-modules.
        layer_idx: Index of this block, forwarded to the attention module.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        n_embd = config.n_embd
        # Fall back to the conventional 4x expansion when no inner size is configured.
        inner_dim = 4 * n_embd if config.n_inner is None else config.n_inner

        self.ln_1 = nn.LayerNorm(n_embd, eps=config.layer_norm_epsilon)
        self.attn = LingLongAttention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(n_embd, eps=config.layer_norm_epsilon)
        self.mlp = LingLongMLP(inner_dim, config)

    def forward(
            self,
            hidden_states,
            layer_past=None,
            attention_mask=None,
            use_cache=False,
            output_attentions=False,
    ):
        """Run the block.

        Returns:
            ``(hidden_states, present, (attentions))`` when ``use_cache`` is truthy,
            otherwise ``(hidden_states, (attentions))``.
        """
        attn_result = self.attn(
            self.ln_1(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # attn_result: a, present, (attentions)
        extras = attn_result[1:]

        # First residual connection (around attention).
        after_attn = attn_result[0] + hidden_states
        # Second residual connection (around the MLP).
        block_output = after_attn + self.mlp(self.ln_2(after_attn))

        # Without caching, the `present` slot is dropped from the result.
        trailing = extras if use_cache else extras[1:]
        return (block_output,) + trailing  # hidden_states, present, (attentions)
class LingLongPreTrainedModel(PreTrainedModel):
    """
    Abstract base for LingLong models: provides weight initialization and the
    class-level settings used when downloading and loading pretrained checkpoints.
    """

    config_class = LingLongConfig
    base_model_prefix = 'transformer'
    supports_gradient_checkpointing = True
    _no_split_modules = ['LingLongBlock']
    _skip_keys_device_placement = 'past_key_values'

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place."""
        base_std = self.config.initializer_range

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=base_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (nn.Linear, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=base_std)
            if module.bias is not None:
                module.bias.data.zero_()

        # GPT-2 style scaled initialization for the output projections on the
        # residual path: weights named `c_proj.weight` directly under `module` are
        # re-drawn with std = initializer_range / sqrt(2 * n_layer). The factor 2
        # matches the two `c_proj` projections (attention and MLP) per block.
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for param_name, param in module.named_parameters():
            if param_name != 'c_proj.weight':
                continue
            param.data.normal_(mean=0.0, std=(base_std / math.sqrt(2 * self.config.n_layer)))
class LingLongModel(LingLongPreTrainedModel):
    """Bare LingLong transformer: token + position embeddings, a stack of blocks, final LayerNorm."""

    def __init__(self, config: LingLongConfig):
        super().__init__(config)
        self.n_embd = config.n_embd

        self.wte = nn.Embedding(config.vocab_size, self.n_embd)
        self.wpe = nn.Embedding(config.n_position, self.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([LingLongBlock(config, layer_idx=i) for i in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.n_embd, eps=config.layer_norm_epsilon)

        # Model parallel
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module."""
        self.wte = new_embeddings

    def forward(
            self,
            input_ids: torch.LongTensor | None = None,
            past_key_values: tuple[tuple[torch.Tensor]] | None = None,
            attention_mask: torch.FloatTensor | None = None,
            position_ids: torch.LongTensor | None = None,
            inputs_embeds: torch.FloatTensor | None = None,
            use_cache: bool | None = None,
            output_attentions: bool | None = None,
            output_hidden_states: bool | None = None,
            return_dict: bool | None = None,
    ) -> tuple | BaseModelOutputWithPast:
        """Encode the inputs and return the final hidden states.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given. Flags left
        as ``None`` fall back to the corresponding values on ``self.config``.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds`` are
                given, or if an attention mask is passed with a non-positive batch size.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Validate the inputs and derive shape, batch size and device from whichever was given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time.')
        if input_ids is None and inputs_embeds is None:
            raise ValueError('You have to specify either input_ids or inputs_embeds.')
        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
            device = inputs_embeds.device

        if past_key_values is None:
            past_length = 0
            past_key_values = (None,) * len(self.h)
        else:
            past_length = past_key_values[0][0].size(-2)

        if position_ids is None:
            # Positions continue from the end of the cached sequence.
            position_ids = torch.arange(
                past_length, input_shape[-1] + past_length, dtype=torch.long, device=device
            ).unsqueeze(0)

        # LingLongAttention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError('batch_size has to be defined and > 0.')
            # Reshape the 2D mask to [batch_size, 1, 1, to_seq_length] so that it
            # broadcasts over [batch_size, num_heads, from_seq_length, to_seq_length].
            attention_mask = attention_mask.view(batch_size, -1)[:, None, None, :]
            # The incoming mask is 1.0 for positions to attend and 0.0 for masked ones.
            # Convert it to an additive mask: 0.0 where attending, the dtype's most
            # negative value where masked, so masked scores vanish after softmax.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = self.drop(inputs_embeds + self.wpe(position_ids))

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            # noinspection PyUnresolvedReferences
            logger.warning_once(
                '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
            )
            use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for block, layer_past in zip(self.h, past_key_values):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if checkpointing:
                # Positional arguments follow the block's forward signature; no cache is passed.
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    None,
                    attention_mask,
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents += (outputs[1],)
            if output_attentions:
                # The attention weights come after `present` only when caching is on.
                attn_position = 2 if use_cache else 1
                all_self_attentions += (outputs[attn_position],)

        hidden_states = self.ln_f(hidden_states).view(output_shape)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=presents,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
            )
        candidates = [hidden_states, presents, all_hidden_states, all_self_attentions]
        return tuple(item for item in candidates if item is not None)
class LingLongForCausalLM(LingLongPreTrainedModel):
    """LingLong transformer with a language-modeling head on top."""

    _tied_weights_keys = ['lm_head.weight']

    def __init__(self, config):
        super().__init__(config)
        self.transformer = LingLongModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Assemble the keyword arguments for one generation step.

        When a cache is present, only the tokens not yet covered by it are kept.
        Position ids are derived from ``attention_mask`` when they are not supplied.

        Returns:
            Dict with ``input_ids`` (or ``inputs_embeds`` on the first step),
            ``past_key_values``, ``use_cache``, ``position_ids`` and ``attention_mask``.
        """
        # Omit tokens covered by past_key_values
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        attention_mask = kwargs.get('attention_mask', None)
        position_ids = kwargs.get('position_ids', None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            # noinspection PyUnresolvedReferences
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked positions get a dummy position id of 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]
        else:
            position_ids = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            model_inputs = {'input_ids': input_ids}

        model_inputs.update(
            {
                'past_key_values': past_key_values,
                'use_cache': kwargs.get('use_cache'),
                'position_ids': position_ids,
                'attention_mask': attention_mask,
            }
        )
        return model_inputs

    def forward(
            self,
            input_ids: torch.LongTensor | None = None,
            past_key_values: tuple[tuple[torch.Tensor]] | None = None,
            attention_mask: torch.FloatTensor | None = None,
            position_ids: torch.LongTensor | None = None,
            inputs_embeds: torch.FloatTensor | None = None,
            labels: torch.LongTensor | None = None,
            use_cache: bool | None = None,
            output_attentions: bool | None = None,
            output_hidden_states: bool | None = None,
            return_dict: bool | None = None,
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`

        Returns:
            With `return_dict` enabled, a `CausalLMOutputWithPast` carrying the loss, the logits, the key/value
            cache produced by the backbone, and the optional hidden states and attentions. Otherwise a tuple
            `(loss?, logits, *backbone_outputs[1:])`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # Forward the backbone's cache so the dict output matches the tuple output
        # and cached generation can reuse it on the next step.
        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past_key_values: tuple[tuple[torch.Tensor]],
            beam_idx: torch.Tensor,
            **kwargs,
    ) -> tuple[tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        # noinspection PyTypeChecker
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )