# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
"""Baichuan decoder-only transformer with ALiBi attention bias and a scalar reward head.

The module defines the ALiBi mask helpers, the transformer building blocks
(``RMSNorm``, ``MLP``, ``BaichuanAttention``, ``BaichuanLayer``), the backbone
``BaichuanModel``, a normalised output head ``NormHead`` and the reward model
``BaichuanCharRM``.
"""

from .configuration_baichuan import BaichuanConfig

# from .generation_utils import build_chat_input, TextIterStreamer

import math
from threading import Thread
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint  # used by BaichuanModel.forward; import the submodule explicitly
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging, ContextManagers

import os
from contextlib import contextmanager
from accelerate import init_empty_weights

logger = logging.get_logger(__name__)

try:
    from xformers import ops as xops
except ImportError:
    xops = None
    logger.warning(
        "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
    )


def _get_interleave(n):
    """Return a list of ``n`` per-head ALiBi slopes.

    For a power-of-two ``n`` the slopes form the geometric sequence
    ``start * start**i`` with ``start = 2 ** (-8 / n)``.  For any other ``n``
    the slopes of the closest lower power of two are used, followed by every
    second slope of the next power of two until ``n`` values are collected.
    """

    def _get_interleave_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return _get_interleave_power_of_2(n)

    closest_power_of_2 = 2 ** math.floor(math.log2(n))
    return (
        _get_interleave_power_of_2(closest_power_of_2)
        + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
    )


def _fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
    """Combine a causal ``-inf`` mask of size ``maxpos`` with ``alibi``.

    The result is converted with ``.to(tensor)`` (so it follows ``tensor``'s
    dtype and device) and truncated to at most ``tensor.shape[0] * attn_heads``
    entries along the first dimension.
    """
    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
    _future_mask = _future_mask.unsqueeze(0) + alibi
    new_future_mask = _future_mask.to(tensor)
    return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]


def _gen_alibi_mask(tensor, n_head, max_pos):
    """Build a ``(n_head, max_pos, max_pos)`` causal mask with ALiBi bias added.

    ``tensor`` is accepted for signature compatibility but is not read here;
    callers move the result to the right dtype/device themselves.
    """
    slopes = torch.Tensor(_get_interleave(n_head))
    position_point = torch.arange(max_pos) - max_pos + 1
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
    # position_point[0] has shape (1, max_pos), so torch.diag yields its single
    # [0, 0] element; subtracting it shifts the positions to start at zero.
    diag = torch.diag(position_point[0])
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    return alibi_mask


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, hidden_size, epsilon=1e-6):
        super().__init__()
        # Allocated uninitialised (torch.empty); values must be set elsewhere,
        # e.g. when weights are loaded.
        self.weight = torch.nn.Parameter(torch.empty(hidden_size))
        self.epsilon = epsilon

    def forward(self, hidden_states):
        """Scale ``hidden_states`` by the reciprocal RMS of its last dimension."""
        # The mean square is computed in float32 regardless of the input dtype.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)

        # convert into half-precision
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class MLP(torch.nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class BaichuanAttention(torch.nn.Module):
    """Multi-head self-attention with a packed QKV projection (``W_pack``).

    Positional information is not applied here; it is expected to arrive
    through the additive ``attention_mask`` passed to ``forward``.
    """

    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.model_max_length

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
            )
        self.W_pack = torch.nn.Linear(
            self.hidden_size, 3 * self.hidden_size, bias=False
        )
        self.o_proj = torch.nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``tensor`` to ``(bsz, num_heads, seq_len, head_dim)``."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention.

        Args:
            hidden_states: input of shape ``(bsz, q_len, hidden_size)``.
            attention_mask: optional additive mask, 3-D or 4-D.
            past_key_value: optional ``(key, value)`` pair that is concatenated
                in front of the new keys/values along dim 2.
            output_attentions: return the softmax weights when available.
            use_cache: return the updated ``(key, value)`` pair.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``.  ``attn_weights``
            is ``None`` unless ``output_attentions`` is set, and is always
            ``None`` on the fused-kernel training path.
        """
        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)
        # Split the packed projection into three (bsz, q_len, hidden_size) slices.
        proj = (
            proj.unflatten(-1, (3, self.hidden_size))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
        )
        query_states = (
            proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        key_states = (
            proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        value_states = (
            proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if xops is not None and self.training:
            attn_weights = None
            # query_states = query_states.transpose(1, 2)
            # key_states = key_states.transpose(1, 2)
            # value_states = value_states.transpose(1, 2)
            # attn_output = xops.memory_efficient_attention(
            #     query_states, key_states, value_states, attn_bias=attention_mask
            # )
            with torch.backends.cuda.sdp_kernel(
                enable_flash=True, enable_math=True, enable_mem_efficient=True
            ):
                attn_output = F.scaled_dot_product_attention(
                    query_states, key_states, value_states, attn_mask=attention_mask
                )
            attn_output = attn_output.transpose(1, 2)
        else:
            attn_weights = torch.matmul(
                query_states, key_states.transpose(2, 3)
            ) / math.sqrt(self.head_dim)

            if attention_mask is not None:
                if q_len == 1:  # inference with cache
                    # Only the last query row of the mask applies to a single new token.
                    if len(attention_mask.size()) == 4:
                        attention_mask = attention_mask[:, :, -1:, :]
                    else:
                        attention_mask = attention_mask[:, -1:, :]
                attn_weights = attn_weights + attention_mask
                # Clamp from below to the dtype's finite minimum.
                attn_weights = torch.max(
                    attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
                )
            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
            attn_output = torch.matmul(attn_weights, value_states)

            attn_output = attn_output.transpose(1, 2)

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class BaichuanLayer(torch.nn.Module):
    """One pre-norm decoder layer: self-attention and MLP, each with a residual."""

    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BaichuanAttention(config=config)
        self.mlp = MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, epsilon=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Apply the layer.

        Returns:
            A tuple ``(hidden_states, [attn_weights], [present_key_value])``.
            ``attn_weights`` is included only when ``output_attentions`` is
            truthy and ``present_key_value`` only when ``use_cache`` is truthy,
            in that order.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        # BaichuanModel.forward reads the attention weights at index 1 and the
        # cache at index ``2 if output_attentions else 1``; emit the attention
        # weights first so those indices are valid.
        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class BaichuanPreTrainedModel(PreTrainedModel):
    """Shared base class holding config binding and weight initialisation."""

    config_class = BaichuanConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BaichuanLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, initializer_range)."""
        std = self.config.initializer_range
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, BaichuanModel):
            module.gradient_checkpointing = value


class BaichuanModel(BaichuanPreTrainedModel):
    """Stack of ``BaichuanLayer`` blocks over a token embedding, with a final RMSNorm."""

    def __init__(self, config: BaichuanConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.n_head = config.num_attention_heads
        self.embed_tokens = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = torch.nn.ModuleList(
            [BaichuanLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)

        self.gradient_checkpointing = config.gradient_checkpointing
        self.post_init()
        # State for the lazily built evaluation-time mask (see get_alibi_mask).
        self.max_cache_pos = config.model_max_length
        self.first_run = True
        self.alibi_mask = None

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_alibi_mask(self, tensor, seq_length_with_past):
        """Return the causal + ALiBi mask for ``seq_length_with_past`` positions.

        In training mode a fresh mask is computed on every call.  Otherwise a
        non-persistent ``future_mask`` buffer of size ``max_cache_pos`` is
        registered on first use, regrown when a longer sequence arrives, and
        sliced to the requested length.
        """
        if self.training:
            slopes = torch.Tensor(_get_interleave(self.n_head))
            position_point = (
                torch.arange(seq_length_with_past) - seq_length_with_past + 1
            )
            position_point = (
                position_point.unsqueeze(0)
                .unsqueeze(0)
                .expand(self.n_head, seq_length_with_past, -1)
            )
            # Subtracting the (transposed) diagonal turns absolute positions
            # into key-minus-query offsets.
            diag = torch.diag(position_point[0])
            position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
                -1, -2
            )
            alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
            mask = _buffered_future_mask(
                tensor, seq_length_with_past, alibi, self.n_head
            )
        else:
            if self.first_run:
                self.first_run = False
                self.register_buffer(
                    "future_mask",
                    _gen_alibi_mask(tensor, self.n_head, self.max_cache_pos).to(
                        tensor
                    ),
                    persistent=False,
                )
            if seq_length_with_past > self.max_cache_pos:
                self.max_cache_pos = seq_length_with_past
                self.register_buffer(
                    "future_mask",
                    _gen_alibi_mask(tensor, self.n_head, self.max_cache_pos).to(
                        tensor
                    ),
                    persistent=False,
                )
            mask = self.future_mask[
                : self.n_head, :seq_length_with_past, :seq_length_with_past
            ]
        return mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` resolves to true,
            otherwise a tuple of the non-``None`` values among
            ``(hidden_states, cache, all_hidden_states, all_attentions)``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot provide both input_ids and inputs_embeds simultaneously"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You need to provide input_ids or inputs_embeds")

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.training:
            # Reuse the cached training mask while the sequence length is unchanged.
            if (
                self.alibi_mask is None
                or self.alibi_mask.shape[-1] != seq_length_with_past
            ):
                self.alibi_mask = self.get_alibi_mask(
                    inputs_embeds, seq_length_with_past
                )
            alibi_mask = self.alibi_mask
        else:
            alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)

        if attention_mask is not None:
            if len(attention_mask.shape) == 2:
                # Expand a (bsz, seq) mask to (bsz, seq, seq): position pairs are
                # kept when both entries are positive and equal, lower triangle only.
                expanded_mask = attention_mask.to(alibi_mask.dtype)
                expanded_mask = torch.tril(
                    torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
                ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
            else:
                expanded_mask = attention_mask
            bsz = inputs_embeds.size(0)
            src_len, tgt_len = alibi_mask.size()[-2:]
            expanded_mask = (
                expanded_mask.unsqueeze(1)
                .expand(bsz, 1, src_len, tgt_len)
                .to(alibi_mask.dtype)
            )
            # Convert keep=1/drop=0 into an additive mask (0 / dtype minimum).
            inverted_mask = 1.0 - expanded_mask
            inverted_mask = inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min
            )
            attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
        else:
            attention_mask = alibi_mask

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # ``inputs`` already ends with None for past_key_value;
                        # the trailing None here lands on ``use_cache``.
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class NormHead(nn.Module):
    """Bias-free linear head whose weight is normalised before use.

    ``bias`` is accepted by the constructor but not used.
    """

    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.first_flag = True

    def forward(self, hidden_states):
        """Project ``hidden_states`` with the normalised weight.

        In training mode the weight is normalised on every call.  On the first
        call outside training mode the stored parameter is replaced by its
        normalised version, and later calls reuse it as is.
        """
        if self.training:
            norm_weight = nn.functional.normalize(self.weight)
            self.first_flag = True
        elif self.first_flag:
            self.first_flag = False
            self.weight = nn.Parameter(nn.functional.normalize(self.weight))
            norm_weight = self.weight
        else:
            norm_weight = self.weight
        return nn.functional.linear(hidden_states, norm_weight)


_init_weights = True


@contextmanager
def no_init_weights(_enable=True):
    """Temporarily set the module-level ``_init_weights`` flag to ``False``.

    The previous value is restored on exit, including when the body raises.
    With ``_enable=False`` the flag is left untouched.
    """
    global _init_weights
    old_init_weights = _init_weights
    if _enable:
        _init_weights = False
    try:
        yield
    finally:
        _init_weights = old_init_weights


class BaichuanCharRM(BaichuanPreTrainedModel):
    """Reward model: ``BaichuanModel`` backbone plus a linear scalar ``score`` head."""

    def __init__(self, config):
        super().__init__(config)
        self.model = BaichuanModel(config)
        self.score = nn.Linear(config.hidden_size, 1, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Score each sequence with a value in ``(0, 1)``.

        The hidden state at the last sequence position is passed through
        ``score`` and a sigmoid.

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Regression targets. They are divided by 4 before a mean-squared-error
            loss is computed against the sigmoid score (presumably labels lie on a
            0-4 scale; confirm against the training data).

        Returns:
            A ``(loss, logits)`` tuple regardless of ``return_dict``; ``loss`` is
            ``None`` when ``labels`` is not given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Use the representation of the final position only.
        hidden_states = hidden_states[:, -1, :]
        # NOTE: a bare squeeze() yields a 0-d tensor for batch size 1; kept so the
        # returned shape stays as callers already see it.
        logits = torch.sigmoid(self.score(hidden_states).squeeze())

        loss = None
        if labels is not None:
            labels = labels.type_as(logits)
            loss_fct = nn.MSELoss()
            loss = loss_fct(logits.view(-1), labels.view(-1) / 4)

        return loss, logits