import torch
import torch.nn as nn
import math
from .optimized_diffattn import MultiheadDiffAttn
# --- Tokenizer Definition ---
# Vocabulary: 256 bytes + IM_START_TOKEN + IM_END_TOKEN + <pad>
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PAD_TOKEN = "<pad>"
SPECIAL_TOKENS = [IM_START_TOKEN, IM_END_TOKEN, PAD_TOKEN]
VOCAB_SIZE = 256 + len(SPECIAL_TOKENS)
# Create token to id mapping
token_to_id = {}
id_to_token = {}
for i in range(256):
    token_to_id[bytes([i])] = i
    id_to_token[i] = bytes([i])
for i, token_str in enumerate(SPECIAL_TOKENS):
    token_id = 256 + i
    token_to_id[token_str] = token_id
    id_to_token[token_id] = token_str
PAD_ID = token_to_id[PAD_TOKEN]
IM_START_ID = token_to_id[IM_START_TOKEN]
IM_END_ID = token_to_id[IM_END_TOKEN]
class ByteTokenizer:
    def __init__(self):
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.vocab_size = VOCAB_SIZE
        self.pad_id = PAD_ID
        self.im_start_id = IM_START_ID
        self.im_end_id = IM_END_ID

    def encode(self, text_bytes: bytes, add_special_tokens=True):
        ids = [self.token_to_id[bytes([b])] for b in text_bytes]
        if add_special_tokens:
            return [self.im_start_id] + ids + [self.im_end_id]
        return ids

    def decode(self, ids: list[int]):
        tokens = []
        for i in ids:
            token = self.id_to_token.get(i)
            if token is None:
                # Unknown token id: append a placeholder byte rather than raising
                tokens.append(b"?")
            elif isinstance(token, bytes):
                tokens.append(token)
            # Special tokens (stored as str) are dropped when decoding to raw bytes
        return b"".join(tokens)

# --- RoPE Embeddings --- (Reused from previous script)
def get_rotary_embeddings(seq_len, dim_model, theta=10000.0):
    if dim_model % 2 != 0:
        raise ValueError(f"dim_model must be even, got {dim_model}")
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim_model, 2).float() * -(math.log(theta) / dim_model)
    )
    angles = position * div_term
    cos_emb = torch.cos(angles)
    sin_emb = torch.sin(angles)
    return cos_emb, sin_emb

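# Shape sketch (values chosen only for illustration): for seq_len=8 and
# dim_model=32 the tables have one row per position and one column per
# rotary frequency pair (dim_model // 2).
#
#   cos_emb, sin_emb = get_rotary_embeddings(8, 32)
#   assert cos_emb.shape == sin_emb.shape == (8, 16)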
# --- Model Definition ---
class FeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.dropout(self.act(self.fc1(x))))

class DiffTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, depth, ffn_hidden_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiheadDiffAttn(embed_dim, depth, num_heads, dropout=dropout)
        self.ffn = FeedForward(embed_dim, ffn_hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, rel_pos, attn_mask=None):
        # Pre-norm residual layout: attention, then feed-forward
        attn_out = self.attn(self.norm1(x), rel_pos, attn_mask)
        x = x + self.dropout(attn_out)
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)
        return x

class DiffTransformerLLM(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ffn_hidden_dim,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        # Positional information is handled by RoPE, so no separate nn.Embedding for positions
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [
                DiffTransformerBlock(
                    embed_dim, num_heads, depth, ffn_hidden_dim, dropout
                )
                for depth in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        # Tie input and output embedding weights
        self.token_embeddings.weight = self.lm_head.weight
        # RoPE precomputation; the head_dim for MultiheadDiffAttn is embed_dim // num_heads // 2
        self.rope_head_dim = embed_dim // num_heads // 2
        cos_emb, sin_emb = get_rotary_embeddings(max_seq_len, self.rope_head_dim)
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    def forward(self, input_ids, attn_mask=None):
        batch_size, seq_len = input_ids.shape
        x = self.token_embeddings(input_ids) * math.sqrt(self.embed_dim)
        x = self.dropout(x)
        # Slice the precomputed RoPE tables and match the activations' device *and* dtype
        rel_pos = (
            self.cos_emb[:seq_len, :].to(x.device, dtype=x.dtype),
            self.sin_emb[:seq_len, :].to(x.device, dtype=x.dtype),
        )
        # Causal mask for autoregressive decoding. MultiheadDiffAttn expects an
        # additive mask: 0 where attention is allowed, -inf where it is blocked.
        # Any attn_mask passed by the caller is currently ignored; padding is
        # handled via the loss function's ignore_index instead.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device) * float("-inf"),
            diagonal=1,
        )
        for layer in self.layers:
            x = layer(x, rel_pos, attn_mask=causal_mask)

        x = self.norm_out(x)
        logits = self.lm_head(x)
        return logits

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

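# Smoke-test sketch: the hyperparameters below are illustrative only (not the
# values used for training) and it assumes MultiheadDiffAttn accepts the
# arguments used in DiffTransformerBlock above. Runs only when executed directly.
if __name__ == "__main__":
    tokenizer = ByteTokenizer()
    model = DiffTransformerLLM(
        vocab_size=tokenizer.vocab_size,
        embed_dim=256,
        num_layers=4,
        num_heads=4,
        ffn_hidden_dim=1024,
        max_seq_len=512,
    )
    print(f"parameters: {model.count_parameters():,}")
    ids = torch.tensor([tokenizer.encode(b"hello world")])
    logits = model(ids)  # (batch, seq_len, vocab_size)
    print(logits.shape)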