ProCreations committed
Commit a2ce935 · 1 Parent(s): 04582f0

Initial RLHF chat UI for intellite 100M

Files changed (6)
  1. .gitignore +4 -0
  2. README.md +57 -7
  3. app.py +338 -0
  4. config.py +60 -0
  5. model.py +162 -0
  6. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__/
+ *.pyc
+ data.json
+ data.json.tmp
README.md CHANGED
@@ -1,13 +1,63 @@
  ---
- title: Intellite
- emoji: 💻
- colorFrom: indigo
- colorTo: red
+ title: intellite-100m
+ emoji: 💬
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
- sdk_version: 6.12.0
+ sdk_version: 5.0.0
  app_file: app.py
  pinned: false
- short_description: 'Intellite '
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # intellite-100M RLHF data collector
+
+ Serves the SFT-tuned intellite 100M model in a chat UI. Every assistant reply
+ gets 👍 / 👎 buttons; each rating appends one record to `data.json` with the
+ prompt, the response, and the binary reward — ready for RLHF / DPO training
+ on your Mac.
+
+ ## Setup
+
+ 1. Copy your SFT checkpoint to the Space root as **`best.pt`**
+    (or set `INTELLITE_CKPT=/path/to/file.pt` in the Space's settings → Variables).
+    Use `git lfs track "best.pt"` before committing the weights file.
+ 2. Push the Space. `app.py` loads the checkpoint once at startup.
+
+ ## Data format
+
+ `data.json` is a list of records, one per rating:
+
+ ```json
+ {
+   "ts": "2026-04-20T15:23:45",
+   "system": "You are a helpful, honest, and concise assistant.",
+   "prompt_messages": [
+     { "role": "user", "content": "..." },
+     { "role": "assistant", "content": "..." },
+     { "role": "user", "content": "..." }
+   ],
+   "response": "...",
+   "liked": true
+ }
+ ```
+
+ Each record is exactly `(prompt, response, reward∈{0,1})` — the shape any
+ preference/RL trainer expects. For DPO, group records by identical `prompt_messages`
+ and pair a `liked=true` response (chosen) with a `liked=false` one (rejected).
+ For REINFORCE/PPO, feed `liked` as a {−1, +1} or {0, 1} reward.
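+
+ A minimal pairing sketch (illustrative only, not shipped with this Space; it
+ assumes `data.json` sits in the working directory and uses only the fields
+ shown above):
+
+ ```python
+ import json
+ from collections import defaultdict
+
+ with open("data.json") as f:
+     records = json.load(f)
+
+ # Group ratings that share the same system prompt and conversation prefix.
+ groups = defaultdict(lambda: {"chosen": [], "rejected": []})
+ for r in records:
+     key = json.dumps([r["system"], r["prompt_messages"]], sort_keys=True)
+     groups[key]["chosen" if r["liked"] else "rejected"].append(r["response"])
+
+ # One DPO pair per prompt that has both a liked and a disliked response.
+ dpo_pairs = [
+     {"prompt": json.loads(k)[1], "chosen": g["chosen"][0], "rejected": g["rejected"][0]}
+     for k, g in groups.items()
+     if g["chosen"] and g["rejected"]
+ ]
+
+ # For REINFORCE/PPO, map the binary label to a scalar reward instead:
+ rewards = [1.0 if r["liked"] else -1.0 for r in records]
+ ```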
+
+ ## Downloading the data
+
+ The right-hand panel has an **⬇ Download data.json** button — one click on your
+ Mac and you've got every rating so far.
+
+ ## Clearing the data
+
+ The **Clear data.json** button empties the file on the Space. Do this after
+ pulling the file locally so you don't double-count records on the next export.
+
+ ## Notes on the free CPU tier
+
+ Generation on CPU is slow (~5–10 tok/s for 100M in fp32). If you move to the
+ paid GPU tier, the app auto-detects `cuda` and uses bf16 autocast — roughly
+ 10× faster.
app.py ADDED
@@ -0,0 +1,338 @@
+ """intellite 100M — RLHF data collector served as a Gradio app on Hugging Face Spaces.
+
+ Every assistant reply gets 👍 / 👎 buttons. When the user rates a reply,
+ the (system, prior messages, response, liked) tuple is appended to
+ data.json in the Space's working directory. A Download button exposes
+ that file so you can grab it on your Mac and use it for RL / DPO.
+
+ The SFT checkpoint is loaded from:
+     $INTELLITE_CKPT (if set), else ./best.pt at the Space root.
+ """
+
+ import json
+ import os
+ import sys
+ import threading
+ import time
+ import traceback
+ from pathlib import Path
+
+ import gradio as gr
+ import tiktoken
+ import torch
+
+ SPACE_DIR = Path(__file__).resolve().parent
+ sys.path.insert(0, str(SPACE_DIR))
+
+ from config import ModelConfig
+ from model import IntelliteGPT
+
+ # ------------------------------------------------------------------------
+ # Paths & constants
+
+ CKPT_PATH = Path(os.environ.get("INTELLITE_CKPT", SPACE_DIR / "best.pt"))
+ DATA_PATH = SPACE_DIR / "data.json"
+
+ DEFAULT_SYSTEM = "You are a helpful, honest, and concise assistant."
+ SYSTEM_TAG = "<|system|>\n"
+ USER_TAG = "<|user|>\n"
+ ASST_TAG = "<|assistant|>\n"
+ STOP_MARKERS = ("<|user|>", "<|system|>")
+
+
+ # ------------------------------------------------------------------------
+ # Model load (once, at startup)
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"[sys] device={DEVICE} ckpt={CKPT_PATH}")
+
+ if not CKPT_PATH.exists():
+     raise FileNotFoundError(
+         f"No checkpoint at {CKPT_PATH}. Upload your SFT best.pt to the Space "
+         f"root, or set the INTELLITE_CKPT environment variable to its path."
+     )
+
+ sd = torch.load(str(CKPT_PATH), map_location=DEVICE)
+ _fields = ModelConfig.__dataclass_fields__.keys()
+ MCFG = ModelConfig(**{k: v for k, v in sd["model_cfg"].items() if k in _fields})
+ MODEL = IntelliteGPT(MCFG).to(DEVICE)
+ MODEL.load_state_dict(sd["model"])
+ MODEL.eval()
+ TOKENS_SEEN = int(sd.get("tokens_seen", 0))
+ BEST_VAL = float(sd.get("best_val", float("nan")))
+
+ ENC = tiktoken.get_encoding("gpt2")
+ EOT = ENC.eot_token
+ N_PARAMS = MODEL.num_params()
+ print(f"[model] {N_PARAMS/1e6:.1f}M params tokens_seen={TOKENS_SEEN:,} best_val={BEST_VAL:.4f}")
+
+
+ # ------------------------------------------------------------------------
+ # Prompt templating + generation (mirrors chat.py)
+
+ def render_prompt_ids(system: str, prior_messages: list[dict], user_msg: str) -> list[int]:
+     """Encode the SFT chat template exactly as sft_prepare.py did."""
+     ids: list[int] = []
+     if system:
+         ids.extend(ENC.encode_ordinary(SYSTEM_TAG + system.strip() + "\n"))
+     # Pair prior messages into (user, assistant) turns.
+     pending_user = None
+     for m in prior_messages:
+         role = m.get("role")
+         content = (m.get("content") or "").strip()
+         if role == "user":
+             pending_user = content
+         elif role == "assistant" and pending_user is not None:
+             ids.extend(ENC.encode_ordinary(USER_TAG + pending_user + "\n"))
+             ids.extend(ENC.encode_ordinary(ASST_TAG))
+             ids.extend(ENC.encode_ordinary(content))
+             ids.append(EOT)
+             pending_user = None
+     # Current user turn + assistant opener.
+     ids.extend(ENC.encode_ordinary(USER_TAG + user_msg.strip() + "\n"))
+     ids.extend(ENC.encode_ordinary(ASST_TAG))
+     return ids
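+
+ # Illustrative note: with an empty history, the ids above decode back to
+ #   "<|system|>\n{system}\n<|user|>\n{user_msg}\n<|assistant|>\n"
+ # and each earlier assistant turn in a longer history is additionally
+ # terminated by the GPT-2 EOT token.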
+
+
+ @torch.no_grad()
+ def stream_reply(prompt_ids, max_new, temperature, top_k, top_p, rep_penalty):
+     """Yield the partial assistant reply after each new token."""
+     x = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
+     ctx = MCFG.seq_len
+     start = len(prompt_ids)
+     reply = ""
+
+     for _ in range(max_new):
+         xc = x[:, -ctx:]
+         if DEVICE == "cuda":
+             with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                 logits, _ = MODEL(xc)
+         else:
+             logits, _ = MODEL(xc)
+         logits = logits[0, -1, :].float()
+
+         if rep_penalty and rep_penalty != 1.0:
+             seen = torch.unique(x[0])
+             prev = logits[seen]
+             logits[seen] = torch.where(prev > 0, prev / rep_penalty, prev * rep_penalty)
+
+         logits = logits / max(temperature, 1e-5)
+
+         if top_k and top_k > 0:
+             k = min(int(top_k), logits.numel())
+             v, _ = torch.topk(logits, k)
+             logits[logits < v[-1]] = -float("inf")
+
+         if top_p and 0.0 < top_p < 1.0:
+             sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+             cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+             mask = cum > top_p
+             mask[1:] = mask[:-1].clone()  # shift right so the token crossing top_p is kept
+             mask[0] = False
+             logits[sorted_idx[mask]] = -float("inf")
+
+         probs = torch.softmax(logits, dim=-1)
+         next_tok = torch.multinomial(probs, num_samples=1)
+         tok_id = int(next_tok.item())
+         x = torch.cat([x, next_tok.unsqueeze(0)], dim=1)
+
+         if tok_id == EOT:
+             break
+
+         reply = ENC.decode(x[0, start:].tolist())
+
+         # Strip trailing replacement char (partial UTF-8) for nicer streaming.
+         while reply.endswith("\ufffd"):
+             reply = reply[:-1]
+
+         hit_stop = False
+         for marker in STOP_MARKERS:
+             idx = reply.find(marker)
+             if idx != -1:
+                 reply = reply[:idx]
+                 hit_stop = True
+                 break
+         if hit_stop:
+             break
+
+         yield reply.strip()
+
+     yield reply.strip()
+
+
+ # ------------------------------------------------------------------------
+ # Feedback store (data.json)
+
+ _feedback_lock = threading.Lock()
+
+
+ def _read_data() -> list:
+     if not DATA_PATH.exists():
+         return []
+     try:
+         with open(DATA_PATH) as f:
+             return json.load(f)
+     except Exception:
+         return []
+
+
+ def _write_data(items: list) -> None:
+     tmp = DATA_PATH.with_suffix(".json.tmp")
+     with open(tmp, "w") as f:
+         json.dump(items, f, indent=2, ensure_ascii=False)
+     tmp.replace(DATA_PATH)
+
+
+ if not DATA_PATH.exists():
+     _write_data([])
+
+
+ def _stats_str() -> str:
+     with _feedback_lock:
+         items = _read_data()
+     total = len(items)
+     liked = sum(1 for i in items if i.get("liked"))
+     return f"**{total}** records · 👍 {liked} · 👎 {total - liked}"
+
+
+ def save_feedback(evt: gr.LikeData, history: list, system: str) -> str:
+     """Handle a thumbs-up / thumbs-down click on a chat message."""
+     if evt.liked is None:
+         return "rating cleared (nothing saved)"
+
+     # evt.index is an int in messages mode; be defensive either way.
+     idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
+     if not isinstance(idx, int) or idx < 0 or idx >= len(history):
+         return f"bad index {evt.index!r}"
+
+     msg = history[idx]
+     if msg.get("role") != "assistant":
+         return "skipped non-assistant message"
+
+     record = {
+         "ts": time.strftime("%Y-%m-%dT%H:%M:%S"),
+         "system": (system or DEFAULT_SYSTEM).strip(),
+         "prompt_messages": history[:idx],
+         "response": msg.get("content", ""),
+         "liked": bool(evt.liked),
+     }
+     with _feedback_lock:
+         items = _read_data()
+         items.append(record)
+         _write_data(items)
+
+     verdict = "👍 good" if evt.liked else "👎 bad"
+     return f"saved {verdict} · {len(items)} records in data.json"
+
+
+ def clear_data() -> str:
+     with _feedback_lock:
+         _write_data([])
+     return "data.json cleared"
+
+
+ # ------------------------------------------------------------------------
+ # Chat callback
+
+ def chat(user_msg, history, system, max_new, temperature, top_k, top_p, rep_penalty):
+     """Stream a reply; yield updated chatbot history after each token."""
+     user_msg = (user_msg or "").strip()
+     if not user_msg:
+         yield history, ""
+         return
+
+     history = list(history) + [
+         {"role": "user", "content": user_msg},
+         {"role": "assistant", "content": ""},
+     ]
+     prior = history[:-2]
+
+     ids = render_prompt_ids(system or DEFAULT_SYSTEM, prior, user_msg)
+     room = MCFG.seq_len - int(max_new)
+     if len(ids) > room > 0:
+         ids = ids[-room:]
+
+     try:
+         for partial in stream_reply(ids, int(max_new), float(temperature),
+                                     int(top_k), float(top_p), float(rep_penalty)):
+             history[-1]["content"] = partial
+             yield history, ""
+     except Exception:
+         history[-1]["content"] = f"[error] {traceback.format_exc()}"
+         yield history, ""
+
+
+ # ------------------------------------------------------------------------
+ # UI
+
+ with gr.Blocks(title="intellite 100M — RLHF collector") as demo:
+     gr.Markdown(
+         f"# intellite 100M — RLHF data collector\n"
+         f"{MCFG.d_model}d × {MCFG.n_layers}L × {MCFG.n_heads}h "
+         f"({N_PARAMS/1e6:.1f}M params) · {TOKENS_SEEN/1e6:.0f}M SFT tokens · "
+         f"val_loss {BEST_VAL:.3f} · device `{DEVICE}` \n"
+         f"**Please rate every response with 👍 or 👎.** Every rating appends a record "
+         f"to `data.json`; grab it from the sidebar for RLHF on your Mac."
+     )
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(
+                 type="messages",
+                 height=520,
+                 show_copy_button=True,
+                 avatar_images=(None, None),
+             )
+             msg = gr.Textbox(
+                 placeholder="Your message — Enter to send",
+                 lines=2,
+                 show_label=False,
+                 autofocus=True,
+             )
+             with gr.Row():
+                 send_btn = gr.Button("Send", variant="primary")
+                 clear_btn = gr.Button("Clear chat")
+             feedback_status = gr.Markdown("_rate replies with 👍 / 👎_")
+
+         with gr.Column(scale=1):
+             system = gr.Textbox(value=DEFAULT_SYSTEM, label="System prompt", lines=3)
+             max_new = gr.Slider(16, 800, value=400, step=16, label="max new tokens")
+             temp = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature")
+             top_k = gr.Slider(0, 200, value=50, step=1, label="top-k (0 = off)")
+             top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top-p")
+             rep = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label="repetition penalty")
+
+             gr.Markdown("### RLHF data")
+             stats_md = gr.Markdown(_stats_str())
+             download = gr.DownloadButton(
+                 label="⬇ Download data.json", value=str(DATA_PATH)
+             )
+             clear_data_btn = gr.Button("Clear data.json", variant="stop")
+
+     # Wire the chat events.
+     send_btn.click(
+         chat,
+         inputs=[msg, chatbot, system, max_new, temp, top_k, top_p, rep],
+         outputs=[chatbot, msg],
+     )
+     msg.submit(
+         chat,
+         inputs=[msg, chatbot, system, max_new, temp, top_k, top_p, rep],
+         outputs=[chatbot, msg],
+     )
+     clear_btn.click(lambda: [], None, chatbot, queue=False)
+
+     # Thumbs-up / thumbs-down → append to data.json, refresh counters.
+     chatbot.like(
+         save_feedback,
+         inputs=[chatbot, system],
+         outputs=[feedback_status],
+     ).then(lambda: _stats_str(), None, stats_md, queue=False)
+
+     clear_data_btn.click(clear_data, None, feedback_status, queue=False).then(
+         lambda: _stats_str(), None, stats_md, queue=False
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue().launch()
config.py ADDED
@@ -0,0 +1,60 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class ModelConfig:
+     vocab_size: int = 50304  # rounded-up GPT-2 vocab for better matmul shapes
+     d_model: int = 768
+     n_layers: int = 10
+     n_heads: int = 12  # head_dim = 64
+     d_ff: int = 2048  # canonical SwiGLU 8/3 * d_model
+     seq_len: int = 2048
+     dropout: float = 0.0
+     rope_theta: float = 10000.0
+     tie_embeddings: bool = True
+     norm_eps: float = 1e-5
+
+
+ @dataclass
+ class TrainConfig:
+     # Paths
+     data_dir: str = "data"
+     out_dir: str = "checkpoints"
+
+     # Model (mirrors ModelConfig so a single dataclass configures runs)
+     vocab_size: int = 50304
+     d_model: int = 768
+     n_layers: int = 10
+     n_heads: int = 12
+     d_ff: int = 2048
+     seq_len: int = 2048
+     dropout: float = 0.0
+
+     # Training budget
+     target_tokens: int = 1_000_000_000
+     # Memory at seq=2048 for ~100M params: keep microbatches small and use
+     # grad accumulation to keep effective batch = 32 × 2048 = 65_536 tok/step.
+     batch_size: int = 4
+     grad_accum_steps: int = 8
+
+     # Optimizer / schedule
+     learning_rate: float = 6e-4
+     min_lr_ratio: float = 0.1
+     warmup_tokens: int = 3_000_000
+     weight_decay: float = 0.1
+     beta1: float = 0.9
+     beta2: float = 0.95
+     grad_clip: float = 1.0
+
+     # Checkpoint / eval cadence (in tokens)
+     ckpt_every_tokens: int = 100_000_000
+     eval_every_tokens: int = 6_000_000
+     eval_batches: int = 50
+
+     # Logging
+     log_every_steps: int = 10
+
+     # System
+     device: str = "mps"
+     dtype: str = "bfloat16"
+     seed: int = 1337
model.py ADDED
@@ -0,0 +1,162 @@
+ """Small but modern decoder-only transformer (~100M params with the default config).
+
+ Uses RoPE, RMSNorm, SwiGLU FFN, tied embeddings, and PyTorch SDPA
+ for causal attention (which lights up MPS fast-paths where available).
+ """
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from config import ModelConfig
+
+
+ def precompute_rope(head_dim: int, seq_len: int, theta: float = 10000.0, device=None):
+     inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
+     t = torch.arange(seq_len, device=device).float()
+     freqs = torch.outer(t, inv_freq)  # (T, head_dim/2)
+     return freqs.cos(), freqs.sin()
+
+
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+     # x: (B, H, T, D); cos/sin: (T, D/2)
+     x1, x2 = x.chunk(2, dim=-1)
+     cos = cos[None, None, :, :]
+     sin = sin[None, None, :, :]
+     return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
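+
+ # Note: this is the "split-half" RoPE layout (rotating the first half of
+ # head_dim against the second), not the interleaved even/odd layout;
+ # a checkpoint is only compatible with the layout it was trained with.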
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, d: int, eps: float = 1e-5):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(d))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Always compute the norm in fp32 for stability, then cast back.
+         dtype = x.dtype
+         x32 = x.float()
+         norm = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
+         return (x32 * norm).to(dtype) * self.weight
+
+
+ class Attention(nn.Module):
+     def __init__(self, cfg: ModelConfig):
+         super().__init__()
+         assert cfg.d_model % cfg.n_heads == 0
+         self.n_heads = cfg.n_heads
+         self.head_dim = cfg.d_model // cfg.n_heads
+         self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
+         self.o = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
+         self.dropout = cfg.dropout
+
+     def forward(self, x, cos, sin):
+         B, T, C = x.shape
+         q, k, v = self.qkv(x).chunk(3, dim=-1)
+         q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         q = apply_rope(q, cos[:T], sin[:T])
+         k = apply_rope(k, cos[:T], sin[:T])
+         y = F.scaled_dot_product_attention(
+             q, k, v,
+             is_causal=True,
+             dropout_p=self.dropout if self.training else 0.0,
+         )
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         return self.o(y)
+
+
+ class SwiGLU(nn.Module):
+     def __init__(self, cfg: ModelConfig):
+         super().__init__()
+         self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)  # gate
+         self.w2 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)  # down
+         self.w3 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)  # up
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class Block(nn.Module):
+     def __init__(self, cfg: ModelConfig):
+         super().__init__()
+         self.attn_norm = RMSNorm(cfg.d_model, cfg.norm_eps)
+         self.attn = Attention(cfg)
+         self.ffn_norm = RMSNorm(cfg.d_model, cfg.norm_eps)
+         self.ffn = SwiGLU(cfg)
+
+     def forward(self, x, cos, sin):
+         x = x + self.attn(self.attn_norm(x), cos, sin)
+         x = x + self.ffn(self.ffn_norm(x))
+         return x
+
+
+ class IntelliteGPT(nn.Module):
+     def __init__(self, cfg: ModelConfig):
+         super().__init__()
+         self.cfg = cfg
+         self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
+         self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
+         self.norm = RMSNorm(cfg.d_model, cfg.norm_eps)
+         self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
+         if cfg.tie_embeddings:
+             self.lm_head.weight = self.tok_emb.weight
+
+         cos, sin = precompute_rope(cfg.d_model // cfg.n_heads, cfg.seq_len, cfg.rope_theta)
+         self.register_buffer("cos", cos, persistent=False)
+         self.register_buffer("sin", sin, persistent=False)
+
+         self.apply(self._init_weights)
+         # GPT-2 style: scale residual projections by 1/sqrt(2*n_layers)
+         scale = 0.02 / math.sqrt(2 * cfg.n_layers)
+         for n, p in self.named_parameters():
+             if n.endswith("attn.o.weight") or n.endswith("ffn.w2.weight"):
+                 nn.init.normal_(p, mean=0.0, std=scale)
+
+     @staticmethod
+     def _init_weights(m):
+         if isinstance(m, nn.Linear):
+             nn.init.normal_(m.weight, mean=0.0, std=0.02)
+             if m.bias is not None:
+                 nn.init.zeros_(m.bias)
+         elif isinstance(m, nn.Embedding):
+             nn.init.normal_(m.weight, mean=0.0, std=0.02)
+
+     def num_params(self, exclude_embedding: bool = False) -> int:
+         n = sum(p.numel() for p in self.parameters())
+         if exclude_embedding:
+             n -= self.tok_emb.weight.numel()
+         return n
+
+     def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
+         B, T = idx.shape
+         x = self.tok_emb(idx)
+         cos, sin = self.cos, self.sin
+         for block in self.blocks:
+             x = block(x, cos, sin)
+         x = self.norm(x)
+         logits = self.lm_head(x)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)).float(),
+                 targets.view(-1),
+                 ignore_index=-1,
+             )
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         for _ in range(max_new_tokens):
+             idx_cond = idx[:, -self.cfg.seq_len:]
+             logits, _ = self(idx_cond)
+             logits = logits[:, -1, :] / max(temperature, 1e-5)
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float("inf")
+             probs = F.softmax(logits, dim=-1)
+             next_tok = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat([idx, next_tok], dim=1)
+         return idx
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio>=5.0.0
+ torch>=2.1.0
+ tiktoken
+ numpy