# mlm_llama / modeling_custom_llama.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import List, Optional, Tuple

from transformers import MODEL_FOR_MASKED_LM_MAPPING, PretrainedConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
)
@dataclass
class ModelOutput:
    """Minimal output container: only the loss and logits are returned."""
    loss: Optional[torch.Tensor]
    logits: torch.Tensor
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# Define your custom configuration
class CustomLlamaConfig(PretrainedConfig):
model_type = "custom_llama" # Must match config.json
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Add any custom hyperparameters if needed
# New Custom Llama Classes
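# The classes below replace Llama's attention with a "start + direction" parameterization:
# each query/key gets a start projection and a separate direction projection (taken as an
# offset from the start), and a per-example `is_causal` flag selects causal or
# bidirectional attention within the same batch.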
class CustomLlamaAttention(LlamaAttention):
def __init__(self, config, layer_idx: int):
super().__init__(config, layer_idx)
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.scale = 1.0 / (self.head_dim ** 0.5)
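        # Separate "start" and "direction" projections for queries and keys; the standard
        # q_proj/k_proj/v_proj created by LlamaAttention are deleted below and replaced.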
self.w_q_start = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.w_q_dir = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.w_k_start = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.w_k_dir = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.w_v = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
del self.q_proj, self.k_proj, self.v_proj
def _compute_metric(self, q_start, q_dir, k_start, k_dir):
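        # score = <q_start, k_start> + <q_dir, k_start> + <q_start, k_dir>, i.e. the usual
        # dot-product term plus two start/direction cross terms, scaled by 1/sqrt(head_dim).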
std_term = torch.einsum("bhqd,bhkd->bhqk", q_start, k_start)
cross_term1 = torch.einsum("bhqd,bhkd->bhqk", q_dir, k_start)
cross_term2 = torch.einsum("bhqd,bhkd->bhqk", q_start, k_dir)
scores = (std_term + cross_term1 + cross_term2) * self.scale
# assert not torch.isnan(scores).any(), "NaN in attention scores"
return scores
def _get_causal_mask(self, query_length, key_length, device):
return torch.triu(
torch.full((query_length, key_length), float('-inf'), device=device),
diagonal=1
).unsqueeze(0).unsqueeze(0) # Shape: [1, 1, seq_len, seq_len]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        is_causal: Optional[torch.Tensor] = None,  # Per-example bool flag: True = causal attention, False = bidirectional
):
batch_size, seq_len, _ = hidden_states.size()
q_base = self.w_q_start(hidden_states)
q_dir = self.w_q_dir(hidden_states) - q_base
k_base = self.w_k_start(hidden_states)
k_dir = self.w_k_dir(hidden_states) - k_base
value = self.w_v(hidden_states)
# assert not torch.isnan(hidden_states).any(), "NaN in hidden_states scores"
# assert not torch.isnan(q_dir).any(), "NaN in q_dir scores"
# assert not torch.isnan(k_base).any(), "NaN in k_base scores"
# assert not torch.isnan(k_dir).any(), "NaN in k_dir scores"
q_start = q_base.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
q_dir = q_dir.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k_start = k_base.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k_dir = k_dir.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
q_start, k_start = apply_rotary_pos_emb(q_start, k_start, cos, sin)
q_dir, k_dir = apply_rotary_pos_emb(q_dir, k_dir, cos, sin)
attn_scores = self._compute_metric(q_start, q_dir, k_start, k_dir)
        # Build the padding mask from attention_mask and a per-example causal mask from is_causal.
if attention_mask is not None:
padding_mask = (attention_mask == 0).view(batch_size, 1, 1, -1)
padding_mask = padding_mask.expand(-1, self.num_heads, -1, -1)
else:
padding_mask = None
if is_causal is not None:
causal_mask = self._get_causal_mask(seq_len, seq_len, attn_scores.device)
causal_mask = causal_mask.expand(batch_size, self.num_heads, -1, -1)
is_causal = is_causal.view(-1, 1, 1, 1)
combined_mask = torch.where(is_causal, causal_mask, 0.0)
else:
combined_mask = 0.0
attn_scores = attn_scores + combined_mask
if padding_mask is not None:
attn_scores = attn_scores.masked_fill(padding_mask, torch.finfo(attn_scores.dtype).min)
attn_weights = F.softmax(attn_scores, dim=-1)
attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
if output_attentions:
return attn_output, attn_weights
return attn_output, None
class CustomLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config, layer_idx):
super().__init__(config, layer_idx)
self.self_attn = CustomLlamaAttention(config, layer_idx)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        is_causal: Optional[torch.Tensor] = None,  # Per-example causal/bidirectional flag, forwarded to attention
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
is_causal=is_causal,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class CustomLlamaModel(LlamaModel):
def __init__(self, config):
super().__init__(config)
self.layers = nn.ModuleList([
CustomLlamaDecoderLayer(config, layer_idx=i)
for i in range(config.num_hidden_layers)
])
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
        is_causal: Optional[torch.Tensor] = None,  # Per-example causal/bidirectional flag, forwarded to each layer
cache_position: Optional[torch.LongTensor] = None,
):
        # Token embeddings and rotary position embeddings.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
        if position_ids is None:
            if attention_mask is not None:
                position_ids = (attention_mask.long().cumsum(dim=1) - 1).masked_fill(attention_mask == 0, 0)
            else:
                position_ids = torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0)
        cos, sin = self.rotary_emb(hidden_states, position_ids=position_ids)
for layer in self.layers:
hidden_states = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=(cos, sin),
is_causal=is_causal,
)
hidden_states = self.norm(hidden_states)
        return hidden_states  # Simplified: returns only the final hidden states (no cache or per-layer outputs)
class CustomLlamaForCausalLM(LlamaForCausalLM):
    config_class = CustomLlamaConfig  # Lets from_pretrained / Auto classes resolve the custom config
def __init__(self, config):
super().__init__(config)
self.model = CustomLlamaModel(config)
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
        is_causal: Optional[torch.Tensor] = None,  # Per-example causal/bidirectional flag
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_causal=is_causal,
)
hidden_states = outputs
logits = self.lm_head(hidden_states)
loss = None
        if labels is not None:
            # Split the batch into causal and masked (bidirectional) examples using is_causal.
            if is_causal is None:
                # Default: treat every example in the batch as causal.
                is_causal = torch.ones(labels.size(0), dtype=torch.bool, device=labels.device)
            is_causal = is_causal.to(labels.device)
            # Causal loss: standard shifted next-token prediction.
            causal_logits = logits[is_causal][..., :-1, :].contiguous()
            causal_labels = labels[is_causal][..., 1:].contiguous()
            # Loss on the masked (bidirectional) examples, computed with the same shifted objective.
            masked_logits = logits[~is_causal][..., :-1, :].contiguous()
            masked_labels = labels[~is_causal][..., 1:].contiguous()
loss = 0.0
if causal_logits.numel() > 0:
loss += F.cross_entropy(
causal_logits.view(-1, causal_logits.size(-1)),
causal_labels.view(-1),
ignore_index=-100
)
if masked_logits.numel() > 0:
loss += F.cross_entropy(
masked_logits.view(-1, masked_logits.size(-1)),
masked_labels.view(-1),
ignore_index=-100
)
return ModelOutput(loss=loss, logits=logits)
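# Note (not in the original): CustomLlamaForCausalLM takes a per-example `is_causal`
# flag, so a single batch can mix objectives, e.g.
#     is_causal = torch.tensor([True, False])   # example 0 causal, example 1 bidirectional
#     out = model(input_ids=ids, attention_mask=mask, labels=labels, is_causal=is_causal)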
class CustomLlamaForMaskedLM(CustomLlamaForCausalLM):
    config_class = CustomLlamaConfig  # Lets from_pretrained / Auto classes resolve the custom config
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs # Ignore any additional arguments
):
        # Force is_causal=False for every example so attention is fully bidirectional.
        batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        is_causal = torch.zeros(batch_size, dtype=torch.bool, device=device)
# Call the parent forward method with is_causal set to False
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_causal=is_causal,
)
from transformers import CONFIG_MAPPING, MODEL_MAPPING
CONFIG_MAPPING.register("custom_llama", CustomLlamaConfig)
MODEL_MAPPING.register(CustomLlamaConfig, CustomLlamaForMaskedLM)
# Register with the Auto classes so AutoConfig / AutoModelForMaskedLM can resolve "custom_llama".
def _register():
    from transformers import AutoConfig
    AutoConfig.register("custom_llama", CustomLlamaConfig)
    MODEL_FOR_MASKED_LM_MAPPING.register(CustomLlamaConfig, CustomLlamaForMaskedLM)
_register()
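# Example usage -- a minimal sketch, not part of the original module. It assumes a local
# checkpoint directory whose config.json sets "model_type": "custom_llama" and contains
# the usual Llama hyperparameters; the path and the mask token below are placeholders.
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    checkpoint = "path/to/custom_llama_checkpoint"  # hypothetical checkpoint directory
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # Thanks to the registrations above, AutoModelForMaskedLM resolves "custom_llama"
    # to CustomLlamaForMaskedLM, which forces bidirectional attention internally.
    model = AutoModelForMaskedLM.from_pretrained(checkpoint)

    inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    print(outputs.logits.shape)  # (batch_size, seq_len, vocab_size)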