Spaces:
Running
Running
second pass model integrated
Browse files- BPSolver_inf.py +246 -16
- Dockerfile +4 -0
- Models_inf.py +66 -4
- Normal_utils_inf.py +100 -3
- Solver_inf.py +14 -30
- main.py +53 -23
BPSolver_inf.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import math
|
2 |
import string
|
|
|
3 |
from collections import defaultdict
|
4 |
from copy import deepcopy
|
5 |
|
@@ -9,6 +10,8 @@ from tqdm import trange
|
|
9 |
|
10 |
from Utils_inf import print_grid, get_word_flips
|
11 |
from Solver_inf import Solver
|
|
|
|
|
12 |
# the probability of each alphabetical character in the crossword
|
13 |
UNIGRAM_PROBS = [('A', 0.0897379968935765), ('B', 0.02121248877769636), ('C', 0.03482206634145926), ('D', 0.03700942543460491), ('E', 0.1159773210750429), ('F', 0.017257461694024614), ('G', 0.025429024796296124), ('H', 0.033122967601502), ('I', 0.06800036223479956), ('J', 0.00294611331754349), ('K', 0.013860682888259786), ('L', 0.05130800574373874), ('M', 0.027962776827660175), ('N', 0.06631994270448001), ('O', 0.07374646543246745), ('P', 0.026750756212433214), ('Q', 0.001507814175439393), ('R', 0.07080460813737305), ('S', 0.07410988246048224), ('T', 0.07242993582154593), ('U', 0.0289272388037645), ('V', 0.009153522059555467), ('W', 0.01434705167591524), ('X', 0.003096729223103298), ('Y', 0.01749958208224007), ('Z', 0.002659777584995724)]
|
14 |
|
@@ -18,12 +21,12 @@ LETTER_SMOOTHING_FACTOR = [0.0, 0.0, 0.04395604395604396, 0.0001372495196266813,
|
|
18 |
|
19 |
class BPVar:
|
20 |
def __init__(self, name, variable, candidates, cells):
|
21 |
-
self.name = name
|
22 |
cells_by_position = {}
|
23 |
-
for cell in cells:
|
24 |
-
cells_by_position[cell.position] = cell
|
25 |
cell._connect(self)
|
26 |
-
self.length = len(cells)
|
27 |
self.ordered_cells = [cells_by_position[pos] for pos in variable['cells']]
|
28 |
self.candidates = candidates
|
29 |
self.words = self.candidates['words']
|
@@ -85,6 +88,7 @@ class BPCell:
|
|
85 |
self.log_probs = log_softmax(sum(self.directional_scores))
|
86 |
|
87 |
def propagate(self):
|
|
|
88 |
try:
|
89 |
for i, v in enumerate(self.crossing_vars):
|
90 |
v._propagate_to_var(self, self.directional_scores[1-i])
|
@@ -98,7 +102,8 @@ class BPSolver(Solver):
|
|
98 |
model_path,
|
99 |
ans_tsv_path,
|
100 |
dense_embd_path,
|
101 |
-
|
|
|
102 |
process_id = 0,
|
103 |
model_type = 'bert',
|
104 |
**kwargs):
|
@@ -111,6 +116,8 @@ class BPSolver(Solver):
|
|
111 |
model_type = model_type,
|
112 |
**kwargs)
|
113 |
self.crossword = crossword
|
|
|
|
|
114 |
|
115 |
# our answer set
|
116 |
self.answer_set = set()
|
@@ -130,12 +137,27 @@ class BPSolver(Solver):
|
|
130 |
self.bp_cells_by_clue[clue].append(cell)
|
131 |
self.bp_vars = []
|
132 |
for key, value in self.crossword.variables.items():
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
var = BPVar(key, value, self.candidates[key], self.bp_cells_by_clue[key])
|
|
|
|
|
|
|
134 |
self.bp_vars.append(var)
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
def solve(self, num_iters=10, iterative_improvement_steps=5, return_greedy_states = False, return_ii_states = False):
|
137 |
# run solving for num_iters iterations
|
138 |
-
print('
|
139 |
for _ in trange(num_iters):
|
140 |
for var in self.bp_vars:
|
141 |
var.propagate()
|
@@ -145,7 +167,7 @@ class BPSolver(Solver):
|
|
145 |
cell.propagate()
|
146 |
for var in self.bp_vars:
|
147 |
var.sync_state()
|
148 |
-
print('
|
149 |
|
150 |
# Get the current based grid based on greedy selection from the marginals
|
151 |
if return_greedy_states:
|
@@ -153,16 +175,168 @@ class BPSolver(Solver):
|
|
153 |
else:
|
154 |
grid = self.greedy_sequential_word_solution()
|
155 |
all_grids = []
|
156 |
-
|
157 |
-
#
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
if return_greedy_states or return_ii_states:
|
162 |
-
return
|
163 |
else:
|
164 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
def greedy_sequential_word_solution(self, return_grids = False):
|
167 |
all_grids = []
|
168 |
# after we've run BP, we run a simple greedy search to get the final.
|
@@ -181,7 +355,6 @@ class BPSolver(Solver):
|
|
181 |
best_index = best_per_var.index(max([x for x in best_per_var if x is not None]))
|
182 |
best_var = self.bp_vars[best_index]
|
183 |
best_word = best_var.words[best_var.log_probs.argmax()]
|
184 |
-
# print('greedy filling in', best_word)
|
185 |
for i, cell in enumerate(best_var.ordered_cells):
|
186 |
letter = best_word[i]
|
187 |
grid[cell.position[0]][cell.position[1]] = letter
|
@@ -201,14 +374,71 @@ class BPSolver(Solver):
|
|
201 |
best_var.words = []
|
202 |
best_var.log_probs = best_var.log_probs[[]]
|
203 |
best_per_var[best_index] = None
|
|
|
|
|
204 |
for cell in self.bp_cells:
|
205 |
if cell.position in unfilled_cells:
|
|
|
206 |
grid[cell.position[0]][cell.position[1]] = string.ascii_uppercase[cell.log_probs.argmax()]
|
207 |
-
|
208 |
for var, (words, log_probs) in zip(self.bp_vars, cache): # restore state
|
209 |
var.words = words
|
210 |
var.log_probs = log_probs
|
211 |
if return_grids:
|
212 |
return grid, all_grids
|
213 |
else:
|
214 |
-
return grid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import math
|
2 |
import string
|
3 |
+
import re
|
4 |
from collections import defaultdict
|
5 |
from copy import deepcopy
|
6 |
|
|
|
10 |
|
11 |
from Utils_inf import print_grid, get_word_flips
|
12 |
from Solver_inf import Solver
|
13 |
+
from Models_inf import setup_t5_reranker, t5_reranker_score_with_clue
|
14 |
+
|
15 |
# the probability of each alphabetical character in the crossword
|
16 |
UNIGRAM_PROBS = [('A', 0.0897379968935765), ('B', 0.02121248877769636), ('C', 0.03482206634145926), ('D', 0.03700942543460491), ('E', 0.1159773210750429), ('F', 0.017257461694024614), ('G', 0.025429024796296124), ('H', 0.033122967601502), ('I', 0.06800036223479956), ('J', 0.00294611331754349), ('K', 0.013860682888259786), ('L', 0.05130800574373874), ('M', 0.027962776827660175), ('N', 0.06631994270448001), ('O', 0.07374646543246745), ('P', 0.026750756212433214), ('Q', 0.001507814175439393), ('R', 0.07080460813737305), ('S', 0.07410988246048224), ('T', 0.07242993582154593), ('U', 0.0289272388037645), ('V', 0.009153522059555467), ('W', 0.01434705167591524), ('X', 0.003096729223103298), ('Y', 0.01749958208224007), ('Z', 0.002659777584995724)]
|
17 |
|
|
|
21 |
|
22 |
class BPVar:
|
23 |
def __init__(self, name, variable, candidates, cells):
|
24 |
+
self.name = name # key from crossword.variables i.e. 1A, 2D, 3A
|
25 |
cells_by_position = {}
|
26 |
+
for cell in cells: # every cells or letter box that a particular variable or filling takes into consideration
|
27 |
+
cells_by_position[cell.position] = cell # cell.position (0,0) -> cell -> BPCell
|
28 |
cell._connect(self)
|
29 |
+
self.length = len(cells) # obviously the length of the answer
|
30 |
self.ordered_cells = [cells_by_position[pos] for pos in variable['cells']]
|
31 |
self.candidates = candidates
|
32 |
self.words = self.candidates['words']
|
|
|
88 |
self.log_probs = log_softmax(sum(self.directional_scores))
|
89 |
|
90 |
def propagate(self):
|
91 |
+
# assert len(self.crossing_vars) == 2
|
92 |
try:
|
93 |
for i, v in enumerate(self.crossing_vars):
|
94 |
v._propagate_to_var(self, self.directional_scores[1-i])
|
|
|
102 |
model_path,
|
103 |
ans_tsv_path,
|
104 |
dense_embd_path,
|
105 |
+
reranker_path,
|
106 |
+
max_candidates = 100,
|
107 |
process_id = 0,
|
108 |
model_type = 'bert',
|
109 |
**kwargs):
|
|
|
116 |
model_type = model_type,
|
117 |
**kwargs)
|
118 |
self.crossword = crossword
|
119 |
+
self.reranker_path = reranker_path
|
120 |
+
self.reranker_model_type = 't5-small'
|
121 |
|
122 |
# our answer set
|
123 |
self.answer_set = set()
|
|
|
137 |
self.bp_cells_by_clue[clue].append(cell)
|
138 |
self.bp_vars = []
|
139 |
for key, value in self.crossword.variables.items():
|
140 |
+
# if key == '1A':
|
141 |
+
# print('-'*100)
|
142 |
+
# print(self.candidates[key]['words'])
|
143 |
+
# print(self.candidates[key]['bit_array'].shape)
|
144 |
+
# print(self.candidates[key]['weights'])
|
145 |
+
# print('-'*100)
|
146 |
var = BPVar(key, value, self.candidates[key], self.bp_cells_by_clue[key])
|
147 |
+
# print('*'*100)
|
148 |
+
# print(self.bp_cells_by_clue[key])
|
149 |
+
# print('*'*100)
|
150 |
self.bp_vars.append(var)
|
151 |
+
|
152 |
+
def extract_float(self, input_string):
|
153 |
+
pattern = r"\d+\.\d+"
|
154 |
+
matches = re.findall(pattern, input_string)
|
155 |
+
float_numbers = [float(match) for match in matches]
|
156 |
+
return float_numbers
|
157 |
|
158 |
def solve(self, num_iters=10, iterative_improvement_steps=5, return_greedy_states = False, return_ii_states = False):
|
159 |
# run solving for num_iters iterations
|
160 |
+
print('\nBeginning Belief Propagation Iteration Steps: ')
|
161 |
for _ in trange(num_iters):
|
162 |
for var in self.bp_vars:
|
163 |
var.propagate()
|
|
|
167 |
cell.propagate()
|
168 |
for var in self.bp_vars:
|
169 |
var.sync_state()
|
170 |
+
print('Belief Propagation Iteration Complete\n')
|
171 |
|
172 |
# Get the current based grid based on greedy selection from the marginals
|
173 |
if return_greedy_states:
|
|
|
175 |
else:
|
176 |
grid = self.greedy_sequential_word_solution()
|
177 |
all_grids = []
|
178 |
+
|
179 |
+
# properly save all the outputs results:
|
180 |
+
output_results = {}
|
181 |
+
output_results['first pass model'] = {}
|
182 |
+
output_results['first pass model']['grid'] = grid
|
183 |
+
|
184 |
+
# save first pass model grid, and letter accuracies
|
185 |
+
_, accu_log = self.evaluate(grid, False)
|
186 |
+
[ori_letter_accu, ori_word_accu] = self.extract_float(accu_log)
|
187 |
+
output_results['first pass model']['letter accuracy'] = ori_letter_accu
|
188 |
+
output_results['first pass model']['word accuracy'] = ori_word_accu
|
189 |
+
|
190 |
+
print("First pass model result was", grid,ori_letter_accu,ori_word_accu)
|
191 |
|
192 |
+
output_results['second pass model'] = {}
|
193 |
+
output_results['second pass model']['final grid'] = [] # just for the sake of the api
|
194 |
+
output_results['second pass model']['final grid'] = grid # just for the sake of the api
|
195 |
+
output_results['second pass model']['all grids'] = []
|
196 |
+
output_results['second pass model']['all letter accuracy'] = []
|
197 |
+
output_results['second pass model']['all word accuracy'] = []
|
198 |
+
|
199 |
+
if iterative_improvement_steps < 1 or ori_letter_accu == 100.0:
|
200 |
if return_greedy_states or return_ii_states:
|
201 |
+
return output_results, all_grids
|
202 |
else:
|
203 |
+
return output_results
|
204 |
+
|
205 |
+
'''
|
206 |
+
Iterative Improvement with t5-small starts from here.
|
207 |
+
'''
|
208 |
+
self.reranker, self.tokenizer = setup_t5_reranker(self.reranker_path, self.reranker_model_type)
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
for i in range(iterative_improvement_steps):
|
213 |
+
grid, did_iterative_improvement_make_edit = self.iterative_improvement(grid)
|
214 |
+
|
215 |
+
_, accu_log = self.evaluate(grid, False)
|
216 |
+
[temp_letter_accu, temp_word_accu] = self.extract_float(accu_log)
|
217 |
+
print(f"{i+1}th iteration: {accu_log}")
|
218 |
+
|
219 |
+
# save grid & accuracies at each iteration
|
220 |
+
output_results['second pass model']['all grids'].append(grid)
|
221 |
+
output_results['second pass model']['all letter accuracy'].append(temp_letter_accu)
|
222 |
+
output_results['second pass model']['all word accuracy'].append(temp_word_accu)
|
223 |
+
|
224 |
+
if not did_iterative_improvement_make_edit or temp_letter_accu == 100.0:
|
225 |
+
break
|
226 |
+
|
227 |
+
if return_ii_states:
|
228 |
+
all_grids.append(deepcopy(grid))
|
229 |
+
|
230 |
+
temp_lett_accu_list = output_results['second pass model']['all letter accuracy'].copy()
|
231 |
+
ii_max_index = temp_lett_accu_list.index(max(temp_lett_accu_list))
|
232 |
+
|
233 |
+
output_results['second pass model']['final grid'] = output_results['second pass model']['all grids'][ii_max_index]
|
234 |
+
output_results['second pass model']['final letter'] = output_results['second pass model']['all letter accuracy'][ii_max_index]
|
235 |
+
output_results['second pass model']['final word'] = output_results['second pass model']['all word accuracy'][ii_max_index]
|
236 |
+
|
237 |
+
if return_greedy_states or return_ii_states:
|
238 |
+
return output_results, all_grids
|
239 |
+
else:
|
240 |
+
return output_results
|
241 |
+
|
242 |
+
def get_candidate_replacements(self, uncertain_answers, grid):
|
243 |
+
# find alternate answers for all the uncertain words
|
244 |
+
candidate_replacements = []
|
245 |
+
replacement_id_set = set()
|
246 |
+
|
247 |
+
# check against dictionaries
|
248 |
+
for clue in uncertain_answers.keys():
|
249 |
+
initial_word = uncertain_answers[clue]
|
250 |
+
clue_flips = get_word_flips(initial_word, 10) # flip then segment
|
251 |
+
clue_positions = [key for key, value in self.crossword.variables.items() if value['clue'] == clue]
|
252 |
+
for clue_position in clue_positions:
|
253 |
+
cells = sorted([cell for cell in self.bp_cells if clue_position in cell.crossing_clues], key=lambda c: c.position)
|
254 |
+
if len(cells) == len(initial_word):
|
255 |
+
break
|
256 |
+
for flip in clue_flips:
|
257 |
+
if len(flip) != len(cells):
|
258 |
+
import pdb; pdb.set_trace()
|
259 |
+
assert len(flip) == len(cells)
|
260 |
+
for i in range(len(flip)):
|
261 |
+
if flip[i] != initial_word[i]:
|
262 |
+
candidate_replacements.append([(cells[i], flip[i])])
|
263 |
+
break
|
264 |
+
|
265 |
+
# also add candidates based on uncertainties in the letters, e.g., if we said P but G also had some probability, try G too
|
266 |
+
for cell_id, cell in enumerate(self.bp_cells):
|
267 |
+
probs = np.exp(cell.log_probs)
|
268 |
+
above_threshold = list(probs > 0.01)
|
269 |
+
new_characters = ['ABCDEFGHIJKLMNOPQRSTUVWXYZ'[i] for i in range(26) if above_threshold[i]]
|
270 |
+
# used = set()
|
271 |
+
# new_characters = [x for x in new_characters if x not in used and (used.add(x) or True)] # unique the set
|
272 |
+
new_characters = [x for x in new_characters if x != grid[cell.position[0]][cell.position[1]]] # ignore if its the same as the original solution
|
273 |
+
if len(new_characters) > 0:
|
274 |
+
for new_character in new_characters:
|
275 |
+
id = '_'.join([str(cell.position), new_character])
|
276 |
+
if id not in replacement_id_set:
|
277 |
+
candidate_replacements.append([(cell, new_character)])
|
278 |
+
replacement_id_set.add(id)
|
279 |
+
|
280 |
+
# create composite flips based on things in the same row/column
|
281 |
+
composite_replacements = []
|
282 |
+
for i in range(len(candidate_replacements)):
|
283 |
+
for j in range(i+1, len(candidate_replacements)):
|
284 |
+
flip1, flip2 = candidate_replacements[i], candidate_replacements[j]
|
285 |
+
if flip1[0][0] != flip2[0][0]:
|
286 |
+
if len(set(flip1[0][0].crossing_clues + flip2[0][0].crossing_clues)) < 4: # shared clue
|
287 |
+
composite_replacements.append(flip1 + flip2)
|
288 |
+
|
289 |
+
candidate_replacements += composite_replacements
|
290 |
+
|
291 |
+
#print('\ncandidate replacements')
|
292 |
+
for cr in candidate_replacements:
|
293 |
+
modified_grid = deepcopy(grid)
|
294 |
+
for cell, letter in cr:
|
295 |
+
modified_grid[cell.position[0]][cell.position[1]] = letter
|
296 |
+
variables = set(sum([cell.crossing_vars for cell, _ in cr], []))
|
297 |
+
for var in variables:
|
298 |
+
original_fill = ''.join([grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
|
299 |
+
modified_fill = ''.join([modified_grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
|
300 |
+
#print('original:', original_fill, 'modified:', modified_fill)
|
301 |
+
|
302 |
+
return candidate_replacements
|
303 |
+
|
304 |
+
def get_uncertain_answers(self, grid):
|
305 |
+
original_qa_pairs = {} # the original puzzle preds that we will try to improve
|
306 |
+
# first save what the argmax word-level prediction was for each grid cell just to make life easier
|
307 |
+
for var in self.crossword.variables:
|
308 |
+
# read the current word off the grid
|
309 |
+
cells = self.crossword.variables[var]["cells"]
|
310 |
+
word = []
|
311 |
+
for cell in cells:
|
312 |
+
word.append(grid[cell[0]][cell[1]])
|
313 |
+
word = ''.join(word)
|
314 |
+
for cell in self.bp_cells: # loop through all cells
|
315 |
+
if cell.position in cells: # if this cell is in the word we are currently handling
|
316 |
+
# save {clue, answer} pair into this cell
|
317 |
+
cell.prediction[self.crossword.variables[var]['clue']] = word
|
318 |
+
original_qa_pairs[self.crossword.variables[var]['clue']] = word
|
319 |
+
|
320 |
+
uncertain_answers = {}
|
321 |
+
|
322 |
+
# find uncertain answers
|
323 |
+
# right now the heuristic we use is any answer that is not in the answer set
|
324 |
+
for clue in original_qa_pairs.keys():
|
325 |
+
if original_qa_pairs[clue] not in self.answer_set:
|
326 |
+
uncertain_answers[clue] = original_qa_pairs[clue]
|
327 |
+
|
328 |
+
return uncertain_answers
|
329 |
|
330 |
+
def score_grid(self, grid):
|
331 |
+
clues = []
|
332 |
+
answers = []
|
333 |
+
for clue, cells in self.bp_cells_by_clue.items():
|
334 |
+
letters = ''.join([grid[cell.position[0]][cell.position[1]] for cell in sorted(list(cells), key=lambda c: c.position)])
|
335 |
+
clues.append(self.crossword.variables[clue]['clue'])
|
336 |
+
answers.append(letters)
|
337 |
+
scores = t5_reranker_score_with_clue(self.reranker, self.tokenizer, self.reranker_model_type, clues, answers)
|
338 |
+
return sum(scores)
|
339 |
+
|
340 |
def greedy_sequential_word_solution(self, return_grids = False):
|
341 |
all_grids = []
|
342 |
# after we've run BP, we run a simple greedy search to get the final.
|
|
|
355 |
best_index = best_per_var.index(max([x for x in best_per_var if x is not None]))
|
356 |
best_var = self.bp_vars[best_index]
|
357 |
best_word = best_var.words[best_var.log_probs.argmax()]
|
|
|
358 |
for i, cell in enumerate(best_var.ordered_cells):
|
359 |
letter = best_word[i]
|
360 |
grid[cell.position[0]][cell.position[1]] = letter
|
|
|
374 |
best_var.words = []
|
375 |
best_var.log_probs = best_var.log_probs[[]]
|
376 |
best_per_var[best_index] = None
|
377 |
+
|
378 |
+
unfilled_cells_count = 0
|
379 |
for cell in self.bp_cells:
|
380 |
if cell.position in unfilled_cells:
|
381 |
+
unfilled_cells_count += 1
|
382 |
grid[cell.position[0]][cell.position[1]] = string.ascii_uppercase[cell.log_probs.argmax()]
|
383 |
+
|
384 |
for var, (words, log_probs) in zip(self.bp_vars, cache): # restore state
|
385 |
var.words = words
|
386 |
var.log_probs = log_probs
|
387 |
if return_grids:
|
388 |
return grid, all_grids
|
389 |
else:
|
390 |
+
return grid
|
391 |
+
|
392 |
+
def iterative_improvement(self, grid):
|
393 |
+
# check the grid for uncertain areas and save those words to be analyzed in local search, aka looking for alternate candidates
|
394 |
+
uncertain_answers = self.get_uncertain_answers(grid)
|
395 |
+
self.candidate_replacements = self.get_candidate_replacements(uncertain_answers, grid)
|
396 |
+
|
397 |
+
# print('\nstarting iterative improvement')
|
398 |
+
original_grid_score = self.score_grid(grid)
|
399 |
+
possible_edits = []
|
400 |
+
for replacements in self.candidate_replacements:
|
401 |
+
modified_grid = deepcopy(grid)
|
402 |
+
for cell, letter in replacements:
|
403 |
+
modified_grid[cell.position[0]][cell.position[1]] = letter
|
404 |
+
modified_grid_score = self.score_grid(modified_grid)
|
405 |
+
# print('candidate edit')
|
406 |
+
variables = set(sum([cell.crossing_vars for cell, _ in replacements], []))
|
407 |
+
for var in variables:
|
408 |
+
original_fill = ''.join([grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
|
409 |
+
modified_fill = ''.join([modified_grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
|
410 |
+
clue_index = list(set(var.ordered_cells[0].crossing_clues).intersection(*[set(cell.crossing_clues) for cell in var.ordered_cells]))[0]
|
411 |
+
# print('original:', original_fill, 'modified:', modified_fill)
|
412 |
+
# print('gold answer', self.crossword.variables[clue_index]['gold'])
|
413 |
+
# print('clue', self.crossword.variables[clue_index]['clue'])
|
414 |
+
# print('original score:', original_grid_score, 'modified score:', modified_grid_score)
|
415 |
+
if modified_grid_score - original_grid_score > 0.5:
|
416 |
+
# print('found a possible edit')
|
417 |
+
possible_edits.append((modified_grid, modified_grid_score, replacements))
|
418 |
+
# print()
|
419 |
+
|
420 |
+
if len(possible_edits) > 0:
|
421 |
+
variables_modified = set()
|
422 |
+
possible_edits = sorted(possible_edits, key=lambda x: x[1], reverse=True)
|
423 |
+
selected_edits = []
|
424 |
+
for edit in possible_edits:
|
425 |
+
replacements = edit[2]
|
426 |
+
variables = set(sum([cell.crossing_vars for cell, _ in replacements], []))
|
427 |
+
if len(variables_modified.intersection(variables)) == 0: # we can do multiple updates at once if they don't share clues
|
428 |
+
variables_modified.update(variables)
|
429 |
+
selected_edits.append(edit)
|
430 |
+
|
431 |
+
new_grid = deepcopy(grid)
|
432 |
+
for edit in selected_edits:
|
433 |
+
# print('\nactually applying edit')
|
434 |
+
replacements = edit[2]
|
435 |
+
for cell, letter in replacements:
|
436 |
+
new_grid[cell.position[0]][cell.position[1]] = letter
|
437 |
+
variables = set(sum([cell.crossing_vars for cell, _ in replacements], []))
|
438 |
+
for var in variables:
|
439 |
+
original_fill = ''.join([grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
|
440 |
+
modified_fill = ''.join([new_grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
|
441 |
+
# print('original:', original_fill, 'modified:', modified_fill)
|
442 |
+
return new_grid, True
|
443 |
+
else:
|
444 |
+
return grid, False
|
Dockerfile
CHANGED
@@ -34,4 +34,8 @@ ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scor
|
|
34 |
ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/dpr_biencoder_trained_500k.bin $HOME/app/Inference_components/
|
35 |
ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/embeddings_all_answers_json_0.pkl $HOME/app/Inference_components/
|
36 |
|
|
|
|
|
|
|
|
|
37 |
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
34 |
ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/dpr_biencoder_trained_500k.bin $HOME/app/Inference_components/
|
35 |
ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/embeddings_all_answers_json_0.pkl $HOME/app/Inference_components/
|
36 |
|
37 |
+
ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/t5_small_new_dataset_2EPOCHS/config.json $HOME/app/Inference_components/t5_small_new_dataset_2EPOCHS/
|
38 |
+
ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/t5_small_new_dataset_2EPOCHS/generation_config.json $HOME/app/Inference_components/t5_small_new_dataset_2EPOCHS/
|
39 |
+
ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/t5_small_new_dataset_2EPOCHS/model.safetensors $HOME/app/Inference_components/t5_small_new_dataset_2EPOCHS/
|
40 |
+
|
41 |
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
Models_inf.py
CHANGED
@@ -7,8 +7,6 @@ import string
|
|
7 |
import sys
|
8 |
from typing import List, Tuple, Dict
|
9 |
import re
|
10 |
-
import math
|
11 |
-
import collections
|
12 |
|
13 |
import numpy as np
|
14 |
import unicodedata
|
@@ -21,7 +19,10 @@ from Options_inf import setup_args_gpu, print_args, set_encoder_params_from_stat
|
|
21 |
from Faiss_Indexers_inf import DenseIndexer, DenseFlatIndexer
|
22 |
from Data_utils_inf import Tensorizer
|
23 |
from Model_utils_inf import load_states_from_checkpoint, get_model_obj
|
24 |
-
|
|
|
|
|
|
|
25 |
|
26 |
SEGMENTER_CACHE = {}
|
27 |
RERANKER_CACHE = {}
|
@@ -37,6 +38,64 @@ def setup_closedbook(model_path, ans_tsv_path, dense_embd_path, process_id, mode
|
|
37 |
)
|
38 |
return dpr
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
def preprocess_clue_fn(clue):
|
41 |
clue = str(clue)
|
42 |
|
@@ -202,6 +261,7 @@ class DenseRetriever(object):
|
|
202 |
query_vectors.extend(out.cpu().split(1, dim=0))
|
203 |
|
204 |
query_tensor = torch.cat(query_vectors, dim=0)
|
|
|
205 |
assert query_tensor.size(0) == len(questions)
|
206 |
return query_tensor
|
207 |
|
@@ -353,9 +413,11 @@ class DPRForCrossword(object):
|
|
353 |
if max_answers > self.len_all_passages:
|
354 |
max_answers = self.len_all_passages
|
355 |
|
|
|
356 |
# get top k results
|
357 |
top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_answers)
|
358 |
-
|
|
|
359 |
if not output_strings:
|
360 |
return top_ids_and_scores
|
361 |
else:
|
|
|
7 |
import sys
|
8 |
from typing import List, Tuple, Dict
|
9 |
import re
|
|
|
|
|
10 |
|
11 |
import numpy as np
|
12 |
import unicodedata
|
|
|
19 |
from Faiss_Indexers_inf import DenseIndexer, DenseFlatIndexer
|
20 |
from Data_utils_inf import Tensorizer
|
21 |
from Model_utils_inf import load_states_from_checkpoint, get_model_obj
|
22 |
+
from transformers import T5ForConditionalGeneration, AutoTokenizer
|
23 |
+
import time
|
24 |
+
from wordsegment import load, segment
|
25 |
+
load()
|
26 |
|
27 |
SEGMENTER_CACHE = {}
|
28 |
RERANKER_CACHE = {}
|
|
|
38 |
)
|
39 |
return dpr
|
40 |
|
41 |
+
def setup_t5_reranker(reranker_path, reranker_model_type = 't5-small'):
|
42 |
+
tokenizer = AutoTokenizer.from_pretrained(reranker_model_type)
|
43 |
+
model = T5ForConditionalGeneration.from_pretrained(reranker_path)
|
44 |
+
model.eval().to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
45 |
+
return model, tokenizer
|
46 |
+
|
47 |
+
def post_process_clue(clue):
|
48 |
+
clue = preprocess_clue_fn(clue)
|
49 |
+
if clue[-3:] == '. .':
|
50 |
+
clue = clue[:-3]
|
51 |
+
elif clue[-3:] == ' ..':
|
52 |
+
clue = clue[:-3]
|
53 |
+
elif clue[-2:] == '..':
|
54 |
+
clue = clue[:-2]
|
55 |
+
elif clue[-1] == '.':
|
56 |
+
clue = clue[:-1]
|
57 |
+
return clue
|
58 |
+
|
59 |
+
def t5_reranker_score_with_clue(model, tokenizer, model_type, clues, possibly_ungrammatical_fills):
|
60 |
+
global RERANKER_CACHE
|
61 |
+
results = []
|
62 |
+
device = model.device
|
63 |
+
|
64 |
+
fills = possibly_ungrammatical_fills.copy()
|
65 |
+
|
66 |
+
if model_type == 't5-small':
|
67 |
+
segmented_fills = []
|
68 |
+
for answer in possibly_ungrammatical_fills:
|
69 |
+
segmented_fills.append(" ".join(segment(answer.lower())))
|
70 |
+
fills = segmented_fills.copy()
|
71 |
+
|
72 |
+
for clue, possibly_ungrammatical_fill in zip(clues, fills):
|
73 |
+
# possibly here is where the byt5 failed
|
74 |
+
if not possibly_ungrammatical_fill.islower():
|
75 |
+
possibly_ungrammatical_fill = possibly_ungrammatical_fill.lower()
|
76 |
+
|
77 |
+
clue = post_process_clue(clue)
|
78 |
+
|
79 |
+
if clue + possibly_ungrammatical_fill in RERANKER_CACHE:
|
80 |
+
results.append(RERANKER_CACHE[clue + possibly_ungrammatical_fill])
|
81 |
+
continue
|
82 |
+
else:
|
83 |
+
with torch.no_grad(), torch.inference_mode():
|
84 |
+
# move all the input tensors to the GPU (cuda)
|
85 |
+
inputs = tokenizer(["Q: " + clue], return_tensors='pt')['input_ids'].to(device)
|
86 |
+
labels = tokenizer([possibly_ungrammatical_fill], return_tensors='pt')['input_ids'].to(device)
|
87 |
+
|
88 |
+
# model mode set to evaluation
|
89 |
+
model.eval()
|
90 |
+
|
91 |
+
loss = model(inputs, labels = labels)
|
92 |
+
answer_length = labels.shape[1]
|
93 |
+
logprob = -loss[0].item() * answer_length
|
94 |
+
results.append(logprob)
|
95 |
+
RERANKER_CACHE[clue + possibly_ungrammatical_fill] = logprob
|
96 |
+
|
97 |
+
return results
|
98 |
+
|
99 |
def preprocess_clue_fn(clue):
|
100 |
clue = str(clue)
|
101 |
|
|
|
261 |
query_vectors.extend(out.cpu().split(1, dim=0))
|
262 |
|
263 |
query_tensor = torch.cat(query_vectors, dim=0)
|
264 |
+
print("CLUE Vector Shape", query_tensor.shape)
|
265 |
assert query_tensor.size(0) == len(questions)
|
266 |
return query_tensor
|
267 |
|
|
|
413 |
if max_answers > self.len_all_passages:
|
414 |
max_answers = self.len_all_passages
|
415 |
|
416 |
+
start_time = time.time()
|
417 |
# get top k results
|
418 |
top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_answers)
|
419 |
+
end_time = time.time()
|
420 |
+
print("\n\nTime taken by FAISS INDEXER: ", end_time - start_time)
|
421 |
if not output_strings:
|
422 |
return top_ids_and_scores
|
423 |
else:
|
Normal_utils_inf.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import puz
|
2 |
-
import
|
3 |
-
import
|
4 |
-
import sys
|
5 |
|
6 |
def puz_to_json(fname):
|
7 |
""" Converts a puzzle in .puz format to .json format
|
@@ -63,3 +62,101 @@ def puz_to_pairs(filepath):
|
|
63 |
|
64 |
return [(k, v) for k, v in pairs.items()]
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import puz
|
2 |
+
import json
|
3 |
+
import requests
|
|
|
4 |
|
5 |
def puz_to_json(fname):
|
6 |
""" Converts a puzzle in .puz format to .json format
|
|
|
62 |
|
63 |
return [(k, v) for k, v in pairs.items()]
|
64 |
|
65 |
+
def json_CA_json_converter(json_file_path, is_path):
|
66 |
+
try:
|
67 |
+
if is_path:
|
68 |
+
with open(json_file_path, "r") as file:
|
69 |
+
data = json.load(file)
|
70 |
+
else:
|
71 |
+
data = json_file_path
|
72 |
+
|
73 |
+
json_conversion_dict = {}
|
74 |
+
|
75 |
+
rows = data["size"]["rows"]
|
76 |
+
cols = data["size"]["cols"]
|
77 |
+
date = data["date"]
|
78 |
+
|
79 |
+
clues = data["clues"]
|
80 |
+
answers = data["answers"]
|
81 |
+
|
82 |
+
json_conversion_dict["metadata"] = {"date": date, "rows": rows, "cols": cols}
|
83 |
+
|
84 |
+
across_clue_answer = {}
|
85 |
+
down_clue_answer = {}
|
86 |
+
|
87 |
+
for clue, ans in zip(clues["across"], answers["across"]):
|
88 |
+
split_clue = clue.split(" ")
|
89 |
+
clue_num = split_clue[0][:-1]
|
90 |
+
clue_ = " ".join(split_clue[1:])
|
91 |
+
clue_ = clue_.replace("[", "").replace("]", "")
|
92 |
+
across_clue_answer[clue_num] = [clue_, ans]
|
93 |
+
|
94 |
+
for clue, ans in zip(clues["down"], answers["down"]):
|
95 |
+
split_clue = clue.split(" ")
|
96 |
+
clue_num = split_clue[0][:-1]
|
97 |
+
clue_ = " ".join(split_clue[1:])
|
98 |
+
clue_ = clue_.replace("[", "").replace("]", "")
|
99 |
+
down_clue_answer[clue_num] = [clue_, ans]
|
100 |
+
|
101 |
+
json_conversion_dict["clues"] = {
|
102 |
+
"across": across_clue_answer,
|
103 |
+
"down": down_clue_answer,
|
104 |
+
}
|
105 |
+
|
106 |
+
grid_info = data["grid"]
|
107 |
+
grid_num = data["gridnums"]
|
108 |
+
|
109 |
+
grid_info_list = []
|
110 |
+
for i in range(rows):
|
111 |
+
row_list = []
|
112 |
+
for j in range(cols):
|
113 |
+
if grid_info[i * rows + j] == ".":
|
114 |
+
row_list.append("BLACK")
|
115 |
+
else:
|
116 |
+
if grid_num[i * rows + j] == 0:
|
117 |
+
row_list.append(["", grid_info[i * rows + j]])
|
118 |
+
else:
|
119 |
+
row_list.append(
|
120 |
+
[str(grid_num[i * rows + j]), grid_info[i * rows + j]]
|
121 |
+
)
|
122 |
+
grid_info_list.append(row_list)
|
123 |
+
|
124 |
+
json_conversion_dict["grid"] = grid_info_list
|
125 |
+
|
126 |
+
return json_conversion_dict
|
127 |
+
|
128 |
+
except:
|
129 |
+
print("ERROR has occured.")
|
130 |
+
|
131 |
+
def fetch_nyt_crossword(dateStr):
    '''
    Fetch an NYT puzzle for a specific date from xwordinfo.com.

    Args:
        dateStr (str): puzzle date in mm/dd/yyyy format.

    Returns:
        dict or None: puzzle dict produced by json_CA_json_converter
        (metadata / clues / grid), or None when the HTTP request fails.
        (Bug fix: the failure branch previously fell off the end and
        returned None implicitly and undocumented.)
    '''
    # xwordinfo.com rejects requests that lack this Referer header.
    headers = {
        'Referer': 'https://www.xwordinfo.com/JSON/'
    }
    # mm/dd/yyyy
    url = 'https://www.xwordinfo.com/JSON/Data.ashx?date=' + dateStr

    response = requests.get(url, headers=headers)

    if response.status_code != 200:
        print(f"Request failed with status code {response.status_code}.")
        return None

    # Bug fix: parse the body as-is. The previous
    # decode().replace("'", '"') pass rewrote every apostrophe inside
    # clue strings into a double quote, producing invalid JSON for any
    # puzzle whose clues contain an apostrophe.
    grid_data = json.loads(response.content.decode('utf-8'))
    puzzle_data = json_CA_json_converter(grid_data, False)

    # Normalize typographic quotes in clue text to plain ASCII
    # apostrophes, as the original did.
    # NOTE(review): the original replacement line was garbled in this
    # copy (unbalanced quote literals); the characters below are the
    # most plausible reconstruction — confirm against upstream.
    for dim in ['across', 'down']:
        for grid_num in puzzle_data['clues'][dim]:
            clue_section, ans_section = puzzle_data['clues'][dim][grid_num]
            clue_section = (clue_section
                            .replace('\u201c', "'")   # left curly double quote
                            .replace('\u201d', "'")   # right curly double quote
                            .replace('\u2018', "'")   # left curly single quote
                            .replace('\u2019', "'"))  # right curly single quote
            puzzle_data['clues'][dim][grid_num] = [clue_section, ans_section]
    return puzzle_data
|
Solver_inf.py
CHANGED
@@ -16,7 +16,7 @@ class Solver:
|
|
16 |
crossword (Crossword): puzzle to solve
|
17 |
max_candidates (int): number of answer candidates to consider per clue
|
18 |
"""
|
19 |
-
def __init__(self, crossword, model_path, ans_tsv_path, dense_embd_path, max_candidates=
|
20 |
self.crossword = crossword
|
21 |
self.max_candidates = max_candidates
|
22 |
self.process_id = process_id
|
@@ -46,10 +46,10 @@ class Solver:
|
|
46 |
clue = clue[:match.start()] + self.crossword.variables[var]['clue'] + clue[match.end():]
|
47 |
all_clues[idx] = clue
|
48 |
|
49 |
-
# print("MODEL PATH: ", type(self.dense_embd_glob))
|
50 |
# get predictions
|
51 |
dpr = setup_closedbook(self.model_path, self.ans_tsv_path, self.dense_embd_glob, self.process_id, self.model_type)
|
52 |
-
all_words, all_scores = answer_clues(dpr, all_clues, max_answers=self.max_candidates, output_strings=True)
|
|
|
53 |
for index, var in enumerate(self.crossword.variables):
|
54 |
length = len(self.crossword.variables[var]["gold"])
|
55 |
self.candidates[var] = {"words": [], "bit_array": None, "weights": {}}
|
@@ -63,21 +63,13 @@ class Solver:
|
|
63 |
keep_positions.append(word_index)
|
64 |
words = [words[i] for i in keep_positions]
|
65 |
scores = [scores[i] for i in keep_positions]
|
|
|
66 |
scores = list(-np.log(softmax(np.array(scores) / 0.75)))
|
67 |
|
68 |
for word, score in zip(words, scores):
|
69 |
self.candidates[var]["weights"][word] = score
|
70 |
-
|
71 |
-
# for debugging purposes, print the rank of the gold answer on our candidate list
|
72 |
-
# the gold answer is otherwise *not* used in any way during solving
|
73 |
-
# if self.crossword.variables[var]["gold"] in words:
|
74 |
-
# print(clue, self.crossword.variables[var]["gold"], words.index(self.crossword.variables[var]["gold"]))
|
75 |
-
# else:
|
76 |
-
# print('not found', clue, self.crossword.variables[var]["gold"])
|
77 |
|
78 |
-
|
79 |
-
for word, score in zip(words, scores):
|
80 |
-
self.candidates[var]["weights"][word] = score
|
81 |
weights = self.candidates[var]["weights"]
|
82 |
self.candidates[var]["words"] = sorted(weights, key=weights.get)
|
83 |
self.candidates[var]["bit_array"] = np.zeros((len(chars), length, len(self.candidates[var]["words"])))
|
@@ -94,8 +86,7 @@ class Solver:
|
|
94 |
# cleanup a bit
|
95 |
del dpr
|
96 |
|
97 |
-
def evaluate(self, solution):
|
98 |
-
# print puzzle accuracy results given a generated solution
|
99 |
letters_correct = 0
|
100 |
letters_total = 0
|
101 |
for i in range(len(self.crossword.letter_grid)):
|
@@ -110,20 +101,13 @@ class Solver:
|
|
110 |
matching_cells = [self.crossword.letter_grid[cell[0]][cell[1]] == solution[cell[0]][cell[1]] for cell in cells]
|
111 |
if len(cells) == sum(matching_cells):
|
112 |
words_correct += 1
|
113 |
-
else:
|
114 |
-
# print('evaluation: correct word', ''.join([self.crossword.letter_grid[cell[0]][cell[1]] for cell in cells]), 'our prediction:', ''.join([solution[cell[0]][cell[1]] for cell in cells]))
|
115 |
-
pass
|
116 |
words_total += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
-
|
119 |
-
print("Letters Correct: {}% | Words Correct: {}%".format(float(letters_correct/letters_total*100), float(words_correct/words_total*100)))
|
120 |
-
|
121 |
-
info = {
|
122 |
-
"total_letters" : int(letters_total),
|
123 |
-
"total_words" : int(words_total),
|
124 |
-
"correct_letters" : int(letters_correct),
|
125 |
-
"correct_words" : int(words_correct),
|
126 |
-
"correct_letters_percent" : float(letters_correct/letters_total*100),
|
127 |
-
"correct_words_percent" : float(words_correct/words_total*100),
|
128 |
-
}
|
129 |
-
return info
|
|
|
16 |
crossword (Crossword): puzzle to solve
|
17 |
max_candidates (int): number of answer candidates to consider per clue
|
18 |
"""
|
19 |
+
def __init__(self, crossword, model_path, ans_tsv_path, dense_embd_path, max_candidates = 100, process_id = 0, model_type = 'bert'):
|
20 |
self.crossword = crossword
|
21 |
self.max_candidates = max_candidates
|
22 |
self.process_id = process_id
|
|
|
46 |
clue = clue[:match.start()] + self.crossword.variables[var]['clue'] + clue[match.end():]
|
47 |
all_clues[idx] = clue
|
48 |
|
|
|
49 |
# get predictions
|
50 |
dpr = setup_closedbook(self.model_path, self.ans_tsv_path, self.dense_embd_glob, self.process_id, self.model_type)
|
51 |
+
all_words, all_scores = answer_clues(dpr, all_clues, max_answers = self.max_candidates, output_strings=True)
|
52 |
+
|
53 |
for index, var in enumerate(self.crossword.variables):
|
54 |
length = len(self.crossword.variables[var]["gold"])
|
55 |
self.candidates[var] = {"words": [], "bit_array": None, "weights": {}}
|
|
|
63 |
keep_positions.append(word_index)
|
64 |
words = [words[i] for i in keep_positions]
|
65 |
scores = [scores[i] for i in keep_positions]
|
66 |
+
|
67 |
scores = list(-np.log(softmax(np.array(scores) / 0.75)))
|
68 |
|
69 |
for word, score in zip(words, scores):
|
70 |
self.candidates[var]["weights"][word] = score
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
+
|
|
|
|
|
73 |
weights = self.candidates[var]["weights"]
|
74 |
self.candidates[var]["words"] = sorted(weights, key=weights.get)
|
75 |
self.candidates[var]["bit_array"] = np.zeros((len(chars), length, len(self.candidates[var]["words"])))
|
|
|
86 |
# cleanup a bit
|
87 |
del dpr
|
88 |
|
89 |
+
def evaluate(self, solution, print_log = True):
|
|
|
90 |
letters_correct = 0
|
91 |
letters_total = 0
|
92 |
for i in range(len(self.crossword.letter_grid)):
|
|
|
101 |
matching_cells = [self.crossword.letter_grid[cell[0]][cell[1]] == solution[cell[0]][cell[1]] for cell in cells]
|
102 |
if len(cells) == sum(matching_cells):
|
103 |
words_correct += 1
|
|
|
|
|
|
|
104 |
words_total += 1
|
105 |
+
|
106 |
+
letter_frac_log = "Letters Correct: {}/{} | Words Correct: {}/{}".format(int(letters_correct), int(letters_total), int(words_correct), int(words_total))
|
107 |
+
letter_acc_log = "Letters Correct: {}% | Words Correct: {}%".format(float(letters_correct/letters_total*100), float(words_correct/words_total*100))
|
108 |
+
|
109 |
+
if print_log:
|
110 |
+
print(letter_frac_log)
|
111 |
+
print(letter_acc_log)
|
112 |
|
113 |
+
return letter_frac_log, letter_acc_log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
CHANGED
@@ -7,9 +7,30 @@ from Strict_json import json_CA_json_converter
|
|
7 |
import asyncio
|
8 |
from fastapi.middleware.cors import CORSMiddleware
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
@@ -22,27 +43,35 @@ app.add_middleware(
|
|
22 |
)
|
23 |
|
24 |
async def solve_puzzle(json):
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
41 |
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
44 |
|
45 |
-
|
|
|
|
|
46 |
|
47 |
|
48 |
fifo_queue = asyncio.Queue()
|
@@ -55,8 +84,9 @@ async def worker():
|
|
55 |
job_id, job, args, future = await fifo_queue.get()
|
56 |
jobs[job_id]["status"] = "processing"
|
57 |
result = await job(*args)
|
|
|
58 |
jobs[job_id]["result"] = result
|
59 |
-
jobs[job_id]["status"] = "completed"
|
60 |
future.set_result(job_id)
|
61 |
|
62 |
@app.on_event("startup")
|
|
|
7 |
import asyncio
|
8 |
from fastapi.middleware.cors import CORSMiddleware
|
9 |
|
10 |
+
MODEL_CONFIG = {
|
11 |
+
'bert':
|
12 |
+
{
|
13 |
+
'MODEL_PATH' : "./Inference_components/dpr_biencoder_trained_EPOCH_2_COMPLETE.bin",
|
14 |
+
'ANS_TSV_PATH': "./Inference_components/all_answer_list.tsv",
|
15 |
+
'DENSE_EMBD_PATH': "./Inference_components/embeddings_BERT_EPOCH_2_COMPLETE0.pkl"
|
16 |
+
},
|
17 |
+
'distilbert':
|
18 |
+
{
|
19 |
+
'MODEL_PATH': "./Inference_components/distilbert_EPOCHs_7_COMPLETE.bin",
|
20 |
+
'ANS_TSV_PATH': "./Inference_components/all_answer_list.tsv",
|
21 |
+
'DENSE_EMBD_PATH': "./Inference_components/distilbert_7_epochs_embeddings.pkl"
|
22 |
+
},
|
23 |
+
't5_small':
|
24 |
+
{
|
25 |
+
'MODEL_PATH': './Inference_components/t5_small_new_dataset_2EPOCHS/'
|
26 |
+
}
|
27 |
+
}
|
28 |
+
|
29 |
+
choosen_model_path = MODEL_CONFIG['distilbert']['MODEL_PATH']
|
30 |
+
ans_list_path = MODEL_CONFIG['distilbert']['ANS_TSV_PATH']
|
31 |
+
dense_embedding_path = MODEL_CONFIG['distilbert']['DENSE_EMBD_PATH']
|
32 |
+
second_pass_model_path = MODEL_CONFIG['t5_small']['MODEL_PATH']
|
33 |
+
|
34 |
|
35 |
app = FastAPI()
|
36 |
|
|
|
43 |
)
|
44 |
|
45 |
async def solve_puzzle(json):
    """Solve one crossword and evaluate both solver passes.

    Args:
        json: puzzle payload accepted by json_CA_json_converter
            (a dict, or a path to a JSON file).
            NOTE(review): the parameter name shadows the stdlib
            ``json`` module inside this function; kept for
            interface compatibility.

    Returns:
        tuple: (first_pass_grid, first_pass_eval,
                second_pass_grid, second_pass_eval),
        or (None, None, None, None) when solving fails.
    """
    try:
        puzzle = json_CA_json_converter(json, False)
        crossword = Crossword(puzzle)

        # Model loading and belief-propagation solving are heavy,
        # blocking work — run them off the event loop thread.
        solver = await asyncio.to_thread(
            BPSolver, crossword,
            model_path=choosen_model_path,
            ans_tsv_path=ans_list_path,
            dense_embd_path=dense_embedding_path,
            reranker_path=second_pass_model_path,
            max_candidates=40000,
            model_type='distilbert')

        solution = await asyncio.to_thread(
            solver.solve, num_iters=60, iterative_improvement_steps=3)

        evaluation1 = await asyncio.to_thread(
            solver.evaluate, solution['first pass model']['grid'])
        evaluation2 = await asyncio.to_thread(
            solver.evaluate, solution['second pass model']['final grid'])

        # Bug fix: the first element is now the first-pass grid (it was
        # the second-pass grid twice), so each grid is paired with the
        # evaluation that was computed from it.
        return (solution['first pass model']['grid'], evaluation1,
                solution['second pass model']['final grid'], evaluation2)

    except Exception as e:
        # Bug fix: return arity now matches the success path (was a
        # 3-tuple vs the 4-tuple above, breaking any caller that
        # unpacks the result).  The worker's result[1] truthiness
        # check still sees None and marks the job failed.
        print(f"An error occurred: {e}")
        return None, None, None, None
|
75 |
|
76 |
|
77 |
fifo_queue = asyncio.Queue()
|
|
|
84 |
job_id, job, args, future = await fifo_queue.get()
|
85 |
jobs[job_id]["status"] = "processing"
|
86 |
result = await job(*args)
|
87 |
+
print(result)
|
88 |
jobs[job_id]["result"] = result
|
89 |
+
jobs[job_id]["status"] = "completed" if result[1] else "failed"
|
90 |
future.set_result(job_id)
|
91 |
|
92 |
@app.on_event("startup")
|