from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn as nn


class SingleLabelClassifier(nn.Module):
    def __init__(self, base_model_name, num_labels, hidden_size=2024, freeze_bert=True):
        super(SingleLabelClassifier, self).__init__()
        self.base = AutoModel.from_pretrained(base_model_name)

        # Optionally freeze the encoder, leaving only the embedding layers trainable.
        if freeze_bert:
            for name, param in self.base.named_parameters():
                if not name.startswith("embeddings"):
                    param.requires_grad = False

        # Classification head: projection -> LayerNorm -> ReLU -> dropout -> output layer.
        self.intermediate = nn.Linear(self.base.config.hidden_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        # Use the hidden state of the [CLS] token as the pooled sequence representation.
        pooled_output = outputs.last_hidden_state[:, 0]

        x = self.intermediate(pooled_output)
        x = self.norm(x)
        x = self.activation(x)
        x = self.dropout(x)
        logits = self.classifier(x)

        # Compute cross-entropy loss only when labels are provided.
        loss = None
        if labels is not None:
            labels = labels.long()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return {"logits": logits, "loss": loss}
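

# --- Usage sketch (illustrative only, not part of the model definition) ---
# A minimal example of tokenizing a small batch and running a forward pass.
# The checkpoint name "bert-base-uncased", the example texts, and num_labels=3
# are assumptions chosen for illustration; substitute your own values.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = SingleLabelClassifier("bert-base-uncased", num_labels=3)
    model.eval()  # disable dropout for a deterministic forward pass

    texts = ["an example sentence", "another example"]
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        out = model(**batch)

    print(out["logits"].shape)  # -> torch.Size([2, 3]): one row of logits per input text
    preds = out["logits"].argmax(dim=-1)
    print(preds)  # predicted class index for each example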