Liyan06
commited on
Commit
•
09efa05
1
Parent(s):
3fe4664
customize chunk_size in score function
Browse files- handler.py +13 -1
- minicheck_web/inference.py +4 -2
- minicheck_web/minicheck.py +6 -2
handler.py
CHANGED
@@ -55,10 +55,20 @@ class EndpointHandler():
|
|
55 |
|
56 |
self.tfidf_order = True
|
57 |
self.num_highlights = 1
|
|
|
|
|
|
|
58 |
|
59 |
|
60 |
def __call__(self, data):
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
claim = data['inputs']['claims'][0]
|
63 |
ents = extract_entities(claim)
|
64 |
|
@@ -128,9 +138,11 @@ class EndpointHandler():
|
|
128 |
retrieved_data = {
|
129 |
'inputs': {
|
130 |
'docs': list(retrieved_docs),
|
131 |
-
'claims': [claim]*len(retrieved_docs)
|
|
|
132 |
}
|
133 |
}
|
|
|
134 |
_, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=retrieved_data)
|
135 |
end = time()
|
136 |
num_chunks = len([item for items in used_chunk for item in items])
|
|
|
55 |
|
56 |
self.tfidf_order = True
|
57 |
self.num_highlights = 1
|
58 |
+
|
59 |
+
self.default_chunk_size = 500
|
60 |
+
self.chunk_size = 500
|
61 |
|
62 |
|
63 |
def __call__(self, data):
|
64 |
|
65 |
+
# this is necessary for setting the chunk size for
|
66 |
+
# retrived docs
|
67 |
+
if 'chunk_size' in data['inputs']:
|
68 |
+
self.chunk_size = int(data['inputs']['chunk_size'])
|
69 |
+
else:
|
70 |
+
self.chunk_size = self.default_chunk_size
|
71 |
+
|
72 |
claim = data['inputs']['claims'][0]
|
73 |
ents = extract_entities(claim)
|
74 |
|
|
|
138 |
retrieved_data = {
|
139 |
'inputs': {
|
140 |
'docs': list(retrieved_docs),
|
141 |
+
'claims': [claim]*len(retrieved_docs),
|
142 |
+
'chunk_size': self.chunk_size
|
143 |
}
|
144 |
}
|
145 |
+
|
146 |
_, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=retrieved_data)
|
147 |
end = time()
|
148 |
num_chunks = len([item for items in used_chunk for item in items])
|
minicheck_web/inference.py
CHANGED
@@ -28,7 +28,7 @@ def sent_tokenize_with_newlines(text):
|
|
28 |
|
29 |
|
30 |
class Inferencer():
|
31 |
-
def __init__(self, path,
|
32 |
|
33 |
self.path = path
|
34 |
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
@@ -36,7 +36,9 @@ class Inferencer():
|
|
36 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
|
37 |
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
38 |
|
39 |
-
self.
|
|
|
|
|
40 |
self.max_input_length=2048 if max_input_length is None else max_input_length
|
41 |
self.max_output_length = 256
|
42 |
|
|
|
28 |
|
29 |
|
30 |
class Inferencer():
|
31 |
+
def __init__(self, path, max_input_length, batch_size) -> None:
|
32 |
|
33 |
self.path = path
|
34 |
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
36 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
|
37 |
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
38 |
|
39 |
+
self.default_chunk_size = 500
|
40 |
+
self.chunk_size=500
|
41 |
+
|
42 |
self.max_input_length=2048 if max_input_length is None else max_input_length
|
43 |
self.max_output_length = 256
|
44 |
|
minicheck_web/minicheck.py
CHANGED
@@ -9,12 +9,11 @@ import numpy as np
|
|
9 |
|
10 |
|
11 |
class MiniCheck:
|
12 |
-
def __init__(self, path,
|
13 |
|
14 |
self.model = Inferencer(
|
15 |
path=path,
|
16 |
batch_size=batch_size,
|
17 |
-
chunk_size=chunk_size,
|
18 |
max_input_length=max_input_length,
|
19 |
)
|
20 |
|
@@ -30,6 +29,11 @@ class MiniCheck:
|
|
30 |
docs = inputs['docs']
|
31 |
claims = inputs['claims']
|
32 |
|
|
|
|
|
|
|
|
|
|
|
33 |
assert isinstance(docs, list) or isinstance(docs, np.ndarray), f"docs must be a list or np.ndarray"
|
34 |
assert isinstance(claims, list) or isinstance(claims, np.ndarray), f"claims must be a list or np.ndarray"
|
35 |
|
|
|
9 |
|
10 |
|
11 |
class MiniCheck:
|
12 |
+
def __init__(self, path, max_input_length=None, batch_size=16) -> None:
|
13 |
|
14 |
self.model = Inferencer(
|
15 |
path=path,
|
16 |
batch_size=batch_size,
|
|
|
17 |
max_input_length=max_input_length,
|
18 |
)
|
19 |
|
|
|
29 |
docs = inputs['docs']
|
30 |
claims = inputs['claims']
|
31 |
|
32 |
+
if 'chunk_size' in inputs:
|
33 |
+
self.model.chunk_size = int(inputs['chunk_size'])
|
34 |
+
else:
|
35 |
+
self.model.chunk_size = self.model.default_chunk_size
|
36 |
+
|
37 |
assert isinstance(docs, list) or isinstance(docs, np.ndarray), f"docs must be a list or np.ndarray"
|
38 |
assert isinstance(claims, list) or isinstance(claims, np.ndarray), f"claims must be a list or np.ndarray"
|
39 |
|