File size: 1,870 Bytes
f388ec1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import evaluate
import torch
from enum import Enum
from tqdm import tqdm


class AssertionType(Enum):
    """Closed set of assertion categories that can be attached to an entity.

    NOTE(review): the integer values presumably mirror the label ids emitted
    by the assertion classifier -- confirm against the model's ``id2label``
    mapping before relying on that.
    """

    PRESENT = 0
    ABSENT = 1
    POSSIBLE = 2


class EntityWithAssertion:
    """Pair an entity string with the assertion type assigned to it."""

    def __init__(self, entity: str, assertion_type: AssertionType):
        """Store ``entity`` and ``assertion_type`` unchanged on the instance."""
        self.entity, self.assertion_type = entity, assertion_type

    def __repr__(self) -> str:
        """Return ``"<assertion type name>: <entity>"``."""
        label = self.assertion_type.name
        return "{}: {}".format(label, self.entity)


def classify_assertions_in_sentences(sentences, model, tokenizer, batch_size=32):
    """Predict a label index for every sentence, running the model in batches.

    Args:
        sentences: Sliceable sequence of sentences; each slice of at most
            ``batch_size`` items is passed to ``tokenizer`` as one batch.
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object exposing ``logits``. Its device
            is read from ``model.device`` or, failing that, from its first
            parameter.
        tokenizer: Callable invoked with ``return_tensors="pt"``,
            ``padding=True`` and ``truncation=True``; its result must support
            ``.to(device)``.
        batch_size: Maximum number of sentences per forward pass.

    Returns:
        A 1-D tensor holding, in input order, the arg-max over dim 1 of the
        logits for each sentence. An empty ``int64`` tensor is returned when
        ``sentences`` is empty.

    Note:
        The model is not switched to evaluation mode here; gradients are only
        disabled for the forward pass.
    """
    # ``len(...) == 0`` instead of truthiness so array-like inputs are accepted;
    # without this guard ``torch.cat`` would raise on an empty prediction list.
    if len(sentences) == 0:
        return torch.empty(0, dtype=torch.long)

    # Send inputs to wherever the model lives instead of assuming a GPU.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    predictions = []
    for start in tqdm(range(0, len(sentences), batch_size)):
        batch = tokenizer(
            sentences[start:start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            outputs = model(**batch)
        predictions.append(torch.argmax(outputs.logits, dim=1))
    return torch.cat(predictions)


def input_classification(model, tokenizer, x: str = None, all_classes=False):
    """Classify a single sentence, prompting on stdin when none is given.

    Args:
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object exposing ``logits``; it must also
            provide ``config.id2label`` mapping label indices to names.
        tokenizer: Callable invoked with ``return_tensors="pt"``,
            ``padding=True`` and ``truncation=True``; its result must support
            ``.to(device)``.
        x: Sentence to classify. When ``None``, the sentence is read with
            ``input()``.
        all_classes: When true, return the softmax score of every label
            instead of only the best label.

    Returns:
        The name of the arg-max label, or, when ``all_classes`` is true, a
        dict mapping each label name to its softmax score (as ``float``) for
        the first row of the logits.
    """
    if x is None:
        x = input("Write your sentence and press Enter to continue")

    # Place the inputs on the model's device; leaving them on the CPU breaks
    # the forward pass whenever the model has been moved to an accelerator.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    tokenized_x = tokenizer(x, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        outputs = model(**tokenized_x)

    if all_classes:
        scores = torch.softmax(outputs.logits, dim=1)[0]
        return {model.config.id2label[i]: float(score) for i, score in enumerate(scores)}

    predicted_label = torch.argmax(outputs.logits, dim=1)
    return model.config.id2label[int(predicted_label)]


def compute_results(y, y_hat):
    """Score predictions against references with F1 and accuracy.

    Args:
        y: Reference labels, passed to the metrics as ``references``.
        y_hat: Predicted labels, passed to the metrics as ``predictions``.

    Returns:
        A dict with the keys ``"macro-f1"``, ``"micro-f1"`` and ``"accuracy"``,
        each holding the value reported by the corresponding ``evaluate``
        metric.
    """
    f1_metric = evaluate.load("f1")
    accuracy_metric = evaluate.load("accuracy")

    macro_f1 = f1_metric.compute(predictions=y_hat, references=y, average="macro")["f1"]
    micro_f1 = f1_metric.compute(predictions=y_hat, references=y, average="micro")["f1"]
    accuracy = accuracy_metric.compute(predictions=y_hat, references=y)["accuracy"]

    return {"macro-f1": macro_f1, "micro-f1": micro_f1, "accuracy": accuracy}