anto18671 committed · Commit c9d78a1 · verified · 1 Parent(s): dea8a41

Delete lumenspark.py

Files changed (1):
  1. lumenspark.py +0 -294
lumenspark.py DELETED
@@ -1,294 +0,0 @@
- from transformers import PretrainedConfig, PreTrainedModel, GenerationConfig
- from torch import nn
- import torch
-
- # ----------------------------
- # Define Lumenspark Configuration
- # ----------------------------
-
- class LumensparkConfig(PretrainedConfig):
-     """
-     Configuration class for the Lumenspark model.
-     Stores model hyperparameters like sequence length, embedding dimension, number of layers, and others.
-     """
-     model_type = "lumenspark"
-
-     def __init__(
-         self,
-         seq_length=512,
-         vocab_size=50257,
-         embed_dim=512,
-         depth=6,
-         heads=4,
-         dropout=0.1,
-         k=128,
-         **kwargs
-     ):
-         super().__init__(**kwargs)
-         self.vocab_size = vocab_size
-         self.embed_dim = embed_dim
-         self.depth = depth
-         self.heads = heads
-         self.seq_length = seq_length
-         self.dropout = dropout
-         self.k = k
-
-     def to_dict(self):
-         """
-         Converts the configuration parameters to a dictionary format.
-         Useful for saving the configuration or inspecting model settings.
-         """
-         output = super().to_dict()
-         output.update({
-             "vocab_size": self.vocab_size,
-             "embed_dim": self.embed_dim,
-             "depth": self.depth,
-             "heads": self.heads,
-             "seq_length": self.seq_length,
-             "dropout": self.dropout,
-             "k": self.k
-         })
-         return output
-
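For orientation, the configuration above can be constructed with its defaults as in the short sketch below; the import path is an assumption based on the deleted file's name and is not documented by this commit.

from lumenspark import LumensparkConfig  # assumed (hypothetical) import path

config = LumensparkConfig(seq_length=512, embed_dim=512, depth=6, heads=4, k=128)
print(config.to_dict()["k"])  # 128, the sequence-projection length used by the attention layers below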
- # ----------------------------
- # Low-Rank Linear Layer Implementation
- # ----------------------------
-
- class LowRankLinear(nn.Module):
-     """
-     A low-rank linear layer that factorizes a standard linear layer into two smaller ones.
-     This allows for reduced parameter count and faster computation.
-     """
-     def __init__(self, in_features, out_features, rank):
-         super().__init__()
-         self.U = nn.Linear(in_features, rank, bias=False)
-         self.V = nn.Linear(rank, out_features, bias=False)
-
-     def forward(self, x):
-         """
-         Forward pass through two low-rank linear layers (U and V).
-         """
-         return self.V(self.U(x))
-
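As a rough illustration of why the factorization above saves parameters: a rank-32 factorization of a 512-to-512 projection stores 512*32 + 32*512 = 32,768 weights instead of 512*512 = 262,144. A minimal sketch, assuming LowRankLinear is importable as defined above:

import torch
from torch import nn

full = nn.Linear(512, 512, bias=False)       # 262,144 weights
low_rank = LowRankLinear(512, 512, rank=32)  # 32,768 weights split across U and V

x = torch.randn(2, 16, 512)
assert low_rank(x).shape == full(x).shape    # same (batch, seq, features) output shape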
- # ----------------------------
- # Lumenspark Self-Attention Implementation
- # ----------------------------
-
- class LumensparkSelfAttention(nn.Module):
-     """
-     Custom self-attention mechanism for the Lumenspark model.
-     It includes low-rank approximations to reduce computational cost and memory usage.
-     """
-     def __init__(self, embed_dim, max_seq_len, proj_dim, num_heads, head_dim=None, single_kv_head=True, shared_kv=True, dropout=0.1):
-         super().__init__()
-         assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'
-
-         self.max_seq_len = max_seq_len
-         self.proj_dim = proj_dim
-         self.num_heads = num_heads
-         self.embed_dim = embed_dim
-
-         # Set the dimensionality of each attention head
-         self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads
-
-         # Query transformation: Low-rank projection followed by linear layer
-         self.query_transform = nn.Sequential(
-             LowRankLinear(embed_dim, embed_dim // 2, rank=32),
-             nn.Linear(embed_dim // 2, self.head_dim * num_heads)
-         )
-         kv_size = self.head_dim if single_kv_head else (self.head_dim * num_heads)
-         self.key_transform = nn.Linear(embed_dim, kv_size, bias=False)
-         self.key_proj = nn.Parameter(self.initialize_proj_matrix(max_seq_len, proj_dim))
-
-         # Shared key-value projection option
-         self.shared_kv = shared_kv
-         if not shared_kv:
-             self.value_transform = nn.Linear(embed_dim, kv_size, bias=False)
-             self.value_proj = nn.Parameter(self.initialize_proj_matrix(max_seq_len, proj_dim))
-
-         self.dropout_layer = nn.Dropout(dropout)  # Dropout for regularization
-         self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)
-
-     def initialize_proj_matrix(self, rows, cols):
-         """
-         Initializes the projection matrix used to reduce the sequence length for key/value pairs.
-         """
-         return torch.nn.init.xavier_uniform_(torch.zeros(rows, cols))
-
-     def forward(self, inputs, context_data=None, **kwargs):
-         """
-         Forward pass of the self-attention mechanism.
-         """
-         batch_size, seq_len, _ = inputs.shape
-         kv_seq_len = inputs.shape[1] if context_data is None else context_data.shape[1]
-         assert kv_seq_len <= self.max_seq_len, f'Key/value sequence length exceeds the max sequence length: {self.max_seq_len}'
-
-         # Apply transformations to queries, keys, and values
-         queries = self.query_transform(inputs)
-         kv_inputs = inputs if context_data is None else context_data
-         keys = self.key_transform(kv_inputs)
-         values = self.value_transform(kv_inputs) if not self.shared_kv else keys
-
-         # Apply projection matrix to keys and values
-         keys = torch.einsum('bnd,nk->bkd', keys, self.key_proj[:kv_seq_len])
-         values = torch.einsum('bnd,nk->bkd', values, self.value_proj[:kv_seq_len] if not self.shared_kv else self.key_proj[:kv_seq_len])
-
-         # Reshape queries, keys, and values for multi-head attention
-         queries = queries.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-         keys = keys.view(batch_size, self.proj_dim, -1, self.head_dim).transpose(1, 2)
-         values = values.view(batch_size, self.proj_dim, -1, self.head_dim).transpose(1, 2)
-
-         # Compute scaled dot-product attention
-         attention_scores = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * (self.head_dim ** -0.5)
-         attention_weights = attention_scores.softmax(dim=-1)
-         attention_weights = self.dropout_layer(attention_weights)
-
-         # Apply attention weights to values and compute output
-         attention_output = torch.einsum('bhnk,bhkd->bhnd', attention_weights, values)
-         attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
-         return self.output_transform(attention_output)
-
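The central trick above is the einsum against key_proj, which shrinks the key/value sequence axis from the input length down to proj_dim, so the attention score matrix is (seq_len x k) rather than (seq_len x seq_len). A standalone sketch of that projection step with made-up shapes:

import torch

batch, seq_len, k, head_dim = 2, 512, 128, 64
keys = torch.randn(batch, seq_len, head_dim)
proj = torch.randn(seq_len, k)                       # plays the role of key_proj

projected = torch.einsum('bnd,nk->bkd', keys, proj)  # sequence axis shrinks: 512 -> 128
assert projected.shape == (batch, k, head_dim)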
- # ----------------------------
- # RMSNorm Layer Implementation
- # ----------------------------
-
- class RMSNorm(nn.Module):
-     """
-     Root Mean Square Layer Normalization (RMSNorm) with a learnable scale but no bias.
-     This normalization technique scales the input without centering it.
-     """
-     def __init__(self, embed_dim, eps=1e-6):
-         super().__init__()
-         self.eps = eps  # Small constant to prevent division by zero
-         self.scale = nn.Parameter(torch.ones(embed_dim))  # Learnable scaling factor
-
-     def forward(self, x):
-         """
-         Forward pass through the RMSNorm layer.
-         """
-         norm_x = x.norm(2, dim=-1, keepdim=True)
-         rms_x = norm_x / (x.size(-1) ** 0.5)  # Root mean square of the last dimension
-         return self.scale * x / (rms_x + self.eps)
-
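Read literally, the forward pass above divides x by the root-mean-square of its last dimension (plus eps) and multiplies by the learned scale, which starts at ones. A quick numerical self-check, assuming RMSNorm as defined above:

import torch

norm = RMSNorm(embed_dim=8)
x = torch.randn(3, 8)
expected = x / (x.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-6)
assert torch.allclose(norm(x), expected, atol=1e-5)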
- # ----------------------------
- # Define Lumenspark Model Wrapper
- # ----------------------------
-
- class LumensparkModel(PreTrainedModel):
-     """
-     Lumenspark model with factorized linear projections, multi-head attention, and RMSNorm for normalization.
-     This model is specifically designed to handle long sequences efficiently.
-     """
-     config_class = LumensparkConfig
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.config = config
-
-         # Token and position embeddings
-         self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
-         self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)
-
-         # Lumenspark transformer encoder layers
-         self.layers = nn.ModuleList([nn.ModuleDict({
-             "attn": LumensparkSelfAttention(
-                 embed_dim=config.embed_dim,
-                 max_seq_len=config.seq_length,
-                 num_heads=config.heads,
-                 proj_dim=config.k,
-                 head_dim=config.embed_dim // config.heads,
-                 single_kv_head=True,
-                 shared_kv=True,
-                 dropout=config.dropout
-             ),
-             "norm1": RMSNorm(config.embed_dim),
-             "ffn": nn.Sequential(
-                 LowRankLinear(config.embed_dim, config.embed_dim // 2, rank=32),
-                 nn.GELU(),
-                 LowRankLinear(config.embed_dim // 2, config.embed_dim, rank=32),
-                 nn.Dropout(config.dropout)
-             ),
-             "norm2": RMSNorm(config.embed_dim)
-         }) for _ in range(config.depth)])
-
-         # Feed-forward output layer
-         self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
-         self.dropout = nn.Dropout(config.dropout)
-
-         # Initialize model weights
-         self.init_weights()
-
-         # Create GenerationConfig instance for text generation
-         self.generation_config = GenerationConfig(
-             max_length=128,
-             min_length=16,
-         )
-
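For scale, the constructor above can be instantiated directly from the default configuration, assuming the classes above are importable; the snippet below only builds the model and counts parameters, and is a sketch rather than a published figure for this checkpoint:

config = LumensparkConfig()
model = LumensparkModel(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")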
-     def generate(self, input_ids, max_length=128, min_length=16, temperature=1.0, top_k=50, top_p=0.95, do_sample=True):
-         """
-         Text generation method that handles auto-regressive generation.
-         """
-         generated_tokens = input_ids
-
-         for _ in range(max_length - input_ids.size(1)):
-             outputs = self.forward(input_ids=generated_tokens)
-             logits = outputs["logits"][:, -1, :]
-             logits = logits / temperature
-
-             # Apply top-k and top-p sampling to select the next token
-             if do_sample:
-                 filtered_logits = LumensparkModel.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
-                 probs = torch.softmax(filtered_logits, dim=-1)
-                 next_token = torch.multinomial(probs, num_samples=1)
-             else:
-                 next_token = torch.argmax(logits, dim=-1, keepdim=True)
-
-             # Append the generated token
-             generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
-
-             # Stop if the EOS token is generated
-             if next_token.item() == self.config.eos_token_id:
-                 break
-
-         return generated_tokens
-
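Note that generate() calls LumensparkModel.top_k_top_p_filtering, which is not defined anywhere in this file, so sampling with do_sample=True would fail as listed. The filtering step it appears to expect is the standard top-k/top-p (nucleus) scheme, sketched below as a plain function that would be attached to the class as a staticmethod; this is not the author's code:

import torch

def top_k_top_p_filtering(logits, top_k=50, top_p=0.95, filter_value=float("-inf")):
    # Keep only the top_k highest-scoring logits.
    if top_k > 0:
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_value, filter_value)
    # Keep the smallest set of sorted tokens whose cumulative probability exceeds top_p.
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cumulative_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift right so the first token over the threshold is kept
        remove[..., 0] = False
        mask = remove.scatter(-1, sorted_indices, remove)
        logits = logits.masked_fill(mask, filter_value)
    return logits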
-     def forward(self, input_ids, attention_mask=None, labels=None):
-         """
-         Forward pass of the model. If labels are provided, the loss is also computed.
-         """
-         batch_size, seq_length = input_ids.size()
-
-         # Generate position ids
-         position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
-         position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
-
-         # Embed tokens and positions
-         token_embeddings = self.token_embedding(input_ids)
-         position_embeddings = self.position_embedding(position_ids)
-
-         # Combine token and position embeddings
-         embeddings = token_embeddings + position_embeddings
-         embeddings = self.dropout(embeddings)
-
-         # Pass through each transformer layer
-         for layer in self.layers:
-             embeddings = layer["attn"](embeddings) + embeddings
-             embeddings = layer["norm1"](embeddings)
-
-             ffn_out = layer["ffn"](embeddings)
-             embeddings = ffn_out + embeddings
-             embeddings = layer["norm2"](embeddings)
-
-         # Compute logits (unnormalized scores)
-         logits = self.fc_out(embeddings)
-
-         # Compute loss if labels are provided
-         loss = None
-         if labels is not None:
-             shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
-             shift_labels = labels[:, 1:].contiguous().view(-1)
-
-             loss_fct = nn.CrossEntropyLoss()
-             loss = loss_fct(shift_logits, shift_labels)
-
-         return {"loss": loss, "logits": logits}
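
Finally, a hypothetical end-to-end call against the forward pass above, assuming the classes are importable from this module; shapes follow the default configuration:

import torch

config = LumensparkConfig()
model = LumensparkModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 64))
outputs = model(input_ids=input_ids, labels=input_ids)  # labels trigger the shifted cross-entropy loss
outputs["loss"].backward()
print(outputs["logits"].shape)                          # torch.Size([2, 64, 50257])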