# models.py — custom classification model definitions.
# Source: yhamidullah (commit 4365a31, verified).
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers import PretrainedConfig
class CustomClassificationConfig(PretrainedConfig):
    """Configuration for :class:`CustomClassifier`.

    Stores the three hyperparameters the model needs: the width of the
    input feature vector, the width of the hidden layers, and the number
    of output classes. Everything else is forwarded to
    ``PretrainedConfig`` so standard save/load machinery keeps working.
    """

    model_type = "custom_classifier"

    def __init__(self, input_dim=32, hidden_dim=64, num_classes=2, **kwargs):
        """Create a config.

        Args:
            input_dim: Size of each input feature vector.
            hidden_dim: Width of the hidden (encoder) layers.
            num_classes: Number of target classes for classification.
            **kwargs: Passed through to ``PretrainedConfig``.
        """
        # Let the base class consume its own keyword arguments first.
        super().__init__(**kwargs)
        # Model-specific hyperparameters, serialized into config.json.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
class CustomClassifier(PreTrainedModel):
    """A small MLP classifier packaged as a ``PreTrainedModel``.

    Architecture: two Linear+ReLU encoder layers followed by a linear
    classification head. The ``nn.Sequential`` layout is kept so saved
    checkpoints keep their parameter names (``encoder.0``, ``encoder.2``,
    ``classifier``).
    """

    config_class = CustomClassificationConfig

    def __init__(self, config):
        """Build the encoder and classification head from *config*."""
        super().__init__(config)
        # Two-layer MLP encoder: input_dim -> hidden_dim -> hidden_dim.
        self.encoder = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
        )
        # Linear head mapping hidden features to class logits.
        self.classifier = nn.Linear(config.hidden_dim, config.num_classes)

    def forward(self, input_ids=None, labels=None, **kwargs):
        """Run a forward pass.

        Args:
            input_ids: Float feature tensor of shape ``(batch_size,
                input_dim)`` — despite the name, these are dense features,
                not token ids (the name matches the Trainer convention).
            labels: Optional class-index tensor of shape ``(batch_size,)``;
                when given, a cross-entropy loss is computed.
            **kwargs: Ignored; accepted for Trainer compatibility.

        Returns:
            dict with ``"loss"`` (``None`` when *labels* is absent) and
            ``"logits"`` of shape ``(batch_size, num_classes)``.
        """
        features = self.encoder(input_ids)
        logits = self.classifier(features)

        # Only compute the training loss when targets are supplied.
        if labels is None:
            return {"loss": None, "logits": logits}
        loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}