import torch
import torch.nn as nn
import torch.nn.functional as F

from data_utils import *
from attention import SelfAttentionHead, MultiHeadAttention, FeedForwardNet, DecoderBlock


class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embed, block_size, num_heads, n_layers) -> None:
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.decoder_blocks = nn.Sequential(
            *[DecoderBlock(n_embed, num_heads, block_size=block_size) for _ in range(n_layers)]
        )
        self.ln_final = nn.LayerNorm(n_embed)
        ## self.sa_head = SelfAttentionHead(vocab_size, n_embed, block_size)
        # self.sa_heads = MultiHeadAttention(num_heads=4, head_size=n_embed//4, n_embed=n_embed, block_size=block_size)
        # self.ffn = FeedForwardNet(n_embed, dropout=0.2)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        # apply the custom initialization below to all Linear and Embedding submodules
        self.apply(self._init_weights)

    def forward(self, idx, targets=None):
        # idx and targets are both tensors of shape (B, T) -> B = batch size, T = sequence length ("time steps")
        B, T = idx.shape
        tok_embed = self.token_embedding_table(idx)  # (B, T, C), where C = "channels" = embedding dim n_embed
        pos_embed = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, C)
        x_in = tok_embed + pos_embed  # broadcast over the batch dimension -> (B, T, C)
        # x_in = self.sa_heads(x_in)
        # x_in = self.ffn(x_in)
        x_in = self.ln_final(self.decoder_blocks(x_in))
        logits = self.lm_head(x_in)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # cross_entropy expects the class dimension second, so flatten (B, T) into a single dimension
            loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T), ignore_index=0)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) tensor of indices in the current context
        for _ in range(max_new_tokens):
            # limit the input context to the last block_size tokens
            idx_cond = idx[:, -BLOCK_SIZE:]
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append the sampled index to the running sequence
            idx = torch.cat([idx, idx_next], dim=1)  # (B, T+1)
        return idx

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For the non-embedding count (default), the position embeddings are subtracted.
        The token embeddings are kept in the count, since in models that tie them to the
        final lm_head weights they also serve as output-layer parameters.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.position_embedding_table.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


if __name__ == "__main__":
    xb, yb = get_random_batch('train')
    xb = xb.to(device)
    yb = yb.to(device)
    m = BigramLanguageModel(
        vocab_size=65,
        n_embed=n_embed,
        block_size=BLOCK_SIZE,
        num_heads=n_head,
        n_layers=n_layer,
    ).to(device)
    logits, loss = m(xb, yb)
    print(logits.shape)
    print(loss)
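
    # Minimal training/sampling sketch for illustration; the optimizer choice,
    # learning rate, and step count below are assumptions, not values taken from
    # data_utils. get_random_batch and device are reused from the imports above.
    optimizer = torch.optim.AdamW(m.parameters(), lr=3e-4)
    for step in range(200):
        xb, yb = get_random_batch('train')
        xb, yb = xb.to(device), yb.to(device)
        _, loss = m(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    print(f"loss after {step + 1} steps: {loss.item():.4f}")

    # Sample a short continuation starting from a single zero token; with only a
    # few hundred training steps the generated token ids will still look noisy.
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    print(m.generate(context, max_new_tokens=50)[0].tolist())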