Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import math | |
from .optimized_diffattn import MultiheadDiffAttn | |
# --- Tokenizer Definition ---
# Vocabulary layout: ids 0-255 are the raw byte values; ids 256 and up are the
# special tokens below, assigned in SPECIAL_TOKENS order
# (IM_START_TOKEN, IM_END_TOKEN, <pad>).
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PAD_TOKEN = "<pad>"
SPECIAL_TOKENS = [IM_START_TOKEN, IM_END_TOKEN, PAD_TOKEN]
VOCAB_SIZE = 256 + len(SPECIAL_TOKENS)

# Token -> id table. Byte tokens are keyed by length-1 ``bytes`` objects while
# special tokens are keyed by their ``str`` spelling, so the key types are mixed.
token_to_id = {bytes([byte_value]): byte_value for byte_value in range(256)}
token_to_id.update(
    {special: 256 + offset for offset, special in enumerate(SPECIAL_TOKENS)}
)
# Id -> token table: the exact inverse of token_to_id (same insertion order).
id_to_token = {ident: token for token, ident in token_to_id.items()}

PAD_ID = token_to_id[PAD_TOKEN]
IM_START_ID = token_to_id[IM_START_TOKEN]
IM_END_ID = token_to_id[IM_END_TOKEN]
class ByteTokenizer:
    """Byte-level tokenizer: one id per byte, plus chat special tokens.

    Wraps the module-level ``token_to_id`` / ``id_to_token`` tables and the
    special-token ids defined alongside them.
    """

    def __init__(self):
        # Shared (not copied) module-level lookup tables.
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.vocab_size = VOCAB_SIZE
        # Special-token ids.
        self.pad_id = PAD_ID
        self.im_start_id = IM_START_ID
        self.im_end_id = IM_END_ID

    def encode(self, text_bytes: bytes, add_special_tokens=True):
        """Return the id of every byte in ``text_bytes``.

        When ``add_special_tokens`` is true the ids are wrapped as
        ``[im_start_id, *ids, im_end_id]``.
        """
        table = self.token_to_id
        body = [table[bytes([byte_value])] for byte_value in text_bytes]
        if not add_special_tokens:
            return body
        return [self.im_start_id, *body, self.im_end_id]

    def decode(self, ids: list[int]):
        """Join the byte tokens for ``ids`` into a single ``bytes`` value.

        Ids missing from ``id_to_token`` decode to ``b"?"``. Ids whose token
        is not ``bytes`` (the special tokens) are dropped from the output.
        """
        pieces = []
        for token_id in ids:
            token = self.id_to_token.get(token_id)
            if token is None:
                # Unknown id: substitute a visible placeholder byte.
                token = b"?"
            if isinstance(token, bytes):
                pieces.append(token)
        return b"".join(pieces)
# --- RoPE Embeddings ---
def get_rotary_embeddings(seq_len, dim_model, theta=10000.0):
    """Precompute rotary position embedding (RoPE) cosine/sine tables.

    Args:
        seq_len: Number of positions to generate.
        dim_model: Rotary dimension; must be even.
        theta: Frequency base.

    Returns:
        ``(cos, sin)`` float32 tensors of shape ``(seq_len, dim_model // 2)``.
        Entry ``[p, k]`` is the cos/sin of ``p * theta ** (-2k / dim_model)``.

    Raises:
        ValueError: If ``dim_model`` is odd.
    """
    if dim_model % 2:
        raise ValueError(f"dim_model must be even, got {dim_model}")
    # Inverse frequencies theta^(-2k/d), evaluated in log space.
    log_step = -(math.log(theta) / dim_model)
    inv_freq = torch.exp(torch.arange(0, dim_model, 2).float() * log_step)
    positions = torch.arange(0, seq_len, dtype=torch.float)
    angles = torch.outer(positions, inv_freq)
    return torch.cos(angles), torch.sin(angles)
# --- Model Definition ---
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        embed_dim: Input and output width.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        # Keep registration order (fc1, fc2, dropout, act): it determines the
        # RNG order of weight initialisation and the state_dict key order.
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)
class DiffTransformerBlock(nn.Module):
    """Pre-norm transformer layer: differential attention followed by an MLP.

    Each sub-layer reads a LayerNorm-ed copy of the residual stream, and its
    output is added back to the stream after dropout.

    Args:
        embed_dim: Model width.
        num_heads: Attention heads, forwarded to ``MultiheadDiffAttn``.
        depth: Index of this layer in the stack, forwarded to ``MultiheadDiffAttn``.
        ffn_hidden_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability. It is forwarded to the attention and MLP
            sub-modules and also applied to both residual branches.
    """

    def __init__(self, embed_dim, num_heads, depth, ffn_hidden_dim, dropout=0.1):
        super().__init__()
        # Registration order matters for init RNG and state_dict key order.
        self.attn = MultiheadDiffAttn(embed_dim, depth, num_heads, dropout=dropout)
        self.ffn = FeedForward(embed_dim, ffn_hidden_dim, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, rel_pos, attn_mask=None):
        """Apply ``x + dropout(sublayer(norm(x)))`` for attention, then the MLP.

        ``rel_pos`` and ``attn_mask`` are passed to the attention module as-is.
        """
        attn_branch = self.attn(self.norm1(x), rel_pos, attn_mask)
        x = x + self.dropout(attn_branch)
        ffn_branch = self.ffn(self.norm2(x))
        return x + self.dropout(ffn_branch)
class DiffTransformerLLM(nn.Module):
    """Causal language model built from stacked ``DiffTransformerBlock`` layers.

    The token embedding table is tied to the output projection (``lm_head``).
    Positions are encoded with rotary embeddings (RoPE), precomputed once
    for ``max_seq_len`` positions.

    Args:
        vocab_size: Number of token ids.
        embed_dim: Model width.
        num_layers: Number of ``DiffTransformerBlock`` layers; each receives
            its index in the stack as ``depth``.
        num_heads: Attention heads per layer.
        ffn_hidden_dim: Hidden width of each feed-forward sub-layer.
        max_seq_len: Longest sequence ``forward`` accepts (size of the RoPE table).
        dropout: Dropout probability used throughout the model.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ffn_hidden_dim,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        # No positional nn.Embedding: positions are injected via RoPE in attention.
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [
                DiffTransformerBlock(
                    embed_dim, num_heads, depth, ffn_hidden_dim, dropout
                )
                for depth in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        # Tie weights: embedding table and output projection share one tensor.
        self.token_embeddings.weight = self.lm_head.weight
        # RoPE precomputation. The rotary dim is assumed to match
        # MultiheadDiffAttn's per-head dim (embed_dim // num_heads // 2).
        self.rope_head_dim = embed_dim // num_heads // 2
        cos_emb, sin_emb = get_rotary_embeddings(max_seq_len, self.rope_head_dim)
        # Non-persistent buffers: rebuilt in __init__, excluded from state_dict.
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    @staticmethod
    def _causal_mask(seq_len, device):
        """Return a (seq_len, seq_len) additive mask: 0 on/below the diagonal, -inf above."""
        return torch.full(
            (seq_len, seq_len), float("-inf"), device=device
        ).triu(diagonal=1)

    def forward(self, input_ids, attn_mask=None):
        """Compute next-token logits.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.
            attn_mask: Currently ignored. A causal mask is always used, and
                padding positions are expected to be excluded through the
                loss function's ``ignore_index``. The parameter is kept for API
                compatibility.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``. Without this
                check, the RoPE tables would be silently too short and fail
                deep inside attention.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        # Scale embeddings by sqrt(embed_dim) before entering the stack.
        x = self.token_embeddings(input_ids) * math.sqrt(self.embed_dim)
        x = self.dropout(x)
        # Match RoPE tables to the activations' device *and* dtype.
        rel_pos = (
            self.cos_emb[:seq_len].to(x.device, dtype=x.dtype),
            self.sin_emb[:seq_len].to(x.device, dtype=x.dtype),
        )
        causal_mask = self._causal_mask(seq_len, x.device)
        for layer in self.layers:
            x = layer(x, rel_pos, attn_mask=causal_mask)
        x = self.norm_out(x)
        return self.lm_head(x)

    def count_parameters(self):
        """Return the number of trainable parameters.

        ``parameters()`` yields the tied embedding/``lm_head`` weight only once,
        so it is counted a single time.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)