davidberenstein1957 HF staff commited on
Commit
5e92245
·
1 Parent(s): 28c4c1d

chore: added batching to default `run`

Browse files
Files changed (1) hide show
  1. app_utils/entailment_checker.py +14 -11
app_utils/entailment_checker.py CHANGED
@@ -60,8 +60,10 @@ class EntailmentChecker(BaseComponent):
60
  def run(self, query: str, documents: List[Document]):
61
 
62
  scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
63
- for i, doc in enumerate(documents):
64
- entailment_info = self.get_entailment(premise=doc.content, hypotesis=query)
 
 
65
  doc.meta["entailment_info"] = entailment_info
66
 
67
  scores += doc.score
@@ -95,15 +97,16 @@ class EntailmentChecker(BaseComponent):
95
  def run_batch(self, queries: List[str], documents: List[Document]):
96
  pass
97
 
98
- def get_entailment(self, premise, hypotesis):
 
 
 
 
 
99
  with torch.inference_mode():
100
- inputs = self.tokenizer(
101
- f"{premise}{self.tokenizer.sep_token}{hypotesis}", return_tensors="pt"
102
- ).to(self.devices[0])
103
  out = self.model(**inputs)
104
  logits = out.logits
105
- probs = (
106
- torch.nn.functional.softmax(logits, dim=-1)[0, :].detach().cpu().numpy()
107
- )
108
- entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}
109
- return entailment_dict
 
60
  def run(self, query: str, documents: List[Document]):
61
 
62
  scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
63
+ premise_batch = [doc.content for doc in documents]
64
+ hypotesis_batch = [query] * len(documents)
65
+ entailment_info_batch = self.get_entailment_batch(premise_batch=premise_batch, hypotesis_batch=hypotesis_batch)
66
+ for i, (doc, entailment_info) in enumerate(zip(documents, entailment_info_batch)):
67
  doc.meta["entailment_info"] = entailment_info
68
 
69
  scores += doc.score
 
97
  def run_batch(self, queries: List[str], documents: List[Document]):
98
  pass
99
 
100
+ def get_entailment_dict(self, probs):
101
+ entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}
102
+ return entailment_dict
103
+
104
+ def get_entailment_batch(self, premise_batch: List[str], hypotesis_batch: List[str]):
105
+ formatted_texts = [f"{premise}{self.tokenizer.sep_token}{hypotesis}" for premise, hypotesis in zip(premise_batch, hypotesis_batch)]
106
  with torch.inference_mode():
107
+ inputs = self.tokenizer(formatted_texts, return_tensors="pt", padding=True, truncation=True).to(self.devices[0])
 
 
108
  out = self.model(**inputs)
109
  logits = out.logits
110
+ probs_batch = (torch.nn.functional.softmax(logits, dim=-1).detach().cpu().numpy() )
111
+ return [self.get_entailment_dict(probs) for probs in probs_batch]
112
+