import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Ensure embedding dimension is divisible by number of heads
        assert config.emb_dim % config.num_head == 0
        self.n_head = config.num_head
        self.n_embd = config.emb_dim
        self.head_size = config.emb_dim // config.num_head
        # Separate projections for Q, K, V instead of a single projection
        self.q_proj = nn.Linear(config.emb_dim, config.emb_dim)
        self.k_proj = nn.Linear(config.emb_dim, config.emb_dim)
        self.v_proj = nn.Linear(config.emb_dim, config.emb_dim)
        self.out_proj = nn.Linear(config.emb_dim, config.emb_dim)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Causal mask
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )

    def forward(self, x):
        B, T, C = x.size()  # batch, sequence length, embedding dim
        # Separate projections for Q, K, V
        q = self.q_proj(x)  # (B, T, C)
        k = self.k_proj(x)  # (B, T, C)
        v = self.v_proj(x)  # (B, T, C)
        # Reshape into heads
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, nh, T, hs)
        # Compute scaled dot-product attention scores
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, T)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        # Apply attention weights to values
        y = att @ v  # (B, nh, T, hs)
        # Re-merge heads and project the output
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
        y = self.out_proj(y)
        y = self.resid_dropout(y)
        return y
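

# Note: on PyTorch >= 2.0, F.scaled_dot_product_attention(q, k, v, is_causal=True)
# computes the same causal attention as the explicit mask/softmax above (dropout can
# be passed via dropout_p), typically without materializing the full (T, T) matrix.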


class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Position-wise MLP: expand to 4x the embedding width, then project back
        self.c_fc = nn.Linear(config.emb_dim, 4 * config.emb_dim)
        self.c_proj = nn.Linear(4 * config.emb_dim, config.emb_dim)
        self.dropout = nn.Dropout(config.dropout)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.gelu(self.c_fc(x))
        x = self.dropout(self.c_proj(x))
        return x


class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.emb_dim)
        self.ln_2 = nn.LayerNorm(config.emb_dim)
        self.attn = MultiHeadAttention(config)
        self.mlp = FeedForward(config)

    def forward(self, x):
        # Pre-norm residual connections: LayerNorm before attention and MLP
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(config.vocab_size, config.emb_dim),
                "wpe": nn.Embedding(config.block_size, config.emb_dim),
                "drop": nn.Dropout(config.dropout),
                "h": nn.ModuleList(
                    [TransformerBlock(config) for _ in range(config.num_layer)]
                ),
                "ln_f": nn.LayerNorm(config.emb_dim),
            }
        )
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        # Initialize weights
        self.apply(self._init_weights)
        # Tie weights between embedding and final linear layer
        self.transformer.wte.weight = self.lm_head.weight

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        # Get positions
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # (1, t)
        # Get token and position embeddings
        tok_emb = self.transformer.wte(idx)  # (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        # Apply transformer blocks
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        # If targets are provided, also return the language-modeling loss
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            return logits, loss
        return logits
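

# Hypothetical usage sketch (illustration only): the demo config below is an
# assumption that simply bundles the fields the model reads; the project may
# define its own config elsewhere.
if __name__ == "__main__":

    @dataclass
    class _DemoConfig:
        vocab_size: int = 256
        block_size: int = 64
        emb_dim: int = 128
        num_head: int = 4
        num_layer: int = 2
        dropout: float = 0.1

    cfg = _DemoConfig()
    model = GPT(cfg)
    idx = torch.randint(0, cfg.vocab_size, (2, 16))  # (batch, seq_len) token ids
    logits = model(idx)
    print(logits.shape)  # torch.Size([2, 16, 256])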