Navyabhat committed
Commit 367dd01
1 Parent(s): 139373c

Update gpt.py

Files changed (1)
  1. gpt.py +77 -82
gpt.py CHANGED
@@ -1,125 +1,120 @@
 import torch
-import torch.nn as nn
-from torch.nn import functional as F
-import config as cfg
+from torch import nn
+import torch.nn.functional as F

-class Head(nn.Module):

-    def __init__(self, head_size):
+class Head(nn.Module):
+    def __init__(self, n_embeds, head_size, block_size, dropout) -> None:
         super().__init__()
-        self.key = nn.Linear(cfg.n_embd, head_size, bias=False)
-        self.query = nn.Linear(cfg.n_embd, head_size, bias=False)
-        self.value = nn.Linear(cfg.n_embd, head_size, bias=False)
-        self.register_buffer('tril', torch.tril(torch.ones(cfg.block_size, cfg.block_size)))
-
-        self.dropout = nn.Dropout(cfg.dropout)
+        self.key = nn.Linear(n_embeds, head_size, bias=False)
+        self.query = nn.Linear(n_embeds, head_size, bias=False)
+        self.value = nn.Linear(n_embeds, head_size, bias=False)
+        self.dropout = nn.Dropout(dropout)
+        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

     def forward(self, x):
-        B,T,C = x.shape
-        k = self.key(x) # (B,T,hs)
-        q = self.query(x) # (B,T,hs)
-        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
-        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
-        wei = F.softmax(wei, dim=-1) # (B, T, T)
+        B, T, C = x.shape
+        k = self.key(x)
+        q = self.query(x)
+        wei = q @ k.transpose(-2, -1) * (C**-0.5)  # (B,T,16) @ (B,16,T) --> (B,T,T)
+        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
+        wei = F.softmax(wei, dim=-1)
         wei = self.dropout(wei)
         v = self.value(x)
-        out = wei @ v
+        out = wei @ v
         return out

-class MultiHeadAttention(nn.Module):
-    """ multiple heads of self-attention in parallel """

-    def __init__(self, num_heads, head_size):
+class MultiHeadAttention(nn.Module):
+    def __init__(self, n_heads, n_embeds, head_size, block_size, dropout):
         super().__init__()
-        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
-        self.proj = nn.Linear(head_size * num_heads, cfg.n_embd)
-        self.dropout = nn.Dropout(cfg.dropout)
+        self.heads = nn.ModuleList(
+            [Head(n_embeds, head_size, block_size, dropout) for _ in range(n_heads)]
+        )
+        self.proj = nn.Linear(n_embeds, n_embeds)
+        self.dropout = nn.Dropout(dropout)

     def forward(self, x):
-        out = torch.cat([h(x) for h in self.heads], dim=-1)
-        out = self.dropout(self.proj(out))
-        return out
+        x = torch.cat([h(x) for h in self.heads], dim=-1)
+        x = self.proj(x)
+        x = self.dropout(x)
+        return x

-class FeedFoward(nn.Module):
-    """ a simple linear layer followed by a non-linearity """

-    def __init__(self, n_embd):
+class FeedForward(nn.Module):
+    def __init__(self, n_embeds, dropout):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(n_embd, 4 * n_embd),
+            nn.Linear(n_embeds, 4 * n_embeds),
             nn.ReLU(),
-            nn.Linear(4 * n_embd, n_embd),
-            nn.Dropout(cfg.dropout),
+            nn.Linear(4 * n_embeds, n_embeds),
+            nn.Dropout(dropout),
         )

     def forward(self, x):
         return self.net(x)

-class Block(nn.Module):
-    """ Transformer block: communication followed by computation """

-    def __init__(self, n_embd, n_head):
-        # n_embd: embedding dimension, n_head: the number of heads we'd like
+class Decoder(nn.Module):
+    def __init__(self, n_embeds, n_heads, block_size, dropout):
         super().__init__()
-        head_size = n_embd // n_head
-        self.sa = MultiHeadAttention(n_head, head_size)
-        self.ffwd = FeedFoward(n_embd)
-        self.ln1 = nn.LayerNorm(n_embd)
-        self.ln2 = nn.LayerNorm(n_embd)
+        head_size = n_embeds // n_heads
+        self.sa_heads = MultiHeadAttention(
+            n_heads, n_embeds, head_size, block_size, dropout
+        )
+        self.ffwd = FeedForward(n_embeds, dropout)
+        self.ln1 = nn.LayerNorm(n_embeds)
+        self.ln2 = nn.LayerNorm(n_embeds)

     def forward(self, x):
-        x = x + self.sa(self.ln1(x))
+        x = x + self.sa_heads(self.ln1(x))
         x = x + self.ffwd(self.ln2(x))
         return x

-class GPTLanguageModel(nn.Module):

-    def __init__(self, vocab_size):
+class GPTModel(nn.Module):
+    def __init__(
+        self, vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device
+    ):
         super().__init__()
-        # each token directly reads off the logits for the next token from a lookup table
-        self.token_embedding_table = nn.Embedding(vocab_size, cfg.n_embd)
-        self.position_embedding_table = nn.Embedding(cfg.block_size, cfg.n_embd)
-        self.blocks = nn.Sequential(*[Block(cfg.n_embd, n_head=cfg.n_head) for _ in range(cfg.n_layer)])
-        self.ln_f = nn.LayerNorm(cfg.n_embd)
-        self.lm_head = nn.Linear(cfg.n_embd, vocab_size)
-        self.apply(self._init_weights)
-
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        self.device = device
+        self.block_size = block_size
+        self.token_embedding_table = nn.Embedding(vocab_size, n_embeds)
+        self.position_embedding_table = nn.Embedding(block_size, n_embeds)
+        self.blocks = nn.Sequential(
+            *[Decoder(n_embeds, n_heads, block_size, dropout) for _ in range(n_layers)]
+        )
+        self.lnf = nn.LayerNorm(n_embeds)
+        self.lm_head = nn.Linear(n_embeds, vocab_size)

     def forward(self, idx, targets=None):
         B, T = idx.shape

-        # idx and targets are both (B,T) tensor of integers
-        tok_emb = self.token_embedding_table(idx) # (B,T,C)
-        pos_emb = self.position_embedding_table(torch.arange(T, device=cfg.device)) # (T,C)
-        x = tok_emb + pos_emb # (B,T,C)
-        x = self.blocks(x) # (B,T,C)
-        x = self.ln_f(x) # (B,T,C)
-        logits = self.lm_head(x) # (B,T,vocab_size)
-
-        if targets is None:
-            loss = None
-        else:
+        tok_embeds = self.token_embedding_table(idx)  # BxTxNemb
+        pos_embeds = self.position_embedding_table(
+            torch.arange(T, device=self.device)
+        )  # TxNemb
+
+        x = tok_embeds + pos_embeds  # BxTxNemb
+        x = self.blocks(x)
+        x = self.lnf(x)
+        logits = self.lm_head(x)  # BxTxVocabSize
+
+        loss = None
+        if targets is not None:
             B, T, C = logits.shape
-            logits = logits.view(B*T, C)
-            targets = targets.view(B*T)
+            logits = logits.view(B * T, C)
+            targets = targets.view(B * T)
             loss = F.cross_entropy(logits, targets)
-
         return logits, loss

     def generate(self, idx, max_new_tokens):
-        # idx is (B, T) array of indices in the current context
         for _ in range(max_new_tokens):
-            idx_cond = idx[:, -cfg.block_size:]
-            logits, loss = self(idx_cond)
-            logits = logits[:, -1, :]
-            probs = F.softmax(logits, dim=-1)
-            idx_next = torch.multinomial(probs, num_samples=1)
-            idx = torch.cat((idx, idx_next), dim=1)
-        return idx
+            idx_cond = idx[:, -self.block_size :]
+            logits, loss = self(idx_cond)  # BxTxC
+            logits = logits[:, -1, :]  # BxC
+            probs = F.softmax(logits, dim=-1)  # BxC
+            idx_next = torch.multinomial(probs, num_samples=1)  # Bx1
+            idx = torch.cat((idx, idx_next), dim=1)  # BxT+1
+
+        return idx
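
For quick reference, a minimal usage sketch of the refactored model, assuming gpt.py is importable from the working directory. The GPTModel constructor and generate signatures come from the new file above; the hyperparameter values, the vocabulary size, and the zero-token prompt are illustrative assumptions, not values read from this repository's config.

import torch
from gpt import GPTModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical hyperparameters for illustration only; the repo's real settings may differ.
model = GPTModel(
    vocab_size=65,   # e.g. a character-level vocabulary
    n_embeds=384,
    block_size=256,
    n_heads=6,
    n_layers=6,
    dropout=0.2,
    device=device,
).to(device)

# Sample 100 new tokens starting from a single zero token.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():
    generated = model.generate(context, max_new_tokens=100)  # shape (1, 101)

The practical effect of the change is that every hyperparameter now travels through constructor arguments rather than a `config` module import, so gpt.py can be instantiated with different settings (for example in tests) without editing config.py.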