# Source: hudsongouge — "Update space" (commit adf0368)
import torch
import torch.nn as nn
import math
from .optimized_diffattn import MultiheadDiffAttn
# --- Tokenizer Definition ---
# Vocabulary layout: ids 0-255 are the raw byte values, followed by the special
# tokens in SPECIAL_TOKENS order (im_start=256, im_end=257, pad=258).
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PAD_TOKEN = "<pad>"
SPECIAL_TOKENS = [IM_START_TOKEN, IM_END_TOKEN, PAD_TOKEN]
VOCAB_SIZE = 256 + len(SPECIAL_TOKENS)

# Forward lookup: single-byte ``bytes`` keys for ordinary tokens, ``str`` keys
# for the special tokens.
token_to_id = {bytes([byte_value]): byte_value for byte_value in range(256)}
token_to_id.update(
    {special: 256 + offset for offset, special in enumerate(SPECIAL_TOKENS)}
)
# Reverse lookup: every id is unique, so inverting the forward table is lossless
# and preserves the 0..VOCAB_SIZE-1 insertion order.
id_to_token = {token_id: token for token, token_id in token_to_id.items()}

PAD_ID = token_to_id[PAD_TOKEN]
IM_START_ID = token_to_id[IM_START_TOKEN]
IM_END_ID = token_to_id[IM_END_TOKEN]
class ByteTokenizer:
    """Byte-level tokenizer over the module-level vocabulary tables.

    Ordinary tokens are single bytes; the chat markers ``<|im_start|>`` /
    ``<|im_end|>`` and ``<pad>`` occupy the ids after the 256 byte ids.
    """

    def __init__(self):
        # Share (not copy) the module-level tables and special ids.
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.vocab_size = VOCAB_SIZE
        self.pad_id = PAD_ID
        self.im_start_id = IM_START_ID
        self.im_end_id = IM_END_ID

    def encode(self, text_bytes: bytes, add_special_tokens=True):
        """Convert raw bytes into a list of token ids.

        Args:
            text_bytes: the bytes to tokenize.
            add_special_tokens: when true, wrap the ids in im_start / im_end.

        Returns:
            A new list of int token ids.

        Raises:
            KeyError: if a byte is missing from ``token_to_id``.
        """
        lookup = self.token_to_id
        body = [lookup[bytes((value,))] for value in text_bytes]
        if not add_special_tokens:
            return body
        return [self.im_start_id, *body, self.im_end_id]

    def decode(self, ids: list[int]):
        """Convert token ids back into raw bytes.

        Ids mapping to ``bytes`` contribute those bytes. Ids absent from
        ``id_to_token`` become ``b"?"``. Special-token ids (stored as ``str``)
        are skipped, so chat markers and padding do not appear in the output.
        """
        buffer = bytearray()
        for token_id in ids:
            piece = self.id_to_token.get(token_id)
            if piece is None:
                buffer += b"?"  # placeholder for an unknown id
            elif isinstance(piece, bytes):
                buffer += piece
        return bytes(buffer)
# --- RoPE Embeddings --- (Reused from previous script)
def get_rotary_embeddings(seq_len, dim_model, theta=10000.0):
    """Precompute rotary position (RoPE) cosine and sine tables.

    The k-th frequency is ``theta ** (-2k / dim_model)`` for
    k = 0 .. dim_model/2 - 1. The angle for position p is ``p * freq_k``.

    Args:
        seq_len: number of positions to tabulate.
        dim_model: rotary dimension; must be even.
        theta: base of the geometric frequency progression.

    Returns:
        ``(cos, sin)``, each a float32 tensor of shape (seq_len, dim_model // 2).

    Raises:
        ValueError: if ``dim_model`` is odd.
    """
    if dim_model % 2:
        raise ValueError(f"dim_model must be even, got {dim_model}")
    log_step = -(math.log(theta) / dim_model)
    inv_freq = torch.exp(torch.arange(0, dim_model, 2).float() * log_step)
    positions = torch.arange(0, seq_len, dtype=torch.float)
    angles = positions[:, None] * inv_freq  # outer product: (seq_len, dim/2)
    return torch.cos(angles), torch.sin(angles)
# --- Model Definition ---
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        embed_dim: input and output feature size.
        hidden_dim: width of the intermediate layer.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        # Submodules are created in this fixed order so the random
        # initialisation draws and state_dict key order stay unchanged.
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)
class DiffTransformerBlock(nn.Module):
    """Pre-norm transformer block: differential attention followed by an MLP.

    Each sub-layer reads a LayerNorm-ed copy of the residual stream. Its output
    goes through dropout and is then added back to the stream.

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        depth: index of this block in the stack (DiffTransformerLLM passes the
            layer index); forwarded to MultiheadDiffAttn.
        ffn_hidden_dim: hidden width of the FeedForward sub-layer.
        dropout: dropout probability for attention, MLP and residual paths.
    """

    def __init__(self, embed_dim, num_heads, depth, ffn_hidden_dim, dropout=0.1):
        super().__init__()
        # NOTE(review): argument order (embed_dim, depth, num_heads) follows
        # MultiheadDiffAttn's signature as called here — confirm against its definition.
        self.attn = MultiheadDiffAttn(embed_dim, depth, num_heads, dropout=dropout)
        self.ffn = FeedForward(embed_dim, ffn_hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, rel_pos, attn_mask=None):
        """Apply attention then MLP, each as a pre-norm residual update.

        ``rel_pos`` is the (cos, sin) RoPE pair and ``attn_mask`` the additive
        mask, both passed straight through to the attention module.
        """
        attn_update = self.dropout(self.attn(self.norm1(x), rel_pos, attn_mask))
        x = x + attn_update
        ffn_update = self.dropout(self.ffn(self.norm2(x)))
        return x + ffn_update
class DiffTransformerLLM(nn.Module):
    """Decoder-only language model built from DiffTransformerBlock layers.

    Token embeddings share their weight with the output projection and are
    scaled by sqrt(embed_dim). Positions are injected through precomputed
    rotary (RoPE) cos/sin tables that are handed to every block.

    Args:
        vocab_size: number of token ids.
        embed_dim: model width.
        num_layers: number of transformer blocks.
        num_heads: attention heads per block.
        ffn_hidden_dim: hidden width of each block's FeedForward.
        max_seq_len: longest sequence the RoPE tables are precomputed for.
        dropout: dropout probability used throughout the model.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ffn_hidden_dim,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        # Positions come from RoPE, so there is no learned position embedding.
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [
                DiffTransformerBlock(
                    embed_dim, num_heads, depth, ffn_hidden_dim, dropout
                )
                for depth in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        # Tie weights: the embedding table *is* the output projection matrix.
        self.token_embeddings.weight = self.lm_head.weight
        # RoPE precomputation. Per the original design note, the head_dim used
        # by MultiheadDiffAttn is embed_dim // num_heads // 2.
        self.rope_head_dim = embed_dim // num_heads // 2
        cos_emb, sin_emb = get_rotary_embeddings(max_seq_len, self.rope_head_dim)
        # Non-persistent: derived data, recomputed at construction, not saved.
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    def forward(self, input_ids, attn_mask=None):
        """Compute next-token logits.

        Args:
            input_ids: integer tensor of shape (batch, seq_len).
            attn_mask: accepted for API compatibility but currently IGNORED.
                A strict causal mask is always applied, and padding is
                expected to be handled by the loss (``ignore_index``) instead.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: if ``input_ids`` is not 2-D, or if seq_len exceeds
                ``max_seq_len`` (no RoPE table rows exist for those positions).
        """
        _, seq_len = input_ids.shape  # also enforces a 2-D input
        if seq_len > self.max_seq_len:
            # Without this check the RoPE slice below silently truncates and the
            # failure surfaces later as an opaque shape mismatch in attention.
            raise ValueError(
                f"Input length {seq_len} exceeds max_seq_len={self.max_seq_len}; "
                "RoPE tables were only precomputed for max_seq_len positions."
            )

        x = self.token_embeddings(input_ids) * math.sqrt(self.embed_dim)
        x = self.dropout(x)

        # RoPE tables follow the activations' device *and* dtype.
        rel_pos = (
            self.cos_emb[:seq_len, :].to(x.device, dtype=x.dtype),
            self.sin_emb[:seq_len, :].to(x.device, dtype=x.dtype),
        )

        # Additive causal mask: 0 on/below the diagonal (attend), -inf above it
        # (future positions). It is built in the activations' dtype, like the RoPE
        # tables, so half-precision models do not receive a float32 mask.
        # NOTE(review): assumes MultiheadDiffAttn adds this mask to its scores,
        # as the original comments state — confirm.
        # TODO: honour a caller-supplied padding mask once MultiheadDiffAttn's
        # expected mask shape for padding is confirmed.
        causal_mask = torch.triu(
            torch.full(
                (seq_len, seq_len), float("-inf"), device=x.device, dtype=x.dtype
            ),
            diagonal=1,
        )

        for layer in self.layers:
            x = layer(x, rel_pos, attn_mask=causal_mask)
        x = self.norm_out(x)
        return self.lm_head(x)

    def count_parameters(self):
        """Return the number of trainable parameters.

        ``Module.parameters()`` deduplicates shared tensors, so the tied
        embedding / lm_head weight is counted once.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)