import torch
from torch import nn
import bitsandbytes as bnb
from torch.utils.checkpoint import checkpoint


class TransformerBlock(nn.Module):
    """A single pre-norm decoder block: causal self-attention followed by a feed-forward network."""

    def __init__(self, device, n_heads, embed_dim, ffn_dim, dropout):
        super().__init__()
        self.device = device
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.self_attention = nn.MultiheadAttention(embed_dim, n_heads, dropout=dropout, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Linear(ffn_dim, embed_dim),
            nn.Dropout(p=dropout),
        )

    def forward(self, x):
        # Causal mask so each position can only attend to itself and earlier positions
        causal_mask = nn.Transformer.generate_square_subsequent_mask(x.shape[1]).to(self.device)
        x_in = self.layer_norm1(x)
        x = x + self.self_attention(x_in, x_in, x_in, attn_mask=causal_mask)[0]
        x = x + self.ffn(self.layer_norm2(x))
        return x


class GPTModel(nn.Module):
    """Decoder-only transformer with learned positional embeddings and tied input/output embeddings."""

    def __init__(self, device, n_layers, n_heads, embed_dim, ffn_dim, n_vocab, max_seq_len, dropout):
        super().__init__()
        self.device = device
        self.max_seq_len = max_seq_len
        self.embedding = bnb.nn.StableEmbedding(n_vocab, embed_dim)
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        self.positional_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(device, n_heads, embed_dim, ffn_dim, dropout) for _ in range(n_layers)]
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.output_projection = nn.Linear(embed_dim, n_vocab, bias=False)
        self.output_projection.weight = self.embedding.weight  # Weight sharing with the token embedding

    def forward(self, input_tokens):
        input_embed = self.embedding(input_tokens)
        positions = torch.arange(0, input_tokens.size(1), device=input_tokens.device).unsqueeze(0)
        input_embed = input_embed + self.positional_embedding(positions)
        x = input_embed
        # Gradient checkpointing: recompute each block's activations in the backward pass to save memory
        for block in self.transformer_blocks:
            x = checkpoint(block, x, use_reentrant=False)
        x = self.layer_norm(x)
        x = self.output_projection(x)
        return x

    # Generate a completion from the model
    def generate(self, input_ctx, max_length, temperature=1.0, top_p=None, top_k=None):
        assert max_length <= self.max_seq_len
        self.eval()
        context_list = [i for i in input_ctx]
        with torch.no_grad():
            while len(context_list) < max_length:
                context = torch.tensor(context_list, dtype=torch.long, device=self.device).unsqueeze(0)
                logits = self(context)[0, -1, :]
                if top_p:
                    # Nucleus (top-p) sampling: keep the smallest set of tokens whose cumulative probability exceeds top_p
                    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
                    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    sorted_mask = cumulative_probs <= top_p
                    # Shift the mask right so the token that crosses the threshold is also kept
                    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                    sorted_mask[..., 0] = True
                    filtered_probs = sorted_probs * sorted_mask
                    filtered_probs = filtered_probs / filtered_probs.sum()
                    probs_sampled = torch.multinomial(filtered_probs, 1).item()
                    selected_token = sorted_indices[probs_sampled].item()
                elif top_k:
                    # Top-k sampling: restrict the distribution to the k highest-scoring tokens
                    scaled_logits = logits / temperature
                    topk_logits, topk_indices = scaled_logits.topk(top_k)
                    probs = nn.functional.softmax(topk_logits, dim=-1)
                    probs_sampled = torch.multinomial(probs, 1).item()
                    selected_token = topk_indices[probs_sampled].item()
                else:
                    # Greedy decoding
                    selected_token = logits.argmax(dim=-1).item()
                context_list.append(selected_token)
        return context_list
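

# --- Usage sketch (illustrative, not part of the model code above) ---
# Shows how the GPTModel might be instantiated and sampled from. All
# hyperparameter values and the prompt token ids below are assumptions
# chosen for the example, not values prescribed by the code above.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GPTModel(
        device=device,
        n_layers=4,
        n_heads=8,
        embed_dim=512,
        ffn_dim=2048,
        n_vocab=50257,   # assumed tokenizer vocabulary size
        max_seq_len=256,
        dropout=0.1,
    ).to(device)

    # Nucleus sampling from an arbitrary example prompt of token ids
    prompt_tokens = [464, 2159, 318]
    completion = model.generate(prompt_tokens, max_length=32, temperature=0.8, top_p=0.9)
    print(completion)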