Liyan06 committed on
Commit
3201a95
1 Parent(s): ca16988

add customized handler

Files changed (3)
  1. handler.py +13 -0
  2. minicheck/inference.py +210 -0
  3. minicheck/minicheck.py +51 -0
handler.py ADDED
@@ -0,0 +1,13 @@
+ from minicheck.minicheck import MiniCheck
+
+ class EndpointHandler():
+     def __init__(self, path="./"):
+         self.scorer = MiniCheck(path=path)
+
+     def __call__(self, data):
+         docs = data.pop("docs", data)
+         claims = data.pop("claims", None)
+
+         _, raw_prob, _, _ = self.scorer.score(docs=docs, claims=claims)
+
+         return raw_prob
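For context, a minimal local sketch of how this handler might be exercised outside of a deployed endpoint. It assumes the MiniCheck weights sit in the repository root ("./") and that the request body is a dict carrying "docs" and "claims" lists, which is what the handler expects; it is an illustration, not part of this commit.

    # Hypothetical local smoke test, not part of this commit.
    from handler import EndpointHandler

    handler = EndpointHandler(path="./")
    payload = {
        "docs": ["A group of students gather in the school library to study for their upcoming final exams."],
        "claims": ["The students are preparing for an examination."],
    }
    raw_prob = handler(payload)  # one "supported" probability per (doc, claim) pair
    print(raw_prob)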
minicheck/inference.py ADDED
@@ -0,0 +1,210 @@
+ # Adapted code from https://github.com/yuh-zha/AlignScore/tree/main
+
+ import sys
+ sys.path.append("..")
+
+ from nltk.tokenize import sent_tokenize
+ import torch
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
+ import torch.nn as nn
+ from tqdm import tqdm
+ import torch.nn.functional as F
+ import os
+
+
+ def sent_tokenize_with_newlines(text):
+     blocks = text.split('\n')
+
+     tokenized_blocks = [sent_tokenize(block) for block in blocks]
+     tokenized_text = []
+     for block in tokenized_blocks:
+         tokenized_text.extend(block)
+         tokenized_text.append('\n')
+
+     return tokenized_text[:-1]
+
+
+ class Inferencer():
+     def __init__(self, path, chunk_size, max_input_length, batch_size) -> None:
+
+         self.path = path
+         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+
+         self.chunk_size = 500 if chunk_size is None else chunk_size
+         self.max_input_length = 2048 if max_input_length is None else max_input_length
+         self.max_output_length = 256
+
+         self.model.eval()
+         self.batch_size = batch_size
+         self.softmax = nn.Softmax(dim=-1)
+
+     def inference_example_batch(self, doc: list, claim: list):
+         """
+         Run inference on a batch of examples,
+         doc: list
+         claim: list
+         using self.inference to batch the process.
+         """
+
+         assert len(doc) == len(claim), "doc must have the same length as claim!"
+
+         max_support_probs = []
+         used_chunks = []
+         support_prob_per_chunk = []
+
+         for one_doc, one_claim in tqdm(zip(doc, claim), desc="Evaluating", total=len(doc)):
+             output = self.inference_per_example(one_doc, one_claim)
+             max_support_probs.append(output['max_support_prob'])
+             used_chunks.append(output['used_chunks'])
+             support_prob_per_chunk.append(output['support_prob_per_chunk'])
+
+         return {
+             'max_support_probs': max_support_probs,
+             'used_chunks': used_chunks,
+             'support_prob_per_chunk': support_prob_per_chunk
+         }
+
+     def inference_per_example(self, doc: str, claim: str):
+         """
+         Run inference on a single example,
+         doc: string
+         claim: string
+         using self.inference to batch the process.
+         """
+         def chunks(lst, n):
+             """Yield successive chunks from lst, each with approximately n tokens.
+
+             For flan-t5, we split on whitespace;
+             for roberta and deberta, we split using the tokenization.
+             """
+             current_chunk = []
+             current_word_count = 0
+             for sentence in lst:
+                 sentence_word_count = len(sentence.split())
+                 if current_word_count + sentence_word_count > n:
+                     yield ' '.join(current_chunk)
+                     current_chunk = [sentence]
+                     current_word_count = sentence_word_count
+                 else:
+                     current_chunk.append(sentence)
+                     current_word_count += sentence_word_count
+             if current_chunk:
+                 yield ' '.join(current_chunk)
+
+         doc_sents = sent_tokenize_with_newlines(doc)
+         doc_sents = doc_sents or ['']
+
+         doc_chunks = [chunk.replace(" \n ", '\n').strip() for chunk in chunks(doc_sents, self.chunk_size)]
+
+         '''
+         [chunk_1, chunk_2, chunk_3, chunk_4, ...]
+         [claim]
+         '''
+         claim_repeat = [claim] * len(doc_chunks)
+
+         output = self.inference(doc_chunks, claim_repeat)
+
+         return output
+
+     def inference(self, doc, claim):
+         """
+         Run inference on a list of docs and claims.
+
+         Standard aggregation (max) over the chunks of a doc.
+
+         Note: We do not apply any post-processing to 'claim'
+         and directly check 'doc' against 'claim'. If there are multiple
+         sentences in 'claim', the sentences are not split and are checked
+         as a single piece of text.
+
+         If there are multiple sentences in 'claim', we suggest users
+         split 'claim' into sentences beforehand and prepare data like
+         (doc, claim_1), (doc, claim_2), ... for a multi-sentence 'claim'.
+
+         **We leave it to the user to decide how to aggregate the results from multiple sentences.**
+
+         Note: AggreFact-CNN is the only dataset that contains three-sentence
+         summaries and has annotations on the whole summaries, so we do not
+         split the sentences in each 'claim' during prediction for simplicity.
+         Therefore, for this dataset, our result is based on treating the whole
+         summary as a single piece of text (one 'claim').
+
+         In general, sentence-level prediction performance is better than
+         full-response-level performance.
+         """
+
+         if isinstance(doc, str) and isinstance(claim, str):
+             doc = [doc]
+             claim = [claim]
+
+         batch_input, _, batch_org_chunks = self.batch_tokenize(doc, claim)
+
+         label_probs_list = []
+         used_chunks = []
+
+         for mini_batch_input, batch_org_chunk in zip(batch_input, batch_org_chunks):
+
+             mini_batch_input = {k: v.to(self.device) for k, v in mini_batch_input.items()}
+
+             with torch.no_grad():
+
+                 decoder_input_ids = torch.zeros((mini_batch_input['input_ids'].size(0), 1), dtype=torch.long).to(self.device)
+                 outputs = self.model(input_ids=mini_batch_input['input_ids'], attention_mask=mini_batch_input['attention_mask'], decoder_input_ids=decoder_input_ids)
+                 logits = outputs.logits.squeeze(1)
+
+                 # token id 3 is "no support", token id 209 is "support"
+                 label_logits = logits[:, torch.tensor([3, 209])].cpu()
+                 label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
+
+                 label_probs_list.append(label_probs)
+                 used_chunks.extend(batch_org_chunk)
+
+         label_probs = torch.cat(label_probs_list)
+         support_prob_per_chunk = label_probs[:, 1].cpu().numpy()
+         max_support_prob = label_probs[:, 1].max().item()
+
+         return {
+             'max_support_prob': max_support_prob,
+             'used_chunks': used_chunks,
+             'support_prob_per_chunk': support_prob_per_chunk
+         }
+
+     def batch_tokenize(self, doc, claim):
+         """
+         Input doc and claim are lists.
+         """
+         assert isinstance(doc, list) and isinstance(claim, list)
+         assert len(doc) == len(claim), "doc and claim should have the same length."
+
+         original_text = [self.tokenizer.eos_token.join([one_doc, one_claim]) for one_doc, one_claim in zip(doc, claim)]
+
+         batch_input = []
+         batch_concat_text = []
+         batch_org_chunks = []
+         for mini_batch in self.chunks(original_text, self.batch_size):
+             model_inputs = self.tokenizer(
+                 ['predict: ' + text for text in mini_batch],
+                 max_length=self.max_input_length,
+                 truncation=True,
+                 padding=True,
+                 return_tensors="pt"
+             )
+
+             batch_input.append(model_inputs)
+             batch_concat_text.append(mini_batch)
+             batch_org_chunks.append([item[:item.find('</s>')] for item in mini_batch])
+
+         return batch_input, batch_concat_text, batch_org_chunks
+
+     def chunks(self, lst, n):
+         """Yield successive n-sized chunks from lst."""
+         for i in range(0, len(lst), n):
+             yield lst[i:i + n]
+
+     def fact_check(self, doc, claim):
+
+         outputs = self.inference_example_batch(doc, claim)
+         return outputs['max_support_probs'], outputs['used_chunks'], outputs['support_prob_per_chunk']
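The docstring in Inferencer.inference recommends splitting a multi-sentence claim into individual sentences before scoring. A minimal sketch of that preparation step, assuming nltk's sent_tokenize (already imported in this file); the helper name is hypothetical, and aggregating the per-sentence probabilities is left to the user, as the docstring notes.

    # Hypothetical preprocessing helper, not part of this commit.
    from nltk.tokenize import sent_tokenize

    def expand_multi_sentence_claim(doc, claim):
        # One (doc, sentence) pair per claim sentence, ready for MiniCheck.score.
        sentences = sent_tokenize(claim)
        return [doc] * len(sentences), sentences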
minicheck/minicheck.py ADDED
@@ -0,0 +1,51 @@
+ # Adapted code from https://github.com/yuh-zha/AlignScore/tree/main
+
+ import sys
+ sys.path.append("..")
+
+ from minicheck.inference import Inferencer
+ from typing import List
+ import numpy as np
+
+
+ class MiniCheck:
+     def __init__(self, path, chunk_size=None, max_input_length=None, batch_size=16) -> None:
+
+         self.model = Inferencer(
+             path=path,
+             batch_size=batch_size,
+             chunk_size=chunk_size,
+             max_input_length=max_input_length,
+         )
+
+     def score(self, docs: List[str], claims: List[str]) -> List[float]:
+         '''
+         pred_labels: 0 / 1 (0: unsupported, 1: supported)
+         max_support_probs: the probability of "supported" for the chunk that determines the final pred_label
+         used_chunks: divided chunks of the input document
+         support_prob_per_chunk: the probability of "supported" for each chunk
+         '''
+
+         assert isinstance(docs, list) or isinstance(docs, np.ndarray), "docs must be a list or np.ndarray"
+         assert isinstance(claims, list) or isinstance(claims, np.ndarray), "claims must be a list or np.ndarray"
+
+         max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
+         pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]
+
+         return pred_label, max_support_prob, used_chunk, support_prob_per_chunk
+
+
+ if __name__ == '__main__':
+
+     path = "./"
+
+     doc = "A group of students gather in the school library to study for their upcoming final exams."
+     claim_1 = "The students are preparing for an examination."
+     claim_2 = "The students are on vacation."
+
+     # flan-t5-large
+     scorer = MiniCheck(path)
+     pred_label, raw_prob, _, _ = scorer.score(docs=[doc, doc], claims=[claim_1, claim_2])
+
+     print(pred_label)  # [1, 0]
+     print(raw_prob)    # [0.9805923700332642, 0.007121307775378227]
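Once this repository is deployed as a Hugging Face Inference Endpoint, the custom handler above receives the parsed JSON request body as data. A hedged client-side sketch follows, assuming the endpoint forwards the body unchanged to EndpointHandler.__call__; the URL and token are placeholders, not values from this commit.

    # Hypothetical client call against a deployed endpoint (illustration only).
    # ENDPOINT_URL and HF_TOKEN are placeholders.
    import requests

    ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
    HF_TOKEN = "<your-token>"

    response = requests.post(
        ENDPOINT_URL,
        headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
        json={
            "docs": ["A group of students gather in the school library to study for their upcoming final exams."],
            "claims": ["The students are preparing for an examination."],
        },
    )
    print(response.json())  # expected: the raw_prob list returned by EndpointHandler.__call__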