# test_model / modeling.py
# Uploaded by edur0: "Upload folder using huggingface_hub" (commit 43e8cb0, verified)
from transformers import BertForSequenceClassification, AutoConfig, AutoTokenizer
import torch.nn as nn
class CustomBertForSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier whose classification head is a two-layer MLP.

    The single linear ``classifier`` built by ``BertForSequenceClassification``
    is replaced with::

        Linear(hidden_size -> custom_head_hidden_size)
        -> ReLU
        -> Linear(custom_head_hidden_size -> num_labels)

    The config must define ``hidden_size`` and ``num_labels`` plus the
    non-standard ``custom_head_hidden_size`` attribute. If
    ``custom_head_hidden_size`` is missing, construction raises
    ``AttributeError``.
    """

    def __init__(self, config):
        """Build the parent BERT model, then swap in the MLP classification head.

        Args:
            config: BERT model config carrying ``hidden_size``, ``num_labels``
                and ``custom_head_hidden_size``.
        """
        super().__init__(config)
        # Replace the default classifier (single linear layer) with a
        # Sequential head.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.custom_head_hidden_size),
            nn.ReLU(),
            nn.Linear(config.custom_head_hidden_size, config.num_labels),
        )
        # The parent constructor already ran its weight initialization
        # (post_init) before this head existed, so the new Linear layers would
        # otherwise keep PyTorch's default init. Apply the model's own init
        # scheme so a freshly constructed head matches the rest of the network.
        # Weights loaded from a checkpoint later replace these values.
        self.classifier.apply(self._init_weights)