import json
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import transformers
from transformers.cache_utils import *
from transformers.models.llama.modeling_llama import *

from .modules.inf_llm import InfLLMGenerator, inf_llm_forward
from .modules.minference_forward import (
    gather_last_q_vertical_slash_topk_v4,
    gather_last_q_vertical_slash_topk_vllm,
    init_minference_parameters,
    minference_forward,
    minference_kv_cache_cpu_forward,
    minference_vllm_forward,
    minference_with_snapkv_forward,
    search_pattern,
    sum_all_diagonal_matrix,
)
from .ops.streaming_kernel import stream_llm_forward

class RotaryEmbeddingESM(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(
        self,
        dim: int,
        base: Union[int, float] = 10000,
        distance_scale: Union[int, float] = 1,
    ):
        super().__init__()
        self.base = base
        self.distance_scale = distance_scale

        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self._seq_len_cached = -1
        self._cos_cached = None
        self._sin_cached = None

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, x, length, right, cos, sin):
        dtype = x.dtype
        if cos.dim() == 2:
            cos = cos[right - length : right, :]
            sin = sin[right - length : right, :]
        elif cos.dim() == 3:
            cos = cos[:, right - length : right, :]
            sin = sin[:, right - length : right, :]
        elif cos.dim() == 4:
            cos = cos[:, :, right - length : right, :]
            sin = sin[:, :, right - length : right, :]
        return ((x.float() * cos) + (self.rotate_half(x).float() * sin)).to(dtype)

    def _update_cos_sin_tables(self, x, seq_dim):
        seq_len = x.size(seq_dim)
        if seq_len > self._seq_len_cached:
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t * self.distance_scale, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            if x.dim() == 2:
                self._cos_cached = emb.cos()
                self._sin_cached = emb.sin()
            elif x.dim() == 3:
                self._cos_cached = emb.cos()[None, :, :]
                self._sin_cached = emb.sin()[None, :, :]
            elif x.dim() == 4:
                self._cos_cached = emb.cos()[None, None, :, :]
                self._sin_cached = emb.sin()[None, None, :, :]
        return self._cos_cached, self._sin_cached

    def _update_cos_sin_tables_len(self, seq_len, device, dim=None):
        if seq_len > self._seq_len_cached:
            if dim is None:
                assert self._cos_cached is not None
                dim = self._cos_cached.dim()
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.outer(t * self.distance_scale, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            if dim == 2:
                self._cos_cached = emb.cos()
                self._sin_cached = emb.sin()
            elif dim == 3:
                self._cos_cached = emb.cos()[None, :, :]
                self._sin_cached = emb.sin()[None, :, :]
            elif dim == 4:
                self._cos_cached = emb.cos()[None, None, :, :]
                self._sin_cached = emb.sin()[None, None, :, :]
        return self._cos_cached, self._sin_cached

    def apply_rotary_pos_emb_one_angle(self, x: torch.Tensor, index):
        dtype = x.dtype
        cos, sin = self._update_cos_sin_tables_len(index, x.device)
        if cos.dim() == 2:
            cos = cos[index - 1 : index, :]
            sin = sin[index - 1 : index, :]
        elif cos.dim() == 3:
            cos = cos[:, index - 1 : index, :]
            sin = sin[:, index - 1 : index, :]
        elif cos.dim() == 4:
            cos = cos[:, :, index - 1 : index, :]
            sin = sin[:, :, index - 1 : index, :]
        return ((x.float() * cos) + (self.rotate_half(x).float() * sin)).to(dtype)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, seq_dim=-2
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k, seq_dim=seq_dim
        )
        return (
            self.apply_rotary_pos_emb(
                q, q.size(seq_dim), k.size(seq_dim), self._cos_cached, self._sin_cached
            ),
            self.apply_rotary_pos_emb(
                k, k.size(seq_dim), k.size(seq_dim), self._cos_cached, self._sin_cached
            ),
        )

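# RotaryEmbeddingESM is instantiated in `patch_hf` below and attached to the model as
# `model.model.position_bias`. A minimal usage sketch (the tensor shapes are
# illustrative assumptions, not taken from this module):
#
#   rope = RotaryEmbeddingESM(dim=head_dim)   # head_dim = hidden_size // num_heads
#   q, k = rope(q, k)                         # q, k: [batch, num_heads, seq_len, head_dim]
#
# The cos/sin tables are cached and only recomputed when a longer sequence arrives;
# `distance_scale` rescales relative positions before the rotation is applied.
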
ATTN_FORWRAD = {
    "streaming": stream_llm_forward,
    "minference": minference_forward,
    "inf_llm": inf_llm_forward,
}

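# `ATTN_FORWRAD[attn_type](**attn_kwargs)` is how `patch_hf` builds the replacement
# attention forward: each entry is a factory that returns a forward function, which
# `huggingface_forward` below then wraps into the HuggingFace attention signature.
# For example (a sketch; the kwargs schema is defined by the individual factories,
# not documented here):
#
#   forward = huggingface_forward(ATTN_FORWRAD["streaming"](**attn_kwargs))
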
def huggingface_forward(forward):
    def hf_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        assert not output_attentions
        ret = forward(
            self,
            hidden_states,
            hidden_states,
            position_ids,
            use_cache,
            past_key_value,
            self.q_proj,
            self.k_proj,
            self.v_proj,
            self.o_proj,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
        )
        if use_cache:
            o, pkv = ret
        else:
            o = ret
            pkv = None
        return o, None, pkv

    return hf_forward

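# `huggingface_forward` adapts the custom attention kernels above to the
# (hidden_states, attention_mask, position_ids, past_key_value, ...) signature that
# HuggingFace Llama/Mistral/Qwen2 attention modules expect, handing them the module's
# own q/k/v/o projections and head geometry. Attention weights are never produced
# (hence `assert not output_attentions`); the middle element of the returned tuple is
# always None so the decoder layer's unpacking still works.
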
def hf_437_prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    if past_key_values is not None:
        if isinstance(past_key_values, transformers.cache_utils.Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            cache_length = past_length = past_key_values[0][0].shape[2]
            max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs

def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    **kwargs,
):
    # With static cache, the `past_key_values` is None
    # TODO joao: standardize interface for the different Cache classes and remove of this if
    has_static_cache = False
    if past_key_values is None:
        past_key_values = getattr(
            getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None
        )
        has_static_cache = past_key_values is not None

    past_length = 0
    if past_key_values is not None:
        if isinstance(past_key_values, transformers.cache_utils.Cache):
            past_length = (
                cache_position[0]
                if cache_position is not None
                else past_key_values.get_seq_length()
            )
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                if past_key_values.get_max_length() is not None
                else None
            )
            cache_length = (
                past_length
                if max_cache_length is None
                else torch.min(max_cache_length, past_length)
            )
        # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
        else:
            # cache_length = past_length = past_key_values[0][0].shape[2]
            cache_length = past_length = cache_position[0]
            max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
        # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
        # TODO: use `next_tokens` directly instead.
        model_inputs = {"input_ids": input_ids.contiguous()}

    input_length = (
        position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
    )
    if cache_position is None:
        cache_position = torch.arange(
            past_length, past_length + input_length, device=input_ids.device
        )
    else:
        cache_position = cache_position[-input_length:]

    if has_static_cache:
        past_key_values = None

    model_inputs.update(
        {
            "position_ids": position_ids,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs

def prepare_inputs_for_generation_snapkv(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    if past_key_values is None:  # [SnapKV]
        for layer in self.model.layers:
            layer.self_attn.kv_seq_len = 0
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # cache_length = past_length = past_key_values[0][0].shape[2]
            # max_cache_length = None
            cache_length = past_length = self.model.layers[0].self_attn.kv_seq_len
            max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs

def _prepare_decoder_attention_mask_inference(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    if past_key_values_length > 0 and attention_mask is not None:
        attention_mask = torch.cat(
            (
                torch.full(
                    (input_shape[0], past_key_values_length),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
                attention_mask,
            ),
            dim=-1,
        )
    if attention_mask is not None and torch.all(attention_mask):
        return None  # This uses the faster call when training with full samples
    return attention_mask

def forward_llama_decoder_layer(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """
    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
    """
    residual = hidden_states.clone()
    batch, seq_len, embed_dim = hidden_states.shape

    for start_idx in range(0, seq_len, 32000):
        end_idx = min(seq_len, start_idx + 32000)
        hidden_states[:, start_idx:end_idx, :] = self.input_layernorm(
            hidden_states[:, start_idx:end_idx, :]
        )

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        padding_mask=padding_mask,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    for start_idx in range(0, seq_len, 32000):
        end_idx = min(seq_len, start_idx + 32000)
        part_hidden_states = hidden_states[:, start_idx:end_idx, :].clone()
        part_hidden_states = self.post_attention_layernorm(part_hidden_states)
        part_hidden_states = self.mlp(part_hidden_states)
        hidden_states[:, start_idx:end_idx, :] += part_hidden_states

    outputs = (hidden_states,)
    if output_attentions:
        outputs += (self_attn_weights,)
    if use_cache:
        outputs += (present_key_value,)
    return outputs

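# Note on the chunked loops above: for very long prompts, running layernorm and the MLP
# over the full [batch, seq_len, hidden] tensor at once would dominate GPU memory, so
# this layer normalizes and feeds the MLP in 32,000-token slices and writes the results
# back in place. `residual` is cloned up front because the in-place layernorm
# overwrites `hidden_states`.
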
def forward_llama_model(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time"
        )
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    seq_length_with_past = seq_length
    past_key_values_length = 0
    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past),
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
        padding_mask = None
    else:
        if 0 in attention_mask:
            padding_mask = attention_mask
        else:
            padding_mask = None

    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    batch, seq_len, embed_dim = hidden_states.shape
    for start_idx in range(0, seq_len, 32000):
        end_idx = min(seq_len, start_idx + 32000)
        hidden_states[:, start_idx:end_idx, :] = self.norm(
            hidden_states[:, start_idx:end_idx, :]
        )

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = (
            next_decoder_cache.to_legacy_cache()
            if use_legacy_cache
            else next_decoder_cache
        )
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
            if v is not None
        )
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )

def forward_llama_for_causal_lm(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    # assert labels is not None
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    torch.cuda.empty_cache()
    hidden_states = outputs[0]

    if labels is not None:
        loss_fct = CrossEntropyLoss(reduction="sum")
        valid_seq_len = input_ids.shape[-1] - 1
        valid_seq_len_slide_win = torch.sum(labels[:, 1:] >= 0).item()
        # print("valid_seq_len_slide_win", valid_seq_len)
        loss = 0.0

        for start_idx in range(0, valid_seq_len, 32000):
            end_idx = min(start_idx + 32000, valid_seq_len)
            shift_logits = self.lm_head(
                hidden_states[..., start_idx:end_idx, :]
            ).float()
            shift_labels = labels[..., start_idx + 1 : end_idx + 1].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss += loss_fct(shift_logits, shift_labels)

        loss /= valid_seq_len_slide_win
        logits = None
    else:
        if self.config.to_dict().get("is_ppl", False):
            logits = self.lm_head(hidden_states)
        else:
            logits = self.lm_head(hidden_states[:, -1:]).float()
        loss = None

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
    )

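# When labels are provided, the cross-entropy above is computed chunk by chunk: the
# lm_head is applied to 32,000-token slices of the hidden states, per-token losses are
# summed (`reduction="sum"`), and the total is divided by the number of valid label
# positions. This is equivalent to a mean over non-ignored tokens without ever
# materializing the full [seq_len, vocab_size] logits tensor. Without labels, only the
# last position's logits are returned unless the config sets `is_ppl`.
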
def minference_patch(model, config):
    from transformers import LlamaForCausalLM

    if config.kv_cache_cpu:
        return minference_patch_kv_cache_cpu(model)
    if config.use_snapkv:
        return minference_patch_with_snapkv(model)

    Attention = model.model.layers[0].self_attn.__class__
    Model = model.model.__class__
    DecoderLayer = model.model.layers[0].__class__
    forward = minference_forward()

    def update_module(m):
        if isinstance(m, Attention):
            m.init_minference_parameters = init_minference_parameters.__get__(
                m, Attention
            )
            m.gather_last_q_vertical_slash_topk_v4 = (
                gather_last_q_vertical_slash_topk_v4.__get__(m, Attention)
            )
            m.forward = forward.__get__(m, Attention)
        if isinstance(m, DecoderLayer):
            m.forward = forward_llama_decoder_layer.__get__(m, DecoderLayer)

    model.apply(update_module)
    model.prepare_inputs_for_generation = hf_437_prepare_inputs_for_generation.__get__(
        model, model.__class__
    )
    model.model._use_sdpa = False
    model.model._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask_inference.__get__(
            model.model, model.model.__class__
        )
    )
    model.model.forward = forward_llama_model.__get__(
        model.model, model.model.__class__
    )
    model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)

    model.has_patch = True
    print("Patched model for minference..")
    return model

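# A minimal usage sketch for `minference_patch` (illustrative; the only requirement
# visible in this module is that `config` exposes `kv_cache_cpu` and `use_snapkv`
# flags, so the config object shown here is an assumption):
#
#   model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype="auto").cuda()
#   model = minference_patch(model, config)
#   output_ids = model.generate(input_ids, max_new_tokens=64)
#
# Patching is done in place by rebinding forwards with `__get__`, so the weights and
# tokenizer are untouched.
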
def minference_patch_kv_cache_cpu(model):
    from transformers import LlamaForCausalLM

    transformers.cache_utils.DynamicCache.update = cpu_cache_update
    transformers.cache_utils.DynamicCache.get = cpu_cache_get

    Attention = model.model.layers[0].self_attn.__class__
    Model = model.model.__class__
    DecoderLayer = model.model.layers[0].__class__
    forward = minference_kv_cache_cpu_forward()

    def update_module(m):
        if isinstance(m, Attention):
            m.init_minference_parameters = init_minference_parameters.__get__(
                m, Attention
            )
            m.gather_last_q_vertical_slash_topk_v4 = (
                gather_last_q_vertical_slash_topk_v4.__get__(m, Attention)
            )
            m.forward = forward.__get__(m, Attention)
        if isinstance(m, DecoderLayer):
            m.forward = forward_llama_decoder_layer.__get__(m, DecoderLayer)

    model.apply(update_module)
    model.prepare_inputs_for_generation = hf_437_prepare_inputs_for_generation.__get__(
        model, model.__class__
    )
    model.model._use_sdpa = False
    model.model._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask_inference.__get__(
            model.model, model.model.__class__
        )
    )
    model.model.forward = forward_llama_model.__get__(
        model.model, model.model.__class__
    )
    model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)

    print("Patched model for MInference load KV Cache to CPU.")
    return model

def minference_patch_with_snapkv(model):
    from transformers import LlamaForCausalLM

    Attention = model.model.layers[0].self_attn.__class__
    Model = model.model.__class__
    DecoderLayer = model.model.layers[0].__class__
    forward = minference_with_snapkv_forward()

    def update_module(m):
        if isinstance(m, Attention):
            m.init_minference_parameters = init_minference_parameters.__get__(
                m, Attention
            )
            m.gather_last_q_vertical_slash_topk_v4 = (
                gather_last_q_vertical_slash_topk_v4.__get__(m, Attention)
            )
            m.forward = forward.__get__(m, Attention)
        if isinstance(m, DecoderLayer):
            m.forward = forward_llama_decoder_layer.__get__(m, DecoderLayer)

    model.apply(update_module)
    model.prepare_inputs_for_generation = prepare_inputs_for_generation_snapkv.__get__(
        model, model.__class__
    )
    model.model._use_sdpa = False
    model.model._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask_inference.__get__(
            model.model, model.model.__class__
        )
    )
    model.model.forward = forward_llama_model.__get__(
        model.model, model.model.__class__
    )
    model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)

    print("Patched model for minference with SnapKV.")
    return model

def llama_model_forward_vllm(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if inputs_embeds is not None:
        hidden_states = inputs_embeds
    else:
        hidden_states = self.get_input_embeddings(input_ids)
    residual = None
    for i in range(len(self.layers)):
        layer = self.layers[i]
        hidden_states, residual = layer(
            positions,
            hidden_states,
            kv_caches[i],
            attn_metadata,
            residual,
            layer_idx=i,
        )
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states


def llama_layer_forward_vllm(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,
    residual: Optional[torch.Tensor],
    layer_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Self Attention
    if residual is None:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
        layer_idx=layer_idx,
    )

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
    hidden_states = self.mlp(hidden_states)
    return hidden_states, residual


def llama_attn_forward_vllm(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,
    layer_idx: int,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale, layer_idx)
    output, _ = self.o_proj(attn_output)
    return output

def vllm_attn_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: Optional[torch.Tensor],
    attn_metadata,
    kv_scale: float = 1.0,
    layer_idx: int = 0,
) -> torch.Tensor:
    return self.impl.forward(
        query, key, value, kv_cache, attn_metadata, kv_scale, layer_idx
    )

def minference_patch_vllm(
    llm,
    config_file,
):
    from vllm.attention import Attention
    from vllm.model_executor.models.llama import (
        LlamaAttention,
        LlamaDecoderLayer,
        LlamaForCausalLM,
        LlamaModel,
    )

    with open(config_file) as f:
        config = json.load(f)
    attn_forward = minference_vllm_forward(config)

    def update_module(m):
        if isinstance(m, Attention):
            m.forward = vllm_attn_forward.__get__(m, Attention)

            m = m.impl
            m_cls = m.__class__
            m.gather_last_q_vertical_slash_topk_vllm = (
                gather_last_q_vertical_slash_topk_vllm.__get__(m, m_cls)
            )
            m.forward = attn_forward.__get__(m, m_cls)
        if isinstance(m, LlamaDecoderLayer):
            m.forward = llama_layer_forward_vllm.__get__(m, LlamaDecoderLayer)
        if isinstance(m, LlamaModel):
            m.forward = llama_model_forward_vllm.__get__(m, LlamaModel)
        if isinstance(m, LlamaAttention):
            m.forward = llama_attn_forward_vllm.__get__(m, LlamaAttention)

    llm.llm_engine.model_executor.driver_worker.model_runner.model.apply(update_module)

    print("Patched model for minference with vLLM.")
    return llm

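# A usage sketch for the vLLM path (illustrative; `config_file` is the sparse-pattern
# JSON loaded above, and the file path shown is a placeholder):
#
#   from vllm import LLM, SamplingParams
#   llm = LLM(model=model_name, max_model_len=128_000)
#   llm = minference_patch_vllm(llm, config_file="path/to/pattern_config.json")
#   outputs = llm.generate(prompts, SamplingParams(max_tokens=64))
#
# Note that the patch reaches into `llm.llm_engine.model_executor.driver_worker`, so it
# assumes a single driver-worker setup.
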
def patch_hf(
    model,
    attn_type: str = "inf_llm",
    attn_kwargs: dict = {},
    base=None,
    distance_scale=None,
    **kwargs,
):
    attn_kwargs.update(kwargs)
    # This approach lacks scalability and will be refactored.
    from transformers import LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM
    from transformers.models.llama.modeling_llama import (
        BaseModelOutputWithPast,
        LlamaAttention,
        LlamaModel,
    )
    from transformers.models.mistral.modeling_mistral import (
        MistralAttention,
        MistralModel,
    )
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2Model

    def model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        *args,
        **kwargs,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            if hasattr(self, "config") and hasattr(self.config, "scale_emb"):
                inputs_embeds = inputs_embeds * self.config.scale_emb

        if use_cache:
            pkv = tuple()
        else:
            pkv = None

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for i, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=self.position_bias,
                past_key_value=(
                    past_key_values[i] if past_key_values is not None else None
                ),
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                _cache = layer_outputs[2 if output_attentions else 1]
                pkv = pkv + (_cache,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # hidden_states = self.norm(hidden_states)
        for start_idx in range(0, hidden_states.size(1), 32000):
            end_idx = min(hidden_states.size(1), start_idx + 32000)
            hidden_states[:, start_idx:end_idx, :] = self.norm(
                hidden_states[:, start_idx:end_idx, :]
            )

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, pkv, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=pkv,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    forward = huggingface_forward(ATTN_FORWRAD[attn_type](**attn_kwargs))

    if isinstance(model, LlamaForCausalLM):
        Attention = model.model.layers[0].self_attn.__class__
        Model = model.model.__class__
    elif isinstance(model, MistralForCausalLM):
        Attention = model.model.layers[0].self_attn.__class__
        Model = model.model.__class__
    elif isinstance(model, Qwen2ForCausalLM):
        Attention = model.model.layers[0].self_attn.__class__
        Model = model.model.__class__
    elif model.__class__.__name__ == "MiniCPMForCausalLM":
        Attention = model.model.layers[0].self_attn.__class__
        Model = model.model.__class__
    elif model.__class__.__name__ == "Phi3ForCausalLM":
        Attention = model.model.layers[0].self_attn.__class__
        Model = model.model.__class__
    else:
        raise ValueError(
            "Only LLaMA, Mistral, Qwen2, MiniCPM, and Phi-3 models are supported."
        )

    hf_rope = model.model.layers[0].self_attn.rotary_emb
    base = base if base is not None else hf_rope.base
    distance_scale = distance_scale if distance_scale is not None else 1.0
    rope = RotaryEmbeddingESM(hf_rope.dim, base, distance_scale)
    model.model.position_bias = rope
    model.model.hf_position_bias = hf_rope

    def set_forward(m):
        if isinstance(m, Attention):
            m._old_forward = m.forward
            m.forward = forward.__get__(m, Attention)

    model.apply(set_forward)

    model._old_prepare_inputs_for_generation = model.prepare_inputs_for_generation
    model.prepare_inputs_for_generation = prepare_inputs_for_generation.__get__(
        model, model.__class__
    )
    model.model._old_forward = model.model.forward
    model.model.forward = model_forward.__get__(model.model, Model)

    if attn_type == "inf_llm":
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model.config._name_or_path
        )
        model = InfLLMGenerator(model, tokenizer)

    print("Patched model ...")
    return model

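# A usage sketch for `patch_hf` (illustrative; the accepted `attn_kwargs` are defined
# by the forward builders in `.modules` / `.ops`, not documented here):
#
#   model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").cuda()
#   model = patch_hf(model, attn_type="streaming", attn_kwargs={})
#
# For `attn_type="inf_llm"`, the function returns an `InfLLMGenerator` that wraps the
# patched model together with its tokenizer, so the caller interacts with that wrapper
# rather than the raw HuggingFace model.
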
def fp8_cache_update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states to cache.
        value_states (`torch.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            The index of the layer to cache the states for.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

    Return:
        A tuple containing the updated key and value states.
    """
    # Update the number of seen tokens
    if layer_idx == 0:
        self.seen_tokens += key_states.shape[-2]

    # Update the cache
    if len(self.key_cache) <= layer_idx:
        self.key_cache.append(key_states.to(torch.float8_e5m2))
        self.value_cache.append(value_states.to(torch.float8_e5m2))
    else:
        self.key_cache[layer_idx] = torch.cat(
            [self.key_cache[layer_idx], key_states.to(torch.float8_e5m2)], dim=-2
        )
        self.value_cache[layer_idx] = torch.cat(
            [self.value_cache[layer_idx], value_states.to(torch.float8_e5m2)], dim=-2
        )
    return self.key_cache[layer_idx].to(key_states.dtype), self.value_cache[
        layer_idx
    ].to(key_states.dtype)

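# `fp8_cache_update` stores keys/values as float8_e5m2 (halving KV-cache memory versus
# fp16/bf16) and casts them back to the incoming dtype on every read. It has the same
# signature as `DynamicCache.update`; installing it would presumably mirror the CPU
# variant below (a sketch, this module does not wire it up anywhere):
#
#   transformers.cache_utils.DynamicCache.update = fp8_cache_update
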
def cpu_cache_update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    if layer_idx == 0:
        if "_seen_tokens" in self.__dict__:
            self._seen_tokens += key_states.shape[-2]
        else:
            self.seen_tokens += key_states.shape[-2]

    # Update the cache
    if len(self.key_cache) <= layer_idx:
        self.key_cache.append(key_states.cpu())
        self.value_cache.append(value_states.cpu())
    else:
        self.key_cache[layer_idx] = torch.cat(
            [self.key_cache[layer_idx], key_states.cpu()], dim=-2
        )
        self.value_cache[layer_idx] = torch.cat(
            [self.value_cache[layer_idx], value_states.cpu()], dim=-2
        )

def cpu_cache_get(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    head_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if layer_idx == 0:
        if "_seen_tokens" in self.__dict__:
            self._seen_tokens += key_states.shape[-2]
        else:
            self.seen_tokens += key_states.shape[-2]

    # Update the cache
    if len(self.key_cache) <= layer_idx:
        return key_states, value_states
    else:
        key_states = torch.cat(
            [self.key_cache[layer_idx][:, head_idx : head_idx + 1].cuda(), key_states],
            dim=-2,
        )
        value_states = torch.cat(
            [
                self.value_cache[layer_idx][:, head_idx : head_idx + 1].cuda(),
                value_states,
            ],
            dim=-2,
        )
        return key_states, value_states

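# `cpu_cache_update` and `cpu_cache_get` are installed onto
# `transformers.cache_utils.DynamicCache` by `minference_patch_kv_cache_cpu` above:
# `update` appends new keys/values on the CPU only, while `get` moves a single head's
# cached keys/values back to the GPU and concatenates them with the current states, so
# only one head per layer needs to be resident in GPU memory at a time. The per-head
# retrieval pattern is presumably driven by `minference_kv_cache_cpu_forward` in
# `.modules.minference_forward`.
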