Liyan06
commited on
Commit
•
ac1a8a5
1
Parent(s):
f2ce59c
input format debug
Browse files- handler.py +1 -1
- minicheck_web/minicheck.py +3 -18
handler.py
CHANGED
@@ -6,6 +6,6 @@ class EndpointHandler():
|
|
6 |
|
7 |
def __call__(self, data):
|
8 |
|
9 |
-
_, raw_prob, _, _ = self.scorer.score(
|
10 |
|
11 |
return raw_prob
|
|
|
6 |
|
7 |
def __call__(self, data):
|
8 |
|
9 |
+
_, raw_prob, _, _ = self.scorer.score(data=data)
|
10 |
|
11 |
return raw_prob
|
minicheck_web/minicheck.py
CHANGED
@@ -18,7 +18,7 @@ class MiniCheck:
|
|
18 |
max_input_length=max_input_length,
|
19 |
)
|
20 |
|
21 |
-
def score(self,
|
22 |
'''
|
23 |
pred_labels: 0 / 1 (0: unsupported, 1: supported)
|
24 |
max_support_probs: the probability of "supported" for the chunk that determin the final pred_label
|
@@ -26,6 +26,7 @@ class MiniCheck:
|
|
26 |
support_prob_per_chunk: the probability of "supported" for each chunk
|
27 |
'''
|
28 |
|
|
|
29 |
docs = inputs['docs']
|
30 |
claims = inputs['claims']
|
31 |
|
@@ -35,20 +36,4 @@ class MiniCheck:
|
|
35 |
max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
|
36 |
pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]
|
37 |
|
38 |
-
return pred_label, max_support_prob, used_chunk, support_prob_per_chunk
|
39 |
-
|
40 |
-
|
41 |
-
if __name__ == '__main__':
|
42 |
-
|
43 |
-
path = "./"
|
44 |
-
|
45 |
-
doc = "A group of students gather in the school library to study for their upcoming final exams."
|
46 |
-
claim_1 = "The students are preparing for an examination."
|
47 |
-
claim_2 = "The students are on vacation."
|
48 |
-
|
49 |
-
# flan-t5-large
|
50 |
-
scorer = MiniCheck(path)
|
51 |
-
pred_label, raw_prob, _, _ = scorer.score(docs=[doc, doc], claims=[claim_1, claim_2])
|
52 |
-
|
53 |
-
print(pred_label) # [1, 0]
|
54 |
-
print(raw_prob) # [0.9805923700332642, 0.007121307775378227]
|
|
|
18 |
max_input_length=max_input_length,
|
19 |
)
|
20 |
|
21 |
+
def score(self, data: Dict) -> List[float]:
|
22 |
'''
|
23 |
pred_labels: 0 / 1 (0: unsupported, 1: supported)
|
24 |
max_support_probs: the probability of "supported" for the chunk that determin the final pred_label
|
|
|
26 |
support_prob_per_chunk: the probability of "supported" for each chunk
|
27 |
'''
|
28 |
|
29 |
+
inputs = data['inputs']
|
30 |
docs = inputs['docs']
|
31 |
claims = inputs['claims']
|
32 |
|
|
|
36 |
max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
|
37 |
pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]
|
38 |
|
39 |
+
return pred_label, max_support_prob, used_chunk, support_prob_per_chunk
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|