inwaves committed on
Commit
405f5b1
·
1 Parent(s): 0b6a10a

Refactor config class, add argparser

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. main.py +30 -9
  3. model.py +22 -7
  4. utils.py +26 -11
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *__pycache__/
main.py CHANGED
@@ -2,22 +2,43 @@ import torch as t
2
  import torch.nn as nn
3
  import torch.functional as F
4
  import torch.optim as optim
 
 
 
 
5
 
6
-
7
- def parse_args():
8
  # TODO: command-line args for hparams
9
- pass
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- def train():
12
  # TODO: training loop
13
- pass
 
14
 
15
  def eval():
16
  pass
17
 
18
- def setup():
19
- # TODO: wandb logging, load configs, all that stuff
20
- pass
 
 
 
21
 
22
  if __name__=="__main__":
23
- parse_args()
 
 
 
2
  import torch.nn as nn
3
  import torch.functional as F
4
  import torch.optim as optim
5
+ import argparse
6
+ from utils import OsSoluConfig
7
+ from model import OsSoluModel
8
+ from typing import Tuple
9
 
10
def parse_arguments(argv=None) -> argparse.Namespace:
    """Parse the model hyperparameters from the command line.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` reads the process
            command line, exactly as before.

    Returns:
        A namespace with one attribute per hyperparameter. Every option has
        a default, so no attribute is ever ``None``.

    Raises:
        SystemExit: Raised by argparse on malformed input, including an
            unknown ``--self_attention_type``.
    """
    parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
    parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
    parser.add_argument("--vocab_size", type=int, default=65536, help="Vocabulary size of the input sequence.")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
    parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
    parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
    # Restrict to the two supported values: the model treats anything that is
    # not "unidirectional" as rotary, so a typo would otherwise silently
    # switch the attention implementation.
    parser.add_argument(
        "--self_attention_type",
        type=str,
        default="unidirectional",
        choices=["unidirectional", "rotary"],
        help="What type of attention to use: rotary or unidirectional. ",
    )
    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings.")
    return parser.parse_args(argv)
25
 
26
def train(config: OsSoluConfig, model: OsSoluModel) -> OsSoluModel:
    """Placeholder training entry point.

    No optimisation happens yet: ``config`` is unused and ``model`` is handed
    back unchanged.
    """
    # TODO: training loop
    trained = model
    return trained
30
 
31
def eval():
    """Evaluation placeholder; does nothing yet and returns ``None``."""
    # NOTE(review): this name shadows the builtin ``eval`` inside this module.
    return None
33
 
34
def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
    """Build the run configuration and the model from the command line.

    Returns:
        The ``(config, model)`` pair, where ``model`` was constructed from
        ``config``.
    """
    # TODO: wandb logging
    config = OsSoluConfig(parse_arguments())
    return config, OsSoluModel(config)
40
 
41
  if __name__=="__main__":
42
+ config, model = setup()
43
+ trained_model = train(config, model)
44
+ eval()
model.py CHANGED
@@ -15,7 +15,8 @@ class OsSoluModel(nn.Module):
15
  self.config = config
16
  self.embed_positions = nn.Embedding(config.max_positional_embeddings, config.d_model)
17
  self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
18
- self.transformer_block = TransformerBlock(config)
 
19
  self.final_ln = nn.LayerNorm(normalized_shape, config.ln_eps)
20
  self.unembed = nn
21
 
@@ -23,23 +24,36 @@ class OsSoluModel(nn.Module):
23
  positional_embeddings = self.embed_positions(t.arange(x.size(1)))
24
  token_embeddings = self.embed_tokens(x)
25
  embeddings = positional_embeddings + token_embeddings
 
 
26
 
 
 
 
 
 
 
27
 
28
- class TransformerBlock(nn.Module):
29
  def __init__(self, config: OsSoluConfig) -> None:
30
  super().__init__()
31
  self.config = config
32
 
 
33
  self.attention = UnidirectionalAttention(config) if config.self_attention_type == "unidirectional" else RotaryAttention(config)
34
- self.linear = nn.Sequential(
35
- nn.Linear(config.d_model, config.d_model),
 
36
  SoLU(),
 
 
37
  )
38
- self.layer_norm = nn.LayerNorm(normalized_shape, config.ln_eps)
39
- self.unembed = nn.Embedding(config.num_embeddings, config.d_model)
40
 
41
  def forward(self, x: t.Tensor) -> t.Tensor:
42
- pass
 
 
 
43
 
44
 
45
  class UnidirectionalAttention(nn.Module):
@@ -96,4 +110,5 @@ class RotaryAttention(nn.Module):
96
  self.config = config
97
 
98
  def forward(self, x: t.Tensor) -> t.Tensor:
 
99
  pass
 
15
  self.config = config
16
  self.embed_positions = nn.Embedding(config.max_positional_embeddings, config.d_model)
17
  self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
18
+ self.dropout = nn.Dropout(config.dropout)
19
+ self.transformer_blocks = nn.ModuleList([GPT2Block(config) for _ in range(config.num_blocks)])
20
  self.final_ln = nn.LayerNorm(normalized_shape, config.ln_eps)
21
  self.unembed = nn
22
 
 
24
  positional_embeddings = self.embed_positions(t.arange(x.size(1)))
25
  token_embeddings = self.embed_tokens(x)
26
  embeddings = positional_embeddings + token_embeddings
27
+ out = self.dropout(embeddings)
28
+ out = self.transformer_blocks(out)
29
 
30
class SoLU(nn.Module):
    """Softmax linear unit activation: ``x * softmax(x)`` over the last dimension."""

    def __init__(self) -> None:
        # nn.Module keeps its hook/parameter registries on the instance;
        # without this call the module raises AttributeError when invoked.
        super().__init__()

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Return ``x`` scaled elementwise by its softmax along the last axis."""
        return x * x.softmax(dim=-1)
36
 
37
class GPT2Block(nn.Module):
    """One residual transformer block: attention followed by an MLP.

    Layer norm is applied before each sub-layer; the MLP's layer norm is the
    first stage of ``self.MLP``.
    """

    def __init__(self, config: OsSoluConfig) -> None:
        """Build the sub-layers from ``config``.

        Reads ``d_model``, ``ln_eps``, ``dropout`` and ``self_attention_type``.
        Any ``self_attention_type`` other than ``"unidirectional"`` selects
        ``RotaryAttention``.
        """
        super().__init__()
        self.config = config

        # Normalise over the hidden dimension. The previous code passed the
        # undefined name `normalized_shape` here.
        self.layer_norm1 = nn.LayerNorm(config.d_model, config.ln_eps)
        self.attention = UnidirectionalAttention(config) if config.self_attention_type == "unidirectional" else RotaryAttention(config)
        self.MLP = nn.Sequential(
            nn.LayerNorm(config.d_model, config.ln_eps),
            nn.Linear(config.d_model, 4*config.d_model),
            SoLU(),
            nn.Linear(4*config.d_model, config.d_model),
            nn.Dropout(config.dropout)
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Apply attention then the MLP, each added back onto its input."""
        x = x + self.attention(self.layer_norm1(x))
        x = x + self.MLP(x)
        return x
56
+
57
 
58
 
59
  class UnidirectionalAttention(nn.Module):
 
110
  self.config = config
111
 
112
  def forward(self, x: t.Tensor) -> t.Tensor:
113
+ # TODO: implement rotary self-attention
114
  pass
utils.py CHANGED
@@ -1,12 +1,27 @@
1
- @dataclass
 
2
  class OsSoluConfig:
3
- d_model: int = 512 # Hidden size of the model.
4
- vocab_size: int = 65536 # Vocabulary size of the input sequence. Unsure about this.
5
- learning_rate: float = 1e-3 # Learning rate for the optimiser.
6
- num_embeddings: int = 1024 # Number of embeddings. Unsure about this.
7
- num_blocks: int = 1 # Number of transformer blocks.
8
- dropout: float = 0.1 # Probability of dropout.
9
- ln_eps: float = 1e-3 # Layer norm epsilon.
10
- num_heads: int = 4 # Number of attention heads in each attention layer.
11
- self_attention_type: str = "unidirectional" # What type of attention to use: rotary or unidirectional.
12
- max_positional_embeddings: int = 1024 # Maximum number of positional embeddings.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
class OsSoluConfig:
    """Hyperparameter container populated from parsed command-line arguments."""

    d_model: int                    # Hidden size of the model.
    vocab_size: int                 # Vocabulary size of the input sequence. Unsure about this.
    learning_rate: float            # Learning rate for the optimiser.
    num_embeddings: int             # Number of embeddings. Unsure about this.
    num_blocks: int                 # Number of transformer blocks.
    dropout: float                  # Probability of dropout.
    ln_eps: float                   # Layer norm epsilon.
    num_heads: int                  # Number of attention heads in each attention layer.
    self_attention_type: str        # What type of attention to use: rotary or unidirectional.
    max_positional_embeddings: int  # Maximum number of positional embeddings.

    def __init__(self, args: argparse.Namespace) -> None:
        """Initialise this config class with values provided by a command-line argument parser.

        Values are never None here, as we provide suitable defaults in the parser call.

        Raises:
            AttributeError: If ``args`` is missing any of the expected fields.
        """
        self.d_model = args.d_model
        self.vocab_size = args.vocab_size
        self.learning_rate = args.learning_rate
        self.num_embeddings = args.num_embeddings
        self.num_blocks = args.num_blocks
        self.dropout = args.dropout
        self.ln_eps = args.ln_eps
        self.num_heads = args.num_heads
        self.self_attention_type = args.self_attention_type
        self.max_positional_embeddings = args.max_positional_embeddings

    def __repr__(self) -> str:
        """Return ``OsSoluConfig(field=value, ...)`` listing every instance attribute."""
        # Restores the readable repr that was lost when @dataclass was dropped.
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"