Stefano Fiorucci committed on
Commit 97688d7 · unverified · 2 parents: dbd4f9e, 34f14d6

Merge pull request #1 from davidberenstein1957/main

Files changed (1):
  app_utils/entailment_checker.py  +29 -12
app_utils/entailment_checker.py CHANGED
@@ -60,8 +60,10 @@ class EntailmentChecker(BaseComponent):
     def run(self, query: str, documents: List[Document]):
 
         scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
-        for i, doc in enumerate(documents):
-            entailment_info = self.get_entailment(premise=doc.content, hypotesis=query)
+        premise_batch = [doc.content for doc in documents]
+        hypotesis_batch = [query] * len(documents)
+        entailment_info_batch = self.get_entailment_batch(premise_batch=premise_batch, hypotesis_batch=hypotesis_batch)
+        for i, (doc, entailment_info) in enumerate(zip(documents, entailment_info_batch)):
             doc.meta["entailment_info"] = entailment_info
 
             scores += doc.score
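
The first hunk swaps the per-document get_entailment call for a single batched forward pass: all premises and the repeated query are paired up front and scored together by the new get_entailment_batch. A minimal standalone sketch of that batching pattern, assuming a generic NLI checkpoint from the Hugging Face Hub (the model name and example texts below are illustrative, not part of the commit):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumed checkpoint: any sequence-classification NLI model works the same way.
model_name = "microsoft/deberta-base-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

query = "The sky is blue."
documents = ["Rayleigh scattering makes the sky appear blue.", "Grass is green."]

# One tokenizer call over every premise/hypothesis pair; padding aligns the
# sequence lengths so a single forward pass covers the whole batch.
texts = [f"{doc}{tokenizer.sep_token}{query}" for doc in documents]
with torch.inference_mode():
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    probs = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)
print(probs.shape)  # (len(documents), num_labels)

This trades n model invocations for one, which is where the batching pays off; truncation=True also guards against premises longer than the model's maximum input length.
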
@@ -93,17 +95,32 @@ class EntailmentChecker(BaseComponent):
         return entailment_checker_result, "output_1"
 
     def run_batch(self, queries: List[str], documents: List[Document]):
-        pass
+        entailment_checker_result_batch = []
+        entailment_info_batch = self.get_entailment_batch(premise_batch=documents, hypotesis_batch=queries)
+        for doc, entailment_info in zip(documents, entailment_info_batch):
+            doc.meta["entailment_info"] = entailment_info
+            aggregate_entailment_info = {
+                "contradiction": round(entailment_info["contradiction"] / doc.score),
+                "neutral": round(entailment_info["neutral"] / doc.score),
+                "entailment": round(entailment_info["entailment"] / doc.score),
+            }
+            entailment_checker_result_batch.append({
+                "documents": [doc],
+                "aggregate_entailment_info": aggregate_entailment_info,
+            })
+        return entailment_checker_result_batch, "output_1"
+
+
+    def get_entailment_dict(self, probs):
+        entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}
+        return entailment_dict
 
-    def get_entailment(self, premise, hypotesis):
+    def get_entailment_batch(self, premise_batch: List[str], hypotesis_batch: List[str]):
+        formatted_texts = [f"{premise}{self.tokenizer.sep_token}{hypotesis}" for premise, hypotesis in zip(premise_batch, hypotesis_batch)]
         with torch.inference_mode():
-            inputs = self.tokenizer(
-                f"{premise}{self.tokenizer.sep_token}{hypotesis}", return_tensors="pt"
-            ).to(self.devices[0])
+            inputs = self.tokenizer(formatted_texts, return_tensors="pt", padding=True, truncation=True).to(self.devices[0])
             out = self.model(**inputs)
             logits = out.logits
-            probs = (
-                torch.nn.functional.softmax(logits, dim=-1)[0, :].detach().cpu().numpy()
-            )
-            entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}
-            return entailment_dict
+            probs_batch = torch.nn.functional.softmax(logits, dim=-1).detach().cpu().numpy()
+            return [self.get_entailment_dict(probs) for probs in probs_batch]
+
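
The second hunk implements the previously stubbed run_batch and splits the old get_entailment into two helpers: get_entailment_batch runs the padded batch through the model, and get_entailment_dict maps each row of probabilities onto the class labels. Continuing the sketch above, this is roughly what that label mapping does (assumption: the labels come from model.config.id2label in the order the model emits them; the committed code keeps its own copy in self.labels):

# Turn each row of `probs` into a {label: probability} dict, mirroring
# the new get_entailment_dict helper.
labels = [model.config.id2label[i].lower() for i in range(model.config.num_labels)]
probs_batch = probs.detach().cpu().numpy()
entailment_dicts = [dict(zip(labels, row)) for row in probs_batch]
for doc, info in zip(documents, entailment_dicts):
    print(doc[:40], {k: round(float(v), 3) for k, v in info.items()})
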