darthcrawl committed
Commit 6e14144 · verified · 1 Parent(s): 2adf8d3

Add files using upload-large-folder tool

Files changed (8)
  1. README.md +120 -0
  2. config.json +12 -0
  3. config.py +50 -0
  4. meta.txt +6 -0
  5. model.py +180 -0
  6. model.safetensors +3 -0
  7. sample.py +55 -0
  8. tokenizer.json +0 -0
README.md ADDED
@@ -0,0 +1,120 @@
---
license: apache-2.0
language:
- en
library_name: pytorch
tags:
- causal-lm
- pretrained-from-scratch
- small-lm
- gpt
datasets:
- roneneldan/TinyStories
- roneneldan/TinyStoriesInstruct
- wikimedia/wikipedia
- nampdn-ai/tiny-textbooks
pipeline_tag: text-generation
---

# tiny-38m

A 37.8M-parameter decoder-only transformer pretrained from scratch on a mix of small, simple-vocabulary corpora. Pure PyTorch, single GPU, no HF Trainer, no PEFT, no distillation.

An educational artifact: it demonstrates that the modern transformer recipe (RMSNorm + RoPE + SwiGLU + SDPA) reaches coherent output at small scale on a single GPU.

## Quick start

```python
import json, sys, torch
from pathlib import Path
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer
from safetensors.torch import load_file

local = snapshot_download("darthcrawl/tiny-38m")
sys.path.insert(0, local)
from config import ModelConfig
from model import GPT

cfg_dict = json.loads((Path(local) / "config.json").read_text())
valid = set(ModelConfig.__dataclass_fields__)
cfg = ModelConfig(**{k: v for k, v in cfg_dict.items() if k in valid})

model = GPT(cfg).eval()
model.load_state_dict(load_file(f"{local}/model.safetensors"), strict=False)

tok = Tokenizer.from_file(f"{local}/tokenizer.json")
eot = tok.token_to_id("<|endoftext|>")

ids = torch.tensor([tok.encode("Once upon a time, there was a small dragon").ids], dtype=torch.long)
out = model.generate(ids, max_new_tokens=200, temperature=0.8, top_k=200, eos_id=eot)
print(tok.decode(out[0].tolist()))
```

`strict=False` is required because the tied embeddings (`lm_head.weight = tok_emb.weight`) are stored only once in the checkpoint.
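
A quick way to confirm the tying survived the load (a minimal sketch, reusing `model` from the snippet above):

```python
# GPT.__init__ ties lm_head.weight to tok_emb.weight; loading the single
# stored copy with strict=False updates both views. Verify they still
# alias the same storage.
assert model.lm_head.weight.data_ptr() == model.tok_emb.weight.data_ptr()
```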

## Architecture

| | |
|---|---|
| Type | Decoder-only transformer |
| Parameters | 37.8M |
| Layers | 8 |
| Hidden dim | 512 |
| Attention heads | 8 |
| Context length | 1024 |
| Vocab size | 8192 |
| Position encoding | RoPE |
| Norm | RMSNorm (pre-norm) |
| MLP | SwiGLU |
| Attention | PyTorch SDPA, causal |
| Embedding tying | Yes |
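
The 37.8M figure checks out from the table alone; a back-of-envelope count (a sketch, with the SwiGLU hidden dim of 4 × 512 = 2048 taken from `config.json` and `model.py`):

```python
d, L, V, hidden = 512, 8, 8192, 2048  # n_embd, n_layer, vocab, SwiGLU hidden (already a multiple of 64)
per_layer = (3 * d * d                # fused QKV projection
             + d * d                  # attention output projection
             + 3 * d * hidden         # SwiGLU w1, w3, w2
             + 2 * d)                 # two RMSNorm weight vectors
total = L * per_layer + d + V * d     # blocks + final norm + tied embedding (counted once)
print(f"{total:,}")                   # 37,757,440 ≈ 37.8M
```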

## Training

| | |
|---|---|
| Source mix | `tinystories:60,tinystories_instruct:15,simple_wiki:15,tiny_textbooks:10` |
| Total train tokens | 477,521,740 |
| Best ckpt step | 19500 |
| Best val loss | 1.8847 |
| Optimizer | AdamW (β = (0.9, 0.95), wd = 0.1) |
| Peak LR | 6e-4 |
| LR schedule | Cosine, 200-step warmup |
| Batch size | 32 × grad accum 4 (effective 128) |
| Precision | bfloat16 (AMP) |
| Hardware | Single GPU |

Mix format is `name:weight,...`; `meta.txt` in this repo is the canonical record.
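
Parsing the mix string into normalized sampling weights is straightforward; a minimal sketch (the `parse_mix` helper is illustrative, not part of the upstream pipeline):

```python
def parse_mix(spec: str) -> dict[str, float]:
    # "a:60,b:40" -> {"a": 0.6, "b": 0.4}
    weights = {name: float(w) for name, w in (item.split(":") for item in spec.split(","))}
    total = sum(weights.values())
    return {name: w / total for name, w in weights.items()}

mix = parse_mix("tinystories:60,tinystories_instruct:15,simple_wiki:15,tiny_textbooks:10")
# {'tinystories': 0.6, 'tinystories_instruct': 0.15, 'simple_wiki': 0.15, 'tiny_textbooks': 0.1}
```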

## Tokenizer

Byte-level BPE trained on the same source mix. Single `tokenizer.json` (HuggingFace `tokenizers` format) with an 8192-token vocabulary. Special tokens: `<|endoftext|>` (eot/eos), `<|pad|>`.
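
A quick round trip with the `tokenizers` API (a sketch; assumes `tokenizer.json` from this repo is in the working directory):

```python
from tokenizers import Tokenizer

tok = Tokenizer.from_file("tokenizer.json")
enc = tok.encode("Once upon a time")
print(enc.ids)                           # byte-level BPE ids, all < 8192
print(tok.decode(enc.ids))               # "Once upon a time"
print(tok.token_to_id("<|endoftext|>"))  # 0, per meta.txt
```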

## What it can do

- Continue toddler-level English narratives in the TinyStories register.
- Produce short factual-sounding text in the simple-Wikipedia register.
- Follow basic prompt → story patterns from TinyStoriesInstruct.

## What it can't do

- General-knowledge QA, code, math, multi-turn chat, reasoning, or instructions beyond what was in the training mix.
- Out-of-distribution vocabulary. The vocab is small and the corpus is intentionally narrow.
- Reliable factuality. Even on simple-wiki-style prompts it will confabulate.

## Intended use

Education, replication, ablations, and a baseline for from-scratch pretraining experiments. Not for downstream production use.

## Limitations and bias

Inherits whatever biases live in the synthetic TinyStories corpora and Simple English Wikipedia. Outputs are not safe for any user-facing application: no safety alignment, no instruction tuning, no RLHF.

## Reproducibility

Inference code (`model.py`, `config.py`, `sample.py`) ships in this repo. The full training pipeline (tokenizer training, data prep, training loop, source mixing) lives in the upstream project.

## License

Apache 2.0 for code and weights. Training data licenses follow their respective sources (see the datasets listed in the metadata).
config.json ADDED
@@ -0,0 +1,12 @@
```json
{
  "vocab_size": 8192,
  "n_layer": 8,
  "n_head": 8,
  "n_embd": 512,
  "block_size": 1024,
  "rope_base": 10000.0,
  "mlp_mult": 4,
  "dropout": 0.0,
  "tie_embeddings": true,
  "arch": "from_scratch_gpt"
}
```
config.py ADDED
@@ -0,0 +1,50 @@
```python
from dataclasses import dataclass, asdict


@dataclass
class ModelConfig:
    vocab_size: int = 8192
    n_layer: int = 8
    n_head: int = 8
    n_embd: int = 512
    block_size: int = 1024
    rope_base: float = 10000.0
    mlp_mult: int = 4
    dropout: float = 0.0
    tie_embeddings: bool = True

    @property
    def head_dim(self) -> int:
        assert self.n_embd % self.n_head == 0
        return self.n_embd // self.n_head


@dataclass
class TrainConfig:
    out_dir: str = "checkpoints"
    data_dir: str = "data"
    tokenizer_path: str = "data/tokenizer.json"

    batch_size: int = 32
    grad_accum: int = 4
    max_steps: int = 20000
    eval_interval: int = 500
    eval_iters: int = 100
    log_interval: int = 20
    save_interval: int = 2000

    lr: float = 6e-4
    min_lr: float = 6e-5
    warmup_steps: int = 200
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    dtype: str = "bfloat16"
    compile: bool = True
    seed: int = 1337
    device: str = "cuda"

    def to_dict(self):
        return asdict(self)
```
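
The README's "cosine, 200-step warmup" schedule maps directly onto these fields; a minimal sketch (the `lr_at` function is illustrative; the decay endpoint, `min_lr` at `max_steps`, is an assumption):

```python
import math

from config import TrainConfig


def lr_at(step: int, cfg: TrainConfig) -> float:
    # Linear warmup to cfg.lr, then cosine decay toward cfg.min_lr.
    if step < cfg.warmup_steps:
        return cfg.lr * (step + 1) / cfg.warmup_steps
    progress = (step - cfg.warmup_steps) / max(1, cfg.max_steps - cfg.warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    return cfg.min_lr + coeff * (cfg.lr - cfg.min_lr)


cfg = TrainConfig()
print(lr_at(0, cfg), lr_at(cfg.warmup_steps, cfg), lr_at(cfg.max_steps, cfg))
# 3e-06 (ramp start), 0.0006 (peak), 6e-05 (floor)
```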
meta.txt ADDED
@@ -0,0 +1,6 @@
```text
dtype=uint16
vocab=8192
eot=0
train_tokens=477521740
val_tokens=9456433
mix=tinystories:60,tinystories_instruct:15,simple_wiki:15,tiny_textbooks:10
```
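
`dtype=uint16` means the packed token files can be memory-mapped directly, since every id fits below 2^16; a sketch of a loader (the `data/train.bin` filename and the `get_batch` helper are assumptions, the actual naming is not recorded here):

```python
import numpy as np
import torch

# Token ids fit in uint16 because vocab=8192 < 65536.
data = np.memmap("data/train.bin", dtype=np.uint16, mode="r")


def get_batch(block_size: int = 1024, batch_size: int = 32):
    # Sample random windows; targets are the inputs shifted by one token.
    ix = np.random.randint(0, len(data) - block_size - 1, size=batch_size)
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])
    return x, y
```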
model.py ADDED
@@ -0,0 +1,180 @@
```python
"""Decoder-only transformer with RMSNorm, RoPE, SwiGLU. Educational, modern, single-GPU."""
from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from config import ModelConfig


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * norm.to(x.dtype)


def build_rope_cache(seq_len: int, head_dim: int, base: float, device, dtype):
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)
    cos = freqs.cos().to(dtype)
    sin = freqs.sin().to(dtype)
    return cos, sin


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (B, H, T, D). Pair adjacent dims and rotate.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos = cos[None, None, :x.size(-2), :]
    sin = sin[None, None, :x.size(-2), :]
    rot1 = x1 * cos - x2 * sin
    rot2 = x1 * sin + x2 * cos
    out = torch.stack((rot1, rot2), dim=-1).flatten(-2)
    return out


class CausalSelfAttention(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.n_head = cfg.n_head
        self.head_dim = cfg.head_dim
        self.qkv = nn.Linear(cfg.n_embd, 3 * cfg.n_embd, bias=False)
        self.proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.dropout = cfg.dropout

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)


class SwiGLU(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        hidden = cfg.mlp_mult * cfg.n_embd
        # Round to multiple of 64 for efficiency.
        hidden = ((hidden + 63) // 64) * 64
        self.w1 = nn.Linear(cfg.n_embd, hidden, bias=False)
        self.w3 = nn.Linear(cfg.n_embd, hidden, bias=False)
        self.w2 = nn.Linear(hidden, cfg.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Block(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.norm1 = RMSNorm(cfg.n_embd)
        self.attn = CausalSelfAttention(cfg)
        self.norm2 = RMSNorm(cfg.n_embd)
        self.mlp = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        x = x + self.attn(self.norm1(x), cos, sin)
        x = x + self.mlp(self.norm2(x))
        return x


class GPT(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.norm = RMSNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            self.lm_head.weight = self.tok_emb.weight

        self.apply(self._init_weights)
        # Scale residual projections per GPT-2 init.
        for name, p in self.named_parameters():
            if name.endswith("proj.weight") or name.endswith("w2.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * cfg.n_layer))

        self._rope_cache = None

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def num_params(self, non_embedding: bool = True) -> int:
        n = sum(p.numel() for p in self.parameters())
        if non_embedding and self.cfg.tie_embeddings:
            n -= self.tok_emb.weight.numel()
        return n

    def _rope(self, T: int, device, dtype):
        if (self._rope_cache is None
                or self._rope_cache[0].size(0) < T
                or self._rope_cache[0].device != device
                or self._rope_cache[0].dtype != dtype):
            self._rope_cache = build_rope_cache(
                self.cfg.block_size, self.cfg.head_dim, self.cfg.rope_base, device, dtype,
            )
        cos, sin = self._rope_cache
        return cos[:T], sin[:T]

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
        B, T = idx.shape
        assert T <= self.cfg.block_size, f"sequence length {T} > block_size {self.cfg.block_size}"

        x = self.tok_emb(idx)
        cos, sin = self._rope(T, x.device, x.dtype)
        for block in self.blocks:
            x = block(x, cos, sin)
        x = self.norm(x)

        if targets is None:
            logits = self.lm_head(x[:, [-1], :])
            return logits, None

        logits = self.lm_head(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int,
                 temperature: float = 1.0, top_k: int | None = None,
                 eos_id: int | None = None):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.cfg.block_size else idx[:, -self.cfg.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id), dim=1)
            if eos_id is not None and (next_id == eos_id).all():
                break
        return idx
```
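
A CPU smoke test of the module above (a sketch; with random tokens the initial loss should sit near ln(8192) ≈ 9.0):

```python
import torch

from config import ModelConfig
from model import GPT

cfg = ModelConfig()
model = GPT(cfg)
print(f"{model.num_params() / 1e6:.1f}M non-embedding params")  # ~33.6M (+4.2M tied embedding = 37.8M)

idx = torch.randint(0, cfg.vocab_size, (2, 64))
targets = torch.randint(0, cfg.vocab_size, (2, 64))
logits, loss = model(idx, targets)
print(logits.shape, loss.item())  # torch.Size([2, 64, 8192]), ~9.0 at init
```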
model.safetensors ADDED
@@ -0,0 +1,3 @@
```text
version https://git-lfs.github.com/spec/v1
oid sha256:61c6b1b608732dfd322ac3b51cfadee1382a575f23a5b1dad2064baf75447f69
size 151035216
```
sample.py ADDED
@@ -0,0 +1,55 @@
```python
"""Generate from a trained checkpoint."""
import argparse

import torch
from tokenizers import Tokenizer

from config import ModelConfig
from model import GPT


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", type=str, default="checkpoints/best.pt")
    p.add_argument("--tokenizer", type=str, default="data/tokenizer.json")
    p.add_argument("--prompt", type=str, default="Once upon a time")
    p.add_argument("--max-new-tokens", type=int, default=256)
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top-k", type=int, default=200)
    p.add_argument("--num-samples", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str, default=None)
    args = p.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)

    ckpt = torch.load(args.ckpt, map_location=device, weights_only=False)
    cfg_dict = ckpt["model_cfg"]
    valid = set(ModelConfig.__dataclass_fields__)
    cfg = ModelConfig(**{k: v for k, v in cfg_dict.items() if k in valid})

    model = GPT(cfg).to(device).eval()
    model.load_state_dict(ckpt["model"])

    tok = Tokenizer.from_file(args.tokenizer)
    eot = tok.token_to_id("<|endoftext|>")

    ids = tok.encode(args.prompt).ids
    if not ids:
        ids = [eot]
    x = torch.tensor([ids], dtype=torch.long, device=device)

    for s in range(args.num_samples):
        out = model.generate(
            x, max_new_tokens=args.max_new_tokens,
            temperature=args.temperature, top_k=args.top_k, eos_id=eot,
        )[0].tolist()
        text = tok.decode(out)
        print(f"\n--- sample {s + 1} ---")
        print(text)


if __name__ == "__main__":
    main()
```
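
Example invocation, relying on the argparse defaults for the checkpoint and tokenizer paths: `python sample.py --prompt "The little robot" --num-samples 2`. The samples within one run differ because each `torch.multinomial` call advances the generator state; re-running with the same `--seed` reproduces them.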
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff