Matthev00 committed on
Commit
036204f
1 Parent(s): ae8d4e0

model building function

Files changed (1)
  1. model.py +169 -0
model.py CHANGED
@@ -0,0 +1,169 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ batch_size = 32
+ block_size = 128
+ max_iters = 1000
+ learning_rate = 3e-4
+ eval_steps = 200
+ n_embd = 384
+ n_head = 4
+ n_layer = 4
+ dropout = 0.2
+
+
+ class Block(nn.Module):
+     """Transformer block: communication followed by computation"""
+
+     def __init__(self, n_embd, n_head):
+         # n_embd: embedding dimension, n_head: the number of heads we'd like
+         super().__init__()
+         head_size = n_embd // n_head
+         self.sa = MultiHeadAttention(n_head, head_size)
+         self.ffwd = FeedFoward(n_embd)
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+
+     def forward(self, x):
+         y = self.sa(x)
+         x = self.ln1(x + y)
+         y = self.ffwd(x)
+         x = self.ln2(x + y)
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+     """multiple heads of self-attention in parallel"""
+
+     def __init__(self, num_heads, head_size):
+         super().__init__()
+         self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
+         self.proj = nn.Linear(head_size * num_heads, n_embd)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         # (B, T, F) -> (B, T, [h1, h1, h1, h1, h2, h2, h2, h2, h3, h3, h3, h3])
+         out = torch.cat([h(x) for h in self.heads], dim=-1)
+         out = self.dropout(self.proj(out))
+         return out
+
+
+ class Head(nn.Module):
+     """one head of self-attention"""
+
+     def __init__(self, head_size):
+         super().__init__()
+         self.key = nn.Linear(n_embd, head_size, bias=False)
+         self.query = nn.Linear(n_embd, head_size, bias=False)
+         self.value = nn.Linear(n_embd, head_size, bias=False)
+         self.register_buffer(
+             "tril", torch.tril(torch.ones(block_size, block_size))
+         )  # noqa: E501
+
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         # input of size (batch, time-step, channels)
+         # output of size (batch, time-step, head size)
+         B, T, C = x.shape
+         k = self.key(x)  # (B,T,hs)
+         q = self.query(x)  # (B,T,hs)
+         # compute attention scores ("affinities")
+         wei = (
+             q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
+         )  # (B, T, hs) @ (B, hs, T) -> (B, T, T)  # noqa: E501
+         wei = wei.masked_fill(
+             self.tril[:T, :T] == 0, float("-inf")
+         )  # (B, T, T)  # noqa: E501
+         wei = F.softmax(wei, dim=-1)  # (B, T, T)
+         wei = self.dropout(wei)
+         # perform the weighted aggregation of the values
+         v = self.value(x)  # (B,T,hs)
+         out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
+         return out
+
+
+ class FeedFoward(nn.Module):
+     """position-wise feed-forward network: Linear -> ReLU -> Linear, with dropout"""
+
+     def __init__(self, n_embd):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(n_embd, 4 * n_embd),
+             nn.ReLU(),
+             nn.Linear(4 * n_embd, n_embd),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class GPTLanguageModel(nn.Module):
+     def __init__(self, vocab_size, device):
+         super().__init__()
+         self.device = device
+         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
+         self.position_embedding_table = nn.Embedding(block_size, n_embd)
+         self.blocks = nn.Sequential(
+             *[Block(n_embd, n_head=n_head) for _ in range(n_layer)]
+         )  # noqa: E501
+         self.ln_f = nn.LayerNorm(n_embd)
+         self.lm_head = nn.Linear(n_embd, vocab_size)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, index, targets=None):
+         B, T = index.shape
+
+         # index and targets are both (B, T) tensors of integers
+         tok_emb = self.token_embedding_table(index)  # (B,T,C)
+         pos_emb = self.position_embedding_table(
+             torch.arange(T, device=self.device)
+         )  # (T,C)  # noqa: E501
+         x = tok_emb + pos_emb  # (B,T,C)
+         x = self.blocks(x)  # (B,T,C)
+         x = self.ln_f(x)  # (B,T,C)
+         logits = self.lm_head(x)  # (B,T,vocab_size)
+
+         if targets is None:
+             loss = None
+         else:
+             B, T, C = logits.shape
+             logits = logits.view(B * T, C)
+             targets = targets.view(B * T)
+             loss = F.cross_entropy(logits, targets)
+
+         return logits, loss
+
+     def generate(self, index, max_new_tokens):
+         # index is a (B, T) array of indices in the current context
+         for _ in range(max_new_tokens):
+             # crop the context to the last block_size tokens
+             index_cond = index[:, -block_size:]
+             # get the predictions
+             logits, loss = self.forward(index_cond)
+             # focus only on the last time step
+             logits = logits[:, -1, :]  # becomes (B, C)
+             # apply softmax to get probabilities
+             probs = F.softmax(logits, dim=-1)  # (B, C)
+             # sample from the distribution
+             index_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
+             # append the sampled index to the running sequence
+             index = torch.cat((index, index_next), dim=1)  # (B, T+1)
+         return index
+
+
+ def create_GPT_model(vocab_size, device):
+     model = GPTLanguageModel(vocab_size=vocab_size, device=device)
+     model = model.to(device)
+     return model
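For context, here is a minimal usage sketch of the new builder. Everything outside `model.py` is an assumption made for illustration (the vocabulary size, the dummy batch, the single AdamW step, and the sampling call); only `create_GPT_model`, the module-level hyperparameters, and `GPTLanguageModel.generate` come from the file above.

# Usage sketch (assumed, not part of this commit): build the model on an
# available device, run one optimisation step on a dummy batch, then sample.
import torch

from model import batch_size, block_size, create_GPT_model, learning_rate

vocab_size = 65  # assumed: e.g. a small character-level vocabulary
device = "cuda" if torch.cuda.is_available() else "cpu"

model = create_GPT_model(vocab_size=vocab_size, device=device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# dummy batch of token indices, each of shape (batch_size, block_size)
xb = torch.randint(0, vocab_size, (batch_size, block_size), device=device)
yb = torch.randint(0, vocab_size, (batch_size, block_size), device=device)

logits, loss = model(xb, yb)  # logits: (B*T, vocab_size), loss: scalar
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()

# autoregressive sampling from a single zero-token context
model.eval()  # disable dropout for sampling
context = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():
    generated = model.generate(context, max_new_tokens=100)  # (1, 101) token ids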