|
import evaluate |
|
import torch |
|
from enum import Enum |
|
from tqdm import tqdm |
|
|
|
|
|
class AssertionType(Enum):
    """Assertion status that can be attached to an entity.

    Each member carries a fixed integer value: PRESENT=0, ABSENT=1, POSSIBLE=2.
    """

    # NOTE(review): these integers may be intended to line up with the
    # classifier's label ids; confirm against model.config.id2label.
    PRESENT = 0  # entity asserted as present
    ABSENT = 1  # entity asserted as absent
    POSSIBLE = 2  # entity asserted as possible
|
|
|
|
|
class EntityWithAssertion:
    """An entity string paired with its AssertionType.

    The repr renders as ``"<ASSERTION_NAME>: <entity>"``, e.g. ``"ABSENT: fever"``.
    """

    def __init__(self, entity: str, assertion_type: AssertionType):
        # Stored as plain attributes; neither value is validated here.
        self.entity = entity
        self.assertion_type = assertion_type

    def __repr__(self) -> str:
        label = self.assertion_type.name
        return "{}: {}".format(label, self.entity)
|
|
|
|
|
def classify_assertions_in_sentences(sentences, model, tokenizer, batch_size=32, device=None):
    """Predict a label id for every sentence, running the model in batches.

    Args:
        sentences: sliceable sequence of input strings (e.g. a list).
        model: classifier called as ``model(**batch)``; its output must have a
            ``.logits`` tensor, and the argmax is taken over dim 1 (presumably
            shape ``(batch, num_labels)`` -- confirm for the model in use).
        tokenizer: callable that turns a list of strings into a PyTorch batch
            supporting ``.to(device)`` and ``**`` unpacking.
        batch_size: number of sentences per forward pass; must be >= 1.
        device: where tokenized batches are placed. Defaults to
            ``model.device`` when the model exposes that attribute, otherwise
            ``"cuda"`` (the previous hard-coded behaviour).

    Returns:
        Tensor of predicted label ids (one per sentence, in input order) on
        ``device``. For empty ``sentences`` an empty ``torch.long`` tensor is
        returned.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.

    NOTE: this function does not call ``model.eval()``; put the model in
    eval mode beforehand if it contains dropout or similar layers.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if device is None:
        # Follow the model's own placement when it advertises one (e.g.
        # Hugging Face models), so CPU or non-default-GPU models work.
        device = getattr(model, "device", None)
        if device is None:
            device = "cuda"

    predictions = []
    for start in tqdm(range(0, len(sentences), batch_size)):
        batch = tokenizer(
            sentences[start:start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            outputs = model(**batch)
        predictions.append(torch.argmax(outputs.logits, dim=1))

    if not predictions:
        # torch.cat rejects an empty list; return an empty id tensor instead.
        return torch.empty(0, dtype=torch.long, device=device)
    return torch.cat(predictions)
|
|
|
|
|
def input_classification(model, tokenizer, x: str = None, all_classes=False):
    """Classify a single sentence and return its label (or all label probabilities).

    Args:
        model: classifier exposing ``model.config.id2label`` and returning an
            object with a ``.logits`` tensor; presumably a Hugging Face
            sequence-classification model -- confirm.
        tokenizer: callable producing the model's keyword inputs for ``x``.
        x: sentence to classify. When ``None`` it is read from stdin via
            ``input()``.
        all_classes: when True, return the softmax probability of every label
            instead of only the top one.

    Returns:
        The top label from ``id2label``, or a ``{label: probability}`` dict
        (built from the first row of the softmax over dim 1) when
        ``all_classes`` is True.
    """
    if x is None:
        x = input("Write your sentence and press Enter to continue")

    encoded = tokenizer(x, return_tensors="pt", padding=True, truncation=True)

    # Put the inputs on the model's device when it exposes one; otherwise a
    # GPU-resident model would receive CPU tensors and fail.
    device = getattr(model, "device", None)
    if device is not None:
        encoded = encoded.to(device)

    with torch.no_grad():
        logits = model(**encoded).logits

    id2label = model.config.id2label
    if all_classes:
        probabilities = torch.softmax(logits, dim=1)[0]
        return {id2label[i]: float(p) for i, p in enumerate(probabilities)}
    return id2label[int(torch.argmax(logits, dim=1))]
|
|
|
|
|
def compute_results(y, y_hat):
    """Score predictions against references with Hugging Face ``evaluate`` metrics.

    Args:
        y: reference labels.
        y_hat: predicted labels, aligned with ``y``.

    Returns:
        Dict with keys ``"macro-f1"``, ``"micro-f1"`` and ``"accuracy"``.
    """
    f1_metric = evaluate.load("f1")
    accuracy_metric = evaluate.load("accuracy")

    pair = {"predictions": y_hat, "references": y}
    macro_f1 = f1_metric.compute(**pair, average="macro")["f1"]
    micro_f1 = f1_metric.compute(**pair, average="micro")["f1"]
    accuracy = accuracy_metric.compute(**pair)["accuracy"]

    return {"macro-f1": macro_f1, "micro-f1": micro_f1, "accuracy": accuracy}
|
|