File size: 810 Bytes
240d9e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch        # Import PyTorch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch import nn
from transformers import AutoModel, AutoTokenizer

class DebertaEvaluator(nn.Module):
    """DeBERTa encoder followed by masked mean pooling and a linear head.

    The backbone's final hidden states are averaged over the non-padding
    positions indicated by ``mask``, passed through dropout, and projected
    to ``num_targets`` outputs.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``AutoModel.from_pretrained``.
        num_targets: Width of the output layer. Defaults to 6, the value the
            original head hard-coded (NOTE(review): presumably one output per
            target score — confirm against the training code).
        dropout: Dropout probability applied to the pooled embedding.
    """

    def __init__(self, model_name='microsoft/deberta-v3-base', num_targets=6, dropout=0.5):
        super().__init__()

        self.deberta = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Take the input width from the backbone config rather than hard-coding
        # 768, so checkpoints with a different hidden size also work.
        self.linear = nn.Linear(self.deberta.config.hidden_size, num_targets)

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average ``hidden`` over dim 1, counting only positions where ``mask`` != 0.

        Falls back to an unweighted mean over dim 1 when ``mask`` is None.
        """
        if mask is None:
            return hidden.mean(dim=1)
        # Broadcast the per-token mask across the feature dimension.
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * weights).sum(dim=1)
        # Clamp so a row with no unmasked tokens yields zeros instead of NaN.
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def forward(self, input_id, mask=None):
        """Score a batch of token sequences.

        Args:
            input_id: Token ids forwarded to the backbone as ``input_ids``.
            mask: Attention mask forwarded to the backbone and reused for
                pooling; assumed (batch, seq_len) like a tokenizer's
                ``attention_mask`` — TODO confirm. ``None`` averages over all
                positions.

        Returns:
            Tensor of shape (batch, num_targets).
        """
        output = self.deberta(input_ids=input_id, attention_mask=mask)
        # Padding positions must not contribute to the pooled embedding,
        # otherwise padded examples are averaged with pad-token states.
        output_pooled = self._masked_mean(output.last_hidden_state, mask)
        dropout_output = self.dropout(output_pooled)
        linear_output = self.linear(dropout_output)

        return linear_output