anakin87 commited on
Commit
d6bdb02
·
1 Parent(s): a027256

class entailment_checker

Browse files
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: purple
5
  colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
- app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
 
5
  colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
+ app_file: rock_fact_checker.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
app.py → Rock_fact_checker.py RENAMED
File without changes
app_utils/entailment_checker.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ from transformers import AutoModelForSequenceClassification,AutoTokenizer,AutoConfig
4
+ import torch
5
+ from haystack.nodes.base import BaseComponent
6
+ from haystack.modeling.utils import initialize_device_settings
7
+ from haystack.schema import Document, Answer, Span
8
+
9
+ class EntailmentChecker(BaseComponent):
10
+ """
11
+ This node checks the entailment between every document content and the query.
12
+ It enrichs the documents metadata with entailment_info
13
+ """
14
+
15
+ outgoing_edges = 1
16
+
17
+ def __init__(
18
+ self,
19
+ model_name_or_path: str = "roberta-large-mnli",
20
+ model_version: Optional[str] = None,
21
+ tokenizer: Optional[str] = None,
22
+ use_gpu: bool = True,
23
+ batch_size: int = 16,
24
+ ):
25
+ """
26
+ Load a Natural Language Inference model from Transformers.
27
+
28
+ :param model_name_or_path: Directory of a saved model or the name of a public model.
29
+ See https://huggingface.co/models for full list of available models.
30
+ :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
31
+ :param tokenizer: Name of the tokenizer (usually the same as model)
32
+ :param use_gpu: Whether to use GPU (if available).
33
+ # :param batch_size: Number of Documents to be processed at a time.
34
+ """
35
+ super().__init__()
36
+
37
+ self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
38
+
39
+ tokenizer = tokenizer or model_name_or_path
40
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
41
+ self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=model_name_or_path,revision=model_version)
42
+ self.batch_size = batch_size
43
+ self.model.to(str(self.devices[0]))
44
+
45
+ id2label = AutoConfig.from_pretrained(model_name_or_path).id2label
46
+ self.labels= [id2label[k].lower() for k in sorted(id2label)]
47
+ if 'entailment' not in self.labels:
48
+ raise ValueError("The model config must contain entailment value in the id2label dict.")
49
+
50
+ def run(self, query: str, documents: List[Document]):
51
+ for doc in documents:
52
+ entailment_dict=self.get_entailment(premise=doc.content, hypotesis=query)
53
+ doc.meta['entailment_info']=entailment_dict
54
+ return {'documents':documents}, "output_1"
55
+
56
+ def run_batch():
57
+ pass
58
+
59
+ def get_entailment(self, premise,hypotesis):
60
+ with torch.no_grad():
61
+ inputs = self.tokenizer(f'{premise}{self.tokenizer.sep_token}{hypotesis}', return_tensors="pt").to(self.devices[0])
62
+ out = self.model(**inputs)
63
+ logits = out.logits
64
+ probs = torch.nn.functional.softmax(logits, dim=-1)[0,:].cpu().detach().numpy()
65
+ entailment_dict={k.lower():v for k,v in zip (self.labels, probs)}
66
+ return entailment_dict
pages/{app.py → Info.py} RENAMED
File without changes
pages/info.py DELETED
@@ -1,3 +0,0 @@
1
- import streamlit as st
2
-
3
- st.title("Test")