from transformers import MODEL_FOR_MASKED_LM_MAPPING
from transformers import PretrainedConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
)

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Optional, Tuple
from dataclasses import dataclass


@dataclass
class ModelOutput:
    """Minimal output container holding the training loss and the language-modeling logits."""

    loss: Optional[torch.Tensor]
    logits: torch.Tensor


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example,
            note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then,
            if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k
            have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
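
# Shape convention assumed below (see the docstring above): q and k are
# [batch, heads, seq_len, head_dim] and cos/sin are [batch, seq_len, head_dim],
# so the default unsqueeze_dim=1 broadcasts cos/sin across the heads dimension:
#
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # shapes unchanged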


class CustomLlamaConfig(PretrainedConfig):
    model_type = "custom_llama"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class CustomLlamaAttention(LlamaAttention):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = 1.0 / (self.head_dim ** 0.5)

        # "Start" and "direction" projections replace the standard q_proj / k_proj;
        # values keep a single projection.
        self.w_q_start = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.w_q_dir = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.w_k_start = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.w_k_dir = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.w_v = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)

        # The parent class' q/k/v projections are unused; drop them to save memory.
        del self.q_proj, self.k_proj, self.v_proj

    def _compute_metric(self, q_start, q_dir, k_start, k_dir):
        std_term = torch.einsum("bhqd,bhkd->bhqk", q_start, k_start)
        cross_term1 = torch.einsum("bhqd,bhkd->bhqk", q_dir, k_start)
        cross_term2 = torch.einsum("bhqd,bhkd->bhqk", q_start, k_dir)
        scores = (std_term + cross_term1 + cross_term2) * self.scale
        return scores
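
    # Score sketch: for query i and key j split into start/direction components,
    #   score(i, j) = (q_start_i · k_start_j + q_dir_i · k_start_j + q_start_i · k_dir_j) / sqrt(head_dim)
    # i.e. the usual dot-product attention term plus two asymmetric cross terms.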

    def _get_causal_mask(self, query_length, key_length, device):
        # Additive mask: 0 on/below the diagonal, -inf above it, shape [1, 1, q_len, k_len].
        return torch.triu(
            torch.full((query_length, key_length), float('-inf'), device=device),
            diagonal=1
        ).unsqueeze(0).unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        is_causal: Optional[torch.Tensor] = None,
    ):
        # past_key_value, use_cache and cache_position are accepted for interface
        # compatibility but ignored: KV caching is not implemented here.
        batch_size, seq_len, _ = hidden_states.size()

        # Project into start/direction components; the direction is the offset of the
        # second projection from the first.
        q_base = self.w_q_start(hidden_states)
        q_dir = self.w_q_dir(hidden_states) - q_base
        k_base = self.w_k_start(hidden_states)
        k_dir = self.w_k_dir(hidden_states) - k_base
        value = self.w_v(hidden_states)

        # [batch, seq, hidden] -> [batch, heads, seq, head_dim]
        q_start = q_base.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        q_dir = q_dir.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k_start = k_base.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k_dir = k_dir.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings to both the start and direction components.
        cos, sin = position_embeddings
        q_start, k_start = apply_rotary_pos_emb(q_start, k_start, cos, sin)
        q_dir, k_dir = apply_rotary_pos_emb(q_dir, k_dir, cos, sin)

        attn_scores = self._compute_metric(q_start, q_dir, k_start, k_dir)

        # Key-padding mask built from the 2D attention_mask, if provided.
        if attention_mask is not None:
            padding_mask = (attention_mask == 0).view(batch_size, 1, 1, -1)
            padding_mask = padding_mask.expand(-1, self.num_heads, -1, -1)
        else:
            padding_mask = None

        # Apply the causal mask only to samples whose is_causal flag is True.
        if is_causal is not None:
            causal_mask = self._get_causal_mask(seq_len, seq_len, attn_scores.device)
            causal_mask = causal_mask.expand(batch_size, self.num_heads, -1, -1)
            is_causal = is_causal.view(-1, 1, 1, 1)
            combined_mask = torch.where(is_causal, causal_mask, 0.0)
        else:
            combined_mask = 0.0

        attn_scores = attn_scores + combined_mask
        if padding_mask is not None:
            attn_scores = attn_scores.masked_fill(padding_mask, torch.finfo(attn_scores.dtype).min)

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        if output_attentions:
            return attn_output, attn_weights
        return attn_output, None


class CustomLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        self.self_attn = CustomLlamaAttention(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        is_causal: Optional[torch.Tensor] = None,
    ):
        # Pre-norm attention block with residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            is_causal=is_causal,
        )
        hidden_states = residual + hidden_states

        # Pre-norm MLP block with residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # Unlike the stock LlamaDecoderLayer, this returns a plain tensor rather than a tuple.
        return hidden_states


class CustomLlamaModel(LlamaModel):
    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList([
            CustomLlamaDecoderLayer(config, layer_idx=i)
            for i in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        is_causal: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds
        if position_ids is None:
            if attention_mask is not None:
                # Positions count only non-padded tokens; padded positions are clamped to 0.
                position_ids = (attention_mask.long().cumsum(dim=1) - 1).masked_fill(attention_mask == 0, 0)
            else:
                position_ids = torch.arange(
                    hidden_states.size(1), device=hidden_states.device
                ).unsqueeze(0).expand(hidden_states.size(0), -1)
        cos, sin = self.rotary_emb(hidden_states, position_ids=position_ids)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=(cos, sin),
                is_causal=is_causal,
            )

        hidden_states = self.norm(hidden_states)
        # Returns the final hidden states directly rather than a BaseModelOutput.
        return hidden_states


class CustomLlamaForCausalLM(LlamaForCausalLM):
    config_class = CustomLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = CustomLlamaModel(config)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        is_causal: Optional[torch.Tensor] = None,
    ):
        # Default to a fully causal batch when no per-sample flag is given.
        if is_causal is None:
            batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            is_causal = torch.ones(batch_size, dtype=torch.bool, device=device)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_causal=is_causal,
        )

        hidden_states = outputs
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            is_causal = is_causal.to(labels.device)

            # Both branches use the next-token shift; they differ only in whether a
            # causal mask was applied inside the attention layers.
            causal_logits = logits[is_causal][..., :-1, :].contiguous()
            causal_labels = labels[is_causal][..., 1:].contiguous()

            masked_logits = logits[~is_causal][..., :-1, :].contiguous()
            masked_labels = labels[~is_causal][..., 1:].contiguous()

            loss = 0.0
            if causal_logits.numel() > 0:
                loss += F.cross_entropy(
                    causal_logits.view(-1, causal_logits.size(-1)),
                    causal_labels.view(-1),
                    ignore_index=-100,
                )
            if masked_logits.numel() > 0:
                loss += F.cross_entropy(
                    masked_logits.view(-1, masked_logits.size(-1)),
                    masked_labels.view(-1),
                    ignore_index=-100,
                )

        return ModelOutput(loss=loss, logits=logits)
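
# Usage sketch (illustrative, not from the original): objectives can be mixed within a
# batch by passing a per-sample boolean flag, e.g.
#
#   is_causal = torch.tensor([True, False], device=input_ids.device)
#   out = model(input_ids=input_ids, attention_mask=attention_mask,
#               labels=labels, is_causal=is_causal)
#
# Samples flagged True are trained with the causal mask; samples flagged False attend
# bidirectionally, while both use the shifted next-token loss above.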


class CustomLlamaForMaskedLM(CustomLlamaForCausalLM):
    config_class = CustomLlamaConfig

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        # Mark every sample as non-causal so the attention layers skip the causal mask.
        batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        is_causal = torch.zeros(batch_size, dtype=torch.bool, device=device)

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_causal=is_causal,
        )


from transformers import AutoConfig, MODEL_MAPPING


def _register():
    """Register the custom architecture with the transformers Auto* factories so
    that the "custom_llama" model type resolves to the classes defined above."""
    AutoConfig.register("custom_llama", CustomLlamaConfig)
    MODEL_MAPPING.register(CustomLlamaConfig, CustomLlamaForMaskedLM)
    MODEL_FOR_MASKED_LM_MAPPING.register(CustomLlamaConfig, CustomLlamaForMaskedLM)


_register()
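

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration; the tiny hyperparameters are
    # assumptions, not a released configuration). It assumes a transformers release in
    # which LlamaModel exposes a model-level rotary_emb, as the code above already does.
    from transformers import LlamaConfig

    base = LlamaConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=128,
    )
    cfg_dict = base.to_dict()
    cfg_dict.pop("model_type", None)  # keep the subclass' "custom_llama" model type
    config = CustomLlamaConfig(**cfg_dict)

    model = CustomLlamaForMaskedLM(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = input_ids.clone()

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 1000])
    print(out.loss)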