Delete lumenspark.py
lumenspark.py +0 -294
lumenspark.py
DELETED
@@ -1,294 +0,0 @@
from transformers import PretrainedConfig, PreTrainedModel, GenerationConfig
from torch import nn
import torch

# ----------------------------
# Define Lumenspark Configuration
# ----------------------------

class LumensparkConfig(PretrainedConfig):
    """
    Configuration class for the Lumenspark model.
    Stores model hyperparameters like sequence length, embedding dimension, number of layers, and others.
    """
    model_type = "lumenspark"

    def __init__(
        self,
        seq_length=512,
        vocab_size=50257,
        embed_dim=512,
        depth=6,
        heads=4,
        dropout=0.1,
        k=128,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.heads = heads
        self.seq_length = seq_length
        self.dropout = dropout
        self.k = k

    def to_dict(self):
        """
        Converts the configuration parameters to a dictionary format.
        Useful for saving the configuration or inspecting model settings.
        """
        output = super().to_dict()
        output.update({
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
            "depth": self.depth,
            "heads": self.heads,
            "seq_length": self.seq_length,
            "dropout": self.dropout,
            "k": self.k
        })
        return output

# ----------------------------
# Low-Rank Linear Layer Implementation
# ----------------------------

class LowRankLinear(nn.Module):
    """
    A low-rank linear layer that factorizes a standard linear layer into two smaller ones.
    This allows for reduced parameter count and faster computation.
    """
    def __init__(self, in_features, out_features, rank):
        super().__init__()
        self.U = nn.Linear(in_features, rank, bias=False)
        self.V = nn.Linear(rank, out_features, bias=False)

    def forward(self, x):
        """
        Forward pass through two low-rank linear layers (U and V).
        """
        return self.V(self.U(x))

# ----------------------------
# Lumenspark Self-Attention Implementation
# ----------------------------

class LumensparkSelfAttention(nn.Module):
    """
    Custom self-attention mechanism for the Lumenspark model.
    It includes low-rank approximations to reduce computational cost and memory usage.
    """
    def __init__(self, embed_dim, max_seq_len, proj_dim, num_heads, head_dim=None, single_kv_head=True, shared_kv=True, dropout=0.1):
        super().__init__()
        assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'

        self.max_seq_len = max_seq_len
        self.proj_dim = proj_dim
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        # Set the dimensionality of each attention head
        self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads

        # Query transformation: Low-rank projection followed by linear layer
        self.query_transform = nn.Sequential(
            LowRankLinear(embed_dim, embed_dim // 2, rank=32),
            nn.Linear(embed_dim // 2, self.head_dim * num_heads)
        )
        kv_size = self.head_dim if single_kv_head else (self.head_dim * num_heads)
        self.key_transform = nn.Linear(embed_dim, kv_size, bias=False)
        self.key_proj = nn.Parameter(self.initialize_proj_matrix(max_seq_len, proj_dim))

        # Shared key-value projection option
        self.shared_kv = shared_kv
        if not shared_kv:
            self.value_transform = nn.Linear(embed_dim, kv_size, bias=False)
            self.value_proj = nn.Parameter(self.initialize_proj_matrix(max_seq_len, proj_dim))

        self.dropout_layer = nn.Dropout(dropout)  # Dropout for regularization
        self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)

    def initialize_proj_matrix(self, rows, cols):
        """
        Initializes the projection matrix used to reduce the sequence length for key/value pairs.
        """
        return torch.nn.init.xavier_uniform_(torch.zeros(rows, cols))

    def forward(self, inputs, context_data=None, **kwargs):
        """
        Forward pass of the self-attention mechanism.
        """
        batch_size, seq_len, _ = inputs.shape
        kv_seq_len = inputs.shape[1] if context_data is None else context_data.shape[1]
        assert kv_seq_len <= self.max_seq_len, f'Key/value sequence length exceeds the max sequence length: {self.max_seq_len}'

        # Apply transformations to queries, keys, and values
        queries = self.query_transform(inputs)
        kv_inputs = inputs if context_data is None else context_data
        keys = self.key_transform(kv_inputs)
        values = self.value_transform(kv_inputs) if not self.shared_kv else keys

        # Apply projection matrix to keys and values
        keys = torch.einsum('bnd,nk->bkd', keys, self.key_proj[:kv_seq_len])
        values = torch.einsum('bnd,nk->bkd', values, self.value_proj[:kv_seq_len] if not self.shared_kv else self.key_proj[:kv_seq_len])

        # Reshape queries, keys, and values for multi-head attention
        queries = queries.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, self.proj_dim, -1, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, self.proj_dim, -1, self.head_dim).transpose(1, 2)

        # Compute scaled dot-product attention
        attention_scores = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * (self.head_dim ** -0.5)
        attention_weights = attention_scores.softmax(dim=-1)
        attention_weights = self.dropout_layer(attention_weights)

        # Apply attention weights to values and compute output
        attention_output = torch.einsum('bhnk,bhkd->bhnd', attention_weights, values)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.output_transform(attention_output)

# ----------------------------
# RMSNorm Layer Implementation
# ----------------------------

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm) without affine transformation.
    This normalization technique scales the input without centering it.
    """
    def __init__(self, embed_dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # Small constant to prevent division by zero
        self.scale = nn.Parameter(torch.ones(embed_dim))  # Scaling factor

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.
        """
        norm_x = x.norm(2, dim=-1, keepdim=True)
        rms_x = norm_x / (x.size(-1) ** 0.5)  # Root mean square normalization
        return self.scale * x / (rms_x + self.eps)

# ----------------------------
# Define Lumenspark Model Wrapper
# ----------------------------

class LumensparkModel(PreTrainedModel):
    """
    Lumenspark model with factorized linear projections, multi-head attention, and RMSNorm for normalization.
    This model is specifically designed to handle long sequences efficiently.
    """
    config_class = LumensparkConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and position embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)

        # Lumenspark transformer encoder layers
        self.layers = nn.ModuleList([nn.ModuleDict({
            "attn": LumensparkSelfAttention(
                embed_dim=config.embed_dim,
                max_seq_len=config.seq_length,
                num_heads=config.heads,
                proj_dim=config.k,
                head_dim=config.embed_dim // config.heads,
                single_kv_head=True,
                shared_kv=True,
                dropout=config.dropout
            ),
            "norm1": RMSNorm(config.embed_dim),
            "ffn": nn.Sequential(
                LowRankLinear(config.embed_dim, config.embed_dim // 2, rank=32),
                nn.GELU(),
                LowRankLinear(config.embed_dim // 2, config.embed_dim, rank=32),
                nn.Dropout(config.dropout)
            ),
            "norm2": RMSNorm(config.embed_dim)
        }) for _ in range(config.depth)])

        # Feed-forward output layer
        self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

        # Initialize model weights
        self.init_weights()

        # Create GenerationConfig instance for text generation
        self.generation_config = GenerationConfig(
            max_length=128,
            min_length=16,
        )

    def generate(self, input_ids, max_length=128, min_length=16, temperature=1.0, top_k=50, top_p=0.95, do_sample=True):
        """
        Text generation method that handles auto-regressive generation.
        """
        generated_tokens = input_ids

        for _ in range(max_length - input_ids.size(1)):
            outputs = self.forward(input_ids=generated_tokens)
            logits = outputs["logits"][:, -1, :]
            logits = logits / temperature

            # Apply top-k and top-p sampling to select the next token
            if do_sample:
                filtered_logits = LumensparkModel.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probs = torch.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # Append the generated token
            generated_tokens = torch.cat((generated_tokens, next_token), dim=1)

            # Stop if the EOS token is generated
            if next_token.item() == self.config.eos_token_id:
                break

        return generated_tokens

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward pass of the model. If labels are provided, the loss is also computed.
        """
        batch_size, seq_length = input_ids.size()

        # Generate position ids
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)

        # Embed tokens and positions
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)

        # Combine token and position embeddings
        embeddings = token_embeddings + position_embeddings
        embeddings = self.dropout(embeddings)

        # Pass through each transformer layer
        for layer in self.layers:
            embeddings = layer["attn"](embeddings) + embeddings
            embeddings = layer["norm1"](embeddings)

            ffn_out = layer["ffn"](embeddings)
            embeddings = ffn_out + embeddings
            embeddings = layer["norm2"](embeddings)

        # Compute logits (unnormalized scores)
        logits = self.fc_out(embeddings)

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[:, 1:].contiguous().view(-1)

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        return {"loss": loss, "logits": logits}
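
Before this deletion, the module would typically have been exercised along the following lines. This is only a sketch, assuming the classes defined above and a GPT-2 tokenizer (the default vocab_size of 50257 and eos_token_id 50256 match GPT-2); it is not a documented API of this repository.

import torch
from transformers import GPT2TokenizerFast  # assumed tokenizer, chosen to match vocab_size=50257

# Build a small Lumenspark model from the default configuration.
config = LumensparkConfig(seq_length=512, embed_dim=512, depth=6, heads=4, k=128, eos_token_id=50256)
model = LumensparkModel(config)
model.eval()

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
input_ids = tokenizer("Lumenspark is", return_tensors="pt").input_ids

# Greedy decoding avoids the undefined top_k_top_p_filtering helper discussed above.
with torch.no_grad():
    output_ids = model.generate(input_ids, max_length=32, do_sample=False)

print(tokenizer.decode(output_ids[0]))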