File size: 551 Bytes
43e8cb0 |
1 2 3 4 5 6 7 8 9 10 11 12 |
from transformers import BertForSequenceClassification, AutoConfig, AutoTokenizer
import torch.nn as nn
class CustomBertForSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier with a two-layer MLP classification head.

    Behaves like ``BertForSequenceClassification`` except that the default
    single-``nn.Linear`` classifier is replaced by::

        Linear(hidden_size, custom_head_hidden_size) -> ReLU
            -> Linear(custom_head_hidden_size, num_labels)

    Only ``__init__`` is overridden; the inherited ``forward`` is reused as-is.
    """

    def __init__(self, config):
        """Build the base model, then swap in the MLP classification head.

        Args:
            config: BERT configuration. ``config.custom_head_hidden_size`` sets
                the width of the head's intermediate layer; if the attribute is
                absent, ``config.hidden_size`` is used instead.
        """
        super().__init__(config)

        # Custom attribute, so tolerate configs that do not define it.
        head_hidden_size = getattr(
            config, "custom_head_hidden_size", config.hidden_size
        )

        # Replace the default classifier (single linear layer) with a Sequential head
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, head_hidden_size),
            nn.ReLU(),
            nn.Linear(head_hidden_size, config.num_labels),
        )

        # The parent constructor initializes weights (post_init) before it
        # returns, so layers created here would otherwise keep PyTorch's default
        # init instead of the model's own scheme. Re-apply the model's
        # initializer to just the new head.
        self.classifier.apply(self._init_weights)