fact-checking-rocks / app_utils /entailment_checker.py
anakin87
great progress in showing output
1434337
raw
history blame
3.04 kB
from typing import List, Optional
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import torch
from haystack.nodes.base import BaseComponent
from haystack.modeling.utils import initialize_device_settings
from haystack.schema import Document, Answer, Span
class EntailmentChecker(BaseComponent):
"""
This node checks the entailment between every document content and the query.
It enrichs the documents metadata with entailment_info
"""
outgoing_edges = 1
def __init__(
self,
model_name_or_path: str = "roberta-large-mnli",
model_version: Optional[str] = None,
tokenizer: Optional[str] = None,
use_gpu: bool = True,
batch_size: int = 16,
):
"""
Load a Natural Language Inference model from Transformers.
:param model_name_or_path: Directory of a saved model or the name of a public model.
See https://huggingface.co/models for full list of available models.
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
:param tokenizer: Name of the tokenizer (usually the same as model)
:param use_gpu: Whether to use GPU (if available).
# :param batch_size: Number of Documents to be processed at a time.
"""
super().__init__()
self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
tokenizer = tokenizer or model_name_or_path
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
self.model = AutoModelForSequenceClassification.from_pretrained(
pretrained_model_name_or_path=model_name_or_path, revision=model_version
)
self.batch_size = batch_size
self.model.to(str(self.devices[0]))
id2label = AutoConfig.from_pretrained(model_name_or_path).id2label
self.labels = [id2label[k].lower() for k in sorted(id2label)]
if "entailment" not in self.labels:
raise ValueError(
"The model config must contain entailment value in the id2label dict."
)
def run(self, query: str, documents: List[Document]):
for doc in documents:
entailment_dict = self.get_entailment(premise=doc.content, hypotesis=query)
doc.meta["entailment_info"] = entailment_dict
return {"documents": documents}, "output_1"
def run_batch():
pass
def get_entailment(self, premise, hypotesis):
with torch.no_grad():
inputs = self.tokenizer(
f"{premise}{self.tokenizer.sep_token}{hypotesis}", return_tensors="pt"
).to(self.devices[0])
out = self.model(**inputs)
logits = out.logits
probs = (
torch.nn.functional.softmax(logits, dim=-1)[0, :].cpu().detach().numpy()
)
entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}
return entailment_dict