Liyan06 commited on
Commit
09efa05
1 Parent(s): 3fe4664

customize chunk_size in score function

Browse files
handler.py CHANGED
@@ -55,10 +55,20 @@ class EndpointHandler():
55
 
56
  self.tfidf_order = True
57
  self.num_highlights = 1
 
 
 
58
 
59
 
60
  def __call__(self, data):
61
 
 
 
 
 
 
 
 
62
  claim = data['inputs']['claims'][0]
63
  ents = extract_entities(claim)
64
 
@@ -128,9 +138,11 @@ class EndpointHandler():
128
  retrieved_data = {
129
  'inputs': {
130
  'docs': list(retrieved_docs),
131
- 'claims': [claim]*len(retrieved_docs)
 
132
  }
133
  }
 
134
  _, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=retrieved_data)
135
  end = time()
136
  num_chunks = len([item for items in used_chunk for item in items])
 
55
 
56
  self.tfidf_order = True
57
  self.num_highlights = 1
58
+
59
+ self.default_chunk_size = 500
60
+ self.chunk_size = 500
61
 
62
 
63
  def __call__(self, data):
64
 
65
+ # this is necessary for setting the chunk size for
66
+ # retrived docs
67
+ if 'chunk_size' in data['inputs']:
68
+ self.chunk_size = int(data['inputs']['chunk_size'])
69
+ else:
70
+ self.chunk_size = self.default_chunk_size
71
+
72
  claim = data['inputs']['claims'][0]
73
  ents = extract_entities(claim)
74
 
 
138
  retrieved_data = {
139
  'inputs': {
140
  'docs': list(retrieved_docs),
141
+ 'claims': [claim]*len(retrieved_docs),
142
+ 'chunk_size': self.chunk_size
143
  }
144
  }
145
+
146
  _, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=retrieved_data)
147
  end = time()
148
  num_chunks = len([item for items in used_chunk for item in items])
minicheck_web/inference.py CHANGED
@@ -28,7 +28,7 @@ def sent_tokenize_with_newlines(text):
28
 
29
 
30
  class Inferencer():
31
- def __init__(self, path, chunk_size, max_input_length, batch_size) -> None:
32
 
33
  self.path = path
34
  self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -36,7 +36,9 @@ class Inferencer():
36
  self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
37
  self.tokenizer = AutoTokenizer.from_pretrained(path)
38
 
39
- self.chunk_size=500 if chunk_size is None else chunk_size
 
 
40
  self.max_input_length=2048 if max_input_length is None else max_input_length
41
  self.max_output_length = 256
42
 
 
28
 
29
 
30
  class Inferencer():
31
+ def __init__(self, path, max_input_length, batch_size) -> None:
32
 
33
  self.path = path
34
  self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
36
  self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
37
  self.tokenizer = AutoTokenizer.from_pretrained(path)
38
 
39
+ self.default_chunk_size = 500
40
+ self.chunk_size=500
41
+
42
  self.max_input_length=2048 if max_input_length is None else max_input_length
43
  self.max_output_length = 256
44
 
minicheck_web/minicheck.py CHANGED
@@ -9,12 +9,11 @@ import numpy as np
9
 
10
 
11
  class MiniCheck:
12
- def __init__(self, path, chunk_size=None, max_input_length=None, batch_size=16) -> None:
13
 
14
  self.model = Inferencer(
15
  path=path,
16
  batch_size=batch_size,
17
- chunk_size=chunk_size,
18
  max_input_length=max_input_length,
19
  )
20
 
@@ -30,6 +29,11 @@ class MiniCheck:
30
  docs = inputs['docs']
31
  claims = inputs['claims']
32
 
 
 
 
 
 
33
  assert isinstance(docs, list) or isinstance(docs, np.ndarray), f"docs must be a list or np.ndarray"
34
  assert isinstance(claims, list) or isinstance(claims, np.ndarray), f"claims must be a list or np.ndarray"
35
 
 
9
 
10
 
11
  class MiniCheck:
12
+ def __init__(self, path, max_input_length=None, batch_size=16) -> None:
13
 
14
  self.model = Inferencer(
15
  path=path,
16
  batch_size=batch_size,
 
17
  max_input_length=max_input_length,
18
  )
19
 
 
29
  docs = inputs['docs']
30
  claims = inputs['claims']
31
 
32
+ if 'chunk_size' in inputs:
33
+ self.model.chunk_size = int(inputs['chunk_size'])
34
+ else:
35
+ self.model.chunk_size = self.model.default_chunk_size
36
+
37
  assert isinstance(docs, list) or isinstance(docs, np.ndarray), f"docs must be a list or np.ndarray"
38
  assert isinstance(claims, list) or isinstance(claims, np.ndarray), f"claims must be a list or np.ndarray"
39