# dyck-3-transformer / transformer.py
import math
import torch
import torch.nn as nn
from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.train import HookedTransformerTrainConfig, train
from transformers import PretrainedConfig, PreTrainedModel
def generate_config(
n_ctx,
d_model,
d_head,
n_heads,
d_mlp,
n_layers,
attention_dir,
act_fn,
d_vocab,
d_vocab_out,
use_attn_result,
device,
use_hook_tokens,
):
    """Build a HookedTransformerConfig from the given hyperparameters."""
    return HookedTransformerConfig(
n_ctx=n_ctx,
d_model=d_model,
d_head=d_head,
n_heads=n_heads,
d_mlp=d_mlp,
n_layers=n_layers,
attention_dir=attention_dir,
act_fn=act_fn,
d_vocab=d_vocab,
d_vocab_out=d_vocab_out,
use_attn_result=use_attn_result,
device=device,
use_hook_tokens=use_hook_tokens,
)
def generate_model(config):
    """Instantiate a HookedTransformer from a HookedTransformerConfig."""
    return HookedTransformer(config)
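# Usage sketch (illustrative hyperparameters, not necessarily those of the
# uploaded checkpoint): build a small HookedTransformer over a bracket
# vocabulary so its activations can be inspected with transformer_lens hooks.
#
#   cfg = generate_config(
#       n_ctx=64, d_model=128, d_head=32, n_heads=4, d_mlp=512, n_layers=2,
#       attention_dir="bidirectional", act_fn="relu", d_vocab=8, d_vocab_out=2,
#       use_attn_result=True, device="cpu", use_hook_tokens=True,
#   )
#   model = generate_model(cfg)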
def train_model(model, n_epochs, batch_size, lr, dataset):
    """Train a HookedTransformer with transformer_lens' built-in training loop (on cuda:0)."""
    train_cfg = HookedTransformerTrainConfig(
        num_epochs=n_epochs, batch_size=batch_size, lr=lr, device="cuda:0"
    )
    return train(model, train_cfg, dataset)
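# Usage sketch (hedged): `dyck_dataset` below is a hypothetical dataset in the
# format expected by transformer_lens.train, e.g. tokenized, labelled Dyck-3 strings.
#
#   trained_model = train_model(model, n_epochs=10, batch_size=128, lr=1e-3, dataset=dyck_dataset)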
class ScaledDotProductAttention(nn.Module):
    """Computes softmax(q k^T / scale) v; positions where mask == 0 are masked out."""

    def __init__(self, scale):
super().__init__()
self.scale = scale
def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale
if mask is not None:
attn = attn.masked_fill(mask == 0, float("-inf"))
attn = torch.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
return out, attn
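# Shape sketch: with q, k, v of shape (batch, heads, seq, depth) and
# scale = sqrt(depth), `out` matches v's shape and `attn` is
# (batch, heads, seq, seq) with each row summing to 1.
#
#   q = k = v = torch.randn(2, 4, 10, 32)
#   out, attn = ScaledDotProductAttention(scale=math.sqrt(32))(q, k, v)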
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project q/k/v, split into n_heads heads of size d_model // n_heads, attend per head, recombine, and apply a final linear layer."""

    def __init__(self, n_heads, d_model):
super().__init__()
assert d_model % n_heads == 0, "d_model should be divisible by n_heads"
self.d_model = d_model
self.n_heads = n_heads
self.depth = d_model // n_heads
self.wq = nn.Linear(d_model, d_model)
self.wk = nn.Linear(d_model, d_model)
self.wv = nn.Linear(d_model, d_model)
self.dense = nn.Linear(d_model, d_model)
self.attn = ScaledDotProductAttention(scale=math.sqrt(self.depth))
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
q = self.wq(q).view(batch_size, -1, self.n_heads, self.depth).transpose(1, 2)
k = self.wk(k).view(batch_size, -1, self.n_heads, self.depth).transpose(1, 2)
v = self.wv(v).view(batch_size, -1, self.n_heads, self.depth).transpose(1, 2)
attn_out, _ = self.attn(q, k, v, mask=mask)
attn_out = (
attn_out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
)
out = self.dense(attn_out)
return out
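# Self-attention sketch: d_model is split across heads (here 256 / 8 = 32 per head)
# and the output keeps the input shape.
#
#   mha = MultiHeadAttention(n_heads=8, d_model=256)
#   x = torch.randn(2, 10, 256)   # (batch, seq, d_model)
#   y = mha(x, x, x)              # -> (2, 10, 256)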
class TransformerEncoderLayer(nn.Module):
    """Post-norm encoder block: self-attention and feed-forward sublayers, each with a residual connection, dropout, and LayerNorm."""

    def __init__(self, d_model, n_heads, ff_dim, dropout=0.1):
super().__init__()
self.attn = MultiHeadAttention(n_heads, d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, ff_dim),
nn.ReLU(),
nn.Linear(ff_dim, d_model),
)
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
attn_out = self.attn(x, x, x, mask=mask)
x = self.ln1(x + self.dropout(attn_out))
ff_out = self.ff(x)
x = self.ln2(x + self.dropout(ff_out))
return x
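# Encoder-layer sketch: one attention + feed-forward block, shape preserved.
#
#   layer = TransformerEncoderLayer(d_model=256, n_heads=8, ff_dim=2048)
#   y = layer(torch.randn(2, 10, 256))   # -> (2, 10, 256)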
class TransformerClassifierConfig(PretrainedConfig):
    """Hugging Face configuration for TransformerClassifier."""

    model_type = "transformer-checker"
def __init__(
self,
in_dim=512,
d_model=256,
n_heads=8,
ff_dim=2048,
n_layers=6,
n_classes=2,
**kwargs,
):
self.in_dim = in_dim
self.d_model = d_model
self.n_heads = n_heads
self.ff_dim = ff_dim
self.n_layers = n_layers
self.n_classes = n_classes
super().__init__(**kwargs)
class TransformerClassifier(PreTrainedModel):
    """Linear embedding of the input features, a stack of encoder layers, and a linear classifier over the first position."""

    config_class = TransformerClassifierConfig
def __init__(self, config: TransformerClassifierConfig):
super().__init__(config)
self.embedding = nn.Linear(config.in_dim, config.d_model)
self.encoders = nn.ModuleList(
[
TransformerEncoderLayer(config.d_model, config.n_heads, config.ff_dim)
for _ in range(config.n_layers)
]
)
self.classifier = nn.Linear(config.d_model, config.n_classes)
def forward(self, x, mask=None):
x = self.embedding(x)
for encoder in self.encoders:
x = encoder(x, mask=mask)
        # classify from the first position's representation (CLS-style pooling)
        x = self.classifier(x[:, 0])
return x
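# Minimal end-to-end sketch. It assumes inputs are already continuous vectors of
# size `in_dim` (the classifier embeds them with a Linear layer, not a token
# embedding); the checkpoint's actual preprocessing may differ.
if __name__ == "__main__":
    cfg = TransformerClassifierConfig()
    clf = TransformerClassifier(cfg)
    dummy = torch.randn(4, 16, cfg.in_dim)  # (batch, seq, in_dim)
    logits = clf(dummy)                     # (batch, n_classes) -> (4, 2)
    print(logits.shape)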