from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import LlamaMLP
from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
from .attention import LlamaAttention
from .diff_attn import DifferentialAttention
from .tensor_prod_attn import CausalTensorProductSelfAttn


class LlamaDecoderLayer(nn.Module):
    """One decoder block made of two pre-normalized residual sublayers.

    The first sublayer normalizes the input with ``input_layernorm``, runs
    ``self_attn`` on it, and adds the result to the un-normalized input.
    The second sublayer does the same with ``post_attention_layernorm`` and
    ``mlp``.
    """

    def __init__(self, config: LlamaConfig, layer_num):
        """Build the attention, MLP and the two norm submodules from ``config``.

        Args:
            config: model configuration; ``hidden_size`` and ``rms_norm_eps``
                are read here, and the whole object is passed to the
                attention and MLP constructors.
            layer_num: accepted for interface compatibility but not used in
                this class (NOTE(review): confirm whether callers expect it
                to be stored or forwarded).
        """
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        # Submodules are assigned in a fixed order; this order determines
        # parameter registration / state_dict ordering, so keep it stable.
        self.self_attn = CausalTensorProductSelfAttn(config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(width, eps=eps)
        self.post_attention_layernorm = LlamaRMSNorm(width, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Apply the attention sublayer, then the MLP sublayer.

        Args:
            hidden_states: block input (presumably (batch, seq, hidden_size);
                TODO confirm against callers).
            attention_mask: passed through unchanged to ``self_attn``.
            position_ids: passed through unchanged to ``self_attn``.

        Returns:
            The block output, produced by two residual additions onto
            ``hidden_states``.
        """
        # Attention sublayer: the norm is applied to the branch input only;
        # the skip path carries the un-normalized tensor.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sublayer, same pre-norm residual pattern.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out