import math

import torch
import torch.nn as nn
from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.train import HookedTransformerTrainConfig, train
from transformers import PretrainedConfig, PreTrainedModel


def generate_config(
    n_ctx,
    d_model,
    d_head,
    n_heads,
    d_mlp,
    n_layers,
    attention_dir,
    act_fn,
    d_vocab,
    d_vocab_out,
    use_attn_result,
    device,
    use_hook_tokens,
):
    """Bundle the given hyperparameters into a HookedTransformerConfig."""
    return HookedTransformerConfig(
        n_ctx=n_ctx,
        d_model=d_model,
        d_head=d_head,
        n_heads=n_heads,
        d_mlp=d_mlp,
        n_layers=n_layers,
        attention_dir=attention_dir,
        act_fn=act_fn,
        d_vocab=d_vocab,
        d_vocab_out=d_vocab_out,
        use_attn_result=use_attn_result,
        device=device,
        use_hook_tokens=use_hook_tokens,
    )


def generate_model(config):
    """Instantiate a HookedTransformer from the given config."""
    return HookedTransformer(config)


def train_model(model, n_epochs, batch_size, lr, dataset):
    """Train the model with transformer_lens' built-in training loop."""
    # Forward the caller's hyperparameters and fall back to CPU when no GPU
    # is available.
    train_cfg = HookedTransformerTrainConfig(
        num_epochs=n_epochs,
        batch_size=batch_size,
        lr=lr,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
    )

    return train(model, train_cfg, dataset)
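

# A minimal usage sketch of the helpers above. The hyperparameter values are
# illustrative assumptions, not settings taken from this project; `dataset` is
# assumed to be a tokenized torch Dataset accepted by transformer_lens.train.train.
def _example_hooked_transformer_pipeline(dataset):
    cfg = generate_config(
        n_ctx=64,
        d_model=128,
        d_head=32,
        n_heads=4,
        d_mlp=512,
        n_layers=2,
        attention_dir="causal",
        act_fn="gelu",
        d_vocab=256,
        d_vocab_out=256,
        use_attn_result=False,
        device="cpu",
        use_hook_tokens=True,
    )
    model = generate_model(cfg)
    return train_model(model, n_epochs=5, batch_size=64, lr=1e-3, dataset=dataset)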


class ScaledDotProductAttention(nn.Module):
    """Computes softmax(q @ k^T / scale) @ v."""

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, q, k, v, mask=None):
        # (..., seq_q, depth) @ (..., depth, seq_k) -> (..., seq_q, seq_k)
        attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        if mask is not None:
            # Positions where the mask is 0 cannot be attended to.
            attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = torch.softmax(attn, dim=-1)

        out = torch.matmul(attn, v)
        return out, attn


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model):
        super().__init__()
        assert d_model % n_heads == 0, "d_model should be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.depth = d_model // n_heads

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.dense = nn.Linear(d_model, d_model)

        self.attn = ScaledDotProductAttention(scale=math.sqrt(self.depth))

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        # Project, split d_model into (n_heads, depth), and move the head axis
        # ahead of the sequence axis: (batch, n_heads, seq_len, depth).
        q = self.wq(q).view(batch_size, -1, self.n_heads, self.depth).transpose(1, 2)
        k = self.wk(k).view(batch_size, -1, self.n_heads, self.depth).transpose(1, 2)
        v = self.wv(v).view(batch_size, -1, self.n_heads, self.depth).transpose(1, 2)

        attn_out, _ = self.attn(q, k, v, mask=mask)
        # Merge the heads back into d_model: (batch, seq_len, d_model).
        attn_out = (
            attn_out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        )

        out = self.dense(attn_out)

        return out
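

# Quick shape check for the attention blocks above (a sketch; the batch size,
# sequence length, and width are arbitrary illustrative values).
def _example_attention_shapes():
    mha = MultiHeadAttention(n_heads=4, d_model=64)
    x = torch.randn(2, 10, 64)  # (batch, seq_len, d_model)
    # Mask broadcasts over heads and query positions; 1 = attend, 0 = block.
    mask = torch.ones(2, 1, 1, 10)
    out = mha(x, x, x, mask=mask)
    assert out.shape == (2, 10, 64)
    return out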


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(n_heads, d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, d_model),
        )

        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out = self.attn(x, x, x, mask=mask)
        x = self.ln1(x + self.dropout(attn_out))

        ff_out = self.ff(x)
        x = self.ln2(x + self.dropout(ff_out))

        return x


class TransformerClassifierConfig(PretrainedConfig):
    model_type = "transformer-checker"

    def __init__(
        self,
        in_dim=512,
        d_model=256,
        n_heads=8,
        ff_dim=2048,
        n_layers=6,
        n_classes=2,
        **kwargs,
    ):
        self.in_dim = in_dim
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.n_layers = n_layers
        self.n_classes = n_classes

        super().__init__(**kwargs)


class TransformerClassifier(PreTrainedModel):
    config_class = TransformerClassifierConfig

    def __init__(self, config: TransformerClassifierConfig):
        super().__init__(config)
        self.embedding = nn.Linear(config.in_dim, config.d_model)
        self.encoders = nn.ModuleList(
            [
                TransformerEncoderLayer(config.d_model, config.n_heads, config.ff_dim)
                for _ in range(config.n_layers)
            ]
        )
        self.classifier = nn.Linear(config.d_model, config.n_classes)

    def forward(self, x, mask=None):
        x = self.embedding(x)
        for encoder in self.encoders:
            x = encoder(x, mask=mask)

        # Pool by taking the representation of the first position.
        x = self.classifier(x[:, 0])

        return x
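

# End-to-end sketch of the classifier (illustrative sizes). The input is assumed
# to be a batch of pre-extracted feature vectors of dimension `in_dim` per
# position, with position 0 serving as the pooled summary token.
def _example_classifier_forward():
    cfg = TransformerClassifierConfig(
        in_dim=512, d_model=256, n_heads=8, ff_dim=1024, n_layers=2, n_classes=2
    )
    model = TransformerClassifier(cfg)
    features = torch.randn(4, 16, 512)  # (batch, seq_len, in_dim)
    logits = model(features)
    assert logits.shape == (4, 2)  # one logit per class
    return logits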