import torch
from torch import nn
from transformers import ElectraPreTrainedModel, ElectraModel

# Custom activation function
class SwishGLU(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.projection = nn.Linear(input_dim, 2 * output_dim)
        self.activation = nn.SiLU()

    def forward(self, x):
        # Project to 2 * output_dim, split into value and gate halves, and apply SiLU gating
        x_proj_gate = self.projection(x)
        projected, gate = x_proj_gate.tensor_split(2, dim=-1)
        return projected * self.activation(gate)


# Custom pooling layer
class PoolingLayer(nn.Module):
    def __init__(self, pooling_type='cls'):
        super().__init__()
        self.pooling_type = pooling_type

    def forward(self, last_hidden_state, attention_mask):
        if self.pooling_type == 'cls':
            return last_hidden_state[:, 0, :]
        elif self.pooling_type == 'mean':
            # Mean over non-padded token embeddings (mask-weighted sum / number of real tokens)
            return (last_hidden_state * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
        elif self.pooling_type == 'max':
            # Max over token embeddings (padded positions are zeroed out before the max)
            return torch.max(last_hidden_state * attention_mask.unsqueeze(-1), dim=1)[0]
        else:
            raise ValueError(f"Unknown pooling method: {self.pooling_type}")


# Custom classifier
class Classifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, hidden_activation, num_layers, n_classes, dropout_rate=0.0):
        super().__init__()
        layers = []
        # Input projection followed by activation (and optional dropout)
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(hidden_activation)
        if dropout_rate > 0:
            layers.append(nn.Dropout(dropout_rate))

        # Additional hidden layers
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(hidden_activation)
            if dropout_rate > 0:
                layers.append(nn.Dropout(dropout_rate))

        # Output projection to class logits
        layers.append(nn.Linear(hidden_dim, n_classes))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


# Custom Electra classifier model
class ElectraClassifier(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.electra = ElectraModel(config)

        # Drop the default pooler head if present; pooling is handled by PoolingLayer below
        if hasattr(self.electra, 'pooler'):
            self.electra.pooler = None

        self.pooling = PoolingLayer(pooling_type=config.pooling)

        # Handle custom activation functions
        activation_name = config.hidden_activation
        if activation_name == 'SwishGLU':
            hidden_activation = SwishGLU(
                input_dim=config.hidden_dim,
                output_dim=config.hidden_dim
            )
        else:
            activation_class = getattr(nn, activation_name)
            hidden_activation = activation_class()

        self.classifier = Classifier(
            input_dim=config.hidden_size,
            hidden_dim=config.hidden_dim,
            hidden_activation=hidden_activation,
            num_layers=config.num_layers,
            n_classes=config.num_labels,
            dropout_rate=config.dropout_rate
        )
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # Encode the tokens, pool hidden states into one vector per sequence, then classify
        outputs = self.electra(input_ids, attention_mask=attention_mask, **kwargs)
        pooled_output = self.pooling(outputs.last_hidden_state, attention_mask)
        logits = self.classifier(pooled_output)
        return logits
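

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this classifier might be instantiated and run.
# The extra config attributes (pooling, hidden_dim, hidden_activation,
# num_layers, dropout_rate) are custom fields this module reads; the
# checkpoint name and the values below are assumptions for illustration only.
if __name__ == "__main__":
    from transformers import AutoTokenizer, ElectraConfig

    checkpoint = "google/electra-small-discriminator"  # example checkpoint
    config = ElectraConfig.from_pretrained(checkpoint)
    config.num_labels = 2
    config.pooling = "mean"            # 'cls', 'mean', or 'max'
    config.hidden_dim = 256            # width of the classifier MLP
    config.hidden_activation = "GELU"  # any torch.nn activation name, or 'SwishGLU'
    config.num_layers = 2              # hidden layers in the classifier MLP
    config.dropout_rate = 0.1

    model = ElectraClassifier.from_pretrained(checkpoint, config=config)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(input_ids=batch["input_ids"],
                       attention_mask=batch["attention_mask"])
    print(logits.shape)  # (batch_size, num_labels)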