File size: 551 Bytes
43e8cb0
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from transformers import BertForSequenceClassification, AutoConfig, AutoTokenizer
import torch.nn as nn

class CustomBertForSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier whose single-linear head is replaced by a small MLP.

    The head is ``Linear(hidden_size -> custom_head_hidden_size) -> ReLU ->
    Linear(custom_head_hidden_size -> num_labels)``. Everything else is
    inherited unchanged from ``BertForSequenceClassification``; only
    ``self.classifier`` is swapped out.

    Config attributes read here:
        hidden_size: input width of the head.
        custom_head_hidden_size: width of the MLP's intermediate layer.
            Optional; when the config does not define it (or sets it to
            ``None``), ``hidden_size`` is used instead.
        num_labels: number of output logits produced by the head.
    """

    def __init__(self, config):
        super().__init__(config)
        # Non-standard config field: tolerate configs (e.g. a stock BertConfig)
        # that don't define it instead of failing with AttributeError.
        head_hidden_size = getattr(config, "custom_head_hidden_size", None)
        if head_hidden_size is None:
            head_hidden_size = config.hidden_size
        # Replace the default classifier (single linear layer) with an MLP head.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, head_hidden_size),
            nn.ReLU(),
            nn.Linear(head_hidden_size, config.num_labels),
        )
        # super().__init__ already ran post_init(), which initialized the
        # *original* classifier. Apply the model's own weight-init scheme to the
        # new head as well, so a model built directly from a config does not
        # leave these layers at PyTorch's default Linear initialization.
        self.classifier.apply(self._init_weights)