# fact-checking-rocks / app_utils / entailment_checker.py
# Author: anakin87 - commit a0ccfc1: "torch inference_mode instead of no_grad and other optimization" (4.61 kB)
from typing import List, Optional
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import torch
from haystack.nodes.base import BaseComponent
from haystack.modeling.utils import initialize_device_settings
from haystack.schema import Document
class EntailmentChecker(BaseComponent):
    """
    This node checks the entailment between every document content and the query.
    It enriches the documents metadata with entailment information
    (``doc.meta["entailment_info"]``).
    It also returns aggregate entailment information.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 16,
        entailment_contradiction_threshold: float = 0.5,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
            See https://huggingface.co/models for full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub.
            Can be tag name, branch name, or commit hash. Also applied to the tokenizer
            when no separate ``tokenizer`` is given.
        :param tokenizer: Name of the tokenizer (usually the same as model).
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of Documents to be processed at a time.
        :param entailment_contradiction_threshold: if in the first N documents there is a strong
            evidence of entailment/contradiction (aggregate entailment or contradiction are greater
            than the threshold), the less relevant documents are not taken into account.
        :raises ValueError: if the model's id2label mapping has no "entailment" label.
        """
        super().__init__()
        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        # When the tokenizer defaults to the model repo, load it from the same
        # revision as the weights so vocabulary and model stay in sync.
        tokenizer_revision = None if tokenizer else model_version
        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, revision=tokenizer_revision)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.entailment_contradiction_threshold = entailment_contradiction_threshold
        self.model.to(str(self.devices[0]))
        # Take the label map from the already-loaded model: no second download,
        # and it matches the requested `model_version`.
        id2label = self.model.config.id2label
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError(
                "The model config must contain entailment value in the id2label dict."
            )

    def run(self, query: str, documents: List[Document]):
        """
        Check the entailment of ``query`` against each document, in the given order.

        Every processed document gets ``meta["entailment_info"]`` set to the per-label
        probabilities. Processing stops early once the score-weighted average of
        contradiction or entailment exceeds ``entailment_contradiction_threshold``.

        :param query: The hypothesis to verify.
        :param documents: Candidate premises. The early stop assumes they are ordered
            from most to least relevant (NOTE(review): confirm with the caller).
        :return: ``(result, "output_1")``. ``result["documents"]`` holds the processed
            prefix of ``documents``. ``result["aggregate_entailment_info"]`` holds the
            score-weighted averages of contradiction/neutral/entailment, rounded to
            2 decimals. All three are 0.0 when there are no documents or their scores
            sum to zero.
        """
        total_score = 0.0
        agg_con = agg_neu = agg_ent = 0.0
        processed = 0
        for doc in documents:
            entailment_info = self.get_entailment(premise=doc.content, hypotesis=query)
            doc.meta["entailment_info"] = entailment_info
            processed += 1
            # assumes every document carries a numeric score (e.g. set by a retriever)
            total_score += doc.score
            agg_con += entailment_info["contradiction"] * doc.score
            agg_neu += entailment_info["neutral"] * doc.score
            agg_ent += entailment_info["entailment"] * doc.score
            # if in the first documents there is a strong evidence of entailment/contradiction,
            # there is no need to consider less relevant documents
            # (guarded: a zero cumulative score would otherwise divide by zero)
            if (
                total_score
                and max(agg_con, agg_ent) / total_score
                > self.entailment_contradiction_threshold
            ):
                break

        if total_score:
            aggregate_entailment_info = {
                "contradiction": round(agg_con / total_score, 2),
                "neutral": round(agg_neu / total_score, 2),
                "entailment": round(agg_ent / total_score, 2),
            }
        else:
            # No documents, or all scores are zero: nothing to weight by.
            aggregate_entailment_info = {
                "contradiction": 0.0,
                "neutral": 0.0,
                "entailment": 0.0,
            }

        entailment_checker_result = {
            "documents": documents[:processed],
            "aggregate_entailment_info": aggregate_entailment_info,
        }
        return entailment_checker_result, "output_1"

    def run_batch(self, queries: List[str], documents: List[Document]):
        """
        Batch mode is not supported by this node.

        :raises NotImplementedError: always; call :meth:`run` once per query instead.
        """
        raise NotImplementedError(
            "EntailmentChecker does not support run_batch; call run() for each query."
        )

    def get_entailment(self, premise, hypotesis):
        """
        Compute NLI class probabilities for ``hypotesis`` given ``premise``.

        :param premise: The premise text (a document content).
        :param hypotesis: The hypothesis text (the query).
        :return: Dict mapping each lower-cased model label (e.g. "entailment") to its
            softmax probability as a plain Python float (JSON-serializable).
        """
        with torch.inference_mode():
            # Encode as a text pair so the tokenizer inserts the model's own
            # separator / segment tokens. Truncate so that long documents cannot
            # exceed the model's maximum input length.
            inputs = self.tokenizer(
                premise, hypotesis, truncation=True, return_tensors="pt"
            ).to(self.devices[0])
            logits = self.model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().tolist()
        # self.labels is already lower-cased in __init__.
        return dict(zip(self.labels, probs))