# Modified based on https://github.com/lm-sys/FastChat import warnings from typing import Optional, Tuple import torch from torch import nn import transformers from einops import rearrange from flash_attn import __version__ as flash_attn_version from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import ( flash_attn_func, flash_attn_varlen_kvpacked_func, flash_attn_varlen_qkvpacked_func ) from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, rotate_half from flash_attn.bert_padding import unpad_input, pad_input import math group_size_ratio = 1/4 sft_group_size = 8192 def forward_flashattn( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel attention_mask: [bsz, q_len] """ if not self.training: warnings.warn("This function should be used just for training as it may exhibit reduced inference performance. For inference, please use forward_flashattn_inference.") if output_attentions: warnings.warn( "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." ) bsz, q_len, _ = hidden_states.size() query_states = ( self.q_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) key_states = ( self.k_proj(hidden_states) .view(bsz, q_len, self.num_key_value_heads, self.head_dim) .transpose(1, 2) ) value_states = ( self.v_proj(hidden_states) .view(bsz, q_len, self.num_key_value_heads, self.head_dim) .transpose(1, 2) ) # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) # Past Key value support if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) # Flash attention codes from # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py # transform the data into the format required by flash attention qkv = torch.stack( [query_states, key_states, value_states], dim=2 ) # [bsz, nh, 3, q_len, hd] qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] # We have disabled _prepare_decoder_attention_mask in LlamaModel # the attention_mask should be the same as the key_padding_mask key_padding_mask = attention_mask.repeat(2, 1) nheads = qkv.shape[-2] # shift if q_len % 4096 == 0: group_size = int(q_len * group_size_ratio) else: group_size = sft_group_size qkv = qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim) x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype) cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2 cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).min cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1) cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1) cu_q_lens = cu_q_lens[cu_q_lens >= 0] x_unpad = rearrange( x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2 ) output_unpad = flash_attn_varlen_qkvpacked_func( x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True ) output = rearrange( pad_input( rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len ), "b s (h d) -> b s h d", h=nheads // 2, ) output = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads, self.head_dim) return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value def forward_flashattn_full( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel attention_mask: [bsz, q_len] """ if output_attentions: warnings.warn( "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." ) bsz, q_len, _ = hidden_states.size() query_states = ( self.q_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) key_states = ( self.k_proj(hidden_states) .view(bsz, q_len, self.num_key_value_heads, self.head_dim) .transpose(1, 2) ) value_states = ( self.v_proj(hidden_states) .view(bsz, q_len, self.num_key_value_heads, self.head_dim) .transpose(1, 2) ) # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) # Past Key value support if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) # Flash attention codes from # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py # transform the data into the format required by flash attention qkv = torch.stack( [query_states, key_states, value_states], dim=2 ) # [bsz, nh, 3, q_len, hd] qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] # We have disabled _prepare_decoder_attention_mask in LlamaModel # the attention_mask should be the same as the key_padding_mask key_padding_mask = attention_mask nheads = qkv.shape[-2] x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) x_unpad = rearrange( x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads ) output_unpad = flash_attn_varlen_qkvpacked_func( x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True ) output = rearrange( pad_input( rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len ), "b s (h d) -> b s h d", h=nheads, ) output = output.reshape(bsz, q_len, self.num_heads, self.head_dim) return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value def forward_noflashattn( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() group_size = int(q_len * group_size_ratio) if q_len % group_size > 0: raise ValueError("q_len %d should be divisible by group size %d."%(q_len, group_size)) num_group = q_len // group_size if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split( (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 ) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] key_states = torch.cat(key_states, dim=-1) value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) # shift def shift(qkv, bsz, q_len, group_size, num_heads, head_dim): qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2) qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2) return qkv query_states = shift(query_states, bsz, q_len, group_size, self.num_heads, self.head_dim) key_states = shift(key_states, bsz, q_len, group_size, self.num_heads, self.head_dim) value_states = shift(value_states, bsz, q_len, group_size, self.num_heads, self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz * num_group, self.num_heads, group_size, group_size): raise ValueError( f"Attention weights should be of size {(bsz * num_group, self.num_heads, group_size, group_size)}, but is" f" {attn_weights.size()}" ) attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1) if attention_mask is not None: if attention_mask.size() != (bsz * num_group, 1, group_size, group_size): raise ValueError( f"Attention mask should be of size {(bsz * num_group, 1, group_size, group_size)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp16 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16).to(query_states.dtype) #torch.float32 attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz * num_group, self.num_heads, group_size, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz * num_group, self.num_heads, group_size, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) # shift back attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) if self.config.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask( self, attention_mask, input_shape, inputs_embeds, past_key_values_length ): # [bsz, seq_len] return attention_mask def apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids): gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1] gather_indices = gather_indices.repeat( 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3] ) bsz = gather_indices.shape[0] cos, sin = ( torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices) for x in cos_sin ) q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k)) return q, k def forward_flashattn_inference( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: warnings.warn( "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." ) bsz, q_len, _ = hidden_states.size() kv_heads = getattr(self, "num_key_value_heads", self.num_heads) q, k, v = ( op(hidden_states).view(bsz, q_len, nh, self.head_dim) for op, nh in ( (self.q_proj, self.num_heads), (self.k_proj, kv_heads), (self.v_proj, kv_heads), ) ) # shape: (b, s, num_heads, head_dim) kv_seq_len = k.shape[1] past_kv_len = 0 if past_key_value is not None: past_kv_len = past_key_value[0].shape[2] kv_seq_len += past_kv_len cos_sin = self.rotary_emb(v, seq_len=kv_seq_len) q, k = apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids) if past_key_value is not None: assert ( flash_attn_version >= "2.1.0" ), "past_key_value support requires flash-attn >= 2.1.0" # reuse k, v k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1) v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1) past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None if attention_mask is None: output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view( bsz, q_len, -1 ) else: q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) # We can skip concat and call unpad twice but seems better to call unpad only once. kv, _, cu_k_lens, max_k = unpad_input( torch.stack((k, v), dim=2), attention_mask ) output_unpad = flash_attn_varlen_kvpacked_func( q, kv, cu_q_lens, cu_k_lens, max_s, max_k, 0.0, softmax_scale=None, causal=True, ) output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) output = pad_input(output_unpad, indices, bsz, q_len) return self.o_proj(output), None, past_key_value def _prepare_decoder_attention_mask_inference( self, attention_mask, input_shape, inputs_embeds, past_key_values_length ): # [bsz, seq_len] if past_key_values_length > 0 and attention_mask is not None: attention_mask = torch.cat( ( torch.full( (input_shape[0], past_key_values_length), True, dtype=attention_mask.dtype, device=attention_mask.device, ), attention_mask, ), dim=-1, ) if attention_mask is not None and torch.all(attention_mask): return None # This uses the faster call when training with full samples return attention_mask def replace_llama_attn(use_flash_attn=True, use_full=False, inference=False): if use_flash_attn: cuda_major, cuda_minor = torch.cuda.get_device_capability() if cuda_major < 8: warnings.warn( "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" ) if inference: transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_inference transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_inference else: transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( _prepare_decoder_attention_mask ) transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_full if use_full else forward_flashattn else: transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_noflashattn