import torch import torch.nn as nn from torch.nn import functional as F from utils import DEVICE class PromeLayerNorm(nn.Module): def __init__(self, epsilon=1e-5): super().__init__() self.epsilon = epsilon def forward(self, x): g = torch.nn.Parameter(torch.ones(x.shape[-1])).to(x.device) b = torch.nn.Parameter(torch.zeros(x.shape[-1])).to(x.device) u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) * torch.rsqrt(s + self.epsilon) x = x * g + b return x class PromeStand(nn.Module): def __init__(self, epsilon=1e-5): super().__init__() self.epsilon = epsilon def forward(self, x): """ x: Input tensor """ mean = x.mean() + self.epsilon std = x.std() + self.epsilon x = x - mean x = x / std return x class PromeEmbedding(nn.Module): """ This class implements a Prome embedding layer. Args: vocab_size (int): The size of the vocabulary. embedding_dim (int): The dimension of the embedding. padding_idx (int, optional): The padding index. If this is not None, then the padding index will be masked out when calculating the embedding. Returns: torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim). """ def __init__(self, vocab_size, embedding_dim, padding_idx = None): super().__init__() self.embedding_dim = embedding_dim self.weight = torch.nn.Parameter(torch.randn(vocab_size, embedding_dim)) self.padding_idx = padding_idx self.context_matrix = torch.nn.Parameter(torch.randn(vocab_size, embedding_dim)) def forward(self, input_ids): """ Calculates the embedding for the given input IDs. Args: input_ids (torch.Tensor): A tensor of shape (batch_size, sequence_length). Returns: torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim). """ input_ids = input_ids.long() if self.padding_idx is not None: input_ids = input_ids.masked_fill(input_ids == self.padding_idx, 0) # get symbol vector embeddings = self.weight[input_ids] # Dynamically update context vector based on input embeddings context_vectors = self.context_matrix[input_ids] # Modify embeddings using context vector output = embeddings + context_vectors return output class AttentionHead(nn.Module): """ One head of the self-attention layer """ def __init__(self, head_size, num_embed, block_size, dropout): super().__init__() self.key = nn.Linear(num_embed, head_size, bias=False) self.query = nn.Linear(num_embed, head_size, bias=False) self.value = nn.Linear(num_embed, head_size, bias=False) # tril is a lower triangular matrix. it is not a parameter # of the model, so we assign it to the module using register_buffer self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size))) # layer norm self.norm = PromeStand() # Dropout self.dropout = nn.Dropout(dropout) def forward(self, x): B, T, C = x.shape key = self.key(x) query = self.query(x) # compute attention scores # (B, T, C) @ (B, C, T) -> (B, T, T) wei = (query @ key.transpose(-2, -1)) * C ** -0.5 # Tril matrix (lower triagular matrix) is used to mask # future positions (setting them to -inf) so that the # decoder "learns" to predict next words wei = wei.masked_fill(self.tril[:T, :T] == 0, -float("inf")) # (B,T,T) wei = F.silu(F.softmax(wei, dim=-1)) # scale # multiplicative attention score = -1 / (C ** -0.5) wei.mul_(score) # weighted aggregation of the values value = self.value(x) out = wei @ value # (B,T,T) @ (B,T,C) ---> (B,T,C) return out class MultiHeadAttention(nn.Module): """ Multiple Heads of self-attention in parallel """ def __init__(self, num_heads, head_size, num_embed, block_size, dropout): super().__init__() self.heads = nn.ModuleList( [ AttentionHead( head_size=head_size, num_embed=num_embed, block_size=block_size, dropout=dropout ) for _ in range(num_heads) ] ) self.proj = nn.Linear(num_embed, num_embed) self.dropout = nn.Dropout(dropout) self.norm = PromeStand() def forward(self, x): # output of the self-attention out = torch.concat([h(x) for h in self.heads], dim=-1) # standartization out = self.norm(out + x) # apply the linear projection layer out = self.dropout(self.proj(out)) return out class MLP(nn.Module): def __init__(self, num_embed, hidden_dim, dropout=0.1): super().__init__() self.dropout = nn.Dropout(dropout) self.fc1 = nn.Linear(num_embed, hidden_dim) self.fc2 = nn.Linear(hidden_dim, hidden_dim) self.fc3 = nn.Linear(hidden_dim, num_embed) def forward(self, x): x = self.fc1(x) x = F.silu(x) x = self.fc2(x) x = self.dropout(x) x = F.silu(x) x = self.fc3(x) return x class TransformerBlock(nn.Module): """ This calss will group together MultiHead Attention and FeedForward NN, so that we can copy it in Transformer """ def __init__(self, num_heads, block_size, num_embed, hidden_dim, dropout): super().__init__() head_size = num_embed // num_heads self.mha = MultiHeadAttention( num_heads=num_heads, head_size=head_size, num_embed=num_embed, block_size=block_size, dropout=dropout ) self.mlp = MLP(num_embed=num_embed, hidden_dim = hidden_dim, dropout=dropout) # add the layer normalization self.ln = PromeStand(num_embed) self.dropout = nn.Dropout(dropout) def forward(self, x): """ Decodes the input sequence. Args: x (torch.Tensor): A tensor of shape (batch_size, sequence_length, embedding_dim). memory (torch.Tensor): A tensor of shape (batch_size, memory_length, embedding_dim). Returns: torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim). """ y = x x = self.ln(x) x = self.mha(x) x = self.dropout(x) x += y y = x x = self.ln(x) x = self.mlp(x) x = self.mha(x) x += y x = self.dropout(x) return x class TransformerDecoder(nn.Module): """ This class implements a Transformer decoder. Args: num_heads (int): The number of attention heads. block_size (int): The size of the input sequence. num_embed (int): The dimension of the embedding. num_layers (int): The number of decoder blocks. dropout (float): The dropout rate. Returns: torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim). """ def __init__(self, num_heads, block_size, num_embed, hidden_dim, num_layers, dropout): super().__init__() # Create the embedding layer. self.pemb = PromeEmbedding(block_size, num_embed) # Create a sequential block of Transformer blocks. self.blocks = nn.Sequential( *[ TransformerBlock( num_heads=num_heads, block_size=block_size, num_embed=num_embed, hidden_dim = hidden_dim, dropout=dropout ) for _ in range(num_layers) ] ) # Create a softmax layer. self.softmax = nn.Softmax(dim=-1) def forward(self, x): """ Decodes the input sequence. Args: x (torch.Tensor): A tensor of shape (batch_size, sequence_length). Returns: torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim). """ # Add positional encodings to the input sequence. x = x + self.pemb(torch.arange(x.size(1))) x = self.blocks(x) # Apply a softmax layer to the output of the last Transformer block. x = self.softmax(x) return x class Transformer(nn.Module): def __init__(self, **kwargs): super().__init__() # a simple lookup table that stores embeddings of a fixed dictionary and size # each token directly reads off the logits for the next token from a lookup table # see more: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html self.vocab_size = kwargs.get("vocab_size", 100) self.num_embed = kwargs.get("num_embed", 32) self.block_size = kwargs.get("block_size", 8) self.num_heads = kwargs.get("num_heads", 4) self.num_layers = kwargs.get("num_layers", 4) self.hidden_dim = kwargs.get("hidden_dim", 768) self.dropout = kwargs.get("dropout", 0.2) # each token reads the logits for the next token from a lookup table self.token_embedding_table = PromeEmbedding(self.vocab_size, self.num_embed) # each position from 0 to block_size-1 will get its embedding self.position_embedding_table = PromeEmbedding(self.block_size, self.num_embed) self.decoder = TransformerDecoder(self.num_heads, self.block_size, self.num_embed, self.hidden_dim, self.num_layers, self.dropout) # we add the layer norm before the Linear layer self.dropout = nn.Dropout(self.dropout) self.ln_f = PromeLayerNorm(self.num_embed) self.lm_head = nn.Linear(self.num_embed, self.vocab_size) def forward(self, idx, targets=None): B, T = idx.shape # idx and targets are (B,T) tensor of integers # the token_emb is (B, T, C), C = NUM_EMBED token_emb = self.token_embedding_table(idx) # (T, C) posit_emb = self.position_embedding_table(torch.arange(T, device=DEVICE)) x = token_emb + posit_emb # apply dropout x = self.dropout(x) # apply one head of self-attention x = self.decoder(x) # apply normalization x = self.ln_f(x) # (B, T, vocab_size) logits = self.lm_head(x) # Compute the loss if targets != None: # cross_entropy accepts inputs in a (batch_size, num_classes) # so we need to reformat our logits dimensions to # (batch_size * time, dim_vocabulary), time = block_size B, T, C = logits.shape logits = torch.reshape(logits, (B * T, C)) targets = torch.reshape(targets, (B * T, )) loss = F.cross_entropy(logits, targets) else: loss = None return logits, loss def generate(self, idx: torch.Tensor, max_new_tokens: int, block_size: int): # idx is (B, T) array of indices in the current context for _ in range(max_new_tokens): # crop the context too the last block_size tokens # because tokens don't communicate between blocks idx_crop = idx[:, -block_size:] # get the predictions logits, loss = self.forward(idx_crop) # focus only on the last time step logits = logits[:, -1, :] # becomes (B, C) # apply softmax to get probabilities probs = F.softmax(logits, dim=-1) # (B, C) # sample from the distribution with probabilities probs idx_next = torch.multinomial(probs, num_samples=1) # (B, 1) # append sampled index to the running sequence idx = torch.cat((idx, idx_next), dim=1) # (B, T+1) return idx