nlp-bert-team / bot /model.py
VerVelVel's picture
bot and new weights
cdb0abe
raw
history blame
784 Bytes
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class BERTClassifier(nn.Module):
    """Frozen pretrained encoder with a small trainable single-score head.

    The pretrained sequence-classification checkpoint is loaded, and its
    ``classifier`` layer is replaced by a square ``hidden_size x hidden_size``
    ``nn.Linear``. This makes ``.logits`` carry a ``hidden_size``-dim feature
    vector. Every parameter of ``self.bert`` is then frozen. Only
    ``self.linear`` is left trainable. It maps those features to one raw
    score per example.

    Args:
        model_name: Hugging Face checkpoint to load. Assumed to be a
            BERT-style ``*ForSequenceClassification`` model whose
            ``classifier`` attribute produces ``.logits`` from a
            ``hidden_size`` vector. TODO confirm before using a non-BERT
            checkpoint.
        head_dim: Width of the hidden layer in the trainable head.
        dropout: Dropout probability inside the trainable head.

    With the defaults this builds the same modules as before. rubert-tiny's
    hidden size is 312, the value that used to be hard-coded. Module names
    are unchanged, so previously saved ``state_dict``s keep loading.
    """

    def __init__(self, model_name='cointegrated/rubert-tiny-toxicity',
                 head_dim=128, dropout=0.5):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Read the encoder width from the config instead of hard-coding 312,
        # so the layers below always agree with the loaded checkpoint.
        hidden_size = self.bert.config.hidden_size
        # Replace the pretrained label head with a square projection so that
        # `.logits` exposes hidden_size features for the trainable head.
        self.bert.classifier = nn.Linear(hidden_size, hidden_size)
        # NOTE(review): this freeze runs after the swap above. The freshly
        # initialised classifier is therefore frozen too and stays at its
        # random init during training. Kept as-is to preserve the original
        # training behavior; confirm whether that was intended.
        for param in self.bert.parameters():
            param.requires_grad = False
        self.linear = nn.Sequential(
            nn.Linear(hidden_size, head_dim),
            nn.Sigmoid(),
            nn.Dropout(p=dropout),
            nn.Linear(head_dim, 1),
        )

    def forward(self, x, attention_mask=None):
        """Return one raw (pre-sigmoid) score per example.

        Args:
            x: Token ids passed to the wrapped model as its first positional
                argument. Presumably shape (batch, seq_len); TODO confirm.
            attention_mask: Optional mask forwarded unchanged to the model.

        Returns:
            Tensor of shape (batch,): the head's (batch, 1) output with the
            trailing dimension squeezed. No sigmoid is applied here.
        """
        bert_out = self.bert(x, attention_mask=attention_mask).logits
        # squeeze(1) rather than squeeze() keeps the batch dim when batch == 1.
        out = self.linear(bert_out).squeeze(1)
        return out