from transformers import BertForSequenceClassification, AutoConfig, AutoTokenizer
import torch.nn as nn


class CustomBertForSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        # Replace the default classifier (a single linear layer) with a Sequential head.
        # Expects `config.custom_head_hidden_size` to be set on the config before
        # the model is instantiated.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.custom_head_hidden_size),
            nn.ReLU(),
            nn.Linear(config.custom_head_hidden_size, config.num_labels),
        )
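
A minimal usage sketch, assuming the extra attribute is attached to the config before the model is built. The checkpoint name "bert-base-uncased", the label count, and the head size of 256 are illustrative choices, not fixed by the class above; only `custom_head_hidden_size` must match the attribute name the custom `__init__` reads.

# Illustrative only: checkpoint, num_labels, and head size are placeholders.
config = AutoConfig.from_pretrained("bert-base-uncased", num_labels=3)
config.custom_head_hidden_size = 256  # attribute consumed by the custom classifier head

model = CustomBertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    config=config,
)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

Because the new head layers do not exist in the pretrained checkpoint, `from_pretrained` will warn that the classifier weights are newly initialized; they are trained from scratch during fine-tuning while the BERT encoder weights are loaded as usual.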