Shriti09 committed on
Commit 7970870 · verified · 1 Parent(s): 5e3a92a

Upload 4 files

Files changed (4)
  1. app.py +61 -90
  2. config.json +15 -0
  3. model.py +384 -0
  4. requirements.txt +3 -0
app.py CHANGED
@@ -1,92 +1,59 @@
  import torch
  import gradio as gr
- from transformers import AutoTokenizer
- from model_smol2 import LlamaForCausalLM, config_model
-
- # Instantiate the model
- model = LlamaForCausalLM(config_model)
-
- # Load the checkpoint
- checkpoint_path = "final_checkpoint.pt"
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
- model.load_state_dict(checkpoint['model_state_dict'])
- model.eval()
-
- # Load tokenizer (replace with the appropriate tokenizer if you're using a custom one)
- # Load the tokenizer
- TOKENIZER_PATH = "HuggingFaceTB/cosmo2-tokenizer"
- tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
- if tokenizer.pad_token is None:
-     tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "[PAD]"
-
-
- # Text generation function
- def generate_text(
-     prompt, max_length=50, temperature=0.7, top_k=50, repetition_penalty=1.2, n_gram_block=2
- ):
-     input_ids = tokenizer.encode(prompt, return_tensors="pt")
-     generated_tokens = input_ids[0].tolist()
-
-     with torch.no_grad():
-         for _ in range(max_length):
-             outputs = model(input_ids)  # model outputs
-
-             # Check if the output is a dictionary with logits
-             if isinstance(outputs, dict) and 'logits' in outputs:
-                 logits = outputs['logits'][:, -1, :]
-             else:
-                 # If not, treat the output as a plain tensor
-                 logits = outputs[:, -1, :]
-
-             # Repetition penalty
-             for token_id in set(generated_tokens):
-                 logits[:, token_id] /= repetition_penalty
-
-             # n-gram blocking
-             if len(generated_tokens) >= n_gram_block:
-                 n_gram = tuple(generated_tokens[-n_gram_block:])
-                 for token_id in set(generated_tokens):
-                     if generated_tokens[-n_gram_block:] == list(n_gram):
-                         logits[:, token_id] -= 1e9
-
-             logits /= temperature
-             top_k_logits, top_k_indices = torch.topk(logits, top_k, dim=-1)
-             probs = torch.softmax(top_k_logits, dim=-1)
-
-             next_token_idx = torch.multinomial(probs, num_samples=1)
-             next_token = top_k_indices[0, next_token_idx[0]]
-
-             generated_tokens.append(next_token.item())
-             input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
-
-             if next_token.item() == tokenizer.eos_token_id:
-                 break
-
-     return tokenizer.decode(generated_tokens, skip_special_tokens=True)
-
-
- # Gradio UI
- def generate_response(prompt, max_length, temperature, top_k, repetition_penalty, n_gram_block):
-     return generate_text(prompt, max_length, temperature, top_k, repetition_penalty, n_gram_block)
-
- with gr.Blocks() as demo:
-     gr.Markdown("# Smol2 Text Generator")
-     with gr.Row():
-         with gr.Column():
-             prompt_input = gr.Textbox(label="Input Prompt", placeholder="Enter your text prompt here...")
-             max_length = gr.Slider(label="Max Length", minimum=10, maximum=200, value=50)
-             temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.1)
-             top_k = gr.Slider(label="Top K", minimum=10, maximum=100, value=50, step=1)
-             repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.1)
-             n_gram_block = gr.Slider(label="N-Gram Blocking", minimum=1, maximum=5, value=2, step=1)
-             generate_button = gr.Button("Generate Text")
-         with gr.Column():
-             output_text = gr.Textbox(label="Generated Text", lines=10)
-
-     generate_button.click(
-         generate_response,
-         inputs=[prompt_input, max_length, temperature, top_k, repetition_penalty, n_gram_block],
-         outputs=[output_text],
-     )
-
- demo.launch()
 
  import torch
  import gradio as gr
+ from model import CustomLLM, CustomConfig
+ from transformers import GPT2Tokenizer
+
+ class ModelLoader:
+     def __init__(self):
+         # Load the architecture config; CustomConfig mirrors the values in config.json.
+         # (CustomLLM reads the config via attribute access, so a plain dict would not work.)
+         self.config = CustomConfig()
+
+         # Instantiate model
+         self.model = CustomLLM(self.config)
+
+         # Load trained weights
+         state_dict = torch.load('pytorch_model.bin', map_location='cpu')
+         self.model.load_state_dict(state_dict)
+         self.model.eval()
+
+         # Load tokenizer
+         self.tokenizer = GPT2Tokenizer.from_pretrained('tokenizer/')
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+     def generate(self, prompt, max_new_tokens=100, temperature=0.9, top_k=50, top_p=0.95):
+         inputs = self.tokenizer(prompt, return_tensors="pt")
+         input_ids = inputs.input_ids
+
+         with torch.no_grad():
+             generated = self.model.generate(
+                 input_ids=input_ids,
+                 max_new_tokens=int(max_new_tokens),  # cast defensively: Gradio sliders may pass floats
+                 temperature=temperature,
+                 top_k=int(top_k),
+                 top_p=top_p,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 pad_token_id=self.tokenizer.pad_token_id
+             )
+
+         return self.tokenizer.decode(generated[0], skip_special_tokens=True)
+
+ # Initialize model
+ loader = ModelLoader()
+
+ # Create Gradio interface
+ interface = gr.Interface(
+     fn=loader.generate,
+     inputs=[
+         gr.Textbox(lines=4, label="Input Prompt"),
+         gr.Slider(1, 500, value=100, label="Max New Tokens"),
+         gr.Slider(0.1, 2.0, value=0.9, label="Temperature"),
+         gr.Slider(1, 100, value=50, label="Top K"),
+         gr.Slider(0.1, 1.0, value=0.95, label="Top P")
+     ],
+     outputs=gr.Textbox(label="Generated Output"),
+     title="Custom LLM Demo",
+     description="Generate text using your custom-trained LLM"
+ )
+
+ interface.launch()
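For local debugging, the same pipeline can be exercised without launching the Gradio UI. The sketch below is not part of the commit and assumes the artifacts app.py expects (a pytorch_model.bin checkpoint and a tokenizer/ directory in the working directory); pad_token_id is left unset so the raw sampled sequence is returned as-is.

    # Headless smoke test (sketch; paths assumed from app.py above)
    import torch
    from transformers import GPT2Tokenizer
    from model import CustomLLM, CustomConfig

    config = CustomConfig()
    model = CustomLLM(config)
    model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained("tokenizer/")
    tokenizer.pad_token = tokenizer.eos_token

    input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=40,
            temperature=0.8,
            top_k=50,
            top_p=0.95,
            eos_token_id=tokenizer.eos_token_id,
        )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))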
config.json ADDED
@@ -0,0 +1,15 @@
+ {
+     "vocab_size": 49152,
+     "hidden_size": 576,
+     "intermediate_size": 1536,
+     "num_hidden_layers": 30,
+     "num_attention_heads": 9,
+     "num_key_value_heads": 3,
+     "max_position_embeddings": 2048,
+     "rms_norm_eps": 1e-5,
+     "rope_theta": 10000.0,
+     "pad_token_id": 0,
+     "bos_token_id": 0,
+     "eos_token_id": 0
+ }
+
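The values in config.json mirror the defaults hard-coded in model.CustomConfig. If you would rather read them from the file than duplicate them in code, a small helper along these lines would work (hypothetical load_config; not part of the commit):

    import json
    from model import CustomConfig

    def load_config(path="config.json"):
        # Start from the in-code defaults, then override with the values from
        # config.json; the JSON keys match CustomConfig's attribute names.
        cfg = CustomConfig()
        with open(path) as f:
            for key, value in json.load(f).items():
                setattr(cfg, key, value)
        return cfg

    config = load_config()
    print(config.to_dict())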
model.py ADDED
@@ -0,0 +1,384 @@
+ import torch
+ import torch.nn as nn
+ import math
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ # 1. Custom Configuration Class
+ class CustomConfig:
+     def __init__(self):
+         # Architecture Parameters
+         self.vocab_size = 49152
+         self.hidden_size = 576               # d_model
+         self.intermediate_size = 1536        # FFN dimension
+         self.num_hidden_layers = 30          # Number of decoder layers
+         self.num_attention_heads = 9         # Query heads
+         self.num_key_value_heads = 3         # Key/Value heads
+         self.max_position_embeddings = 2048
+         self.rms_norm_eps = 1e-5
+         self.rope_theta = 10000.0            # Rotary embedding base
+
+         # Tokenizer/Generation Params
+         self.pad_token_id = None
+         self.bos_token_id = 0
+         self.eos_token_id = 0
+
+     def to_dict(self):
+         # Serialize the config parameters
+         return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
+
+ # 2. Custom RMS Normalization
+ class CustomRMSNorm(nn.Module):
+     def __init__(self, dim, eps=1e-5):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.eps = eps
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         return self.weight * self._norm(x.float()).type_as(x)
+
+ # 3. Rotary Positional Embeddings
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim, max_seq_len=2048, theta=10000.0):
+         super().__init__()
+         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self._set_cos_sin_cache(max_seq_len)
+
+     def _set_cos_sin_cache(self, seq_len):
+         t = torch.arange(seq_len, device=self.inv_freq.device)
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
+         self.register_buffer("sin_cached", emb.sin()[None, None, :, :])
+
+     def forward(self, x, seq_len):
+         if seq_len > self.cos_cached.shape[2]:
+             self._set_cos_sin_cache(seq_len)
+         return self.cos_cached[:, :, :seq_len], self.sin_cached[:, :, :seq_len]
+
+ # 4. Attention Layer with Grouped Query Attention
+ class CustomAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_kv_heads = config.num_key_value_heads
+
+         # Projections
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+         # Rotary embeddings
+         self.rotary_emb = RotaryEmbedding(
+             self.head_dim,
+             max_seq_len=config.max_position_embeddings,
+             theta=config.rope_theta
+         )
+
+     def forward(self, x, attention_mask=None):
+         batch_size, seq_len, _ = x.shape
+
+         # Project queries/keys/values
+         q = self.q_proj(x)
+         k = self.k_proj(x)
+         v = self.v_proj(x)
+
+         # Reshape for attention computation
+         q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+         v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+         # Apply rotary embeddings
+         cos, sin = self.rotary_emb(x, seq_len=seq_len)
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         # Repeat keys and values to match the number of query heads
+         repeat_factor = self.num_heads // self.num_kv_heads
+         k = k.repeat_interleave(repeat_factor, dim=1)
+         v = v.repeat_interleave(repeat_factor, dim=1)
+
+         # Attention scores
+         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+         # Apply attention mask
+         if attention_mask is not None:
+             attn_weights = attn_weights + attention_mask
+
+         attn_weights = torch.softmax(attn_weights, dim=-1)
+         attn_output = torch.matmul(attn_weights, v)
+
+         # Reshape and project back
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
+         return self.o_proj(attn_output)
+
+ # 5. MLP Layer
+ class CustomMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+         self.act_fn = nn.SiLU()
+
+     def forward(self, x):
+         gate = self.act_fn(self.gate_proj(x))
+         up = self.up_proj(x)
+         return self.down_proj(gate * up)
+
+ # 6. Transformer Decoder Layer
+ class DecoderLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.self_attn = CustomAttention(config)
+         self.mlp = CustomMLP(config)
+         self.input_norm = CustomRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attn_norm = CustomRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(self, x, attention_mask=None):
+         # Self-attention
+         residual = x
+         x = self.input_norm(x)
+         x = self.self_attn(x, attention_mask)
+         x = residual + x
+
+         # MLP
+         residual = x
+         x = self.post_attn_norm(x)
+         x = self.mlp(x)
+         x = residual + x
+         return x
+
+ # 7. Full Model
+ class CustomLLM(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.norm = CustomRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.lm_head.weight = self.embed_tokens.weight  # Tie lm_head and embedding weights to reduce the parameter count
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, input_ids, attention_mask=None, labels=None):
+         x = self.embed_tokens(input_ids)
+         batch_size, seq_len = input_ids.shape
+
+         # Create causal mask
+         causal_mask = torch.full((seq_len, seq_len), float("-inf"), device=x.device)
+         causal_mask = torch.triu(causal_mask, diagonal=1)
+         causal_mask = causal_mask[None, None, :, :]  # Shape: [1, 1, seq_len, seq_len]
+
+         # Combine with padding mask
+         if attention_mask is not None:
+             padding_mask = (1.0 - attention_mask.float()) * torch.finfo(x.dtype).min
+             padding_mask = padding_mask.view(batch_size, 1, 1, seq_len)
+             combined_mask = causal_mask + padding_mask
+         else:
+             combined_mask = causal_mask
+
+         # Process through decoder layers
+         for layer in self.layers:
+             x = layer(x, attention_mask=combined_mask)
+
+         x = self.norm(x)
+         logits = self.lm_head(x)
+
+         loss = None
+         if labels is not None:
+             # Shift logits and labels for causal LM
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=None,
+             hidden_states=None,
+             attentions=None,
+         )
+
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_new_tokens: int = 100,
+         temperature: float = 1.0,
+         top_k: int = None,
+         top_p: float = None,
+         repetition_penalty: float = 1.0,
+         eos_token_id: int = None,
+         pad_token_id: int = None,
+     ):
+         """
+         Generates text using various decoding strategies.
+
+         Args:
+             input_ids: Input token IDs of shape (batch_size, seq_len)
+             max_new_tokens: Maximum number of tokens to generate
+             temperature: Sampling temperature (higher = more random)
+             top_k: Top-k sampling cutoff
+             top_p: Nucleus sampling cutoff
+             repetition_penalty: Penalty for repeated tokens (1.0 = no penalty)
+             eos_token_id: Stop generation when this token is produced
+             pad_token_id: Padding token ID for sequence termination
+
+         Returns:
+             Generated sequence of token IDs
+         """
+         # Ensure model is in eval mode
+         self.eval()
+
+         # Move inputs to model device
+         input_ids = input_ids.to(self.embed_tokens.weight.device)
+         batch_size = input_ids.size(0)
+
+         # Storage for generated sequences
+         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+         past_key_values = None  # Could implement KV caching here for efficiency
+
+         for _ in range(max_new_tokens):
+             # Full forward pass; keep only the logits for the last position
+             with torch.no_grad():
+                 outputs = self(input_ids)
+                 next_token_logits = outputs.logits[:, -1, :]
+
+             # Repetition penalty
+             if repetition_penalty != 1.0:
+                 next_token_logits = self._apply_repetition_penalty(
+                     next_token_logits, input_ids, repetition_penalty
+                 )
+
+             # Temperature scaling
+             if temperature != 1.0:
+                 next_token_logits = next_token_logits / temperature
+
+             # Top-k filtering
+             if top_k is not None and top_k > 0:
+                 top_k_values, _ = torch.topk(next_token_logits, top_k)
+                 min_top_k = top_k_values[:, -1].unsqueeze(-1)
+                 next_token_logits = torch.where(
+                     next_token_logits < min_top_k,
+                     torch.tensor(-float('inf')).to(next_token_logits.device),
+                     next_token_logits
+                 )
+
+             # Top-p (nucleus) sampling
+             if top_p is not None and top_p < 1.0:
+                 sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                 cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+                 # Remove tokens with cumulative probability above threshold
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                 next_token_logits[indices_to_remove] = -float('inf')
+
+             # Convert logits to probabilities
+             probs = torch.softmax(next_token_logits, dim=-1)
+
+             # Sample next tokens
+             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+             # Update sequences
+             input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
+
+             # Check for EOS tokens
+             if eos_token_id is not None:
+                 unfinished = (next_tokens != eos_token_id).long() * unfinished_sequences
+                 unfinished_sequences = unfinished
+
+             if unfinished_sequences.max() == 0:
+                 break
+
+         # Pad sequences if requested
+         if pad_token_id is not None and eos_token_id is not None:
+             input_ids = self._pad_sequences(input_ids, eos_token_id, pad_token_id)
+
+         return input_ids
+
+     def _apply_repetition_penalty(self, logits, sequences, penalty):
+         """Applies repetition penalty to logits"""
+         score = torch.gather(logits, 1, sequences)
+         score = torch.where(score < 0, score * penalty, score / penalty)
+         logits.scatter_(1, sequences, score)
+         return logits
+
+     def _pad_sequences(self, sequences, eos_token_id, pad_token_id):
+         """Replace tokens after EOS with pad token"""
+         # Create mask of positions after EOS
+         eos_positions = (sequences == eos_token_id).int().argmax(dim=-1)
+         padding_mask = torch.arange(sequences.size(1), device=sequences.device) > eos_positions.unsqueeze(-1)
+
+         # Apply padding
+         sequences[padding_mask] = pad_token_id
+         return sequences
+
+ # Helper function for rotary embeddings
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+ '''
+ # Usage
+ config = CustomConfig()
+ model = CustomLLM(config)
+
+ # Verify parameters
+ total_params = sum(p.numel() for p in model.parameters())
+ print(f"Total parameters: {total_params/1e6:.2f}M")  # Should output ~135.00M
+ print(model)
+ # Test forward pass
+ input_ids = torch.randint(0, config.vocab_size, (1, 256))
+ output = model(input_ids)
+ print(output.logits.shape)  # Expected output: torch.Size([1, 256, 49152])
+
+ # Initialize model
+ config = CustomConfig()
+ model = CustomLLM(config)
+
+ # Generate text
+ prompt = torch.tensor([[config.bos_token_id]])  # Start token
+ generated = model.generate(
+     prompt,
+     max_new_tokens=50,
+     temperature=0.7,
+     top_p=0.9,
+     eos_token_id=config.eos_token_id,
+     pad_token_id=config.pad_token_id
+ )
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+ tokenizer.pad_token = tokenizer.eos_token  # For padding
+ # Decode tokens
+ generated_text = tokenizer.decode(generated[0].tolist())
+ print(prompt)
+ print(generated_text)
+ '''
+
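The attention block above implements grouped-query attention: 9 query heads share 3 key/value heads, and k/v are expanded with repeat_interleave to match the query heads before the score computation. A quick shape check of a single CustomAttention layer (a sketch using the classes above; the attention mask is omitted for brevity):

    import torch
    from model import CustomConfig, CustomAttention

    config = CustomConfig()                      # 9 query heads, 3 KV heads, hidden_size 576
    attn = CustomAttention(config)

    x = torch.randn(2, 16, config.hidden_size)   # (batch, seq_len, hidden)
    with torch.no_grad():
        out = attn(x)                            # k/v are repeated 9 // 3 = 3 times internally
    print(out.shape)                             # expected: torch.Size([2, 16, 576])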
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch>=2.0.0
+ gradio>=3.0.0
+ transformers>=4.30.0