abersbail committed
Commit 79078fe · verified · 1 Parent(s): 3b6bcbe

Add small GPT Python Space

README.md CHANGED
@@ -1,12 +1,27 @@
 ---
-title: Small Gpt Python
-emoji: 👀
-colorFrom: green
-colorTo: purple
+title: Small GPT Python
+colorFrom: indigo
+colorTo: blue
 sdk: gradio
-sdk_version: 6.10.0
 app_file: app.py
 pinned: false
+license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Small GPT Python
+
+This is a tiny GPT-style language model project written in Python from scratch.
+
+## What it includes
+
+- Word-level tokenizer
+- Causal transformer decoder with self-attention
+- Local CPU training loop
+- Checkpoint save and load
+- Gradio user interface
+
+## Important
+
+- No external pretrained LLM is used
+- This is a small educational GPT-like model
+- The first generate or train call will initialize and train the model locally
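For orientation, a minimal sketch of how the modules added in this commit fit together when used programmatically (the Gradio app in `app.py` wires up the same calls; the prompt and sampling values below are illustrative):

```python
from small_gpt.config import SmallGPTConfig
from small_gpt.service import SmallGPTService

service = SmallGPTService(config=SmallGPTConfig())

# The first generate (or train) call bootstraps the model: it fits the
# word-level tokenizer on the base corpus, trains for
# config.bootstrap_steps on CPU, and writes a checkpoint to artifacts/.
text, status = service.generate(
    prompt="User: hello\nAssistant:",
    max_new_tokens=48,
    temperature=0.75,
    top_k=8,
)
print(status)
print(text)
```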
app.py ADDED
@@ -0,0 +1,88 @@
+import gradio as gr
+
+from small_gpt.config import SmallGPTConfig
+from small_gpt.service import SmallGPTService
+
+
+config = SmallGPTConfig()
+service = SmallGPTService(config=config)
+
+
+def generate_text(prompt, max_new_tokens, temperature, top_k):
+    return service.generate(
+        prompt=prompt,
+        max_new_tokens=int(max_new_tokens),
+        temperature=float(temperature),
+        top_k=int(top_k),
+    )
+
+
+def train_model(extra_text, steps):
+    return service.train(extra_text=extra_text, steps=int(steps))
+
+
+def reset_model():
+    return service.reset()
+
+
+with gr.Blocks(
+    title="Small GPT Python",
+    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"),
+) as demo:
+    gr.Markdown(
+        """
+        # Small GPT Python
+        A tiny GPT-style language model written in Python from scratch.
+
+        - Causal transformer decoder
+        - Word-level tokenizer
+        - No external pretrained LLM
+        - Local CPU training and generation
+        """
+    )
+
+    with gr.Tab("Generate"):
+        prompt_input = gr.Textbox(
+            label="Prompt",
+            value="User: hello\nAssistant:",
+            lines=6,
+        )
+        with gr.Row():
+            max_tokens_input = gr.Slider(10, 180, value=72, step=2, label="Max New Tokens")
+            temperature_input = gr.Slider(0.2, 1.3, value=0.75, step=0.05, label="Temperature")
+            top_k_input = gr.Slider(1, 20, value=8, step=1, label="Top-K")
+        generate_button = gr.Button("Generate", variant="primary")
+        output_text = gr.Textbox(label="Output", lines=10)
+        output_status = gr.Textbox(label="Status", lines=4)
+
+    with gr.Tab("Train"):
+        extra_text_input = gr.Textbox(
+            label="Extra Training Text",
+            placeholder="Add more local text to continue training the small GPT model.",
+            lines=10,
+        )
+        steps_input = gr.Slider(10, 400, value=120, step=10, label="Training Steps")
+        train_button = gr.Button("Train / Continue Training", variant="primary")
+        reset_button = gr.Button("Reset Model")
+        train_status = gr.Textbox(label="Training Status", lines=6)
+
+    generate_button.click(
+        fn=generate_text,
+        inputs=[prompt_input, max_tokens_input, temperature_input, top_k_input],
+        outputs=[output_text, output_status],
+    )
+
+    train_button.click(
+        fn=train_model,
+        inputs=[extra_text_input, steps_input],
+        outputs=[train_status],
+    )
+
+    reset_button.click(
+        fn=reset_model,
+        outputs=[train_status],
+    )
+
+
+if __name__ == "__main__":
+    demo.launch()
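Running `python app.py` starts the UI via the `__main__` guard. A hedged variant for serving the app on a local network (the host/port values are illustrative assumptions, not part of this commit):

```python
# local_run.py - assumes app.py is on the import path.
from app import demo  # importing builds the Blocks UI but does not start training

demo.launch(server_name="0.0.0.0", server_port=7860)
```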
requirements.txt ADDED
@@ -0,0 +1,2 @@
+gradio>=5.23.0
+torch>=2.3.0
small_gpt/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .config import SmallGPTConfig
+from .service import SmallGPTService
+
+__all__ = ["SmallGPTConfig", "SmallGPTService"]
small_gpt/config.py ADDED
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass
+class SmallGPTConfig:
+    block_size: int = 48
+    batch_size: int = 20
+    d_model: int = 96
+    n_heads: int = 4
+    n_layers: int = 3
+    dropout: float = 0.1
+    learning_rate: float = 2.5e-3
+    bootstrap_steps: int = 80
+    cpu_threads: int = 4
+    seed: int = 42
+
+    @property
+    def root_dir(self) -> Path:
+        return Path(__file__).resolve().parents[1]
+
+    @property
+    def checkpoint_path(self) -> Path:
+        return self.root_dir / "artifacts" / "small_gpt_checkpoint.pt"
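All hyperparameters live in one dataclass, so a variant run is a single constructor call. A small sketch (the overridden values are arbitrary examples):

```python
from small_gpt.config import SmallGPTConfig

cfg = SmallGPTConfig(n_layers=4, block_size=64, bootstrap_steps=200)
# checkpoint_path is derived from the package location:
# <repo root>/artifacts/small_gpt_checkpoint.pt
print(cfg.checkpoint_path)
```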
small_gpt/data.py ADDED
@@ -0,0 +1,87 @@
+BASE_CORPUS = """
+System: You are a small local GPT style assistant.
+User: hello
+Assistant: Hello. I am a tiny GPT style language model built from scratch in Python.
+
+User: what are you
+Assistant: I am a local small GPT style model trained on a compact text corpus.
+
+User: explain machine learning
+Assistant: Machine learning builds patterns from data so models can make predictions.
+
+User: explain deep learning
+Assistant: Deep learning uses neural networks with multiple layers to learn useful representations.
+
+User: explain natural language processing
+Assistant: Natural language processing helps computers understand and generate human language.
+
+User: what is regression
+Assistant: Regression predicts a continuous value such as price, score, demand, or temperature.
+
+User: what is classification
+Assistant: Classification predicts a category or label for an example.
+
+User: what is clustering
+Assistant: Clustering groups similar data points without labeled targets.
+
+User: what is preprocessing
+Assistant: Preprocessing cleans and transforms data before modeling.
+
+User: what is model evaluation
+Assistant: Model evaluation measures how well a model performs on unseen data.
+
+User: what is Python
+Assistant: Python is a programming language known for readable code and strong libraries.
+
+User: what is attention
+Assistant: Attention lets a model focus on useful earlier tokens when predicting the next token.
+
+User: what is a transformer
+Assistant: A transformer is a neural architecture that uses attention and layered processing.
+
+User: what is a tokenizer
+Assistant: A tokenizer converts text into smaller units that a model can process.
+
+User: what is overfitting
+Assistant: Overfitting happens when a model memorizes training data and generalizes poorly.
+
+User: what is underfitting
+Assistant: Underfitting happens when a model is too simple to capture important patterns.
+
+User: write a short greeting
+Assistant: Hello and welcome. I am ready to help with a short answer.
+
+User: write a short poem
+Assistant: Quiet rain on window light, patient thoughts arrive at night.
+
+User: give a coding tip
+Assistant: Build small pieces, test often, and keep your functions clear.
+
+User: how should I debug code
+Assistant: Reproduce the problem, isolate the failing step, inspect values, and test one change at a time.
+
+User: summarize good study habits
+Assistant: Good study habits use planning, active recall, spaced review, and steady practice.
+
+User: what is data science
+Assistant: Data science combines programming, statistics, and domain knowledge to learn from data.
+
+User: what is a dataset
+Assistant: A dataset is a collection of examples or records used for analysis or training.
+
+User: what is a feature
+Assistant: A feature is an input variable used by a model.
+
+User: what is a target
+Assistant: A target is the value or label a model tries to predict.
+
+User: what is local ai
+Assistant: Local AI runs on your own machine so you control the code, files, and execution.
+""".strip()
+
+
+def build_training_text(extra_text: str = "") -> str:
+    extra = " ".join((extra_text or "").split())
+    if not extra:
+        return BASE_CORPUS
+    return BASE_CORPUS + "\n\n" + extra
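`build_training_text` is the single entry point the trainer uses: extra text is whitespace-normalized (newlines collapse to spaces) and appended after a blank line. A quick sketch with a made-up exchange:

```python
from small_gpt.data import build_training_text

text = build_training_text("User: what is a tensor\nAssistant: A tensor is an n-dimensional array.")
assert text.startswith("System: You are a small local GPT style assistant.")
assert text.endswith("A tensor is an n-dimensional array.")
```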
small_gpt/model.py ADDED
@@ -0,0 +1,111 @@
+import math
+
+import torch
+from torch import nn
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, d_model, n_heads, block_size, dropout):
+        super().__init__()
+        self.n_heads = n_heads
+        self.head_dim = d_model // n_heads
+        self.qkv = nn.Linear(d_model, 3 * d_model)
+        self.out_proj = nn.Linear(d_model, d_model)
+        self.dropout = nn.Dropout(dropout)
+        mask = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
+        self.register_buffer("mask", mask)
+
+    def forward(self, x):
+        batch, seq_len, channels = x.shape
+        qkv = self.qkv(x)
+        q, k, v = qkv.chunk(3, dim=-1)
+
+        q = q.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+        k = k.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+        v = v.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+
+        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+        att = att.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float("-inf"))
+        att = torch.softmax(att, dim=-1)
+        att = self.dropout(att)
+
+        out = att @ v
+        out = out.transpose(1, 2).contiguous().view(batch, seq_len, channels)
+        return self.out_proj(out)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, d_model, dropout):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(d_model, 4 * d_model),
+            nn.GELU(),
+            nn.Linear(4 * d_model, d_model),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class GPTBlock(nn.Module):
+    def __init__(self, d_model, n_heads, block_size, dropout):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(d_model)
+        self.attn = CausalSelfAttention(d_model, n_heads, block_size, dropout)
+        self.ln2 = nn.LayerNorm(d_model)
+        self.ff = FeedForward(d_model, dropout)
+
+    def forward(self, x):
+        x = x + self.attn(self.ln1(x))
+        x = x + self.ff(self.ln2(x))
+        return x
+
+
+class SmallGPTModel(nn.Module):
+    def __init__(self, vocab_size, block_size, d_model, n_heads, n_layers, dropout):
+        super().__init__()
+        self.block_size = block_size
+        self.token_emb = nn.Embedding(vocab_size, d_model)
+        self.pos_emb = nn.Embedding(block_size, d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.blocks = nn.Sequential(
+            *[GPTBlock(d_model, n_heads, block_size, dropout) for _ in range(n_layers)]
+        )
+        self.ln_f = nn.LayerNorm(d_model)
+        self.head = nn.Linear(d_model, vocab_size, bias=False)
+        self.head.weight = self.token_emb.weight
+
+    def forward(self, idx, targets=None):
+        batch, seq_len = idx.shape
+        positions = torch.arange(seq_len, device=idx.device)
+        x = self.token_emb(idx) + self.pos_emb(positions)[None, :, :]
+        x = self.dropout(x)
+        x = self.blocks(x)
+        x = self.ln_f(x)
+        logits = self.head(x)
+
+        loss = None
+        if targets is not None:
+            loss = nn.functional.cross_entropy(
+                logits.reshape(-1, logits.size(-1)),
+                targets.reshape(-1),
+            )
+        return logits, loss
+
+    def generate(self, idx, max_new_tokens, eos_id, temperature=1.0, top_k=8):
+        for _ in range(max_new_tokens):
+            idx_cond = idx[:, -self.block_size :]
+            logits, _ = self(idx_cond)
+            logits = logits[:, -1, :] / max(temperature, 1e-4)
+
+            if top_k is not None and top_k > 0:
+                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                logits[logits < values[:, [-1]]] = float("-inf")
+
+            probs = torch.softmax(logits, dim=-1)
+            next_id = torch.multinomial(probs, num_samples=1)
+            idx = torch.cat([idx, next_id], dim=1)
+            if int(next_id.item()) == eos_id:
+                break
+        return idx
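A shape sanity check for the decoder, with illustrative sizes (`vocab_size=50` is arbitrary; the real value comes from the fitted tokenizer). Note that the output head is weight-tied to the token embedding:

```python
import torch
from small_gpt.model import SmallGPTModel

model = SmallGPTModel(
    vocab_size=50, block_size=48, d_model=96, n_heads=4, n_layers=3, dropout=0.1
)
idx = torch.randint(0, 50, (2, 16))      # (batch, seq_len)
logits, loss = model(idx, targets=idx)   # shifting targets by one is the caller's job
assert logits.shape == (2, 16, 50)
assert model.head.weight is model.token_emb.weight  # tied weights
```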
small_gpt/service.py ADDED
@@ -0,0 +1,120 @@
+import shutil
+
+import torch
+
+from .config import SmallGPTConfig
+from .model import SmallGPTModel
+from .tokenizer import WordTokenizer
+from .trainer import create_model_and_tokenizer, set_seed, train_model
+
+
+class SmallGPTService:
+    def __init__(self, config: SmallGPTConfig):
+        self.config = config
+        torch.set_num_threads(max(1, self.config.cpu_threads))
+        self.model = None
+        self.tokenizer = None
+
+    def generate(self, prompt: str, max_new_tokens: int, temperature: float, top_k: int):
+        clean_prompt = prompt or "User: hello\nAssistant:"
+        self._ensure_ready()
+        encoded = self.tokenizer.encode(clean_prompt, add_bos=True)
+        idx = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
+        self.model.eval()
+
+        with torch.inference_mode():
+            output = self.model.generate(
+                idx=idx,
+                max_new_tokens=max_new_tokens,
+                eos_id=self.tokenizer.eos_id,
+                temperature=temperature,
+                top_k=top_k,
+            )
+
+        text = self.tokenizer.decode(output[0].tolist())
+        status = (
+            f"Generated with small GPT Python. "
+            f"Architecture=causal transformer, Vocab={self.tokenizer.vocab_size}, Layers={self.config.n_layers}."
+        )
+        return text, status
+
+    def train(self, extra_text: str, steps: int):
+        steps = max(1, steps)
+        checkpoint_exists = self.config.checkpoint_path.exists()
+        training_text = extra_text or ""
+
+        if checkpoint_exists:
+            self._load_or_initialize(extra_text="")
+
+        model, tokenizer, encoded = create_model_and_tokenizer(self.config, training_text)
+        if checkpoint_exists and self.model is not None and self.tokenizer is not None:
+            if tokenizer.stoi == self.tokenizer.stoi:
+                model.load_state_dict(self.model.state_dict())
+
+        losses = train_model(model, encoded, self.config, steps)
+        self.model = model
+        self.tokenizer = tokenizer
+        self._save_checkpoint(extra_text=training_text)
+
+        return (
+            f"small GPT training finished.\n"
+            f"Steps: {steps}\n"
+            f"Start Loss: {losses[0]:.4f}\n"
+            f"End Loss: {losses[-1]:.4f}\n"
+            f"Checkpoint: {self.config.checkpoint_path}"
+        )
+
+    def reset(self):
+        checkpoint_dir = self.config.checkpoint_path.parent
+        if checkpoint_dir.exists():
+            shutil.rmtree(checkpoint_dir)
+        self.model = None
+        self.tokenizer = None
+        return "small GPT reset complete. Next train or generate call will rebuild from scratch."
+
+    def _ensure_ready(self):
+        if self.model is not None and self.tokenizer is not None:
+            return
+        self._load_or_initialize(extra_text="")
+
+    def _load_or_initialize(self, extra_text: str):
+        checkpoint = self.config.checkpoint_path
+        if checkpoint.exists():
+            state = torch.load(checkpoint, map_location="cpu")
+            self.tokenizer = WordTokenizer.from_state_dict(state["tokenizer"])
+            self.model = SmallGPTModel(
+                vocab_size=state["config"]["vocab_size"],
+                block_size=state["config"]["block_size"],
+                d_model=state["config"]["d_model"],
+                n_heads=state["config"]["n_heads"],
+                n_layers=state["config"]["n_layers"],
+                dropout=state["config"]["dropout"],
+            )
+            self.model.load_state_dict(state["model"])
+            self.model.eval()
+            return
+
+        set_seed(self.config.seed)
+        self.model, self.tokenizer, encoded = create_model_and_tokenizer(self.config, extra_text)
+        train_model(self.model, encoded, self.config, self.config.bootstrap_steps)
+        self._save_checkpoint(extra_text=extra_text)
+
+    def _save_checkpoint(self, extra_text: str):
+        checkpoint = self.config.checkpoint_path
+        checkpoint.parent.mkdir(parents=True, exist_ok=True)
+        torch.save(
+            {
+                "model": self.model.state_dict(),
+                "tokenizer": self.tokenizer.state_dict(),
+                "config": {
+                    "vocab_size": self.tokenizer.vocab_size,
+                    "block_size": self.config.block_size,
+                    "d_model": self.config.d_model,
+                    "n_heads": self.config.n_heads,
+                    "n_layers": self.config.n_layers,
+                    "dropout": self.config.dropout,
+                    "extra_text": extra_text,
+                },
+            },
+            checkpoint,
+        )
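The service owns the checkpoint lifecycle: `train` rebuilds the tokenizer over the base corpus plus any extra text, copies weights forward only when the vocabulary is unchanged, and saves to `artifacts/`. A usage sketch (the extra text and step count are illustrative):

```python
from small_gpt.config import SmallGPTConfig
from small_gpt.service import SmallGPTService

service = SmallGPTService(config=SmallGPTConfig())
print(service.train(extra_text="User: hi\nAssistant: hi there", steps=50))
print(service.reset())  # removes the whole artifacts/ directory
```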
small_gpt/tokenizer.py ADDED
@@ -0,0 +1,69 @@
+import re
+
+
+TOKEN_PATTERN = re.compile(r"\n|[A-Za-z0-9_']+|[^\w\s]")
+
+
+class WordTokenizer:
+    def __init__(self):
+        self.special_tokens = ["<pad>", "<unk>", "<bos>", "<eos>"]
+        self.stoi = {}
+        self.itos = {}
+
+    @property
+    def vocab_size(self):
+        return len(self.stoi)
+
+    @property
+    def bos_id(self):
+        return self.stoi["<bos>"]
+
+    @property
+    def eos_id(self):
+        return self.stoi["<eos>"]
+
+    def tokenize(self, text: str):
+        return TOKEN_PATTERN.findall(text)
+
+    def fit(self, text: str):
+        vocab = self.special_tokens + sorted(set(self.tokenize(text)))
+        self.stoi = {token: idx for idx, token in enumerate(vocab)}
+        self.itos = {idx: token for token, idx in self.stoi.items()}
+        return self
+
+    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False):
+        tokens = self.tokenize(text)
+        ids = [self.stoi.get(token, self.stoi["<unk>"]) for token in tokens]
+        if add_bos:
+            ids = [self.bos_id] + ids
+        if add_eos:
+            ids = ids + [self.eos_id]
+        return ids
+
+    def decode(self, ids):
+        tokens = []
+        for idx in ids:
+            token = self.itos.get(int(idx), "<unk>")
+            if token in self.special_tokens:
+                continue
+            tokens.append(token)
+
+        text = ""
+        for token in tokens:
+            if token == "\n":
+                text = text.rstrip() + "\n"
+            elif token in {".", ",", "!", "?", ":", ";"}:
+                text = text.rstrip() + token + " "
+            else:
+                text += token + " "
+        return text.strip()
+
+    def state_dict(self):
+        return {"stoi": self.stoi}
+
+    @classmethod
+    def from_state_dict(cls, state):
+        tok = cls()
+        tok.stoi = dict(state["stoi"])
+        tok.itos = {idx: token for token, idx in tok.stoi.items()}
+        return tok
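A round trip through the tokenizer: special tokens are dropped on decode and punctuation spacing is restored:

```python
from small_gpt.tokenizer import WordTokenizer

tok = WordTokenizer().fit("User: hello\nAssistant: Hello.")
ids = tok.encode("User: hello", add_bos=True, add_eos=True)
print(tok.decode(ids))  # -> "User: hello"  (<bos>/<eos> are skipped)
```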
small_gpt/trainer.py ADDED
@@ -0,0 +1,52 @@
+import random
+
+import torch
+
+from .data import build_training_text
+from .model import SmallGPTModel
+from .tokenizer import WordTokenizer
+
+
+def set_seed(seed: int):
+    random.seed(seed)
+    torch.manual_seed(seed)
+
+
+def create_model_and_tokenizer(config, extra_text=""):
+    text = build_training_text(extra_text)
+    tokenizer = WordTokenizer().fit(text)
+    encoded = tokenizer.encode(text, add_bos=True, add_eos=True)
+    encoded = torch.tensor(encoded, dtype=torch.long)
+    model = SmallGPTModel(
+        vocab_size=tokenizer.vocab_size,
+        block_size=config.block_size,
+        d_model=config.d_model,
+        n_heads=config.n_heads,
+        n_layers=config.n_layers,
+        dropout=config.dropout,
+    )
+    return model, tokenizer, encoded
+
+
+def build_batch(encoded, block_size, batch_size):
+    max_start = max(1, len(encoded) - block_size - 1)
+    starts = torch.randint(0, max_start, (batch_size,))
+    x = torch.stack([encoded[start : start + block_size] for start in starts])
+    y = torch.stack([encoded[start + 1 : start + block_size + 1] for start in starts])
+    return x, y
+
+
+def train_model(model, encoded, config, steps):
+    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
+    model.train()
+    losses = []
+
+    for _ in range(steps):
+        xb, yb = build_batch(encoded, config.block_size, config.batch_size)
+        _, loss = model(xb, targets=yb)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        losses.append(float(loss.item()))
+
+    return losses
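End to end without the service layer, using the bootstrap pieces directly (the step count is illustrative; on the tiny base corpus the loss should drop within a few dozen steps):

```python
from small_gpt.config import SmallGPTConfig
from small_gpt.trainer import create_model_and_tokenizer, set_seed, train_model

cfg = SmallGPTConfig()
set_seed(cfg.seed)
model, tokenizer, encoded = create_model_and_tokenizer(cfg)
losses = train_model(model, encoded, cfg, steps=20)
print(f"loss {losses[0]:.3f} -> {losses[-1]:.3f}")
```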