import math

import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast
from transformers.utils import logging
from transformers.activations import ACT2FN
from transformers.pytorch_utils import Conv1D
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutput

from .configuration_linglong import LingLongConfig

logger = logging.get_logger(__name__)


class LingLongAttention(nn.Module):
    """Multi-head causal self-attention with an optional strided sparse mask.

    When ``config.attn_mode == 'sparse'``, layers whose ``layer_idx`` is odd use the
    mask produced by :meth:`_sparse_causal_mask`; every other layer uses the dense
    lower-triangular mask produced by :meth:`_causal_mask`.
    """

    def __init__(self, config, layer_idx=None):
        """Build the projections, dropout layers and the dense causal-mask buffer.

        Args:
            config: Model configuration. The fields read here are `n_position`, `n_embd`,
                `n_head`, `scale_attn_weights`, `scale_attn_by_inverse_layer_idx`,
                `reorder_and_upcast_attn`, `attn_pdrop`, `resid_pdrop`, `attn_mode`,
                `attn_stride` and `attn_c`.
            layer_idx: Index of the enclosing block. It is used for inverse-layer scaling
                and to decide whether this layer uses the sparse mask.

        Raises:
            ValueError: If `n_embd` is not divisible by `n_head`.
        """
        super().__init__()
        n_position = config.n_position
        # Dense lower-triangular mask, stored as a non-persistent buffer so that it follows
        # the module across devices without being written to checkpoints.
        self.register_buffer(
            'bias',
            torch.tril(torch.ones((n_position, n_position), dtype=torch.bool)).view(
                1, 1, n_position, n_position
            ),
            persistent=False,
        )
        self.register_buffer('masked_bias', torch.tensor(-1e4), persistent=False)

        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_dim = self.n_embd // self.n_head
        self.split_size = self.n_embd
        if self.head_dim * self.n_head != self.n_embd:
            raise ValueError(
                f'`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.n_embd} and `num_heads`:'
                f' {self.n_head}).'
            )

        self.scale_attn_weights = config.scale_attn_weights

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        # One fused projection producing query, key and value; one output projection.
        self.c_attn = Conv1D(3 * self.n_embd, self.n_embd)
        self.c_proj = Conv1D(self.n_embd, self.n_embd)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        # LingLong sparse attention.
        self.mode = config.attn_mode
        self.stride = config.attn_stride
        self.c = config.attn_c
        # Most recently built mask; it is reused while shape and device stay the same.
        self.causal_mask = None

    def _causal_mask(self, query_length, key_length):
        """Return the dense causal mask rows for the last `query_length` of `key_length` positions.

        The result is a boolean view of the `bias` buffer with shape
        `(1, 1, query_length, key_length)`.
        """
        return self.bias[:, :, key_length - query_length: key_length, :key_length]

    def _sparse_causal_mask(self, query_length, key_length):
        """Return the strided sparse causal mask for the last `query_length` positions.

        A query at position `q` may attend to key `k` when `k <= q` and either

        * `k` lies in the same `stride`-sized block as `q`, or
        * `k` is one of the columns selected by the slices
          `(stride - 1 - idx)::stride` for `idx in range(c)`.

        Only the requested query rows are materialised, so the cost is
        `query_length * key_length` rather than `key_length ** 2`.

        Returns:
            Boolean tensor of shape `(1, 1, query_length, key_length)` on the device of
            the `bias` buffer.
        """
        device = self.bias.device
        key_positions = torch.arange(key_length, device=device)
        query_positions = key_positions[key_length - query_length:]

        # Columns that every query may attend to (subject to causality).
        summary_columns = torch.zeros(key_length, dtype=torch.bool, device=device)
        for idx in range(self.c):
            summary_columns[(self.stride - 1 - idx)::self.stride] = True

        # Keys that share the query's stride-sized block.
        key_blocks = torch.div(key_positions, self.stride, rounding_mode='floor')
        query_blocks = key_blocks[key_length - query_length:]
        same_block = query_blocks.unsqueeze(1) == key_blocks.unsqueeze(0)

        # Any query cannot attend to keys above it.
        causal = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)

        layout = (same_block | summary_columns.unsqueeze(0)) & causal
        return layout.view(1, 1, query_length, key_length)

    def _get_causal_mask(self, query_length, key_length, device):
        """Return the boolean mask for this layer, rebuilding the cache when needed.

        The cached `self.causal_mask` is reused only when both its shape and its device
        match the request. Because the cache is a plain attribute rather than a buffer,
        it does not follow the module when it is moved to another device; the device
        check avoids a device mismatch in `torch.where` in that case.

        Args:
            query_length: Number of query positions.
            key_length: Number of key positions.
            device: Device on which the attention scores live.

        Returns:
            Boolean tensor of shape `(1, 1, query_length, key_length)` on `device`.
        """
        expected_size = torch.Size([1, 1, query_length, key_length])
        cached = self.causal_mask
        if cached is None or cached.size() != expected_size or cached.device != device:
            if self.mode == 'sparse' and self.layer_idx % 2 != 0:
                cached = self._sparse_causal_mask(query_length, key_length)
            else:
                cached = self._causal_mask(query_length, key_length)
            self.causal_mask = cached.to(device)
        return self.causal_mask

    def _attn(self, query, key, value, attention_mask=None):
        """Compute masked scaled dot-product attention.

        Args:
            query, key, value: Per-head tensors as produced by `_split_heads`.
            attention_mask: Optional additive mask that is added to the raw scores.

        Returns:
            Tuple `(attn_output, attn_weights)`; `attn_weights` are the post-softmax,
            post-dropout probabilities cast to the dtype of `value`.
        """
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self._get_causal_mask(query_length, key_length, attn_weights.device)
        mask_value = torch.finfo(attn_weights.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None):
        """Compute attention with the score matmul done in float32 via `torch.baddbmm`.

        The scaling factors are folded into the `alpha` argument of `baddbmm` and
        autocast is disabled around the matmul so that the scores are float32.

        Args:
            query, key, value: Per-head 4-D tensors as produced by `_split_heads`.
            attention_mask: Optional additive mask that is added to the raw scores.

        Returns:
            Tuple `(attn_output, attn_weights)`.

        Raises:
            RuntimeError: If the post-softmax weights are not float32.
        """
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            # beta=0 means the uninitialised contents of `attn_weights` are ignored.
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self._get_causal_mask(query_length, key_length, attn_weights.device)
        mask_value = torch.finfo(attn_weights.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError('Error with upcasting, attn_weights does not have dtype torch.float32.')
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    @staticmethod
    def _split_heads(tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    @staticmethod
    def _merge_heads(tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
            self,
            hidden_states,
            layer_past=None,
            attention_mask=None,
            use_cache=False,
            output_attentions=False,
    ):
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input tensor; the fused projection is split along dim 2.
            layer_past: Optional `(past_key, past_value)` pair that is prepended to the
                new keys/values along the sequence dimension.
            attention_mask: Optional additive mask forwarded to the attention kernel.
            use_cache: When `True`, the (concatenated) key/value pair is returned.
            output_attentions: When truthy, the attention weights are appended.

        Returns:
            Tuple `(attn_output, present)` or `(attn_output, present, attn_weights)`;
            `present` is `None` unless `use_cache is True`.
        """
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.n_head, self.head_dim)
        key = self._split_heads(key, self.n_head, self.head_dim)
        value = self._split_heads(value, self.n_head, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask)

        attn_output = self._merge_heads(attn_output, self.n_head, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)


class LingLongMLP(nn.Module):
    """Position-wise feed-forward network: projection, activation, projection, dropout."""

    def __init__(self, intermediate_size, config):
        """
        Args:
            intermediate_size: Width of the hidden projection.
            config: Model configuration; reads `n_embd`, `activation_function` and `resid_pdrop`.
        """
        super().__init__()
        n_embd = config.n_embd
        self.c_fc = Conv1D(intermediate_size, n_embd)
        self.c_proj = Conv1D(n_embd, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        """Apply `c_fc`, the activation, `c_proj` and dropout, in that order."""
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class LingLongBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, config, layer_idx=None):
        """
        Args:
            config: Model configuration; `n_inner` defaults to `4 * n_embd` when it is `None`.
            layer_idx: Index of this block, forwarded to the attention module.
        """
        super().__init__()
        n_embd = config.n_embd
        inner_dim = config.n_inner if config.n_inner is not None else 4 * n_embd

        self.ln_1 = nn.LayerNorm(n_embd, eps=config.layer_norm_epsilon)
        self.attn = LingLongAttention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(n_embd, eps=config.layer_norm_epsilon)
        self.mlp = LingLongMLP(inner_dim, config)

    def forward(
            self,
            hidden_states,
            layer_past=None,
            attention_mask=None,
            use_cache=False,
            output_attentions=False,
    ):
        """Run one transformer block.

        Returns:
            `(hidden_states, present, (attentions))` when `use_cache` is truthy, otherwise
            `(hidden_states, (attentions))`.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            # Drop the `present` slot (it is None when caching is off).
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions)


class LingLongPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LingLongConfig
    base_model_prefix = 'transformer'
    supports_gradient_checkpointing = True
    _no_split_modules = ['LingLongBlock']
    _skip_keys_device_placement = 'past_key_values'

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name == 'c_proj.weight':
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))


class LingLongModel(LingLongPreTrainedModel):
    """Bare LingLong transformer: token + position embeddings, a block stack and a final LayerNorm."""

    def __init__(self, config: LingLongConfig):
        super().__init__(config)

        self.n_embd = config.n_embd

        self.wte = nn.Embedding(config.vocab_size, self.n_embd)
        self.wpe = nn.Embedding(config.n_position, self.n_embd)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([LingLongBlock(config, layer_idx=i) for i in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.n_embd, eps=config.layer_norm_epsilon)

        # Gradient checkpointing is off until enabled from outside this class.
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module."""
        self.wte = new_embeddings

    def forward(
            self,
            input_ids: torch.LongTensor | None = None,
            past_key_values: tuple[tuple[torch.Tensor]] | None = None,
            attention_mask: torch.FloatTensor | None = None,
            position_ids: torch.LongTensor | None = None,
            inputs_embeds: torch.FloatTensor | None = None,
            use_cache: bool | None = None,
            output_attentions: bool | None = None,
            output_hidden_states: bool | None = None,
            return_dict: bool | None = None,
    ) -> tuple | BaseModelOutputWithPast:
        """Encode a batch of tokens (or precomputed embeddings).

        Exactly one of `input_ids` and `inputs_embeds` must be given. Flags left as `None`
        fall back to the corresponding values on `self.config`.

        Returns:
            A `BaseModelOutputWithPast` when `return_dict` is truthy, otherwise a tuple of
            the non-`None` entries among (last hidden state, presents, all hidden states,
            all attentions).

        Raises:
            ValueError: If both or neither of `input_ids` / `inputs_embeds` are given, or if
                an attention mask is supplied with a non-positive batch size.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time.')
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError('You have to specify either input_ids or inputs_embeds.')

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            # The cached keys carry the sequence length on their second-to-last dimension.
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        # LingLongAttention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError('batch_size has to be defined and > 0.')
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 4D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and the dtype's smallest value for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                # noinspection PyUnresolvedReferences
                logger.warning_once(
                    '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for block, layer_past in zip(self.h, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # No past is passed on the checkpointed path (use_cache was disabled above).
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    None,
                    attention_mask,
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # The attention weights sit after `present` only when caching is on.
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class LingLongForCausalLM(LingLongPreTrainedModel):
    """LingLong transformer with a bias-free linear language-modeling head."""

    _tied_weights_keys = ['lm_head.weight']

    def __init__(self, config):
        super().__init__(config)
        self.transformer = LingLongModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Assemble the keyword arguments for one generation step.

        Tokens already covered by `past_key_values` are stripped from `input_ids`. When an
        attention mask is given and no `position_ids` are, position ids are derived from the
        mask; in every other case `position_ids` is set to `None`.

        Returns:
            Dict with either `inputs_embeds` (first step only) or `input_ids`, plus
            `past_key_values`, `use_cache`, `position_ids` and `attention_mask`.
        """
        # Omit tokens covered by past_key_values
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        attention_mask = kwargs.get('attention_mask', None)
        position_ids = kwargs.get('position_ids', None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            # noinspection PyUnresolvedReferences
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked slots get position 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]
        else:
            position_ids = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            model_inputs = {'input_ids': input_ids}

        model_inputs.update(
            {
                'past_key_values': past_key_values,
                'use_cache': kwargs.get('use_cache'),
                'position_ids': position_ids,
                'attention_mask': attention_mask,
            }
        )

        return model_inputs

    def forward(
            self,
            input_ids: torch.LongTensor | None = None,
            past_key_values: tuple[tuple[torch.Tensor]] | None = None,
            attention_mask: torch.FloatTensor | None = None,
            position_ids: torch.LongTensor | None = None,
            inputs_embeds: torch.FloatTensor | None = None,
            labels: torch.LongTensor | None = None,
            use_cache: bool | None = None,
            output_attentions: bool | None = None,
            output_hidden_states: bool | None = None,
            return_dict: bool | None = None,
    ) -> tuple | CausalLMOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # NOTE(review): only loss, logits, hidden_states and attentions are forwarded here;
        # the transformer's `past_key_values` is not returned on this path -- confirm intended.
        return CausalLMOutput(
            loss=loss,
            logits=lm_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past_key_values: tuple[tuple[torch.Tensor]],
            beam_idx: torch.Tensor,
            **kwargs,
    ) -> tuple[tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        # noinspection PyTypeChecker
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )