import torch
from torch import nn
import bitsandbytes as bnb
from torch.utils.checkpoint import checkpoint
class TransformerBlock(nn.Module):
    # Pre-LayerNorm transformer block: causal self-attention followed by a feed-forward
    # sub-layer, each wrapped in a residual connection.
    def __init__(self, device, n_heads, embed_dim, ffn_dim, dropout):
        super().__init__()
        self.device = device
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.self_attention = nn.MultiheadAttention(embed_dim, n_heads, dropout=dropout, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, ffn_dim),
                                 nn.SiLU(),
                                 nn.Dropout(p=dropout),
                                 nn.Linear(ffn_dim, embed_dim),
                                 nn.Dropout(p=dropout))

    def forward(self, x):
        # Causal mask so each position can only attend to itself and earlier positions.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(x.shape[1]).to(self.device)
        # Self-attention sub-layer with a pre-norm residual connection.
        x_in = self.layer_norm1(x)
        x = x + self.self_attention(x_in, x_in, x_in, attn_mask=causal_mask)[0]
        # Feed-forward sub-layer with a pre-norm residual connection.
        x = x + self.ffn(self.layer_norm2(x))
        return x
class GPTModel(nn.Module):
    # Decoder-only GPT-style language model with learned positional embeddings and tied
    # input/output embedding weights.
    def __init__(self, device, n_layers, n_heads, embed_dim, ffn_dim, n_vocab, max_seq_len, dropout):
        super().__init__()
        self.device = device
        self.max_seq_len = max_seq_len
        # bitsandbytes StableEmbedding adds a layer norm and keeps 32-bit optimizer state for
        # this layer, improving training stability when using 8-bit optimizers.
        self.embedding = bnb.nn.StableEmbedding(n_vocab, embed_dim)
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        self.positional_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.transformer_blocks = nn.ModuleList([TransformerBlock(device, n_heads, embed_dim, ffn_dim, dropout)
                                                 for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.output_projection = nn.Linear(embed_dim, n_vocab, bias=False)
        self.output_projection.weight = self.embedding.weight  # Using weight sharing (tied embeddings)
    def forward(self, input_tokens):
        # Sum token embeddings and learned positional embeddings.
        input_embed = self.embedding(input_tokens)
        positions = torch.arange(0, input_tokens.size(1), device=input_tokens.device).unsqueeze(0)
        input_embed = input_embed + self.positional_embedding(positions)
        x = input_embed
        # Gradient checkpointing trades compute for memory: block activations are recomputed
        # during the backward pass instead of being stored.
        for block in self.transformer_blocks:
            x = checkpoint(block, x, use_reentrant=False)
        x = self.layer_norm(x)
        x = self.output_projection(x)
        return x
    # Generate a completion from the model
    def generate(self, input_ctx, max_length, temperature=1.0, top_p=None, top_k=None):
        assert max_length <= self.max_seq_len
        self.eval()
        context_list = list(input_ctx)
        with torch.no_grad():
            while len(context_list) < max_length:
                context = torch.tensor(context_list, dtype=torch.long, device=self.device).unsqueeze(0)
                # Logits at the last position give the next-token distribution.
                logits = self(context)[0, -1, :]
                if top_p:
                    # Nucleus (top-p) sampling: keep the smallest set of tokens whose
                    # cumulative probability reaches top_p.
                    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
                    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    sorted_mask = cumulative_probs <= top_p
                    # Shift the mask right so the token that crosses the threshold is kept too.
                    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                    sorted_mask[..., 0] = True
                    filtered_probs = sorted_probs * sorted_mask
                    filtered_probs = filtered_probs / filtered_probs.sum()
                    probs_sampled = torch.multinomial(filtered_probs, 1).item()
                    selected_token = sorted_indices[probs_sampled].item()
                elif top_k:
                    # Top-k sampling: restrict the candidate set to the k highest-scoring tokens.
                    scaled_logits = logits / temperature
                    topk_logits, topk_indices = scaled_logits.topk(top_k)
                    probs = nn.functional.softmax(topk_logits, dim=-1)
                    probs_sampled = torch.multinomial(probs, 1).item()
                    selected_token = topk_indices[probs_sampled].item()
                else:  # Greedy decoding
                    selected_token = logits.argmax(dim=-1).item()
                context_list.append(selected_token)
        return context_list
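
# Example usage (a minimal sketch, not part of the original model code): the hyperparameters,
# vocabulary size, and prompt token IDs below are illustrative assumptions, and the tokenizer
# that would produce the token IDs is left unspecified.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GPTModel(device=device, n_layers=4, n_heads=8, embed_dim=512, ffn_dim=2048,
                     n_vocab=32000, max_seq_len=256, dropout=0.1).to(device)
    prompt_tokens = [1, 15, 42]  # placeholder token IDs; real IDs come from the model's tokenizer
    completion = model.generate(prompt_tokens, max_length=64, temperature=0.8, top_p=0.9)
    print(completion)  # list of token IDs: the prompt followed by the generated continuation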