Liyan06 commited on
Commit
ac1a8a5
1 Parent(s): f2ce59c

input format debug

Browse files
Files changed (2) hide show
  1. handler.py +1 -1
  2. minicheck_web/minicheck.py +3 -18
handler.py CHANGED
@@ -6,6 +6,6 @@ class EndpointHandler():
6
 
7
  def __call__(self, data):
8
 
9
- _, raw_prob, _, _ = self.scorer.score(inputs=data)
10
 
11
  return raw_prob
 
6
 
7
  def __call__(self, data):
8
 
9
+ _, raw_prob, _, _ = self.scorer.score(data=data)
10
 
11
  return raw_prob
minicheck_web/minicheck.py CHANGED
@@ -18,7 +18,7 @@ class MiniCheck:
18
  max_input_length=max_input_length,
19
  )
20
 
21
- def score(self, inputs: Dict) -> List[float]:
22
  '''
23
  pred_labels: 0 / 1 (0: unsupported, 1: supported)
24
  max_support_probs: the probability of "supported" for the chunk that determin the final pred_label
@@ -26,6 +26,7 @@ class MiniCheck:
26
  support_prob_per_chunk: the probability of "supported" for each chunk
27
  '''
28
 
 
29
  docs = inputs['docs']
30
  claims = inputs['claims']
31
 
@@ -35,20 +36,4 @@ class MiniCheck:
35
  max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
36
  pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]
37
 
38
- return pred_label, max_support_prob, used_chunk, support_prob_per_chunk
39
-
40
-
41
- if __name__ == '__main__':
42
-
43
- path = "./"
44
-
45
- doc = "A group of students gather in the school library to study for their upcoming final exams."
46
- claim_1 = "The students are preparing for an examination."
47
- claim_2 = "The students are on vacation."
48
-
49
- # flan-t5-large
50
- scorer = MiniCheck(path)
51
- pred_label, raw_prob, _, _ = scorer.score(docs=[doc, doc], claims=[claim_1, claim_2])
52
-
53
- print(pred_label) # [1, 0]
54
- print(raw_prob) # [0.9805923700332642, 0.007121307775378227]
 
18
  max_input_length=max_input_length,
19
  )
20
 
21
+ def score(self, data: Dict) -> List[float]:
22
  '''
23
  pred_labels: 0 / 1 (0: unsupported, 1: supported)
24
  max_support_probs: the probability of "supported" for the chunk that determin the final pred_label
 
26
  support_prob_per_chunk: the probability of "supported" for each chunk
27
  '''
28
 
29
+ inputs = data['inputs']
30
  docs = inputs['docs']
31
  claims = inputs['claims']
32
 
 
36
  max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
37
  pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]
38
 
39
+ return pred_label, max_support_prob, used_chunk, support_prob_per_chunk