from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import math
import warnings
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
# sinusoidal positional encoding
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
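# Illustrative usage (added note, not part of the original module): maps a 1-D tensor
# of positions or timesteps of shape (B,) to embeddings of shape (B, dim), with the
# first half of the channels holding sin terms and the second half cos terms, e.g.
#   pos_emb = SinusoidalPosEmb(dim=256)
#   t = torch.arange(8)      # (8,)
#   e = pos_emb(t)           # (8, 256)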
class LlamaAdaptiveRMSNorm(nn.Module):
    """RMSNorm with a learnable gain. Despite the name, this variant takes no
    conditioning input; it simply replaces the stock LlamaRMSNorm."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        return self.weight * self._norm(x.float()).type_as(x)
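# For reference (added note): the computation above is standard RMSNorm,
#   y = weight * x / sqrt(mean(x ** 2, dim=-1, keepdim=True) + eps)
# i.e. each hidden vector is rescaled by the reciprocal of its root mean square.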
class MultiEmbedding(nn.Module):
    """Embedding for multiple quantization layers, summing up the embeddings of each layer."""
    def __init__(
        self,
        num_embeddings=1028,
        embedding_dim=1024,
        num_quantization_layers=8,
    ):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(num_embeddings, embedding_dim)
                for _ in range(num_quantization_layers)
            ]
        )
        # initialize embeddings
        for i in range(num_quantization_layers):
            self.embeddings[i].weight.data.normal_(mean=0.0, std=0.02)
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, input_ids):
        """Input: [num_quant, B, T] -> Output: [B, T, H]"""
        num_quant, B, T = input_ids.shape
        summed_embeddings = torch.zeros(
            B, T, self.embeddings[0].embedding_dim,
            device=input_ids.device,
            dtype=self.embeddings[0].weight.dtype,  # match embedding dtype (e.g. under fp16)
        )
        for i in range(num_quant):
            summed_embeddings += self.embeddings[i](input_ids[i])
        return summed_embeddings
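# Illustrative usage (added note; the call pattern is an assumption based on the
# docstring above): eight codebook index streams are embedded separately and summed
# into a single hidden sequence.
#   me = MultiEmbedding(num_embeddings=1028, embedding_dim=1024, num_quantization_layers=8)
#   codes = torch.randint(0, 1028, (8, 2, 50))   # [num_quant, B, T]
#   h = me(codes)                                # [2, 50, 1024]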
class LlamaNARDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        """Override to adaptive layer norm"""
        super().__init__(config, layer_idx)  # init attention, mlp, etc.
        self.input_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
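# Note (added for clarity, not in the original file): this layer is a stock
# LlamaDecoderLayer with the norms swapped for LlamaAdaptiveRMSNorm; the
# bidirectional (non-causal) behaviour comes entirely from the 4-D mask built by
# LlamaNAR._prepare_decoder_attention_mask below. A hypothetical standalone call
# (exact kwargs depend on the installed transformers version):
#   layer = LlamaNARDecoderLayer(
#       LlamaConfig(hidden_size=1024, num_attention_heads=16, intermediate_size=4096),
#       layer_idx=0,
#   )
#   (out,) = layer(hidden_states, attention_mask=mask_4d, position_ids=pos_ids,
#                  use_cache=False)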
class LlamaNAR(LlamaModel):
    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    config=LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = LlamaAdaptiveRMSNorm(hidden_size)
        self.multi_embedding = MultiEmbedding(
            num_quantization_layers=8, embedding_dim=hidden_size
        )
        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create non-causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len
            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )
            inverted_mask = 1.0 - expanded_mask
            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )
        return combined_attention_mask
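    # Worked example (added for clarity): for a padding mask [[1, 1, 0]] in fp32,
    # _expand_mask returns a [1, 1, 3, 3] additive mask whose last column equals
    # torch.finfo(torch.float32).min and is 0 elsewhere, so every position can attend
    # to every non-padded position in both directions (no causal triangle).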
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        length: Optional[torch.LongTensor] = None,  # accepted for API compatibility, unused here
    ) -> torch.Tensor:
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # input_ids: [B, T, num_quant] -> summed codebook embeddings [B, T, H]
        batch_size, seq_length, num_quant = input_ids.shape
        input_ids = input_ids.permute(2, 0, 1)  # [num_quant, B, T]
        inputs_embeds = self.multi_embedding(input_ids)
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
        hidden_states = inputs_embeds
        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )
            if self.gradient_checkpointing and self.training:
                # Gradient checkpointing is not supported for this NAR variant;
                # the code below the raise is unreachable and kept from the original.
                raise NotImplementedError

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
        # Only the final hidden states are returned; the caches/attentions collected
        # above are not propagated by this implementation.
        return hidden_states
class LlamaNAREmb(LlamaModel):
    """LlamaNAR model that works directly with embedding inputs.

    This variant of LlamaNAR takes pre-computed embeddings (`inputs_embeds`)
    instead of token IDs that need to be embedded.
    """
    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    config=LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = LlamaAdaptiveRMSNorm(hidden_size)
        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create non-causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len
            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )
            inverted_mask = 1.0 - expanded_mask
            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )
        return combined_attention_mask
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Returns:
            hidden_states: Tensor of shape (batch_size, sequence_length, hidden_size)
        """
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided for LlamaNAREmb")
        if input_ids is not None:
            warnings.warn("input_ids is ignored in LlamaNAREmb, use inputs_embeds instead")
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, seq_length, hidden_size = inputs_embeds.shape
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
        hidden_states = inputs_embeds
        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )
            if self.gradient_checkpointing and self.training:
                # Gradient checkpointing is not supported for this NAR variant;
                # the code below the raise is unreachable and kept from the original.
                raise NotImplementedError

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
        return hidden_states
if __name__ == '__main__':
    config = LlamaConfig(hidden_size=1024, num_attention_heads=8, num_hidden_layers=8)
    model = LlamaNAR(config=config)
    # Build dummy inputs
    batch_size = 2
    seq_length = 10
    n_q = 8
    input_ids = torch.randint(0, 1028, (batch_size, seq_length, n_q))  # random code indices
    inputs_embeds = torch.randn(batch_size, seq_length, config.hidden_size)  # random input embeddings
    attention_mask = torch.ones(batch_size, seq_length)  # all positions visible
    length = torch.tensor([4, 10])  # per-sample lengths
    # Forward pass; the model returns only the final hidden states
    hidden_states = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=True,
        output_hidden_states=True,
        length=length,
    )
    # Print output shape
    print("Hidden States Shape:", hidden_states.shape)  # (batch_size, seq_length, 1024)
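    # Additional sanity check (added for illustration; assumed usage of the
    # embedding-input variant, not part of the original script):
    emb_model = LlamaNAREmb(config=config)
    emb_hidden = emb_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
    print("LlamaNAREmb Hidden States Shape:", emb_hidden.shape)  # (batch_size, seq_length, 1024)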