from typing import List, Optional

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import torch
from haystack.nodes.base import BaseComponent
from haystack.modeling.utils import initialize_device_settings
from haystack.schema import Document


class EntailmentChecker(BaseComponent):
    """
    This node checks the entailment between the content of every document and the query.
    It enriches the documents' metadata with entailment information.
    It also returns aggregate entailment information.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 16,
        entailment_contradiction_threshold: float = 0.5,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
        See https://huggingface.co/models for a full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as the model).
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of Documents to be processed at a time.
        :param entailment_contradiction_threshold: If the aggregate entailment or contradiction computed over the most
        relevant documents exceeds this threshold, the less relevant documents are not taken into account.
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.entailment_contradiction_threshold = entailment_contradiction_threshold
        self.model.to(str(self.devices[0]))

        id2label = AutoConfig.from_pretrained(model_name_or_path).id2label
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError(
                "The model config must contain entailment value in the id2label dict."
            )

    def run(self, query: str, documents: List[Document]):
        # NOTE: assumes a non-empty list of documents, each with a relevance score set
        scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
        premise_batch = [doc.content for doc in documents]
        hypothesis_batch = [query] * len(documents)
        entailment_info_batch = self.get_entailment_batch(
            premise_batch=premise_batch, hypothesis_batch=hypothesis_batch
        )
        for i, (doc, entailment_info) in enumerate(zip(documents, entailment_info_batch)):
            doc.meta["entailment_info"] = entailment_info

            scores += doc.score
            con, neu, ent = (
                entailment_info["contradiction"],
                entailment_info["neutral"],
                entailment_info["entailment"],
            )
            agg_con += con * doc.score
            agg_neu += neu * doc.score
            agg_ent += ent * doc.score

            # If the most relevant documents already provide strong evidence of
            # entailment/contradiction, there is no need to consider the less relevant ones.
            if max(agg_con, agg_ent) / scores > self.entailment_contradiction_threshold:
                break

        aggregate_entailment_info = {
            "contradiction": round(agg_con / scores, 2),
            "neutral": round(agg_neu / scores, 2),
            "entailment": round(agg_ent / scores, 2),
        }

        entailment_checker_result = {
            "documents": documents[: i + 1],
            "aggregate_entailment_info": aggregate_entailment_info,
        }

        return entailment_checker_result, "output_1"

    def run_batch(self, queries: List[str], documents: List[Document]):
        entailment_checker_result_batch = []
        premise_batch = [doc.content for doc in documents]
        entailment_info_batch = self.get_entailment_batch(premise_batch=premise_batch, hypothesis_batch=queries)
        for doc, entailment_info in zip(documents, entailment_info_batch):
            doc.meta["entailment_info"] = entailment_info
            # With a single document, the score weight cancels out, so the aggregate
            # values coincide with the document's own entailment probabilities
            aggregate_entailment_info = {
                "contradiction": round(entailment_info["contradiction"], 2),
                "neutral": round(entailment_info["neutral"], 2),
                "entailment": round(entailment_info["entailment"], 2),
            }
            entailment_checker_result_batch.append(
                {
                    "documents": [doc],
                    "aggregate_entailment_info": aggregate_entailment_info,
                }
            )
        return entailment_checker_result_batch, "output_1"

    def get_entailment_dict(self, probs):
        # self.labels are already lowercased in __init__, so they map directly onto the probabilities
        return {label: float(prob) for label, prob in zip(self.labels, probs)}

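    # Illustration (assuming the default roberta-large-mnli model, whose sep_token
    # is "</s>"): a premise/hypothesis pair is packed into a single sequence such as
    #   "The Earth orbits the Sun.</s>The Earth revolves around the Sun."
    # and the softmax over the model logits yields one dict per pair, e.g.
    #   {"contradiction": 0.01, "neutral": 0.05, "entailment": 0.94}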
    def get_entailment_batch(self, premise_batch: List[str], hypothesis_batch: List[str]):
        formatted_texts = [
            f"{premise}{self.tokenizer.sep_token}{hypothesis}"
            for premise, hypothesis in zip(premise_batch, hypothesis_batch)
        ]
        with torch.inference_mode():
            inputs = self.tokenizer(formatted_texts, return_tensors="pt", padding=True, truncation=True).to(
                self.devices[0]
            )
            out = self.model(**inputs)
            logits = out.logits
            probs_batch = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
        return [self.get_entailment_dict(probs) for probs in probs_batch]
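

# A minimal usage sketch (illustrative, not part of the node): it wires the checker
# after a retriever in a Haystack 1.x pipeline. The in-memory store, BM25 retriever,
# and example documents below are assumptions made for the demo.
if __name__ == "__main__":
    from haystack import Pipeline
    from haystack.document_stores import InMemoryDocumentStore
    from haystack.nodes import BM25Retriever

    document_store = InMemoryDocumentStore(use_bm25=True)
    document_store.write_documents(
        [
            {"content": "The Earth orbits the Sun once every year."},
            {"content": "The Moon is a natural satellite of the Earth."},
        ]
    )

    pipeline = Pipeline()
    pipeline.add_node(component=BM25Retriever(document_store=document_store), name="Retriever", inputs=["Query"])
    pipeline.add_node(component=EntailmentChecker(), name="EntailmentChecker", inputs=["Retriever"])

    result = pipeline.run(query="The Earth revolves around the Sun", params={"Retriever": {"top_k": 2}})
    # Each returned document carries doc.meta["entailment_info"];
    # the score-weighted aggregate over the considered documents is returned alongside.
    print(result["aggregate_entailment_info"])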