File size: 609 Bytes
190f036
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import torch

# Both the tokenizer and the sequence classifier are loaded once, at import
# time, from the same Hugging Face checkpoint so that the vocabulary and the
# model weights always match.
_CHECKPOINT = "textattack/distilbert-base-uncased-CoLA"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT)


def classify_correctness(sentence: str) -> str:
    """Return the model's class-1 probability for *sentence* as a string.

    The sentence is tokenized with the module-level ``tokenizer`` and run
    through the module-level ``model``. A softmax over the output classes
    is then applied, and the probability of class index 1 is returned.

    Args:
        sentence: The text to score.

    Returns:
        The class-1 probability formatted with two decimals, e.g. ``"0.97"``.
        NOTE(review): class 1 is assumed to be the CoLA "acceptable" label
        (GLUE convention) -- confirm against ``model.config.id2label``.
    """
    encoded_input = tokenizer(sentence, return_tensors="pt")
    # Pure inference: disable autograd so no computation graph is built.
    # This saves memory and time, and makes a later ``.detach()`` unnecessary.
    with torch.no_grad():
        output = model(**encoded_input)
    # output[0] holds the logits. A single input string yields a batch of
    # one, so row 0 is this sentence's scores.
    probabilities = F.softmax(output[0], dim=-1)
    correct = probabilities[0, 1].item()
    return "{:.2f}".format(correct)