import torch
import torch.nn as nn
import torch.nn.functional as F
from data_utils import *
from attention import SelfAttentionHead, MultiHeadAttention, FeedForwardNet, DecoderBlock
class BigramLanguageModel(nn.Module):
    """Decoder-only language model over integer token ids.

    Token and learned position embeddings are summed, passed through a stack
    of ``n_layers`` ``DecoderBlock`` modules, layer-normalised and projected
    to vocabulary logits by ``lm_head``.

    Args:
        vocab_size: Number of distinct token ids (size of the logit axis).
        n_embed: Embedding / hidden width.
        block_size: Maximum context length; also the number of rows in the
            position-embedding table.
        num_heads: Forwarded to every ``DecoderBlock``.
        n_layers: Number of ``DecoderBlock`` modules stacked in sequence.
    """

    def __init__(self, vocab_size, n_embed, block_size, num_heads, n_layers) -> None:
        # nn.Module bookkeeping must exist before any sub-module is assigned.
        super().__init__()
        # Remembered so generate() can crop its context without relying on a
        # module-level constant.
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.decoder_blocks = nn.Sequential(
            *[DecoderBlock(n_embed, num_heads, block_size=block_size) for _ in range(n_layers)]
        )
        self.ln_final = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids, shape (B, T). T must not exceed
                ``block_size`` because positions index the position table.
            targets: Optional integer tensor with B*T elements holding the
                expected next token for every position.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape (B, T, vocab_size)
            and ``loss`` is ``None`` when ``targets`` is ``None``, otherwise
            the scalar cross-entropy loss.
        """
        B, T = idx.shape
        tok_embed = self.token_embedding_table(idx)  # (B, T, n_embed)
        positions = torch.arange(T, device=idx.device)
        pos_embed = self.position_embedding_table(positions)  # (T, n_embed)
        x = tok_embed + pos_embed  # position rows broadcast over the batch axis
        # assumes DecoderBlock keeps the (B, T, n_embed) shape -- the LayerNorm
        # and lm_head below both require a trailing n_embed axis.
        x = self.ln_final(self.decoder_blocks(x))
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            # Inference path (used by generate): nothing to score against.
            return logits, None

        B, T, C = logits.shape
        # cross_entropy expects (N, C) logits and (N,) class indices, so the
        # batch and time axes are flattened together.
        # NOTE(review): ignore_index=0 drops every target equal to 0 from the
        # loss -- confirm that token id 0 is padding in this vocabulary.
        loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T), ignore_index=0)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: Integer tensor of shape (B, T) holding the current context.
            max_new_tokens: Number of tokens to append to every row.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).

        Sampling runs without autograd tracking. The module's train/eval mode
        is left untouched; call ``.eval()`` first if that matters.
        """
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Keep only the most recent block_size tokens: the position
                # table has no rows beyond that.
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                # Only the final time step predicts the next token.
                logits = logits[:, -1, :]  # (B, vocab_size)
                probs = F.softmax(logits, dim=-1)  # (B, vocab_size)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                idx = torch.cat([idx, idx_next], dim=1)  # (B, T + 1)
        return idx

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        Args:
            non_embedding: When true (default), the position-embedding
                parameters are subtracted from the total. The token-embedding
                parameters are always counted.

        Returns:
            The parameter count as an int.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.position_embedding_table.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialise one sub-module in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from N(0, 0.02);
        ``nn.Linear`` biases, when present, are zeroed. Other module types
        are left untouched. Has the signature expected by
        ``nn.Module.apply``; it is not invoked automatically by ``__init__``.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if __name__ == "__main__":
    # Smoke test: push one training batch through a freshly built model.
    # get_random_batch, device, BLOCK_SIZE, n_embed, n_head and n_layer are
    # presumably supplied by this star import -- verify against data_utils.
    from data_utils import *

    # Fetch one (inputs, targets) pair and move both tensors to the device.
    xb, yb = (part.to(device) for part in get_random_batch('train'))

    m = BigramLanguageModel(
        vocab_size=65,
        n_embed=n_embed,
        block_size=BLOCK_SIZE,
        num_heads=n_head,
        n_layers=n_layer,
    )
    m = m.to(device)

    logits, loss = m(xb, yb)