# NOTE(review): the original file began with stray "Spaces:" / "Runtime error"
# lines — copy/paste residue from a notebook or judge export, not program text.
import evaluate | |
import torch | |
from enum import Enum | |
from tqdm import tqdm | |
class AssertionType(Enum):
    """Closed set of assertion labels for a clinical entity mention.

    Explicit integer values are kept (rather than ``auto()``) because they
    double as the classifier's label ids.
    """

    PRESENT = 0
    ABSENT = 1
    POSSIBLE = 2
class EntityWithAssertion:
    """Pairs an extracted entity string with its assertion label."""

    def __init__(self, entity: str, assertion_type: AssertionType):
        self.entity, self.assertion_type = entity, assertion_type

    def __repr__(self) -> str:
        # e.g. "PRESENT: fever" — label name first, then the entity text.
        return f"{self.assertion_type.name}: {self.entity}"
def classify_assertions_in_sentences(sentences, model, tokenizer, batch_size=32, device=None):
    """Run the assertion classifier over *sentences* in mini-batches.

    Args:
        sentences: Sequence of raw sentence strings.
        model: Sequence-classification model (HF-style: calling it returns
            an object with a ``logits`` tensor).
        tokenizer: Matching tokenizer supporting ``return_tensors="pt"``,
            padding and truncation.
        batch_size: Number of sentences tokenized and classified per step.
        device: Target device for the input tensors. Defaults to the device
            of the model's parameters, so the function works on CPU-only
            hosts too (it was previously hard-coded to ``"cuda"``).

    Returns:
        1-D ``torch.Tensor`` of predicted label ids, one per sentence.
    """
    if device is None:
        # Assumes a torch module with at least one parameter — true for
        # HF transformers models. TODO confirm for custom wrappers.
        device = next(model.parameters()).device
    predictions = []
    for start in tqdm(range(0, len(sentences), batch_size)):
        batch = tokenizer(
            sentences[start:start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            outputs = model(**batch)
        predictions.append(torch.argmax(outputs.logits, dim=1))
    if not predictions:
        # torch.cat([]) raises; return an empty id tensor for empty input.
        return torch.empty(0, dtype=torch.long)
    return torch.cat(predictions)
def input_classification(model, tokenizer, x: str = None, all_classes = False):
    """Classify a single sentence, prompting on stdin when none is given.

    Args:
        model: Sequence-classification model exposing ``config.id2label``.
        tokenizer: Matching tokenizer (HF-style; returns a BatchEncoding).
        x: Sentence to classify; ``None`` prompts the user interactively.
        all_classes: When True, return the full ``{label: probability}``
            distribution instead of just the top label.

    Returns:
        The predicted label name, or a dict of label -> softmax probability
        when ``all_classes`` is True.
    """
    if x is None:
        x = input("Write your sentence and press Enter to continue")
    tokenized_x = tokenizer(x, return_tensors="pt", padding=True, truncation=True)
    # Move inputs onto the model's device (consistent with the batched
    # helper); previously they stayed on CPU and crashed against a CUDA model.
    device = next(model.parameters()).device
    tokenized_x = tokenized_x.to(device)
    with torch.no_grad():
        outputs = model(**tokenized_x)
    if all_classes:
        probabilities = torch.softmax(outputs.logits, dim=1)[0]
        return {model.config.id2label[i]: float(p) for i, p in enumerate(probabilities)}
    predicted_label = int(torch.argmax(outputs.logits, dim=1))
    return model.config.id2label[predicted_label]
def compute_results(y, y_hat):
    """Score predictions against gold labels.

    Args:
        y: Reference (gold) label ids.
        y_hat: Predicted label ids.

    Returns:
        Dict with macro-F1, micro-F1 and accuracy scores.
    """
    f1_metric = evaluate.load("f1")
    accuracy_metric = evaluate.load("accuracy")
    macro_f1 = f1_metric.compute(predictions=y_hat, references=y, average="macro")["f1"]
    micro_f1 = f1_metric.compute(predictions=y_hat, references=y, average="micro")["f1"]
    accuracy = accuracy_metric.compute(predictions=y_hat, references=y)["accuracy"]
    return {"macro-f1": macro_f1, "micro-f1": micro_f1, "accuracy": accuracy}