import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class CustomClassificationConfig(PretrainedConfig):
    # Configuration holding the feature size, hidden width, and number of classes.
    model_type = "custom_classifier"

    def __init__(self, input_dim=32, hidden_dim=64, num_classes=2, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes


class CustomClassifier(PreTrainedModel):
    config_class = CustomClassificationConfig

    def __init__(self, config):
        super().__init__(config)
        # Two-layer MLP encoder followed by a linear classification head.
        self.encoder = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(config.hidden_dim, config.num_classes)

    def forward(self, input_ids=None, labels=None, **kwargs):
        # `input_ids` is expected to be a float tensor of shape (batch_size, input_dim).
        hidden = self.encoder(input_ids)
        logits = self.classifier(hidden)

        # Compute the loss only when labels are provided, so the same forward
        # pass serves both training and inference.
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        return {"loss": loss, "logits": logits}
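
# A minimal usage sketch (the random tensors below are illustrative dummy data,
# not part of the original example): build a config and model, run a forward
# pass, and inspect the returned loss and logits.
config = CustomClassificationConfig(input_dim=32, hidden_dim=64, num_classes=2)
model = CustomClassifier(config)

features = torch.randn(8, config.input_dim)          # batch of 8 float feature vectors
labels = torch.randint(0, config.num_classes, (8,))  # integer class labels

outputs = model(input_ids=features, labels=labels)
print(outputs["loss"], outputs["logits"].shape)      # scalar loss, logits of shape (8, 2)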