Ujjwal123 commited on
Commit
076da67
1 Parent(s): 3de893e

second pass model integrated

Browse files
Files changed (6) hide show
  1. BPSolver_inf.py +246 -16
  2. Dockerfile +4 -0
  3. Models_inf.py +66 -4
  4. Normal_utils_inf.py +100 -3
  5. Solver_inf.py +14 -30
  6. main.py +53 -23
BPSolver_inf.py CHANGED
@@ -1,5 +1,6 @@
1
  import math
2
  import string
 
3
  from collections import defaultdict
4
  from copy import deepcopy
5
 
@@ -9,6 +10,8 @@ from tqdm import trange
9
 
10
  from Utils_inf import print_grid, get_word_flips
11
  from Solver_inf import Solver
 
 
12
  # the probability of each alphabetical character in the crossword
13
  UNIGRAM_PROBS = [('A', 0.0897379968935765), ('B', 0.02121248877769636), ('C', 0.03482206634145926), ('D', 0.03700942543460491), ('E', 0.1159773210750429), ('F', 0.017257461694024614), ('G', 0.025429024796296124), ('H', 0.033122967601502), ('I', 0.06800036223479956), ('J', 0.00294611331754349), ('K', 0.013860682888259786), ('L', 0.05130800574373874), ('M', 0.027962776827660175), ('N', 0.06631994270448001), ('O', 0.07374646543246745), ('P', 0.026750756212433214), ('Q', 0.001507814175439393), ('R', 0.07080460813737305), ('S', 0.07410988246048224), ('T', 0.07242993582154593), ('U', 0.0289272388037645), ('V', 0.009153522059555467), ('W', 0.01434705167591524), ('X', 0.003096729223103298), ('Y', 0.01749958208224007), ('Z', 0.002659777584995724)]
14
 
@@ -18,12 +21,12 @@ LETTER_SMOOTHING_FACTOR = [0.0, 0.0, 0.04395604395604396, 0.0001372495196266813,
18
 
19
  class BPVar:
20
  def __init__(self, name, variable, candidates, cells):
21
- self.name = name
22
  cells_by_position = {}
23
- for cell in cells:
24
- cells_by_position[cell.position] = cell
25
  cell._connect(self)
26
- self.length = len(cells)
27
  self.ordered_cells = [cells_by_position[pos] for pos in variable['cells']]
28
  self.candidates = candidates
29
  self.words = self.candidates['words']
@@ -85,6 +88,7 @@ class BPCell:
85
  self.log_probs = log_softmax(sum(self.directional_scores))
86
 
87
  def propagate(self):
 
88
  try:
89
  for i, v in enumerate(self.crossing_vars):
90
  v._propagate_to_var(self, self.directional_scores[1-i])
@@ -98,7 +102,8 @@ class BPSolver(Solver):
98
  model_path,
99
  ans_tsv_path,
100
  dense_embd_path,
101
- max_candidates = 5000,
 
102
  process_id = 0,
103
  model_type = 'bert',
104
  **kwargs):
@@ -111,6 +116,8 @@ class BPSolver(Solver):
111
  model_type = model_type,
112
  **kwargs)
113
  self.crossword = crossword
 
 
114
 
115
  # our answer set
116
  self.answer_set = set()
@@ -130,12 +137,27 @@ class BPSolver(Solver):
130
  self.bp_cells_by_clue[clue].append(cell)
131
  self.bp_vars = []
132
  for key, value in self.crossword.variables.items():
 
 
 
 
 
 
133
  var = BPVar(key, value, self.candidates[key], self.bp_cells_by_clue[key])
 
 
 
134
  self.bp_vars.append(var)
 
 
 
 
 
 
135
 
136
  def solve(self, num_iters=10, iterative_improvement_steps=5, return_greedy_states = False, return_ii_states = False):
137
  # run solving for num_iters iterations
138
- print('beginning BP iterations')
139
  for _ in trange(num_iters):
140
  for var in self.bp_vars:
141
  var.propagate()
@@ -145,7 +167,7 @@ class BPSolver(Solver):
145
  cell.propagate()
146
  for var in self.bp_vars:
147
  var.sync_state()
148
- print('done BP iterations')
149
 
150
  # Get the current based grid based on greedy selection from the marginals
151
  if return_greedy_states:
@@ -153,16 +175,168 @@ class BPSolver(Solver):
153
  else:
154
  grid = self.greedy_sequential_word_solution()
155
  all_grids = []
156
- grid = self.greedy_sequential_word_solution()
157
- # print('=====Greedy search grid=====')
158
- # print_grid(grid)
 
 
 
 
 
 
 
 
 
 
159
 
160
- if iterative_improvement_steps < 1:
 
 
 
 
 
 
 
161
  if return_greedy_states or return_ii_states:
162
- return grid, all_grids
163
  else:
164
- return grid
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
 
 
 
 
 
 
 
 
 
 
166
  def greedy_sequential_word_solution(self, return_grids = False):
167
  all_grids = []
168
  # after we've run BP, we run a simple greedy search to get the final.
@@ -181,7 +355,6 @@ class BPSolver(Solver):
181
  best_index = best_per_var.index(max([x for x in best_per_var if x is not None]))
182
  best_var = self.bp_vars[best_index]
183
  best_word = best_var.words[best_var.log_probs.argmax()]
184
- # print('greedy filling in', best_word)
185
  for i, cell in enumerate(best_var.ordered_cells):
186
  letter = best_word[i]
187
  grid[cell.position[0]][cell.position[1]] = letter
@@ -201,14 +374,71 @@ class BPSolver(Solver):
201
  best_var.words = []
202
  best_var.log_probs = best_var.log_probs[[]]
203
  best_per_var[best_index] = None
 
 
204
  for cell in self.bp_cells:
205
  if cell.position in unfilled_cells:
 
206
  grid[cell.position[0]][cell.position[1]] = string.ascii_uppercase[cell.log_probs.argmax()]
207
-
208
  for var, (words, log_probs) in zip(self.bp_vars, cache): # restore state
209
  var.words = words
210
  var.log_probs = log_probs
211
  if return_grids:
212
  return grid, all_grids
213
  else:
214
- return grid
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import math
2
  import string
3
+ import re
4
  from collections import defaultdict
5
  from copy import deepcopy
6
 
 
10
 
11
  from Utils_inf import print_grid, get_word_flips
12
  from Solver_inf import Solver
13
+ from Models_inf import setup_t5_reranker, t5_reranker_score_with_clue
14
+
15
  # the probability of each alphabetical character in the crossword
16
  UNIGRAM_PROBS = [('A', 0.0897379968935765), ('B', 0.02121248877769636), ('C', 0.03482206634145926), ('D', 0.03700942543460491), ('E', 0.1159773210750429), ('F', 0.017257461694024614), ('G', 0.025429024796296124), ('H', 0.033122967601502), ('I', 0.06800036223479956), ('J', 0.00294611331754349), ('K', 0.013860682888259786), ('L', 0.05130800574373874), ('M', 0.027962776827660175), ('N', 0.06631994270448001), ('O', 0.07374646543246745), ('P', 0.026750756212433214), ('Q', 0.001507814175439393), ('R', 0.07080460813737305), ('S', 0.07410988246048224), ('T', 0.07242993582154593), ('U', 0.0289272388037645), ('V', 0.009153522059555467), ('W', 0.01434705167591524), ('X', 0.003096729223103298), ('Y', 0.01749958208224007), ('Z', 0.002659777584995724)]
17
 
 
21
 
22
  class BPVar:
23
  def __init__(self, name, variable, candidates, cells):
24
+ self.name = name # key from crossword.variables i.e. 1A, 2D, 3A
25
  cells_by_position = {}
26
+ for cell in cells: # every cells or letter box that a particular variable or filling takes into consideration
27
+ cells_by_position[cell.position] = cell # cell.position (0,0) -> cell -> BPCell
28
  cell._connect(self)
29
+ self.length = len(cells) # obviously the length of the answer
30
  self.ordered_cells = [cells_by_position[pos] for pos in variable['cells']]
31
  self.candidates = candidates
32
  self.words = self.candidates['words']
 
88
  self.log_probs = log_softmax(sum(self.directional_scores))
89
 
90
  def propagate(self):
91
+ # assert len(self.crossing_vars) == 2
92
  try:
93
  for i, v in enumerate(self.crossing_vars):
94
  v._propagate_to_var(self, self.directional_scores[1-i])
 
102
  model_path,
103
  ans_tsv_path,
104
  dense_embd_path,
105
+ reranker_path,
106
+ max_candidates = 100,
107
  process_id = 0,
108
  model_type = 'bert',
109
  **kwargs):
 
116
  model_type = model_type,
117
  **kwargs)
118
  self.crossword = crossword
119
+ self.reranker_path = reranker_path
120
+ self.reranker_model_type = 't5-small'
121
 
122
  # our answer set
123
  self.answer_set = set()
 
137
  self.bp_cells_by_clue[clue].append(cell)
138
  self.bp_vars = []
139
  for key, value in self.crossword.variables.items():
140
+ # if key == '1A':
141
+ # print('-'*100)
142
+ # print(self.candidates[key]['words'])
143
+ # print(self.candidates[key]['bit_array'].shape)
144
+ # print(self.candidates[key]['weights'])
145
+ # print('-'*100)
146
  var = BPVar(key, value, self.candidates[key], self.bp_cells_by_clue[key])
147
+ # print('*'*100)
148
+ # print(self.bp_cells_by_clue[key])
149
+ # print('*'*100)
150
  self.bp_vars.append(var)
151
+
152
def extract_float(self, input_string):
    """Pull every decimal number (e.g. '97.5') out of *input_string*.

    Used by solve() to parse letter/word accuracy percentages out of the
    log string returned by evaluate(). Integers without a decimal point
    are intentionally not matched.
    """
    decimal_pattern = re.compile(r"\d+\.\d+")
    return [float(token) for token in decimal_pattern.findall(input_string)]
157
 
158
  def solve(self, num_iters=10, iterative_improvement_steps=5, return_greedy_states = False, return_ii_states = False):
159
  # run solving for num_iters iterations
160
+ print('\nBeginning Belief Propagation Iteration Steps: ')
161
  for _ in trange(num_iters):
162
  for var in self.bp_vars:
163
  var.propagate()
 
167
  cell.propagate()
168
  for var in self.bp_vars:
169
  var.sync_state()
170
+ print('Belief Propagation Iteration Complete\n')
171
 
172
  # Get the current based grid based on greedy selection from the marginals
173
  if return_greedy_states:
 
175
  else:
176
  grid = self.greedy_sequential_word_solution()
177
  all_grids = []
178
+
179
+ # properly save all the outputs results:
180
+ output_results = {}
181
+ output_results['first pass model'] = {}
182
+ output_results['first pass model']['grid'] = grid
183
+
184
+ # save first pass model grid, and letter accuracies
185
+ _, accu_log = self.evaluate(grid, False)
186
+ [ori_letter_accu, ori_word_accu] = self.extract_float(accu_log)
187
+ output_results['first pass model']['letter accuracy'] = ori_letter_accu
188
+ output_results['first pass model']['word accuracy'] = ori_word_accu
189
+
190
+ print("First pass model result was", grid,ori_letter_accu,ori_word_accu)
191
 
192
+ output_results['second pass model'] = {}
193
+ output_results['second pass model']['final grid'] = [] # just for the sake of the api
194
+ output_results['second pass model']['final grid'] = grid # just for the sake of the api
195
+ output_results['second pass model']['all grids'] = []
196
+ output_results['second pass model']['all letter accuracy'] = []
197
+ output_results['second pass model']['all word accuracy'] = []
198
+
199
+ if iterative_improvement_steps < 1 or ori_letter_accu == 100.0:
200
  if return_greedy_states or return_ii_states:
201
+ return output_results, all_grids
202
  else:
203
+ return output_results
204
+
205
+ '''
206
+ Iterative Improvement with t5-small starts from here.
207
+ '''
208
+ self.reranker, self.tokenizer = setup_t5_reranker(self.reranker_path, self.reranker_model_type)
209
+
210
+
211
+
212
+ for i in range(iterative_improvement_steps):
213
+ grid, did_iterative_improvement_make_edit = self.iterative_improvement(grid)
214
+
215
+ _, accu_log = self.evaluate(grid, False)
216
+ [temp_letter_accu, temp_word_accu] = self.extract_float(accu_log)
217
+ print(f"{i+1}th iteration: {accu_log}")
218
+
219
+ # save grid & accuracies at each iteration
220
+ output_results['second pass model']['all grids'].append(grid)
221
+ output_results['second pass model']['all letter accuracy'].append(temp_letter_accu)
222
+ output_results['second pass model']['all word accuracy'].append(temp_word_accu)
223
+
224
+ if not did_iterative_improvement_make_edit or temp_letter_accu == 100.0:
225
+ break
226
+
227
+ if return_ii_states:
228
+ all_grids.append(deepcopy(grid))
229
+
230
+ temp_lett_accu_list = output_results['second pass model']['all letter accuracy'].copy()
231
+ ii_max_index = temp_lett_accu_list.index(max(temp_lett_accu_list))
232
+
233
+ output_results['second pass model']['final grid'] = output_results['second pass model']['all grids'][ii_max_index]
234
+ output_results['second pass model']['final letter'] = output_results['second pass model']['all letter accuracy'][ii_max_index]
235
+ output_results['second pass model']['final word'] = output_results['second pass model']['all word accuracy'][ii_max_index]
236
+
237
+ if return_greedy_states or return_ii_states:
238
+ return output_results, all_grids
239
+ else:
240
+ return output_results
241
+
242
def get_candidate_replacements(self, uncertain_answers, grid):
    """Build candidate letter-replacement sets for the iterative-improvement pass.

    Args:
        uncertain_answers: dict mapping clue text -> current (suspect) fill,
            as produced by get_uncertain_answers().
        grid: current 2-D list-of-lists of letters.

    Returns:
        A list of replacement sets; each set is a list of (cell, letter)
        pairs that are applied to the grid together.
    """
    candidate_replacements = []
    replacement_id_set = set()

    # 1) dictionary-based flips: alternate spellings for each uncertain answer
    for clue in uncertain_answers.keys():
        initial_word = uncertain_answers[clue]
        clue_flips = get_word_flips(initial_word, 10)  # flip then segment
        clue_positions = [key for key, value in self.crossword.variables.items() if value['clue'] == clue]
        if not clue_positions:
            # no grid slot matches this clue text; nothing to flip
            continue
        cells = []
        for clue_position in clue_positions:
            cells = sorted([cell for cell in self.bp_cells if clue_position in cell.crossing_clues], key=lambda c: c.position)
            if len(cells) == len(initial_word):
                break
        # NOTE(review): if no position matches the word length, `cells` keeps
        # the last candidate slot; mismatched flips are filtered out below.
        for flip in clue_flips:
            # Defensive skip instead of the previous
            # `import pdb; pdb.set_trace()` + assert, which crashed (or
            # dropped into a debugger) on any length mismatch.
            if len(flip) != len(cells):
                continue
            for i in range(len(flip)):
                if flip[i] != initial_word[i]:
                    # one single-cell replacement per flip: the first differing letter
                    candidate_replacements.append([(cells[i], flip[i])])
                    break

    # 2) letter-uncertainty flips: e.g. if we said P but G also had some
    #    probability mass (> 1%), try G too
    for cell in self.bp_cells:
        probs = np.exp(cell.log_probs)
        above_threshold = list(probs > 0.01)
        new_characters = ['ABCDEFGHIJKLMNOPQRSTUVWXYZ'[i] for i in range(26) if above_threshold[i]]
        # ignore a letter if it is the same as the current solution
        new_characters = [x for x in new_characters if x != grid[cell.position[0]][cell.position[1]]]
        for new_character in new_characters:
            replacement_id = '_'.join([str(cell.position), new_character])
            if replacement_id not in replacement_id_set:
                candidate_replacements.append([(cell, new_character)])
                replacement_id_set.add(replacement_id)

    # 3) composite flips: pair two single-cell flips on different cells that
    #    share a clue (< 4 distinct crossing clues between them)
    composite_replacements = []
    for i in range(len(candidate_replacements)):
        for j in range(i + 1, len(candidate_replacements)):
            flip1, flip2 = candidate_replacements[i], candidate_replacements[j]
            if flip1[0][0] != flip2[0][0]:
                if len(set(flip1[0][0].crossing_clues + flip2[0][0].crossing_clues)) < 4:  # shared clue
                    composite_replacements.append(flip1 + flip2)

    candidate_replacements += composite_replacements

    return candidate_replacements
303
+
304
def get_uncertain_answers(self, grid):
    """Collect the currently-filled answers that look unreliable.

    Reads the argmax fill for every variable off *grid*, records it on each
    member cell's `prediction` dict (keyed by clue text), and returns the
    subset of {clue text: fill} pairs whose fill is not in self.answer_set.

    Side effect: mutates `cell.prediction` for every cell of every variable.
    """
    original_qa_pairs = {}  # the original puzzle preds that we will try to improve
    # first save what the argmax word-level prediction was for each grid cell just to make life easier
    for var in self.crossword.variables:
        # read the current word off the grid
        cells = self.crossword.variables[var]["cells"]
        word = []
        for cell in cells:
            word.append(grid[cell[0]][cell[1]])
        word = ''.join(word)
        for cell in self.bp_cells:  # loop through all cells
            if cell.position in cells:  # if this cell is in the word we are currently handling
                # save {clue, answer} pair into this cell
                # NOTE(review): assumes BPCell exposes a dict attribute
                # `prediction` — confirm against BPCell.__init__
                cell.prediction[self.crossword.variables[var]['clue']] = word
                original_qa_pairs[self.crossword.variables[var]['clue']] = word

    uncertain_answers = {}

    # find uncertain answers
    # right now the heuristic we use is any answer that is not in the answer set
    for clue in original_qa_pairs.keys():
        if original_qa_pairs[clue] not in self.answer_set:
            uncertain_answers[clue] = original_qa_pairs[clue]

    return uncertain_answers
329
 
330
def score_grid(self, grid):
    """Return the total T5-reranker score of every answer currently in *grid*.

    For each clue, the fill is read off the grid in cell-position order and
    scored jointly with its clue text; the sum over all clues is the grid's
    overall score (higher is better).
    """
    clue_texts, fills = [], []
    for var_key, cell_group in self.bp_cells_by_clue.items():
        ordered_cells = sorted(list(cell_group), key=lambda c: c.position)
        fills.append(''.join(grid[c.position[0]][c.position[1]] for c in ordered_cells))
        clue_texts.append(self.crossword.variables[var_key]['clue'])
    scores = t5_reranker_score_with_clue(self.reranker, self.tokenizer, self.reranker_model_type, clue_texts, fills)
    return sum(scores)
339
+
340
  def greedy_sequential_word_solution(self, return_grids = False):
341
  all_grids = []
342
  # after we've run BP, we run a simple greedy search to get the final.
 
355
  best_index = best_per_var.index(max([x for x in best_per_var if x is not None]))
356
  best_var = self.bp_vars[best_index]
357
  best_word = best_var.words[best_var.log_probs.argmax()]
 
358
  for i, cell in enumerate(best_var.ordered_cells):
359
  letter = best_word[i]
360
  grid[cell.position[0]][cell.position[1]] = letter
 
374
  best_var.words = []
375
  best_var.log_probs = best_var.log_probs[[]]
376
  best_per_var[best_index] = None
377
+
378
+ unfilled_cells_count = 0
379
  for cell in self.bp_cells:
380
  if cell.position in unfilled_cells:
381
+ unfilled_cells_count += 1
382
  grid[cell.position[0]][cell.position[1]] = string.ascii_uppercase[cell.log_probs.argmax()]
383
+
384
  for var, (words, log_probs) in zip(self.bp_vars, cache): # restore state
385
  var.words = words
386
  var.log_probs = log_probs
387
  if return_grids:
388
  return grid, all_grids
389
  else:
390
+ return grid
391
+
392
def iterative_improvement(self, grid):
    """Run one local-search pass over *grid* using the T5 reranker.

    Generates candidate letter replacements for uncertain regions, keeps
    every replacement whose reranked grid score improves by more than 0.5,
    then greedily applies a non-conflicting subset of the best ones.

    Returns:
        (new_grid, True) if at least one edit was applied,
        (grid, False) otherwise (the input grid, unchanged).
    """
    # check the grid for uncertain areas and save those words to be analyzed
    # in local search, aka looking for alternate candidates
    uncertain_answers = self.get_uncertain_answers(grid)
    self.candidate_replacements = self.get_candidate_replacements(uncertain_answers, grid)

    original_grid_score = self.score_grid(grid)
    possible_edits = []
    for replacements in self.candidate_replacements:
        # apply this replacement set to a copy and rescore the whole grid
        modified_grid = deepcopy(grid)
        for cell, letter in replacements:
            modified_grid[cell.position[0]][cell.position[1]] = letter
        modified_grid_score = self.score_grid(modified_grid)
        # variables (slots) touched by this replacement set
        variables = set(sum([cell.crossing_vars for cell, _ in replacements], []))
        for var in variables:
            # NOTE(review): these fills / clue_index were only used by
            # (now removed) debug prints; the loop has no other effect
            original_fill = ''.join([grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
            modified_fill = ''.join([modified_grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
            clue_index = list(set(var.ordered_cells[0].crossing_clues).intersection(*[set(cell.crossing_clues) for cell in var.ordered_cells]))[0]
        # keep edits that improve the reranker score by a margin (> 0.5)
        if modified_grid_score - original_grid_score > 0.5:
            possible_edits.append((modified_grid, modified_grid_score, replacements))

    if len(possible_edits) > 0:
        variables_modified = set()
        # best-scoring edits first
        possible_edits = sorted(possible_edits, key=lambda x: x[1], reverse=True)
        selected_edits = []
        for edit in possible_edits:
            replacements = edit[2]
            variables = set(sum([cell.crossing_vars for cell, _ in replacements], []))
            # we can do multiple updates at once if they don't share clues
            if len(variables_modified.intersection(variables)) == 0:
                variables_modified.update(variables)
                selected_edits.append(edit)

        # apply all selected, mutually non-conflicting edits to a fresh copy
        new_grid = deepcopy(grid)
        for edit in selected_edits:
            replacements = edit[2]
            for cell, letter in replacements:
                new_grid[cell.position[0]][cell.position[1]] = letter
            variables = set(sum([cell.crossing_vars for cell, _ in replacements], []))
            for var in variables:
                # debug-only fills, kept to preserve the original code exactly
                original_fill = ''.join([grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
                modified_fill = ''.join([new_grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
        return new_grid, True
    else:
        return grid, False
Dockerfile CHANGED
@@ -34,4 +34,8 @@ ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scor
34
  ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/dpr_biencoder_trained_500k.bin $HOME/app/Inference_components/
35
  ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/embeddings_all_answers_json_0.pkl $HOME/app/Inference_components/
36
 
 
 
 
 
37
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
34
  ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/dpr_biencoder_trained_500k.bin $HOME/app/Inference_components/
35
  ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/embeddings_all_answers_json_0.pkl $HOME/app/Inference_components/
36
 
37
+ ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/t5_small_new_dataset_2EPOCHS/config.json $HOME/app/Inference_components/t5_small_new_dataset_2EPOCHS/
38
+ ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/t5_small_new_dataset_2EPOCHS/generation_config.json $HOME/app/Inference_components/t5_small_new_dataset_2EPOCHS/
39
+ ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/t5_small_new_dataset_2EPOCHS/model.safetensors $HOME/app/Inference_components/t5_small_new_dataset_2EPOCHS/
40
+
41
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
Models_inf.py CHANGED
@@ -7,8 +7,6 @@ import string
7
  import sys
8
  from typing import List, Tuple, Dict
9
  import re
10
- import math
11
- import collections
12
 
13
  import numpy as np
14
  import unicodedata
@@ -21,7 +19,10 @@ from Options_inf import setup_args_gpu, print_args, set_encoder_params_from_stat
21
  from Faiss_Indexers_inf import DenseIndexer, DenseFlatIndexer
22
  from Data_utils_inf import Tensorizer
23
  from Model_utils_inf import load_states_from_checkpoint, get_model_obj
24
-
 
 
 
25
 
26
  SEGMENTER_CACHE = {}
27
  RERANKER_CACHE = {}
@@ -37,6 +38,64 @@ def setup_closedbook(model_path, ans_tsv_path, dense_embd_path, process_id, mode
37
  )
38
  return dpr
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def preprocess_clue_fn(clue):
41
  clue = str(clue)
42
 
@@ -202,6 +261,7 @@ class DenseRetriever(object):
202
  query_vectors.extend(out.cpu().split(1, dim=0))
203
 
204
  query_tensor = torch.cat(query_vectors, dim=0)
 
205
  assert query_tensor.size(0) == len(questions)
206
  return query_tensor
207
 
@@ -353,9 +413,11 @@ class DPRForCrossword(object):
353
  if max_answers > self.len_all_passages:
354
  max_answers = self.len_all_passages
355
 
 
356
  # get top k results
357
  top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_answers)
358
-
 
359
  if not output_strings:
360
  return top_ids_and_scores
361
  else:
 
7
  import sys
8
  from typing import List, Tuple, Dict
9
  import re
 
 
10
 
11
  import numpy as np
12
  import unicodedata
 
19
  from Faiss_Indexers_inf import DenseIndexer, DenseFlatIndexer
20
  from Data_utils_inf import Tensorizer
21
  from Model_utils_inf import load_states_from_checkpoint, get_model_obj
22
+ from transformers import T5ForConditionalGeneration, AutoTokenizer
23
+ import time
24
+ from wordsegment import load, segment
25
+ load()
26
 
27
  SEGMENTER_CACHE = {}
28
  RERANKER_CACHE = {}
 
38
  )
39
  return dpr
40
 
41
def setup_t5_reranker(reranker_path, reranker_model_type = 't5-small'):
    """Load the fine-tuned T5 reranker checkpoint and its tokenizer.

    The tokenizer comes from the base model name, the weights from the local
    checkpoint path. The model is put in eval mode on GPU when available,
    otherwise CPU. Returns (model, tokenizer).
    """
    reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_model_type)
    reranker_model = T5ForConditionalGeneration.from_pretrained(reranker_path)
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    reranker_model.eval().to(target_device)
    return reranker_model, reranker_tokenizer
46
+
47
def post_process_clue(clue):
    """Normalize a clue string for the T5 reranker.

    Runs the standard clue preprocessing, then strips one trailing
    ellipsis/period variant ('. .', ' ..', '..', or '.') if present.

    Fix: the original `clue[-1] == '.'` raised IndexError when the
    preprocessed clue came back empty; `endswith` is safe on ''.
    """
    clue = preprocess_clue_fn(clue)
    # longest suffixes first, and only one suffix is stripped (matches the
    # original elif chain)
    for suffix in ('. .', ' ..', '..', '.'):
        if clue.endswith(suffix):
            clue = clue[:-len(suffix)]
            break
    return clue
58
+
59
def t5_reranker_score_with_clue(model, tokenizer, model_type, clues, possibly_ungrammatical_fills):
    """Score each (clue, fill) pair with the T5 reranker.

    Args:
        model, tokenizer: as returned by setup_t5_reranker().
        model_type: 't5-small' triggers wordsegment-based answer splitting.
        clues: list of clue strings.
        possibly_ungrammatical_fills: list of answer strings (same length).

    Returns:
        A list of per-pair scores: -loss * answer_token_length (a
        length-scaled log-probability; higher is better). Results are
        memoized in the module-level RERANKER_CACHE.
    """
    global RERANKER_CACHE
    results = []
    device = model.device

    fills = possibly_ungrammatical_fills.copy()

    if model_type == 't5-small':
        # t5-small was fine-tuned on whitespace-segmented answers
        fills = [" ".join(segment(answer.lower())) for answer in possibly_ungrammatical_fills]

    # Fix: eval mode set once, instead of redundantly inside the loop.
    model.eval()

    for clue, possibly_ungrammatical_fill in zip(clues, fills):
        if not possibly_ungrammatical_fill.islower():
            possibly_ungrammatical_fill = possibly_ungrammatical_fill.lower()

        clue = post_process_clue(clue)

        # Fix: tuple key instead of string concatenation, which could
        # collide (e.g. 'ab'+'c' vs 'a'+'bc').
        cache_key = (clue, possibly_ungrammatical_fill)
        if cache_key in RERANKER_CACHE:
            results.append(RERANKER_CACHE[cache_key])
            continue

        with torch.no_grad(), torch.inference_mode():
            # move all the input tensors to the model's device
            inputs = tokenizer(["Q: " + clue], return_tensors='pt')['input_ids'].to(device)
            labels = tokenizer([possibly_ungrammatical_fill], return_tensors='pt')['input_ids'].to(device)

            loss = model(inputs, labels = labels)
            answer_length = labels.shape[1]
            # scale by token length so long answers aren't penalized
            logprob = -loss[0].item() * answer_length
            results.append(logprob)
            RERANKER_CACHE[cache_key] = logprob

    return results
98
+
99
  def preprocess_clue_fn(clue):
100
  clue = str(clue)
101
 
 
261
  query_vectors.extend(out.cpu().split(1, dim=0))
262
 
263
  query_tensor = torch.cat(query_vectors, dim=0)
264
+ print("CLUE Vector Shape", query_tensor.shape)
265
  assert query_tensor.size(0) == len(questions)
266
  return query_tensor
267
 
 
413
  if max_answers > self.len_all_passages:
414
  max_answers = self.len_all_passages
415
 
416
+ start_time = time.time()
417
  # get top k results
418
  top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_answers)
419
+ end_time = time.time()
420
+ print("\n\nTime taken by FAISS INDEXER: ", end_time - start_time)
421
  if not output_strings:
422
  return top_ids_and_scores
423
  else:
Normal_utils_inf.py CHANGED
@@ -1,7 +1,6 @@
1
  import puz
2
- import re
3
- import unicodedata
4
- import sys
5
 
6
  def puz_to_json(fname):
7
  """ Converts a puzzle in .puz format to .json format
@@ -63,3 +62,101 @@ def puz_to_pairs(filepath):
63
 
64
  return [(k, v) for k, v in pairs.items()]
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import puz
2
+ import json
3
+ import requests
 
4
 
5
  def puz_to_json(fname):
6
  """ Converts a puzzle in .puz format to .json format
 
62
 
63
  return [(k, v) for k, v in pairs.items()]
64
 
65
def _parse_clue_answer_list(clue_strings, answer_strings):
    """Map 'NUM. text' clue strings + answers to {num: [text, answer]}."""
    parsed = {}
    for clue, ans in zip(clue_strings, answer_strings):
        split_clue = clue.split(" ")
        clue_num = split_clue[0][:-1]  # "12." -> "12"
        clue_text = " ".join(split_clue[1:]).replace("[", "").replace("]", "")
        parsed[clue_num] = [clue_text, ans]
    return parsed


def json_CA_json_converter(json_file_path, is_path):
    """Convert an xwordinfo-style crossword JSON into the solver's format.

    Args:
        json_file_path: a path to a JSON file (is_path=True) or an
            already-loaded dict (is_path=False).
        is_path: whether the first argument is a filesystem path.

    Returns:
        dict with 'metadata', 'clues' (across/down -> {num: [clue, answer]})
        and 'grid' (rows of "BLACK" or [gridnum_str_or_empty, letter]),
        or None if conversion fails.
    """
    try:
        if is_path:
            with open(json_file_path, "r") as file:
                data = json.load(file)
        else:
            data = json_file_path

        json_conversion_dict = {}

        rows = data["size"]["rows"]
        cols = data["size"]["cols"]
        date = data["date"]

        clues = data["clues"]
        answers = data["answers"]

        json_conversion_dict["metadata"] = {"date": date, "rows": rows, "cols": cols}

        json_conversion_dict["clues"] = {
            "across": _parse_clue_answer_list(clues["across"], answers["across"]),
            "down": _parse_clue_answer_list(clues["down"], answers["down"]),
        }

        grid_info = data["grid"]
        grid_num = data["gridnums"]

        grid_info_list = []
        for i in range(rows):
            row_list = []
            for j in range(cols):
                # Fix: row-major stride is the COLUMN count; the original
                # `i * rows + j` mis-indexed every non-square grid.
                idx = i * cols + j
                if grid_info[idx] == ".":
                    row_list.append("BLACK")
                elif grid_num[idx] == 0:
                    # unnumbered white square
                    row_list.append(["", grid_info[idx]])
                else:
                    row_list.append([str(grid_num[idx]), grid_info[idx]])
            grid_info_list.append(row_list)

        json_conversion_dict["grid"] = grid_info_list

        return json_conversion_dict

    # Fix: narrow, informative handler instead of a bare `except:` that
    # swallowed everything (including KeyboardInterrupt) and misspelled
    # its message; still best-effort — returns None on failure.
    except (KeyError, IndexError, TypeError, ValueError, OSError) as err:
        print(f"ERROR converting puzzle JSON: {err}")
130
+
131
def fetch_nyt_crossword(dateStr):
    '''
    Fetch NYT puzzle from a specific date.

    dateStr is mm/dd/yyyy. Downloads the xwordinfo JSON, converts it with
    json_CA_json_converter, unescapes HTML quote entities in every clue,
    and returns the puzzle dict (or None on a failed request).
    '''
    headers = {
        'Referer': 'https://www.xwordinfo.com/JSON/'
    }
    url = 'https://www.xwordinfo.com/JSON/Data.ashx?date=' + dateStr

    response = requests.get(url, headers=headers)

    if response.status_code != 200:
        print(f"Request failed with status code {response.status_code}.")
        return None

    # the payload uses single quotes; normalize before json.loads
    jsonText = response.content.decode('utf-8').replace("'", '"')
    grid_data = json.loads(jsonText)
    puzzle_data = json_CA_json_converter(grid_data, False)

    # unescape HTML quote entities in every clue, both directions
    for dim in ['across', 'down']:
        for grid_num in puzzle_data['clues'][dim].keys():
            clue_text, answer_text = puzzle_data['clues'][dim][grid_num]
            clue_text = clue_text.replace("&quot;", "'").replace("&#39;", "'")
            puzzle_data['clues'][dim][grid_num] = [clue_text, answer_text]

    return puzzle_data
Solver_inf.py CHANGED
@@ -16,7 +16,7 @@ class Solver:
16
  crossword (Crossword): puzzle to solve
17
  max_candidates (int): number of answer candidates to consider per clue
18
  """
19
- def __init__(self, crossword, model_path, ans_tsv_path, dense_embd_path, max_candidates=1000, process_id = 0, model_type = 'bert'):
20
  self.crossword = crossword
21
  self.max_candidates = max_candidates
22
  self.process_id = process_id
@@ -46,10 +46,10 @@ class Solver:
46
  clue = clue[:match.start()] + self.crossword.variables[var]['clue'] + clue[match.end():]
47
  all_clues[idx] = clue
48
 
49
- # print("MODEL PATH: ", type(self.dense_embd_glob))
50
  # get predictions
51
  dpr = setup_closedbook(self.model_path, self.ans_tsv_path, self.dense_embd_glob, self.process_id, self.model_type)
52
- all_words, all_scores = answer_clues(dpr, all_clues, max_answers=self.max_candidates, output_strings=True)
 
53
  for index, var in enumerate(self.crossword.variables):
54
  length = len(self.crossword.variables[var]["gold"])
55
  self.candidates[var] = {"words": [], "bit_array": None, "weights": {}}
@@ -63,21 +63,13 @@ class Solver:
63
  keep_positions.append(word_index)
64
  words = [words[i] for i in keep_positions]
65
  scores = [scores[i] for i in keep_positions]
 
66
  scores = list(-np.log(softmax(np.array(scores) / 0.75)))
67
 
68
  for word, score in zip(words, scores):
69
  self.candidates[var]["weights"][word] = score
70
-
71
- # for debugging purposes, print the rank of the gold answer on our candidate list
72
- # the gold answer is otherwise *not* used in any way during solving
73
- # if self.crossword.variables[var]["gold"] in words:
74
- # print(clue, self.crossword.variables[var]["gold"], words.index(self.crossword.variables[var]["gold"]))
75
- # else:
76
- # print('not found', clue, self.crossword.variables[var]["gold"])
77
 
78
- # fill up some data structures used later in solving
79
- for word, score in zip(words, scores):
80
- self.candidates[var]["weights"][word] = score
81
  weights = self.candidates[var]["weights"]
82
  self.candidates[var]["words"] = sorted(weights, key=weights.get)
83
  self.candidates[var]["bit_array"] = np.zeros((len(chars), length, len(self.candidates[var]["words"])))
@@ -94,8 +86,7 @@ class Solver:
94
  # cleanup a bit
95
  del dpr
96
 
97
- def evaluate(self, solution):
98
- # print puzzle accuracy results given a generated solution
99
  letters_correct = 0
100
  letters_total = 0
101
  for i in range(len(self.crossword.letter_grid)):
@@ -110,20 +101,13 @@ class Solver:
110
  matching_cells = [self.crossword.letter_grid[cell[0]][cell[1]] == solution[cell[0]][cell[1]] for cell in cells]
111
  if len(cells) == sum(matching_cells):
112
  words_correct += 1
113
- else:
114
- # print('evaluation: correct word', ''.join([self.crossword.letter_grid[cell[0]][cell[1]] for cell in cells]), 'our prediction:', ''.join([solution[cell[0]][cell[1]] for cell in cells]))
115
- pass
116
  words_total += 1
 
 
 
 
 
 
 
117
 
118
- print("Letters Correct: {}/{} | Words Correct: {}/{}".format(int(letters_correct), int(letters_total), int(words_correct), int(words_total)))
119
- print("Letters Correct: {}% | Words Correct: {}%".format(float(letters_correct/letters_total*100), float(words_correct/words_total*100)))
120
-
121
- info = {
122
- "total_letters" : int(letters_total),
123
- "total_words" : int(words_total),
124
- "correct_letters" : int(letters_correct),
125
- "correct_words" : int(words_correct),
126
- "correct_letters_percent" : float(letters_correct/letters_total*100),
127
- "correct_words_percent" : float(words_correct/words_total*100),
128
- }
129
- return info
 
16
  crossword (Crossword): puzzle to solve
17
  max_candidates (int): number of answer candidates to consider per clue
18
  """
19
+ def __init__(self, crossword, model_path, ans_tsv_path, dense_embd_path, max_candidates = 100, process_id = 0, model_type = 'bert'):
20
  self.crossword = crossword
21
  self.max_candidates = max_candidates
22
  self.process_id = process_id
 
46
  clue = clue[:match.start()] + self.crossword.variables[var]['clue'] + clue[match.end():]
47
  all_clues[idx] = clue
48
 
 
49
  # get predictions
50
  dpr = setup_closedbook(self.model_path, self.ans_tsv_path, self.dense_embd_glob, self.process_id, self.model_type)
51
+ all_words, all_scores = answer_clues(dpr, all_clues, max_answers = self.max_candidates, output_strings=True)
52
+
53
  for index, var in enumerate(self.crossword.variables):
54
  length = len(self.crossword.variables[var]["gold"])
55
  self.candidates[var] = {"words": [], "bit_array": None, "weights": {}}
 
63
  keep_positions.append(word_index)
64
  words = [words[i] for i in keep_positions]
65
  scores = [scores[i] for i in keep_positions]
66
+
67
  scores = list(-np.log(softmax(np.array(scores) / 0.75)))
68
 
69
  for word, score in zip(words, scores):
70
  self.candidates[var]["weights"][word] = score
 
 
 
 
 
 
 
71
 
72
+
 
 
73
  weights = self.candidates[var]["weights"]
74
  self.candidates[var]["words"] = sorted(weights, key=weights.get)
75
  self.candidates[var]["bit_array"] = np.zeros((len(chars), length, len(self.candidates[var]["words"])))
 
86
  # cleanup a bit
87
  del dpr
88
 
89
+ def evaluate(self, solution, print_log = True):
 
90
  letters_correct = 0
91
  letters_total = 0
92
  for i in range(len(self.crossword.letter_grid)):
 
101
  matching_cells = [self.crossword.letter_grid[cell[0]][cell[1]] == solution[cell[0]][cell[1]] for cell in cells]
102
  if len(cells) == sum(matching_cells):
103
  words_correct += 1
 
 
 
104
  words_total += 1
105
+
106
+ letter_frac_log = "Letters Correct: {}/{} | Words Correct: {}/{}".format(int(letters_correct), int(letters_total), int(words_correct), int(words_total))
107
+ letter_acc_log = "Letters Correct: {}% | Words Correct: {}%".format(float(letters_correct/letters_total*100), float(words_correct/words_total*100))
108
+
109
+ if print_log:
110
+ print(letter_frac_log)
111
+ print(letter_acc_log)
112
 
113
+ return letter_frac_log, letter_acc_log
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -7,9 +7,30 @@ from Strict_json import json_CA_json_converter
7
  import asyncio
8
  from fastapi.middleware.cors import CORSMiddleware
9
 
10
- MODEL_PATH_DISTIL = os.path.join("Inference_components","distilbert_EPOCHs_7_COMPLETE.bin")
11
- ANS_TSV_PATH_DISTIL = os.path.join("Inference_components","all_answer_list.tsv")
12
- DENSE_EMBD_PATH_DISTIL = os.path.join("Inference_components","distilbert_7_epochs_embeddings.pkl")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  app = FastAPI()
15
 
@@ -22,27 +43,35 @@ app.add_middleware(
22
  )
23
 
24
  async def solve_puzzle(json):
25
- puzzle = json_CA_json_converter(json, False)
26
- crossword = Crossword(puzzle)
27
-
28
- # Perform asynchronous operations using asyncio.gather or asyncio.create_task
29
- async def solve_async():
30
- return await asyncio.to_thread(BPSolver, crossword, model_path=MODEL_PATH_DISTIL,
31
- ans_tsv_path=ANS_TSV_PATH_DISTIL,
32
- dense_embd_path=DENSE_EMBD_PATH_DISTIL,
33
- max_candidates=40000,
34
- model_type='distilbert')
35
-
36
- solver = await solve_async()
37
-
38
- # Run solve method asynchronously
39
- async def solve_method_async():
40
- return await asyncio.to_thread(solver.solve, num_iters=100, iterative_improvement_steps=0)
 
 
41
 
42
- solution = await solve_method_async()
43
- evaluation = await asyncio.to_thread(solver.evaluate, solution)
 
 
 
 
44
 
45
- return solution, evaluation
 
 
46
 
47
 
48
  fifo_queue = asyncio.Queue()
@@ -55,8 +84,9 @@ async def worker():
55
  job_id, job, args, future = await fifo_queue.get()
56
  jobs[job_id]["status"] = "processing"
57
  result = await job(*args)
 
58
  jobs[job_id]["result"] = result
59
- jobs[job_id]["status"] = "completed"
60
  future.set_result(job_id)
61
 
62
  @app.on_event("startup")
 
7
  import asyncio
8
  from fastapi.middleware.cors import CORSMiddleware
9
 
10
+ MODEL_CONFIG = {
11
+ 'bert':
12
+ {
13
+ 'MODEL_PATH' : "./Inference_components/dpr_biencoder_trained_EPOCH_2_COMPLETE.bin",
14
+ 'ANS_TSV_PATH': "./Inference_components/all_answer_list.tsv",
15
+ 'DENSE_EMBD_PATH': "./Inference_components/embeddings_BERT_EPOCH_2_COMPLETE0.pkl"
16
+ },
17
+ 'distilbert':
18
+ {
19
+ 'MODEL_PATH': "./Inference_components/distilbert_EPOCHs_7_COMPLETE.bin",
20
+ 'ANS_TSV_PATH': "./Inference_components/all_answer_list.tsv",
21
+ 'DENSE_EMBD_PATH': "./Inference_components/distilbert_7_epochs_embeddings.pkl"
22
+ },
23
+ 't5_small':
24
+ {
25
+ 'MODEL_PATH': './Inference_components/t5_small_new_dataset_2EPOCHS/'
26
+ }
27
+ }
28
+
29
+ choosen_model_path = MODEL_CONFIG['distilbert']['MODEL_PATH']
30
+ ans_list_path = MODEL_CONFIG['distilbert']['ANS_TSV_PATH']
31
+ dense_embedding_path = MODEL_CONFIG['distilbert']['DENSE_EMBD_PATH']
32
+ second_pass_model_path = MODEL_CONFIG['t5_small']['MODEL_PATH']
33
+
34
 
35
  app = FastAPI()
36
 
 
43
  )
44
 
45
  async def solve_puzzle(json):
46
+ try:
47
+
48
+ puzzle = json_CA_json_converter(json, False)
49
+ crossword = Crossword(puzzle)
50
+
51
+ async def solve_async():
52
+ return await asyncio.to_thread(BPSolver, crossword,
53
+ model_path = choosen_model_path,
54
+ ans_tsv_path = ans_list_path,
55
+ dense_embd_path = dense_embedding_path,
56
+ reranker_path = second_pass_model_path,
57
+ max_candidates = 40000,
58
+ model_type = 'distilbert')
59
+
60
+ solver = await solve_async()
61
+
62
+ async def solve_method_async():
63
+ return await asyncio.to_thread(solver.solve,num_iters=60, iterative_improvement_steps=3)
64
 
65
+ solution = await solve_method_async()
66
+
67
+ evaluation1 = await asyncio.to_thread(solver.evaluate, solution['first pass model']['grid'])
68
+ evaluation2 = await asyncio.to_thread(solver.evaluate, solution['second pass model']['final grid'])
69
+
70
+ return solution['second pass model']['final grid'], evaluation1, solution['second pass model']['final grid'], evaluation2
71
 
72
+ except Exception as e:
73
+ print(f"An error occurred: {e}")
74
+ return None, None, None
75
 
76
 
77
  fifo_queue = asyncio.Queue()
 
84
  job_id, job, args, future = await fifo_queue.get()
85
  jobs[job_id]["status"] = "processing"
86
  result = await job(*args)
87
+ print(result)
88
  jobs[job_id]["result"] = result
89
+ jobs[job_id]["status"] = "completed" if result[1] else "failed"
90
  future.set_result(job_id)
91
 
92
  @app.on_event("startup")