sanjanatule committed on
Commit
a26f15b
1 Parent(s): 8aedac7

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +154 -0
utils.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import lightning.pytorch as pl
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets
7
+
8
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Relies on the module-level hyperparameters ``n_embd``, ``block_size`` and
    ``dropout``, which must be defined before a ``Head`` is constructed.

    Args:
        head_size: Output width of the key, query and value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular causal mask, registered as a buffer (not a parameter).
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embd); returns (B, T, head_size).

        Each position only attends to itself and earlier positions.
        """
        _, T, _ = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Compute attention scores ("affinities"). Scale by the key width
        # (head_size), not the input embedding width: scaled dot-product
        # attention divides by sqrt(d_k).
        # NOTE(review): the previous code scaled by n_embd**-0.5; weights
        # trained with that scaling see a softmax sharper by sqrt(n_embd / head_size).
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # Perform the weighted aggregation of the values.
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out
33
+
34
class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention run in parallel, then projected.

    Relies on the module-level hyperparameters ``n_embd`` and ``dropout``
    (and on those required by ``Head``).

    Args:
        num_heads: Number of ``Head`` modules to run side by side.
        head_size: Output width of each individual head.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. That
        # only equals n_embd when n_embd is divisible by the head count, so
        # size the projection input from the heads rather than assuming n_embd.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate every head's output on the last dim and project to n_embd."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
47
+
48
class FeedFoward(nn.Module):
    """Two-layer MLP: widen to 4x, apply ReLU, narrow back, then dropout.

    The dropout probability is read from the module-level ``dropout`` name.

    Args:
        n_embd: Width of both the input and the output features.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension keeps its size."""
        transformed = self.net(x)
        return transformed
62
+
63
class Block(nn.Module):
    """Transformer block: self-attention followed by a feed-forward layer.

    Layer norm is applied to the input of each sub-layer, and the sub-layer
    output is added back onto ``x`` (residual connection).

    Args:
        n_embd: Embedding dimension.
        n_head: Number of attention heads; each head gets ``n_embd // n_head``
            features.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers with residuals."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed
79
+
80
+
81
class GPTLanguageModel(nn.Module):
    """Decoder-only transformer language model.

    Relies on the module-level hyperparameters ``vocab_size``, ``n_embd``,
    ``block_size``, ``n_head`` and ``n_layer`` (plus those required by the
    sub-modules), which must be defined before construction.
    """

    def __init__(self):
        super().__init__()
        # Token ids and positions are each embedded into n_embd features.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) integer tensor of token ids.
            targets: Optional (B, T) integer tensor of target token ids.

        Returns:
            ``(logits, loss)``. Without ``targets``, logits are
            (B, T, vocab_size) and loss is ``None``. With ``targets``, logits
            are flattened to (B*T, vocab_size) and loss is a scalar tensor.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, C)
        # Build the positions on the same device as the input instead of
        # depending on a global `device`, so the model works wherever it is moved.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T, C)
        x = tok_emb + pos_emb  # (B, T, C)
        x = self.blocks(x)  # (B, T, C)
        x = self.ln_f(x)  # (B, T, C)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: (B, T) integer tensor holding the current context.
            max_new_tokens: Number of tokens to sample.

        Returns:
            (B, T + max_new_tokens) integer tensor, located on the device of
            the model's weights.
        """
        # Move the whole running sequence to the model's device once. Moving
        # only the cropped context would make the later torch.cat fail when
        # the prompt starts on a different device.
        idx = idx.to(self.token_embedding_table.weight.device)
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop idx to the last block_size tokens.
                idx_cond = idx[:, -block_size:]
                # Get the predictions.
                logits, _ = self(idx_cond)
                # Focus only on the last time step.
                logits = logits[:, -1, :]  # becomes (B, C)
                # Apply softmax to get probabilities.
                probs = F.softmax(logits, dim=-1)  # (B, C)
                # Sample from the distribution.
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                # Append the sampled index to the running sequence.
                idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
128
+
129
class GPTLM(pl.LightningModule):
    """Lightning wrapper that trains and validates a ``GPTLanguageModel``.

    The optimizer's learning rate is read from the module-level
    ``learning_rate`` name.
    """

    def __init__(self):
        super().__init__()
        self.model = GPTLanguageModel()

    def forward(self, idx, targets=None):
        """Delegate to the wrapped model; returns its ``(logits, loss)`` pair."""
        return self.model(idx, targets)

    def process_step(self, batch):
        """Unpack an ``(inputs, targets)`` batch and return ``(logits, loss)``."""
        inputs, labels = batch
        logits, loss = self(inputs, labels)
        return logits, loss

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log it as ``train_loss``."""
        loss = self.process_step(batch)[1]
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and log it as ``val_loss``."""
        loss = self.process_step(batch)[1]
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters of this module."""
        return torch.optim.AdamW(self.parameters(), lr=learning_rate)