import math

import torch
import torch.nn as nn
from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.train import HookedTransformerTrainConfig, train
from transformers import PretrainedConfig, PreTrainedModel


def generate_config(
    n_ctx,
    d_model,
    d_head,
    n_heads,
    d_mlp,
    n_layers,
    attention_dir,
    act_fn,
    d_vocab,
    d_vocab_out,
    use_attn_result,
    device,
    use_hook_tokens,
):
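    """Build a HookedTransformerConfig from the given hyperparameters."""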
    return HookedTransformerConfig(
        n_ctx=n_ctx,
        d_model=d_model,
        d_head=d_head,
        n_heads=n_heads,
        d_mlp=d_mlp,
        n_layers=n_layers,
        attention_dir=attention_dir,
        act_fn=act_fn,
        d_vocab=d_vocab,
        d_vocab_out=d_vocab_out,
        use_attn_result=use_attn_result,
        device=device,
        use_hook_tokens=use_hook_tokens,
    )


def generate_model(config):
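    """Instantiate a HookedTransformer from a HookedTransformerConfig."""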
    return HookedTransformer(config)


def train_model(model, n_epochs, batch_size, lr, dataset):
    """Train `model` on `dataset` with transformer_lens's built-in training loop."""
    train_cfg = HookedTransformerTrainConfig(
        num_epochs=n_epochs, batch_size=batch_size, lr=lr, device="cuda:0"
    )

    return train(model, train_cfg, dataset)


class ScaledDotProductAttention(nn.Module):
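    """Scaled dot-product attention: softmax(q @ k^T / scale) @ v."""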
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, q, k, v, mask=None):
        # Similarity scores, scaled to keep the softmax numerically well-behaved.
        attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        if mask is not None:
            # Positions where mask == 0 cannot be attended to.
            attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = torch.softmax(attn, dim=-1)

        out = torch.matmul(attn, v)
        return out, attn


class MultiHeadAttention(nn.Module):
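    """Multi-head attention: project q/k/v into n_heads subspaces, attend in each, then recombine."""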
    def __init__(self, n_heads, d_model):
        super().__init__()
        assert d_model % n_heads == 0, "d_model should be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.depth = d_model // n_heads

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.dense = nn.Linear(d_model, d_model)

        self.attn = ScaledDotProductAttention(scale=math.sqrt(self.depth))

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

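        # Project and split into heads: (batch, n_heads, seq_len, depth).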
        q = self.wq(q).view(batch_size, -1, self.n_heads, self.depth).transpose(1, 2)
        k = self.wk(k).view(batch_size, -1, self.n_heads, self.depth).transpose(1, 2)
        v = self.wv(v).view(batch_size, -1, self.n_heads, self.depth).transpose(1, 2)

        attn_out, _ = self.attn(q, k, v, mask=mask)
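        # Merge the heads back into a single (batch, seq_len, d_model) tensor.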
        attn_out = (
            attn_out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        )

        out = self.dense(attn_out)

        return out


class TransformerEncoderLayer(nn.Module):
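    """Post-LN encoder block: self-attention and feed-forward sublayers, each with a residual connection."""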
    def __init__(self, d_model, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(n_heads, d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, d_model),
        )

        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out = self.attn(x, x, x, mask=mask)
        x = self.ln1(x + self.dropout(attn_out))

        ff_out = self.ff(x)
        x = self.ln2(x + self.dropout(ff_out))

        return x


class TransformerClassifierConfig(PretrainedConfig):
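    """Hugging Face-style configuration for TransformerClassifier."""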
    model_type = "transformer-checker"

    def __init__(
        self,
        in_dim=512,
        d_model=256,
        n_heads=8,
        ff_dim=2048,
        n_layers=6,
        n_classes=2,
        **kwargs,
    ):
        self.in_dim = in_dim
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.n_layers = n_layers
        self.n_classes = n_classes

        super().__init__(**kwargs)


class TransformerClassifier(PreTrainedModel):
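    """Encoder-only transformer that classifies a sequence from its first position's representation."""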
    config_class = TransformerClassifierConfig

    def __init__(self, config: TransformerClassifierConfig):
        super().__init__(config)
        self.embedding = nn.Linear(config.in_dim, config.d_model)
        self.encoders = nn.ModuleList(
            [
                TransformerEncoderLayer(config.d_model, config.n_heads, config.ff_dim)
                for _ in range(config.n_layers)
            ]
        )
        self.classifier = nn.Linear(config.d_model, config.n_classes)

    def forward(self, x, mask=None):
        x = self.embedding(x)
        for encoder in self.encoders:
            x = encoder(x, mask=mask)

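        # Classify from the first position's representation (CLS-style pooling).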
        x = self.classifier(x[:, 0])

        return x
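

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch; the shapes and hyperparameters
    # below are assumptions, not values taken from any particular training setup).
    cfg = TransformerClassifierConfig(in_dim=512, d_model=256, n_heads=8, ff_dim=512, n_layers=2)
    clf = TransformerClassifier(cfg)
    features = torch.randn(4, 16, cfg.in_dim)  # (batch, seq_len, in_dim) of pre-embedded features
    logits = clf(features)
    print(logits.shape)  # expected: torch.Size([4, 2])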