|
import torch |
|
from torch import nn |
|
from transformers import DistilBertModel, DistilBertTokenizer |
|
|
|
class MultilabelClassifier(nn.Module):
    """Multilabel text classifier: a DistilBERT backbone with a linear head.

    Produces independent per-label probabilities via a sigmoid, so each
    label is scored separately (multilabel, not multiclass/softmax).

    NOTE: because ``forward`` returns post-sigmoid probabilities, train
    this model with ``nn.BCELoss``. If you ever switch the loss to
    ``nn.BCEWithLogitsLoss`` (numerically more stable), ``forward`` must
    return raw logits instead — that would be a breaking interface change
    for existing callers.
    """

    def __init__(self, model_name, num_labels, dropout=0.1):
        """Build the backbone and classification head.

        Args:
            model_name: HuggingFace model id or local path of a
                DistilBERT-compatible checkpoint passed to
                ``DistilBertModel.from_pretrained``.
            num_labels: Number of independent output labels.
            dropout: Dropout probability applied to the pooled
                representation before the classifier head.
        """
        super().__init__()

        self.backbone = DistilBertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Read the hidden size from the loaded config instead of
        # hard-coding 768, so non-base checkpoints with a different
        # hidden dim also work (DistilBertConfig aliases
        # `hidden_size` to its `dim` attribute).
        hidden_size = self.backbone.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Score each label independently for a batch of sequences.

        Args:
            input_ids: Long tensor of token ids, shape (batch, seq_len).
            attention_mask: Tensor of 1s/0s marking real vs. padding
                tokens, shape (batch, seq_len).

        Returns:
            Tensor of per-label probabilities in [0, 1],
            shape (batch, num_labels).
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state of the first ([CLS]) token as the pooled
        # sequence representation; DistilBERT has no dedicated pooler.
        pooled_output = outputs.last_hidden_state[:, 0]
        x = self.dropout(pooled_output)
        logits = self.classifier(x)
        return self.sigmoid(logits)
|
|