File size: 853 Bytes
8514b06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from torch import nn
from transformers import DistilBertModel, DistilBertTokenizer

class MultilabelClassifier(nn.Module):
    """Multilabel classifier head on top of a pretrained DistilBERT encoder.

    The encoder's first-token hidden state goes through dropout and a linear
    layer that produces one score per label. By default ``forward`` returns
    independent per-label probabilities (the sigmoid of those scores).

    Args:
        model_name: Checkpoint name or path passed to
            ``DistilBertModel.from_pretrained``.
        num_labels: Number of output labels (output width of the linear layer).
        dropout: Dropout probability applied to the pooled representation.
    """

    def __init__(self, model_name, num_labels, dropout=0.1):
        super().__init__()

        self.backbone = DistilBertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Read the encoder width from its config instead of hard-coding 768,
        # so checkpoints whose hidden size is not 768 still get a correctly
        # shaped classification head.
        hidden_size = self.backbone.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask=None, *, return_logits=False):
        """Score every label for a batch of token sequences.

        Args:
            input_ids: Token-id tensor forwarded to the backbone; presumably
                shape (batch, seq_len) -- confirm against callers.
            attention_mask: Optional mask forwarded to the backbone. When
                ``None``, the backbone applies its own default.
            return_logits: If True, return raw scores instead of sigmoid
                probabilities. Raw scores pair with ``nn.BCEWithLogitsLoss``,
                which is more numerically stable than ``nn.BCELoss`` applied
                to sigmoid outputs.

        Returns:
            Tensor of shape (batch, num_labels). By default it holds
            probabilities in [0, 1]; with ``return_logits=True`` it holds
            unbounded logits.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # last_hidden_state is (batch, seq_len, hidden). Position 0 is the
        # first token, which DistilBertTokenizer fills with [CLS].
        pooled_output = outputs.last_hidden_state[:, 0]
        logits = self.classifier(self.dropout(pooled_output))
        if return_logits:
            return logits
        return self.sigmoid(logits)