GunaKoppula committed on
Commit
6921f91
1 Parent(s): a856b9f

Update gpt.py

Files changed (1)
  1. gpt.py +77 -95
gpt.py CHANGED
@@ -1,138 +1,120 @@
 import torch
-import torch.nn as nn
-from torch.nn import functional as F
-import config as cfg
+from torch import nn
+import torch.nn.functional as F
 
-class Head(nn.Module):
-    """ one head of self-attention """
 
-    def __init__(self, head_size):
+class Head(nn.Module):
+    def __init__(self, n_embeds, head_size, block_size, dropout) -> None:
         super().__init__()
-        self.key = nn.Linear(cfg.n_embd, head_size, bias=False)
-        self.query = nn.Linear(cfg.n_embd, head_size, bias=False)
-        self.value = nn.Linear(cfg.n_embd, head_size, bias=False)
-        self.register_buffer('tril', torch.tril(torch.ones(cfg.block_size, cfg.block_size)))
-
-        self.dropout = nn.Dropout(cfg.dropout)
+        self.key = nn.Linear(n_embeds, head_size, bias=False)
+        self.query = nn.Linear(n_embeds, head_size, bias=False)
+        self.value = nn.Linear(n_embeds, head_size, bias=False)
+        self.dropout = nn.Dropout(dropout)
+        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
 
     def forward(self, x):
-        # input of size (batch, time-step, channels)
-        # output of size (batch, time-step, head size)
-        B,T,C = x.shape
-        k = self.key(x) # (B,T,hs)
-        q = self.query(x) # (B,T,hs)
-        # compute attention scores ("affinities")
-        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
-        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
-        wei = F.softmax(wei, dim=-1) # (B, T, T)
+        B, T, C = x.shape
+        k = self.key(x)
+        q = self.query(x)
+        wei = q @ k.transpose(-2, -1) * (C**-0.5)  # (B,T,16) @ (B,16,T) --> (B,T,T)
+        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
+        wei = F.softmax(wei, dim=-1)
         wei = self.dropout(wei)
-        # perform the weighted aggregation of the values
-        v = self.value(x) # (B,T,hs)
-        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
+        v = self.value(x)
+        out = wei @ v
         return out
 
-class MultiHeadAttention(nn.Module):
-    """ multiple heads of self-attention in parallel """
 
-    def __init__(self, num_heads, head_size):
+class MultiHeadAttention(nn.Module):
+    def __init__(self, n_heads, n_embeds, head_size, block_size, dropout):
         super().__init__()
-        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
-        self.proj = nn.Linear(head_size * num_heads, cfg.n_embd)
-        self.dropout = nn.Dropout(cfg.dropout)
+        self.heads = nn.ModuleList(
+            [Head(n_embeds, head_size, block_size, dropout) for _ in range(n_heads)]
+        )
+        self.proj = nn.Linear(n_embeds, n_embeds)
+        self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
-        out = torch.cat([h(x) for h in self.heads], dim=-1)
-        out = self.dropout(self.proj(out))
-        return out
+        x = torch.cat([h(x) for h in self.heads], dim=-1)
+        x = self.proj(x)
+        x = self.dropout(x)
+        return x
 
-class FeedFoward(nn.Module):
-    """ a simple linear layer followed by a non-linearity """
 
-    def __init__(self, n_embd):
+class FeedForward(nn.Module):
+    def __init__(self, n_embeds, dropout):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(n_embd, 4 * n_embd),
+            nn.Linear(n_embeds, 4 * n_embeds),
             nn.ReLU(),
-            nn.Linear(4 * n_embd, n_embd),
-            nn.Dropout(cfg.dropout),
+            nn.Linear(4 * n_embeds, n_embeds),
+            nn.Dropout(dropout),
         )
 
     def forward(self, x):
         return self.net(x)
 
-class Block(nn.Module):
-    """ Transformer block: communication followed by computation """
 
-    def __init__(self, n_embd, n_head):
-        # n_embd: embedding dimension, n_head: the number of heads we'd like
+class Decoder(nn.Module):
+    def __init__(self, n_embeds, n_heads, block_size, dropout):
         super().__init__()
-        head_size = n_embd // n_head
-        self.sa = MultiHeadAttention(n_head, head_size)
-        self.ffwd = FeedFoward(n_embd)
-        self.ln1 = nn.LayerNorm(n_embd)
-        self.ln2 = nn.LayerNorm(n_embd)
+        head_size = n_embeds // n_heads
+        self.sa_heads = MultiHeadAttention(
+            n_heads, n_embeds, head_size, block_size, dropout
+        )
+        self.ffwd = FeedForward(n_embeds, dropout)
+        self.ln1 = nn.LayerNorm(n_embeds)
+        self.ln2 = nn.LayerNorm(n_embeds)
 
     def forward(self, x):
-        x = x + self.sa(self.ln1(x))
+        x = x + self.sa_heads(self.ln1(x))
         x = x + self.ffwd(self.ln2(x))
         return x
 
-class GPTLanguageModel(nn.Module):
 
-    def __init__(self, vocab_size):
+class GPTModel(nn.Module):
+    def __init__(
+        self, vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device
+    ):
         super().__init__()
-        # each token directly reads off the logits for the next token from a lookup table
-        self.token_embedding_table = nn.Embedding(vocab_size, cfg.n_embd)
-        self.position_embedding_table = nn.Embedding(cfg.block_size, cfg.n_embd)
-        self.blocks = nn.Sequential(*[Block(cfg.n_embd, n_head=cfg.n_head) for _ in range(cfg.n_layer)])
-        self.ln_f = nn.LayerNorm(cfg.n_embd) # final layer norm
-        self.lm_head = nn.Linear(cfg.n_embd, vocab_size)
-
-        # better init, not covered in the original GPT video, but important, will cover in followup video
-        self.apply(self._init_weights)
-
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        self.device = device
+        self.block_size = block_size
+        self.token_embedding_table = nn.Embedding(vocab_size, n_embeds)
+        self.position_embedding_table = nn.Embedding(block_size, n_embeds)
+        self.blocks = nn.Sequential(
+            *[Decoder(n_embeds, n_heads, block_size, dropout) for _ in range(n_layers)]
+        )
+        self.lnf = nn.LayerNorm(n_embeds)
+        self.lm_head = nn.Linear(n_embeds, vocab_size)
 
     def forward(self, idx, targets=None):
         B, T = idx.shape
 
-        # idx and targets are both (B,T) tensor of integers
-        tok_emb = self.token_embedding_table(idx) # (B,T,C)
-        pos_emb = self.position_embedding_table(torch.arange(T, device=cfg.device)) # (T,C)
-        x = tok_emb + pos_emb # (B,T,C)
-        x = self.blocks(x) # (B,T,C)
-        x = self.ln_f(x) # (B,T,C)
-        logits = self.lm_head(x) # (B,T,vocab_size)
-
-        if targets is None:
-            loss = None
-        else:
+        tok_embeds = self.token_embedding_table(idx)  # BxTxNemb
+        pos_embeds = self.position_embedding_table(
+            torch.arange(T, device=self.device)
+        )  # TXNemb
+
+        x = tok_embeds + pos_embeds  # BxTxNemb
+        x = self.blocks(x)
+        x = self.lnf(x)
+        logits = self.lm_head(x)  # BxTxVocabSize
+
+        loss = None
+        if targets is not None:
             B, T, C = logits.shape
-            logits = logits.view(B*T, C)
-            targets = targets.view(B*T)
+            logits = logits.view(B * T, C)
+            targets = targets.view(B * T)
             loss = F.cross_entropy(logits, targets)
-
         return logits, loss
 
     def generate(self, idx, max_new_tokens):
-        # idx is (B, T) array of indices in the current context
         for _ in range(max_new_tokens):
-            # crop idx to the last block_size tokens
-            idx_cond = idx[:, -cfg.block_size:]
-            # get the predictions
-            logits, loss = self(idx_cond)
-            # focus only on the last time step
-            logits = logits[:, -1, :] # becomes (B, C)
-            # apply softmax to get probabilities
-            probs = F.softmax(logits, dim=-1) # (B, C)
-            # sample from the distribution
-            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
-            # append sampled index to the running sequence
-            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
+            idx_cond = idx[:, -self.block_size :]
+            logits, loss = self(idx_cond)  # BxTxC
+            logits = logits[:, -1, :]  # BxC
+            probs = F.softmax(logits, dim=-1)  # BxC
+            idx_next = torch.multinomial(probs, num_samples=1)  # Bx1
+            idx = torch.cat((idx, idx_next), dim=1)  # BxT+1
+
         return idx
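
For context on how the refactored class is meant to be called: `GPTModel` now takes its hyperparameters explicitly instead of reading them from `import config as cfg`. A minimal usage sketch with illustrative values only (the actual hyperparameters live elsewhere in the repo's config or training script and are not part of this diff):

```python
import torch

from gpt import GPTModel  # assumes gpt.py from this commit is importable

# Illustrative hyperparameters only; substitute the values used by the repo.
vocab_size = 65
n_embeds = 384
block_size = 256
n_heads = 6
n_layers = 6
dropout = 0.2
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GPTModel(
    vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device
).to(device)

# Generate 100 new tokens starting from a single context token with index 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
out = model.generate(context, max_new_tokens=100)  # (1, 101) tensor of token ids
```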