import math
from typing import Optional, Tuple

import torch
from torch import nn

from src.model.modules.kv_cache import KVCache


class GemmaConfig:
    """Plain container for the hyper-parameters of the Gemma decoder.

    Every named argument is stored as an attribute of the same name.
    Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        vocab_size,
        hidden_size,
        intermediate_size,
        num_hidden_layers,
        num_attention_heads,
        num_key_value_heads,
        head_dim=256,
        max_position_embeddings=8192,
        rms_norm_eps=1e-6,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        pad_token_id=None,
        **kwargs,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.pad_token_id = pad_token_id


class GemmaRMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a ``(1 + weight)`` gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero-initialised: forward() multiplies by (1 + weight), so the
        # initial gain is exactly 1.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (``eps`` added for stability)."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Normalise ``x`` in float32 and return the result in ``x``'s dtype."""
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)


class GemmaRotaryEmbedding(nn.Module):
    """Computes the cos/sin tables used by rotary positional embeddings.

    Args:
        dim: Size of the rotated dimension (the attention head dimension).
        max_position_embeddings: Stored on the module; not used in the
            computation.
        base: Base of the geometric progression of frequencies.
        device: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim  # it is set to the head_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Inverse frequencies: theta_i = base^(-2i/dim) for i = 0, 1, ..., dim // 2 - 1
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
        )
        # Non-persistent: recomputed on construction, never stored in the state dict.
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` for the given positions.

        Args:
            x: Tensor used only for its device and dtype.
            position_ids: ``[Batch_Size, Seq_Len]`` positions.
            seq_len: Unused; kept for interface compatibility.

        Returns:
            Two tensors of shape ``[Batch_Size, Seq_Len, Head_Dim]`` in
            ``x``'s dtype.
        """
        # Tensor.to() is not in-place: keep the returned tensor, otherwise the
        # device move is silently lost.
        inv_freq = self.inv_freq.to(x.device)

        # Copy the inv_freq tensor for each batch element
        # inv_freq_expanded: [Batch_Size, Head_Dim // 2, 1]
        inv_freq_expanded = (
            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        # position_ids_expanded: [Batch_Size, 1, Seq_Len]
        position_ids_expanded = position_ids[:, None, :].to(
            device=x.device, dtype=torch.float32
        )

        device_type = x.device.type
        # "mps" is redirected to "cpu" for the autocast context below.
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        # Autocast is disabled so the angles are computed in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            # Multiply each theta by the position (the argument of sin and cos)
            # [Batch_Size, Head_Dim // 2, 1] @ [Batch_Size, 1, Seq_Len]
            #   -> [Batch_Size, Head_Dim // 2, Seq_Len]
            #   -> transpose -> [Batch_Size, Seq_Len, Head_Dim // 2]
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            # emb: [Batch_Size, Seq_Len, Head_Dim]
            emb = torch.cat((freqs, freqs), dim=-1)
            # cos, sin: [Batch_Size, Seq_Len, Head_Dim]
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Return ``cat(-x2, x1)`` where ``x1``/``x2`` are the first/second halves
    of the last dimension of ``x``.

    This is the half-split (not interleaved) rotation used for the sin term of
    the rotary positional encoding.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]  # first half of the last dimension
    x2 = x[..., half:]  # second half of the last dimension
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary positional embeddings to the query and key tensors.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table, unsqueezed at ``unsqueeze_dim`` before use.
        sin: Sine table, unsqueezed at ``unsqueeze_dim`` before use.
        unsqueeze_dim: Axis inserted into ``cos``/``sin`` so that they
            broadcast over the head dimension of ``q`` and ``k``.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)  # Add the head dimension
    sin = sin.unsqueeze(unsqueeze_dim)  # Add the head dimension
    # Apply the formula (34) of the Rotary Positional Encoding paper.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class GemmaMLP(nn.Module):
    """Gated feed-forward block: ``down(gelu_tanh(gate(x)) * up(x))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        """Map ``[Batch_Size, Seq_Len, Hidden_Size]`` to the same shape."""
        # [Batch_Size, Seq_Len, Hidden_Size] -> [Batch_Size, Seq_Len, Intermediate_Size]
        gate = nn.functional.gelu(self.gate_proj(x), approximate="tanh")
        # [Batch_Size, Seq_Len, Hidden_Size] -> [Batch_Size, Seq_Len, Intermediate_Size]
        up = self.up_proj(x)
        # [Batch_Size, Seq_Len, Intermediate_Size] -> [Batch_Size, Seq_Len, Hidden_Size]
        return self.down_proj(gate * up)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    ``[Batch, Num_KV_Heads, Seq_Len, Head_Dim]`` becomes
    ``[Batch, Num_KV_Heads * n_rep, Seq_Len, Head_Dim]``; the copies of each
    head are adjacent. With ``n_rep == 1`` the input is returned unchanged.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class GemmaAttention(nn.Module):
    """Multi-head attention with grouped key/value heads and rotary embeddings.

    Raises:
        ValueError: If ``config.hidden_size`` is not divisible by
            ``config.num_attention_heads``.
    """

    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                "hidden_size must be divisible by num_attention_heads "
                f"(got hidden_size={self.hidden_size}, "
                f"num_attention_heads={self.num_heads})"
            )

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
        )
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        kv_cache: Optional[KVCache] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: ``[Batch_Size, Seq_Len, Hidden_Size]``.
            attention_mask: Additive mask added to the attention scores.
                Required despite the ``None`` default.
            position_ids: ``[Batch_Size, Seq_Len]`` positions for the rotary
                embedding.
            kv_cache: Optional cache; when given, its ``update`` method is
                called with this layer's keys/values and ``layer_idx``, and
                the tensors it returns are used for attention.
            **kwargs: Ignored.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` is
            ``[Batch_Size, Seq_Len, Hidden_Size]`` and ``attn_weights`` is
            ``[Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV]``.

        Raises:
            ValueError: If ``attention_mask`` is ``None`` or the attention
                output has an unexpected shape.
        """
        # Fail fast, before any projection work is done.
        if attention_mask is None:
            raise ValueError("attention_mask is required by GemmaAttention.forward")

        bsz, q_len, _ = hidden_states.size()  # [Batch_Size, Seq_Len, Hidden_Size]

        # [Batch_Size, Seq_Len, Num_Heads_Q * Head_Dim]
        query_states = self.q_proj(hidden_states)
        # [Batch_Size, Seq_Len, Num_Heads_KV * Head_Dim]
        key_states = self.k_proj(hidden_states)
        # [Batch_Size, Seq_Len, Num_Heads_KV * Head_Dim]
        value_states = self.v_proj(hidden_states)

        # [Batch_Size, Num_Heads_Q, Seq_Len, Head_Dim]
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        # [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        # [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # value_states is passed only to supply the device and dtype.
        # [Batch_Size, Seq_Len, Head_Dim], [Batch_Size, Seq_Len, Head_Dim]
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        # [Batch_Size, Num_Heads_Q, Seq_Len, Head_Dim], [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if kv_cache is not None:
            # NOTE(review): presumably returns the cached keys/values with the
            # new ones appended — confirm against KVCache.update.
            key_states, value_states = kv_cache.update(
                key_states, value_states, self.layer_idx
            )

        # Repeat the keys and values to match the number of query heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Q * K^T / sqrt(head_dim). Shape: [Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV]
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the query dtype
        # [Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV]
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        # Dropout is only active in training mode
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        # [Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV] x [Batch_Size, Num_Heads_Q, Seq_Len_KV, Head_Dim]
        #   -> [Batch_Size, Num_Heads_Q, Seq_Len_Q, Head_Dim]
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Make sure the sequence length is the second dimension.
        # [Batch_Size, Num_Heads_Q, Seq_Len_Q, Head_Dim] -> [Batch_Size, Seq_Len_Q, Num_Heads_Q, Head_Dim]
        attn_output = attn_output.transpose(1, 2).contiguous()
        # Concatenate all the heads together.
        # [Batch_Size, Seq_Len_Q, Num_Heads_Q, Head_Dim] -> [Batch_Size, Seq_Len_Q, Num_Heads_Q * Head_Dim]
        attn_output = attn_output.view(bsz, q_len, -1)
        # Multiply by W_o. [Batch_Size, Seq_Len_Q, Hidden_Size]
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights


class GemmaDecoderLayer(nn.Module):
    """One pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)

        self.mlp = GemmaMLP(config)
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> torch.FloatTensor:
        """Apply the block to ``[Batch_Size, Seq_Len, Hidden_Size]`` states.

        Returns:
            The transformed hidden states, same shape as the input. The
            attention weights returned by the attention module are discarded.
        """
        # --- Attention sub-layer ---
        residual = hidden_states
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = self.input_layernorm(hidden_states)
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            kv_cache=kv_cache,
        )
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = residual + hidden_states

        # --- Feed-forward sub-layer ---
        residual = hidden_states
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = self.post_attention_layernorm(hidden_states)
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = self.mlp(hidden_states)
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = residual + hidden_states

        return hidden_states


class GemmaModel(nn.Module):
    """Stack of decoder layers followed by a final RMS norm.

    The token embedding table is owned by this module but ``forward`` takes
    already-embedded inputs.
    """

    def __init__(self, config: GemmaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                GemmaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> torch.FloatTensor:
        """Run the decoder stack on ``inputs_embeds``.

        Args:
            attention_mask: Forwarded to every decoder layer.
            position_ids: Forwarded to every decoder layer.
            inputs_embeds: ``[Batch_Size, Seq_Len, Hidden_Size]``; required
                despite the ``None`` default.
            kv_cache: Forwarded to every decoder layer.

        Returns:
            Normalised hidden states, ``[Batch_Size, Seq_Len, Hidden_Size]``.
        """
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = inputs_embeds
        # Scale the embeddings by sqrt(hidden_size). The scalar is created in
        # the activations' dtype and on their device.
        normalizer = torch.tensor(
            self.config.hidden_size**0.5,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        hidden_states = hidden_states * normalizer

        for decoder_layer in self.layers:
            # [Batch_Size, Seq_Len, Hidden_Size]
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                kv_cache=kv_cache,
            )

        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = self.norm(hidden_states)

        # [Batch_Size, Seq_Len, Hidden_Size]
        return hidden_states


class GemmaForCausalLM(nn.Module):
    """Gemma decoder with a linear language-modelling head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = GemmaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped model."""
        return self.model.embed_tokens

    def tie_weights(self):
        """Make the LM head share its weight with the token embedding table."""
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> dict:
        """Compute next-token logits.

        Args:
            attention_mask: Forwarded to the decoder.
            position_ids: Forwarded to the decoder.
            inputs_embeds: ``[Batch_Size, Seq_Len, Hidden_Size]``.
            kv_cache: Optional cache forwarded to the decoder.

        Returns:
            A dict with key ``"logits"`` (float32,
            ``[Batch_Size, Seq_Len, Vocab_Size]``) and, when a cache was
            supplied, key ``"kv_cache"`` holding that same cache object.
        """
        # inputs_embeds: [Batch_Size, Seq_Len, Hidden_Size]
        # hidden_states: [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache,
        )

        # Logits are always returned in float32.
        logits = self.lm_head(hidden_states).float()

        return_data = {
            "logits": logits,
        }

        if kv_cache is not None:
            # Return the updated cache
            return_data["kv_cache"] = kv_cache

        return return_data