from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers import MODEL_FOR_MASKED_LM_MAPPING from transformers import PretrainedConfig from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline from transformers import GPT2TokenizerFast from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaForCausalLM import torch import torch.nn as nn from typing import Optional, Tuple, List import torch.nn.functional as F from dataclasses import dataclass @dataclass class ModelOutput: loss: float logits: any def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed # Define your custom configuration class CustomLlamaConfig(PretrainedConfig): model_type = "custom_llama" # Must match config.json def __init__(self, **kwargs): super().__init__(**kwargs) # Add any custom hyperparameters if needed # New Custom Llama Classes class CustomLlamaAttention(LlamaAttention): def __init__(self, config, layer_idx: int): super().__init__(config, layer_idx) self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads self.scale = 1.0 / (self.head_dim ** 0.5) self.w_q_start = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.w_q_dir = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.w_k_start = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.w_k_dir = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.w_v = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) del self.q_proj, self.k_proj, self.v_proj def _compute_metric(self, q_start, q_dir, k_start, k_dir): std_term = torch.einsum("bhqd,bhkd->bhqk", q_start, k_start) cross_term1 = torch.einsum("bhqd,bhkd->bhqk", q_dir, k_start) cross_term2 = torch.einsum("bhqd,bhkd->bhqk", q_start, k_dir) scores = (std_term + cross_term1 + cross_term2) * self.scale # assert not torch.isnan(scores).any(), "NaN in attention scores" return scores def _get_causal_mask(self, query_length, key_length, device): return torch.triu( torch.full((query_length, key_length), float('-inf'), device=device), diagonal=1 ).unsqueeze(0).unsqueeze(0) # Shape: [1, 1, seq_len, seq_len] def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, is_causal: Optional[torch.Tensor] = None, # New parameter ): batch_size, seq_len, _ = hidden_states.size() q_base = self.w_q_start(hidden_states) q_dir = self.w_q_dir(hidden_states) - q_base k_base = self.w_k_start(hidden_states) k_dir = self.w_k_dir(hidden_states) - k_base value = self.w_v(hidden_states) # assert not torch.isnan(hidden_states).any(), "NaN in hidden_states scores" # assert not torch.isnan(q_dir).any(), "NaN in q_dir scores" # assert not torch.isnan(k_base).any(), "NaN in k_base scores" # assert not torch.isnan(k_dir).any(), "NaN in k_dir scores" q_start = q_base.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) q_dir = q_dir.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) k_start = k_base.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) k_dir = k_dir.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings q_start, k_start = apply_rotary_pos_emb(q_start, k_start, cos, sin) q_dir, k_dir = apply_rotary_pos_emb(q_dir, k_dir, cos, sin) attn_scores = self._compute_metric(q_start, q_dir, k_start, k_dir) # Replace existing mask logic with: if attention_mask is not None: padding_mask = (attention_mask == 0).view(batch_size, 1, 1, -1) padding_mask = padding_mask.expand(-1, self.num_heads, -1, -1) else: padding_mask = None if is_causal is not None: causal_mask = self._get_causal_mask(seq_len, seq_len, attn_scores.device) causal_mask = causal_mask.expand(batch_size, self.num_heads, -1, -1) is_causal = is_causal.view(-1, 1, 1, 1) combined_mask = torch.where(is_causal, causal_mask, 0.0) else: combined_mask = 0.0 attn_scores = attn_scores + combined_mask if padding_mask is not None: attn_scores = attn_scores.masked_fill(padding_mask, torch.finfo(attn_scores.dtype).min) attn_weights = F.softmax(attn_scores, dim=-1) attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(batch_size, seq_len, self.num_heads * self.head_dim) attn_output = self.o_proj(attn_output) if output_attentions: return attn_output, attn_weights return attn_output, None class CustomLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config, layer_idx): super().__init__(config, layer_idx) self.self_attn = CustomLlamaAttention(config, layer_idx) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, is_causal: Optional[torch.Tensor] = None, # Add this ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, is_causal=is_causal, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class CustomLlamaModel(LlamaModel): def __init__(self, config): super().__init__(config) self.layers = nn.ModuleList([ CustomLlamaDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers) ]) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, is_causal: Optional[torch.Tensor] = None, # Add this cache_position: Optional[torch.LongTensor] = None, ): # Existing embedding and position embedding logic remains unchanged if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds if position_ids is None: position_ids = (attention_mask.long().cumsum(dim=1) - 1).masked_fill(attention_mask == 0, 0) cos, sin = self.rotary_emb(hidden_states, position_ids=position_ids) for layer in self.layers: hidden_states = layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=(cos, sin), is_causal=is_causal, ) hidden_states = self.norm(hidden_states) return hidden_states # Simplified for brevity; adjust as per original class CustomLlamaForCausalLM(LlamaForCausalLM): config_class = CustomLlamaConfig # Add this line def __init__(self, config): super().__init__(config) self.model = CustomLlamaModel(config) self.post_init() def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, is_causal: Optional[torch.Tensor] = None, # Add this ): outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_causal=is_causal, ) hidden_states = outputs logits = self.lm_head(hidden_states) loss = None if labels is not None: # Split batch into causal/masked using is_causal is_causal = is_causal.to(labels.device) # Causal loss (standard next-token prediction) causal_logits = logits[is_causal][..., :-1, :].contiguous() causal_labels = labels[is_causal][..., 1:].contiguous() # Masked loss (predict current masked token) masked_logits = logits[~is_causal][..., :-1, :].contiguous() masked_labels = labels[~is_causal][..., 1:].contiguous() loss = 0.0 if causal_logits.numel() > 0: loss += F.cross_entropy( causal_logits.view(-1, causal_logits.size(-1)), causal_labels.view(-1), ignore_index=-100 ) if masked_logits.numel() > 0: loss += F.cross_entropy( masked_logits.view(-1, masked_logits.size(-1)), masked_labels.view(-1), ignore_index=-100 ) return ModelOutput(loss=loss, logits=logits) class CustomLlamaForMaskedLM(CustomLlamaForCausalLM): config_class = CustomLlamaConfig # Add this line def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs # Ignore any additional arguments ): # Force is_causal=False for all examples to enable bidirectional attention batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0) is_causal = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device) # Call the parent forward method with is_causal set to False return super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_causal=is_causal, ) from transformers import CONFIG_MAPPING, MODEL_MAPPING CONFIG_MAPPING.update({"custom_llama": CustomLlamaConfig}) MODEL_MAPPING.update({"custom_llama": CustomLlamaForMaskedLM}) # Optionally, if you need to register for remote code loading: def _register(): from transformers import AutoConfig, AutoModelForCausalLM AutoConfig.register("custom_llama", CustomLlamaConfig) MODEL_FOR_MASKED_LM_MAPPING.register(CustomLlamaConfig, CustomLlamaForMaskedLM) _register()