import math
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union

import transformers.models
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2RotaryEmbedding,
    Qwen2ForCausalLM,
    Qwen2MLP,
    Qwen2RMSNorm,
    apply_rotary_pos_emb,
    repeat_kv,
    _prepare_4d_causal_attention_mask_with_cache_position,
)
from transformers.utils import logging
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache, StaticCache, SlidingWindowCache
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

from .utils import InferTaskCode

logger = logging.get_logger(__name__)

_GPU_QWEN_TORCH_COMPILE = True


# ===================================================================
# =============================Attention=============================
# ===================================================================
class GPUQwen2Attention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = Qwen2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    # Adapted from Qwen2Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # NOTE: RoPE is computed for the full (max cache) sequence length so torch.compile always sees static shapes
        cos, sin = self.rotary_emb(value_states, seq_len=past_key_value.get_max_length())
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : past_key_value.get_max_length()]

        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
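
# Shape flow through GPUQwen2Attention.forward, for reference (B = batch, S = q_len,
# H = num_heads, H_kv = num_key_value_heads, D = head_dim, T = static cache max length):
#   q/k/v proj + view:  (B, S, hidden) -> (B, H, S, D) for queries, (B, H_kv, S, D) for keys/values
#   cache update:       with a StaticCache, update() returns the full preallocated (B, H_kv, T, D) buffers
#   repeat_kv:          (B, H_kv, T, D) -> (B, H, T, D), duplicating each KV head num_key_value_groups times
#   SDPA output:        (B, H, S, D) -> transpose/view -> (B, S, hidden) -> o_proj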


# ===================================================================
# =============================Layer=================================
# ===================================================================
class GPUQwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if config.sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )
        self.self_attn = GPUQwen2Attention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
                into the model
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
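
# GPUQwen2DecoderLayer keeps Qwen2's pre-norm residual layout:
#   hidden = hidden + SelfAttn(RMSNorm(hidden)); hidden = hidden + MLP(RMSNorm(hidden))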


# ===================================================================
# ========================Qwen2ForCausalLM===========================
# ===================================================================
class InferQwen2ForCausalLM(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.compile_forward = (
            torch.compile(self.simplify_forward, dynamic=False, fullgraph=True)
            if _GPU_QWEN_TORCH_COMPILE
            else self.simplify_forward
        )
        self.text_phase = True

    # NOTE: Rewrites the original Qwen2ForCausalLM.forward. Compiling the original function directly
    # (e.g. with torchair) fails when it returns a CausalLMOutputWithPast, so this variant returns the
    # raw decoder outputs and the uncompiled wrapper builds the output dataclass.
    def simplify_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        return outputs
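
    # `self.compile_forward` wraps `simplify_forward` with dynamic=False and fullgraph=True, so it is
    # only reached for decode steps, where every tensor shape is pinned by the static KV cache; the
    # variable-length prefill (and any call with do_compile=False) goes through the eager version.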

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        do_compile: bool = True,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if past_key_values is not None:
            past_key_values.training = False

        if input_ids is not None:
            if self.text_phase:
                inputs_embeds = self.model.embed_tokens(input_ids)
            else:
                inputs_embeds = self.speech_token_emded(input_ids)
            # 151645 is Qwen2's `<|im_end|>` token id: once it appears, switch to the speech phase
            if torch.isin(input_ids, 151645).any():
                self.text_phase = False
            input_ids = None

        if (inputs_embeds is not None and cache_position[0] == 0) or not do_compile:
            # prefill branch
            outputs = self.simplify_forward(
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
                use_cache,
                output_attentions,
                output_hidden_states,
                return_dict,
                cache_position,
            )
        else:
            # decoding
            outputs = self.compile_forward(
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
                use_cache,
                output_attentions,
                output_hidden_states,
                return_dict,
                cache_position,
            )

        last_hidden_states = outputs.last_hidden_state
        if self.text_phase:
            logits = self.lm_head(last_hidden_states)
        else:
            logits = self.speech_head(last_hidden_states)
        logits = logits.float()

        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Mainly adds static cache support (expands the attention mask to 4D during decoding for StaticCache).
        """
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`,
                # as otherwise the input `position_ids` would have various stride during the decoding.
                # Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case,
                # `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # NOTE: same rationale as for `position_ids` above: clone for torch.compile and CUDA graphs
            input_ids = input_ids.clone(memory_format=torch.contiguous_format)
            model_inputs = {"input_ids": input_ids}

        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if inputs_embeds is not None and cache_position[0] == 0:
                # prefill phase, inputs_embeds has shape (B, S, H)
                batch_size, sequence_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
                device = inputs_embeds.device
            else:
                # decoding phase, input_ids has shape (B, S)
                batch_size, sequence_length = input_ids.shape
                device = input_ids.device

            dtype = self.lm_head.weight.dtype
            min_dtype = torch.finfo(dtype).min

            if (inputs_embeds is not None and inputs_embeds.ndim == 2) or (input_ids is not None and input_ids.size(-1) == 1):
                # we only expand the attention mask in decoding mode
                attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                    attention_mask,
                    sequence_length=sequence_length,
                    target_length=past_key_values.get_max_length(),
                    dtype=dtype,
                    device=device,
                    min_dtype=min_dtype,
                    cache_position=cache_position,
                    batch_size=batch_size,
                )
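                # After expansion the mask has shape (batch, 1, sequence_length, max_cache_len) with
                # `min_dtype` at masked positions, so every decode step hands the compiled graph a
                # fixed-size tensor instead of a mask that grows with the generated sequence.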

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                # fall back to the `forward` default if the caller did not pass `do_compile`
                "do_compile": kwargs.get("do_compile", True),
            }
        )
        return model_inputs


# ===================================================================
print("========================= DO Qwen2 PATCH ===========================")
# ===================================================================
transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel._supports_static_cache = True  # enable static cache
transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer = GPUQwen2DecoderLayer
transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM = InferQwen2ForCausalLM
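

# A minimal usage sketch, not part of the patch itself. Assumptions: "Qwen/Qwen2-0.5B-Instruct" is a
# placeholder checkpoint; the speech branch (`speech_token_emded` / `speech_head`) is attached by the
# surrounding project, so a plain checkpoint only exercises the text phase; `do_compile=True` is
# forwarded because `prepare_inputs_for_generation` reads it from the generation kwargs, and
# `cache_implementation="static"` lets `generate` allocate the StaticCache the compiled decode path
# relies on. Run via `python -m <package>.<module>` since the file uses a relative import.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "Qwen/Qwen2-0.5B-Instruct"  # hypothetical checkpoint, for illustration only
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = InferQwen2ForCausalLM.from_pretrained(model_name, torch_dtype="auto")
    model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

    inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=32,
        cache_implementation="static",  # preallocate a StaticCache with fixed shapes
        do_compile=True,                # consumed by prepare_inputs_for_generation / forward
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))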