testspace / inference.py
kdevoe's picture
Create inference.py
240d9e8 verified
raw history blame
No virus
810 Bytes
import torch # Import PyTorch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch import nn
from transformers import AutoModel, AutoTokenizer
class DebertaEvaluator(nn.Module):
    """Pretrained transformer encoder with mean pooling and a linear head.

    The encoder's last hidden states are averaged over dimension 1, passed
    through dropout, and projected to ``num_labels`` outputs by a single
    linear layer.

    Submodule attribute names (``deberta``, ``dropout``, ``linear``) are part
    of the state_dict layout and must not be renamed, otherwise previously
    saved weights will no longer load.
    """

    def __init__(
        self,
        model_name: str = 'microsoft/deberta-v3-base',
        num_labels: int = 6,
        dropout: float = 0.5,
    ) -> None:
        """Build the encoder, dropout layer and output head.

        Args:
            model_name: Identifier or path handed to
                ``AutoModel.from_pretrained``.
            num_labels: Number of values the linear head emits per input.
            dropout: Drop probability applied to the pooled representation.
        """
        super().__init__()
        self.deberta = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the encoder's own config rather than a literal
        # 768, so checkpoints with a different hidden width work unchanged.
        self.linear = nn.Linear(self.deberta.config.hidden_size, num_labels)

    def forward(self, input_id: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Run the encoder and return the head's outputs.

        Args:
            input_id: Token ids, forwarded to the encoder as ``input_ids``.
            mask: Attention mask, forwarded as ``attention_mask``.

        Returns:
            The linear head's output for the pooled encoder states.
        """
        encoded = self.deberta(input_ids=input_id, attention_mask=mask)
        # NOTE(review): this is a plain mean over dimension 1 (presumably the
        # sequence axis); ``mask`` is only used inside the encoder, so any
        # padded positions also contribute to the average. Kept as-is because
        # existing weights were presumably trained with this pooling.
        pooled = encoded.last_hidden_state.mean(dim=1)
        return self.linear(self.dropout(pooled))