# OpenPeerLLM/src/modeling_openpeer.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)
        self.out = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(config.attention_dropout)
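
    # Shape walkthrough for forward() below, derived directly from the code:
    # hidden_states enters as (batch, seq, hidden_size); the Q/K/V projections are
    # reshaped to (batch, num_heads, seq, head_size); attention scores and
    # probabilities are (batch, num_heads, seq, seq); the context is merged back to
    # (batch, seq, hidden_size) before the output projection.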
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_length = hidden_states.shape[:2]

        # Project queries, keys, and values
        query_states = self.query(hidden_states)
        key_states = self.key(hidden_states)
        value_states = self.value(hidden_states)

        # Reshape for multi-head attention
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_size).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, self.num_heads, self.head_size).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, self.num_heads, self.head_size).transpose(1, 2)

        # Calculate attention scores
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.head_size)

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Apply attention to values
        context_layer = torch.matmul(attention_probs, value_states)
        context_layer = context_layer.transpose(1, 2).contiguous()

        # Reshape back
        context_layer = context_layer.view(batch_size, seq_length, self.hidden_size)
        context_layer = self.out(context_layer)

        return context_layer, attention_probs
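
# A minimal sketch of exercising MultiHeadAttention on its own (illustrative only;
# the repository's real config class is not defined in this file, so a SimpleNamespace
# stand-in carrying the three attributes the layer reads is assumed):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(num_attention_heads=4, hidden_size=64, attention_dropout=0.1)
#   attn = MultiHeadAttention(cfg)
#   context, probs = attn(torch.randn(2, 10, 64))
#   # context: (2, 10, 64), probs: (2, 4, 10, 10)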


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dense_4h_to_h(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
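
# The dense_h_to_4h / dense_4h_to_h names follow the conventional 4x feed-forward
# expansion, but the actual width is whatever config.intermediate_size specifies.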


class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self-attention (LayerNorm is applied to the sub-layer input)
        attention_layernorm_out = self.input_layernorm(hidden_states)
        attention_output, attention_probs = self.attention(
            attention_layernorm_out,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        attention_output = self.dropout(attention_output)

        # Residual connection around the attention sub-layer
        attention_output = attention_output + hidden_states

        # MLP
        mlp_layernorm_out = self.post_attention_layernorm(attention_output)
        mlp_output = self.mlp(mlp_layernorm_out)

        # Residual connection around the MLP sub-layer
        layer_output = mlp_output + attention_output

        return layer_output, attention_probs
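
# Note: TransformerBlock uses the pre-LayerNorm arrangement (each sub-layer normalizes
# its input, then adds a residual), which is why OpenPeerLLM applies one more LayerNorm
# after the final block.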


class OpenPeerLLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Token and position embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])

        # Final layer norm
        self.final_layernorm = nn.LayerNorm(config.hidden_size)

        # Output head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights
        self.init_weights()

    def init_weights(self):
        """Initialize weights with small random values."""
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for different layer types."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
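
    # All Linear/Embedding weights are drawn from N(0, 0.02) and LayerNorm is reset to
    # weight=1 / bias=0. Note that lm_head is initialized independently of
    # word_embeddings, i.e. the input and output embeddings are not tied here.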
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        batch_size, seq_length = input_ids.shape

        # Create position IDs
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Get embeddings
        inputs_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        # Combine embeddings
        hidden_states = inputs_embeds + position_embeds

        # Expand the (batch, seq) padding mask to an additive (batch, 1, 1, seq) mask
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = attention_mask.to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min

        # Process through transformer layers
        all_attentions = []
        for layer in self.layers:
            hidden_states, attention_probs = layer(hidden_states, attention_mask)
            all_attentions.append(attention_probs)

        # Final layer norm
        hidden_states = self.final_layernorm(hidden_states)

        # Get logits
        logits = self.lm_head(hidden_states)

        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return {
            "loss": loss,
            "logits": logits,
            "hidden_states": hidden_states,
            "attentions": all_attentions,
        }
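

if __name__ == "__main__":
    # Minimal smoke test, illustrative only: the repository's real config class is not
    # defined in this file, so a SimpleNamespace stand-in carrying the attributes the
    # model reads is assumed; the values below are arbitrary small numbers, not the
    # released model's hyperparameters.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        max_position_embeddings=128,
        attention_dropout=0.1,
        hidden_dropout=0.1,
    )
    model = OpenPeerLLM(cfg)

    input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)

    print(outputs["logits"].shape)  # torch.Size([2, 16, 1000])
    print(outputs["loss"])          # scalar cross-entropy loss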