WIP getting the Pile dataset up and running
Files changed:
- main.py (+127, -20)
- model.py (+16, -8)
- requirements.txt (+6, -5)
- utils.py (+27, -19)
main.py
CHANGED
@@ -1,44 +1,151 @@
+import argparse
 import torch as t
 import torch.nn as nn
 import torch.functional as F
 import torch.optim as optim
-import
+import time
+from tqdm import tqdm
+import wandb
+
+from typing import Tuple
+from torch.utils.data.dataloader import DataLoader
+from datasets import load_dataset
 from utils import OsSoluConfig
 from model import OsSoluModel
-from typing import Tuple
 
+WANDB_PROJECT_NAME = "os_solu"
+DEVICE = "cuda" if t.cuda.is_available() else "cpu"
+
+def parse_arguments() -> dict:
+    """Parses command-line arguments for this model run. Arguments of type string have allowed values,
+    which are enforced. Default parameter values are provided so that fields in the config are never None.
+
+    Raises:
+        ValueError: optimiser type must be adam or sgd.
+        ValueError: attention type must be rotary or unidirectional.
+
+    Returns:
+        dict: a dictionary containing the command-line arguments parsed by this function.
+    """
     parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
+    parser.add_argument("--batch_size", type=int, default=256, help="Batch size used in training.")
     parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
-    parser.add_argument("--vocab_size", type=int, default=65536, help="Vocabulary size of the input sequence.")
-    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
-    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
-    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
     parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
+    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
     parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
-    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
-    parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
     parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings.")
+    parser.add_argument("--nonlinearity", type=str, default="solu", help="Nonlinearity to use inside MLP block: must be relu or solu.")
+    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
+    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
+    parser.add_argument("--num_epochs", type=int, default=5, help="Number of epochs to run for.")
+    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
+    parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type.")
+    parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
+    parser.add_argument("--vocab_size", type=int, default=65536, help="Vocabulary size of the input sequence.")
+    args = vars(parser.parse_args())
+
+    # Validate string arguments against their allowed values.
+    allowed_values = {
+        "optimiser_type": ["adam", "sgd"],
+        "self_attention_type": ["unidirectional", "rotary"],
+        "nonlinearity": ["relu", "solu"],
+    }
+    for key, values in allowed_values.items():
+        if args[key] not in values:
+            raise ValueError(f"{key} should be one of {values}.")
+
     return args
 
-def train(config: OsSoluConfig, model: OsSoluModel) -> OsSoluModel:
+def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader) -> OsSoluModel:
+    """Trains a model using the config and training dataset provided.
+
+    Args:
+        config (OsSoluConfig): The config object.
+        model (OsSoluModel): The model to train.
+        train_dataloader (t.utils.data.DataLoader): The training dataset provided as a torch DataLoader object.
+
+    Returns:
+        OsSoluModel: The trained model.
+    """
     # TODO: training loop
+    train_loss_fn = t.nn.CrossEntropyLoss()
+    wandb.watch(model, criterion=train_loss_fn, log="all", log_freq=10, log_graph=True)
+
+    # Initialise optimiser.
+    opt = optim.Adam if config.optimiser_type.lower() == "adam" else optim.SGD
+    optimiser = opt(model.parameters(), lr=config.learning_rate)
+
+    # Train loop.
+    start_time = time.time()
+    examples_seen = 0
+    for epoch in range(config.num_epochs):
+        for i, (data, target) in enumerate(tqdm(train_dataloader)):
+            data = data.to(DEVICE)
+            target = target.to(DEVICE)
+
+            predictions = model(data)
+            accuracy = (predictions.argmax(dim=-1) == target).sum() / len(data)
+            optimiser.zero_grad()
+            loss = train_loss_fn(predictions, target)
+            loss.backward()
+            optimiser.step()
+
+            wandb.log(dict(train_loss=loss, train_accuracy=accuracy, elapsed=time.time() - start_time), step=examples_seen)
+            examples_seen += len(data)
+
     return model
 
-def eval():
-
+def eval(model: OsSoluModel, test_dataloader: DataLoader) -> None:
+    """Evaluates a trained model on the test dataset provided.
+
+    Args:
+        model (OsSoluModel): The trained model.
+        test_dataloader (t.utils.data.DataLoader): The dataset on which to evaluate the model.
+    """
+    test_loss_fn = t.nn.CrossEntropyLoss()
+
+    # Eval loop.
+    start_time = time.time()
+    examples_seen = 0
+    total_loss, num_correct = 0, 0
+    model.eval()
+    with t.inference_mode():
+        for i, (data, target) in enumerate(tqdm(test_dataloader)):
+            data = data.to(DEVICE)
+            target = target.to(DEVICE)
+
+            predictions = model(data)
+            num_correct += (predictions.argmax(dim=-1) == target).sum().item()
+            total_loss += test_loss_fn(predictions, target).item()
+            examples_seen += len(data)
+        wandb.log(dict(test_loss=total_loss, test_accuracy=num_correct / examples_seen, elapsed=time.time() - start_time), step=examples_seen)
+
+    # Save the model's state on disk, then upload to wandb.
+    filename = f"{wandb.run.dir}/model_state_dict.pt"
+    t.save(model.state_dict(), filename)
+    wandb.save(filename)
 
 def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
+    """This function delegates the setup to various helper functions.
+
+    Returns:
+        Tuple[OsSoluConfig, OsSoluModel, Tuple[DataLoader, DataLoader]]: a config, a model, and the train and test dataloaders.
+    """
     args = parse_arguments()
+    wandb.init(project=WANDB_PROJECT_NAME, config=args)
     config = OsSoluConfig(args)
     model = OsSoluModel(config)
-
+
+    # Load and prep data.
+    ds = load_dataset("the_pile", streaming=True)
+    train_dataset = ds["train"].with_format("torch")
+    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
+
+    test_dataset = ds["test"].with_format("torch")
+    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size)
+    return config, model, (train_dataloader, test_dataloader)
 
 if __name__=="__main__":
-    config, model = setup()
-    trained_model = train(config, model)
-    eval()
+    config, model, (train_dataloader, test_dataloader) = setup()
+    trained_model = train(config, model, train_dataloader)
+    eval(trained_model, test_dataloader)
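Note on the remaining TODO: train() unpacks batches as (data, target), but the streaming Pile examples returned by load_dataset("the_pile", streaming=True) are raw documents, so a tokenisation step still has to sit between the dataset and the DataLoader. The sketch below shows one way that step could look; it is not part of this commit: the column names ("text", "meta"), the GPT-2 tokenizer from transformers (not in requirements.txt), and the helper name tokenize are all assumptions.

# Sketch only: turn streaming Pile text into fixed-length (data, target) pairs.
from datasets import load_dataset
from torch.utils.data.dataloader import DataLoader
from transformers import GPT2TokenizerFast  # assumed dependency, not in requirements.txt

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

def tokenize(examples):
    # Truncate/pad each document to a fixed context length.
    tokens = tokenizer(examples["text"], truncation=True, max_length=1024, padding="max_length")["input_ids"]
    # Next-token prediction: inputs are all tokens but the last, targets are shifted by one.
    return {"data": [seq[:-1] for seq in tokens], "target": [seq[1:] for seq in tokens]}

ds = load_dataset("the_pile", streaming=True)
train_dataset = (
    ds["train"]
    .map(tokenize, batched=True, remove_columns=["text", "meta"])  # column names assumed
    .with_format("torch")
)
train_dataloader = DataLoader(train_dataset, batch_size=8)
batch = next(iter(train_dataloader))
print(batch["data"].shape, batch["target"].shape)

With batches shaped like this, the training loop would read batch["data"] and batch["target"], or use a custom collate_fn to keep the existing (data, target) unpacking.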
model.py
CHANGED
@@ -8,28 +8,35 @@ from einops import rearrange, repeat, reduce
 from utils import OsSoluConfig
 
 
+
 class OsSoluModel(nn.Module):
+    """An open-source implementation of a SoLU-based transformer. This is a GPT-style architecture model
+    where the nonlinearity in the MLP block is replaced with SoLU(x) = x * softmax(x)."""
     def __init__(self, config: OsSoluConfig) -> None:
         super().__init__()
-        normalised_shape = None # TODO: normalised_shape should be defined properly
         self.config = config
         self.embed_positions = nn.Embedding(config.max_positional_embeddings, config.d_model)
         self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
         self.transformer_blocks = nn.ModuleList([GPT2Block(config) for _ in range(config.num_blocks)])
-        self.final_ln = nn.LayerNorm(
-        self.unembed = nn
+        self.final_ln = nn.LayerNorm(config.d_model, config.ln_eps)
 
     def forward(self, x: t.Tensor) -> t.Tensor:
         positional_embeddings = self.embed_positions(t.arange(x.size(1)))
         token_embeddings = self.embed_tokens(x)
         embeddings = positional_embeddings + token_embeddings
         out = self.dropout(embeddings)
-
+        for block in self.transformer_blocks:
+            out = block(out)
+
+        # Unembedding is not separate, so we just einsum with the token embedding weights.
+        out = einsum("vocab hidden, batch seq hidden -> batch seq vocab", self.embed_tokens.weight, out)
+        return out
 
 class SoLU(nn.Module):
+    """A simple wrapper around the SoLU function such that it can be used as a layer in a model."""
     def __init__(self):
-
+        super().__init__()
 
     def forward(self, x: t.Tensor) -> t.Tensor:
         return x * x.softmax(dim=-1)
@@ -39,12 +46,13 @@ class GPT2Block(nn.Module):
         super().__init__()
         self.config = config
 
-        self.layer_norm1 = nn.LayerNorm(
+        self.layer_norm1 = nn.LayerNorm(config.d_model, config.ln_eps)
         self.attention = UnidirectionalAttention(config) if config.self_attention_type == "unidirectional" else RotaryAttention(config)
+        nonlinearity = SoLU() if config.nonlinearity == "solu" else nn.ReLU()
         self.MLP = nn.Sequential(
-            nn.LayerNorm(
+            nn.LayerNorm(config.d_model, config.ln_eps),
             nn.Linear(config.d_model, 4*config.d_model),
-
+            nonlinearity,
            nn.Linear(4*config.d_model, config.d_model),
            nn.Dropout(config.dropout)
        )
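A quick, self-contained check of the two pieces added to model.py: SoLU leaves the input shape unchanged, and the tied-unembedding einsum maps hidden states back to vocabulary logits. This is a sketch with illustrative dimensions, assuming einsum comes from fancy_einsum (listed in requirements.txt); it is not part of the commit.

# Sketch: shape check for SoLU and the tied unembedding, with made-up dimensions.
import torch as t
from fancy_einsum import einsum

class SoLU(t.nn.Module):
    """SoLU(x) = x * softmax(x), applied over the last dimension."""
    def __init__(self):
        super().__init__()

    def forward(self, x: t.Tensor) -> t.Tensor:
        return x * x.softmax(dim=-1)

batch, seq, d_model, vocab = 2, 16, 8, 100
hidden = SoLU()(t.randn(batch, seq, d_model))  # same shape in, same shape out
embed_tokens = t.nn.Embedding(vocab, d_model)

# Tied unembedding: reuse the token-embedding matrix to project back to vocab logits.
logits = einsum("vocab hidden, batch seq hidden -> batch seq vocab", embed_tokens.weight, hidden)
print(hidden.shape, logits.shape)  # torch.Size([2, 16, 8]) torch.Size([2, 16, 100])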
requirements.txt
CHANGED
@@ -1,13 +1,14 @@
-
-wandb
+datasets
 einops
 fancy_einsum
-tqdm
 ipykernel
-notebook
 ipywidgets
 jupyter
 matplotlib
+notebook
 numpy-stl
+plotly
+torch
+tqdm
 wandb
-
+zstandard
utils.py
CHANGED
@@ -1,27 +1,35 @@
-import argparse
-
 class OsSoluConfig:
+    """A class to hold hyperparameters for the model itself and for the training process."""
+
+    batch_size: int                  # Training data batch size.
     d_model: int                     # Hidden size of the model.
-    vocab_size: int                  # Vocabulary size of the input sequence. Unsure about this.
-    learning_rate: float             # Learning rate for the optimiser.
-    num_embeddings: int              # Number of embeddings. Unsure about this.
-    num_blocks: int                  # Number of transformer blocks.
     dropout: float                   # Probability of dropout.
+    learning_rate: float             # Learning rate for the optimiser.
     ln_eps: float                    # Layer norm epsilon.
+    max_positional_embeddings: int   # Maximum number of positional embeddings.
+    nonlinearity: str                # Nonlinearity to use inside MLP block: must be ReLU or SoLU.
+    num_blocks: int                  # Number of transformer blocks.
+    num_embeddings: int              # Number of embeddings. Unsure about this.
+    num_epochs: int                  # Number of epochs for this run.
     num_heads: int                   # Number of attention heads in each attention layer.
     self_attention_type: str         # What type of attention to use: rotary or unidirectional.
-
-
-
+    optimiser_type: str              # Optimiser type: SGD, Adam.
+    vocab_size: int                  # Vocabulary size of the input sequence. Unsure about this.
+
+    def __init__(self, args: dict) -> None:
         """Initialise this config class with values provided by a command-line argument parser.
         Values are never None here, as we provide suitable defaults in the parser call."""
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
+        self.batch_size = args["batch_size"]
+        self.d_model = args["d_model"]
+        self.dropout = args["dropout"]
+        self.learning_rate = args["learning_rate"]
+        self.ln_eps = args["ln_eps"]
+        self.max_positional_embeddings = args["max_positional_embeddings"]
+        self.nonlinearity = args["nonlinearity"]
+        self.num_blocks = args["num_blocks"]
+        self.num_embeddings = args["num_embeddings"]
+        self.num_epochs = args["num_epochs"]
+        self.num_heads = args["num_heads"]
+        self.optimiser_type = args["optimiser_type"]
+        self.self_attention_type = args["self_attention_type"]
+        self.vocab_size = args["vocab_size"]
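Since every field now has a parser default, an OsSoluConfig can also be built without touching the command line, which is convenient for notebooks and quick smoke tests. A minimal sketch; the dict below simply mirrors the defaults defined in parse_arguments():

# Sketch: constructing OsSoluConfig from a plain dict of the parser defaults.
from utils import OsSoluConfig

default_args = {
    "batch_size": 256,
    "d_model": 512,
    "dropout": 0.1,
    "learning_rate": 1e-3,
    "ln_eps": 1e-3,
    "max_positional_embeddings": 1024,
    "nonlinearity": "solu",
    "num_blocks": 1,
    "num_embeddings": 1024,
    "num_epochs": 5,
    "num_heads": 4,
    "optimiser_type": "adam",
    "self_attention_type": "unidirectional",
    "vocab_size": 65536,
}

config = OsSoluConfig(default_args)
print(config.d_model, config.nonlinearity)  # 512 solu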