kdevoe commited on
Commit
240d9e8
1 Parent(s): 80792da

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +25 -0
inference.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch # Import PyTorch
2
+ import torch.optim as optim
3
+ import torch.optim.lr_scheduler as lr_scheduler
4
+ from torch.utils.data import DataLoader
5
+ from torch import nn
6
+ from transformers import AutoModel, AutoTokenizer
7
+
8
+ class DebertaEvaluator(nn.Module):
9
+
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ self.deberta = AutoModel.from_pretrained('microsoft/deberta-v3-base')
14
+ self.dropout = nn.Dropout(0.5)
15
+ self.linear = nn.Linear(768, 6)
16
+
17
+ def forward(self, input_id, mask):
18
+ output = self.deberta(input_ids=input_id, attention_mask=mask)
19
+ output_pooled = torch.mean(output.last_hidden_state, 1)
20
+ dropout_output = self.dropout(output_pooled)
21
+ linear_output = self.linear(dropout_output)
22
+
23
+ return linear_output
24
+
25
+