EZ-Crossword / BPSolver_inf.py
Ujjwal123's picture
Removed second pass and solved for british crossword
c8ed164
import math
import string
import re
from collections import defaultdict
from copy import deepcopy
import numpy as np
from scipy.special import log_softmax, softmax
from tqdm import trange
from Utils_inf import print_grid, get_word_flips
from Solver_inf import Solver
from Models_inf import setup_t5_reranker, t5_reranker_score_with_clue
# the probability of each alphabetical character in the crossword
UNIGRAM_PROBS = [('A', 0.0897379968935765), ('B', 0.02121248877769636), ('C', 0.03482206634145926), ('D', 0.03700942543460491), ('E', 0.1159773210750429), ('F', 0.017257461694024614), ('G', 0.025429024796296124), ('H', 0.033122967601502), ('I', 0.06800036223479956), ('J', 0.00294611331754349), ('K', 0.013860682888259786), ('L', 0.05130800574373874), ('M', 0.027962776827660175), ('N', 0.06631994270448001), ('O', 0.07374646543246745), ('P', 0.026750756212433214), ('Q', 0.001507814175439393), ('R', 0.07080460813737305), ('S', 0.07410988246048224), ('T', 0.07242993582154593), ('U', 0.0289272388037645), ('V', 0.009153522059555467), ('W', 0.01434705167591524), ('X', 0.003096729223103298), ('Y', 0.01749958208224007), ('Z', 0.002659777584995724)]
# the LETTER_SMOOTHING_FACTOR controls how much we interpolate with the unigram LM. TODO this should be tuned.
# Right now it is set according to the probability that the answer is not in the answer set
LETTER_SMOOTHING_FACTOR = [0.0, 0.0, 0.04395604395604396, 0.0001372495196266813, 0.0005752186417796561, 0.0019841824329989103, 0.0048042463338563764, 0.013325257419745608, 0.027154447774285505, 0.06513517299341645, 0.12527790128946198, 0.22003002358996354, 0.23172376584839494, 0.254873006497342, 0.3985086992543496, 0.2764976958525346, 0.672645739910314, 0.6818181818181818, 0.8571428571428571, 0.8245614035087719, 0.8, 0.71900826446281, 0.0]
class BPVar:
def __init__(self, name, variable, candidates, cells):
self.name = name # key from crossword.variables i.e. 1A, 2D, 3A
cells_by_position = {}
for cell in cells: # every cells or letter box that a particular variable or filling takes into consideration
cells_by_position[cell.position] = cell # cell.position (0,0) -> cell -> BPCell
cell._connect(self)
self.length = len(cells) # obviously the length of the answer
self.ordered_cells = [cells_by_position[pos] for pos in variable['cells']]
self.candidates = candidates
self.words = self.candidates['words']
self.word_indices = np.array([[string.ascii_uppercase.index(l) for l in fill] for fill in self.candidates['words']]) # words x length of letter indices
self.scores = -np.array([self.candidates['weights'][fill] for fill in self.candidates['words']]) # the incoming 'weights' are costs
self.prior_log_probs = log_softmax(self.scores)
self.log_probs = log_softmax(self.scores)
self.directional_scores = [np.zeros(len(self.log_probs)) for _ in range(len(self.ordered_cells))]
def _propagate_to_var(self, other, belief_state):
assert other in self.ordered_cells
other_idx = self.ordered_cells.index(other)
letter_scores = belief_state
self.directional_scores[other_idx] = letter_scores[self.word_indices[:, other_idx]]
def _postprocess(self, all_letter_probs):
# unigram smoothing
unigram_probs = np.array([x[1] for x in UNIGRAM_PROBS])
for i in range(len(all_letter_probs)):
all_letter_probs[i] = (1 - LETTER_SMOOTHING_FACTOR[self.length]) * all_letter_probs[i] + LETTER_SMOOTHING_FACTOR[self.length] * unigram_probs
return all_letter_probs
def sync_state(self):
self.log_probs = log_softmax(sum(self.directional_scores) + self.prior_log_probs)
def propagate(self):
all_letter_probs = []
for i in range(len(self.ordered_cells)):
word_scores = self.log_probs - self.directional_scores[i]
word_probs = softmax(word_scores)
letter_probs = (self.candidates['bit_array'][:, i] * np.expand_dims(word_probs, axis=0)).sum(axis=1) + 1e-8
all_letter_probs.append(letter_probs)
all_letter_probs = self._postprocess(all_letter_probs) # unigram postprocessing
for i, cell in enumerate(self.ordered_cells):
cell._propagate_to_cell(self, np.log(all_letter_probs[i]))
class BPCell:
def __init__(self, position, clue_pair):
self.crossing_clues = clue_pair
self.position = tuple(position)
self.letters = list(string.ascii_uppercase)
self.log_probs = np.log(np.array([1./len(self.letters) for _ in range(len(self.letters))]))
self.crossing_vars = []
self.directional_scores = []
self.prediction = {}
def _connect(self, other):
self.crossing_vars.append(other)
self.directional_scores.append(None)
# assert len(self.crossing_vars) <= 2
def _propagate_to_cell(self, other, belief_state):
assert other in self.crossing_vars
other_idx = self.crossing_vars.index(other)
self.directional_scores[other_idx] = belief_state
def sync_state(self):
self.log_probs = log_softmax(sum(self.directional_scores))
def propagate(self):
# assert len(self.crossing_vars) == 2
try:
for i, v in enumerate(self.crossing_vars):
v._propagate_to_var(self, self.directional_scores[1-i])
except IndexError:
pass
class BPSolver(Solver):
def __init__(self,
crossword,
model_path,
ans_tsv_path,
dense_embd_path,
reranker_path,
max_candidates = 100,
process_id = 0,
model_type = 'bert',
**kwargs):
super().__init__(crossword,
model_path,
ans_tsv_path,
dense_embd_path,
max_candidates = max_candidates,
process_id = process_id,
model_type = model_type,
**kwargs)
self.crossword = crossword
self.reranker_path = reranker_path
self.reranker_model_type = 't5-small'
# our answer set
self.answer_set = set()
with open(ans_tsv_path, 'r') as rf:
for line in rf:
w = ''.join([c.upper() for c in (line.split('\t')[-1]).upper() if c in string.ascii_uppercase])
self.answer_set.add(w)
self.reset()
def reset(self):
self.bp_cells = []
self.bp_cells_by_clue = defaultdict(lambda: [])
for position, clue_pair in self.crossword.grid_cells.items():
cell = BPCell(position, clue_pair)
self.bp_cells.append(cell)
for clue in clue_pair:
self.bp_cells_by_clue[clue].append(cell)
self.bp_vars = []
for key, value in self.crossword.variables.items():
# if key == '1A':
# print('-'*100)
# print(self.candidates[key]['words'])
# print(self.candidates[key]['bit_array'].shape)
# print(self.candidates[key]['weights'])
# print('-'*100)
var = BPVar(key, value, self.candidates[key], self.bp_cells_by_clue[key])
# print('*'*100)
# print(self.bp_cells_by_clue[key])
# print('*'*100)
self.bp_vars.append(var)
def extract_float(self, input_string):
pattern = r"\d+\.\d+"
matches = re.findall(pattern, input_string)
float_numbers = [float(match) for match in matches]
return float_numbers
def solve(self, num_iters=10, iterative_improvement_steps=5, return_greedy_states = False, return_ii_states = False):
# run solving for num_iters iterations
print('\nBeginning Belief Propagation Iteration Steps: ')
for _ in trange(num_iters):
for var in self.bp_vars:
var.propagate()
for cell in self.bp_cells:
cell.sync_state()
for cell in self.bp_cells:
cell.propagate()
for var in self.bp_vars:
var.sync_state()
print('Belief Propagation Iteration Complete\n')
# Get the current based grid based on greedy selection from the marginals
if return_greedy_states:
grid, all_grids = self.greedy_sequential_word_solution(return_grids = True)
else:
grid = self.greedy_sequential_word_solution()
all_grids = []
# properly save all the outputs results:
output_results = {}
output_results['first pass model'] = {}
output_results['first pass model']['grid'] = grid
# save first pass model grid, and letter accuracies
_, accu_log = self.evaluate(grid, False)
[ori_letter_accu, ori_word_accu] = self.extract_float(accu_log)
output_results['first pass model']['letter accuracy'] = ori_letter_accu
output_results['first pass model']['word accuracy'] = ori_word_accu
print("First pass model result was", grid,ori_letter_accu,ori_word_accu)
output_results['second pass model'] = {}
output_results['second pass model']['final grid'] = [] # just for the sake of the api
output_results['second pass model']['final grid'] = grid # just for the sake of the api
output_results['second pass model']['all grids'] = []
output_results['second pass model']['all letter accuracy'] = []
output_results['second pass model']['all word accuracy'] = []
if iterative_improvement_steps < 1 or ori_letter_accu == 100.0:
if return_greedy_states or return_ii_states:
return output_results, all_grids
else:
return output_results
'''
Iterative Improvement with t5-small starts from here.
'''
self.reranker, self.tokenizer = setup_t5_reranker(self.reranker_path, self.reranker_model_type)
for i in range(iterative_improvement_steps):
grid, did_iterative_improvement_make_edit = self.iterative_improvement(grid)
_, accu_log = self.evaluate(grid, False)
[temp_letter_accu, temp_word_accu] = self.extract_float(accu_log)
print(f"{i+1}th iteration: {accu_log}")
# save grid & accuracies at each iteration
output_results['second pass model']['all grids'].append(grid)
output_results['second pass model']['all letter accuracy'].append(temp_letter_accu)
output_results['second pass model']['all word accuracy'].append(temp_word_accu)
if not did_iterative_improvement_make_edit or temp_letter_accu == 100.0:
break
if return_ii_states:
all_grids.append(deepcopy(grid))
temp_lett_accu_list = output_results['second pass model']['all letter accuracy'].copy()
ii_max_index = temp_lett_accu_list.index(max(temp_lett_accu_list))
output_results['second pass model']['final grid'] = output_results['second pass model']['all grids'][ii_max_index]
output_results['second pass model']['final letter'] = output_results['second pass model']['all letter accuracy'][ii_max_index]
output_results['second pass model']['final word'] = output_results['second pass model']['all word accuracy'][ii_max_index]
if return_greedy_states or return_ii_states:
return output_results, all_grids
else:
return output_results
def get_candidate_replacements(self, uncertain_answers, grid):
# find alternate answers for all the uncertain words
candidate_replacements = []
replacement_id_set = set()
# check against dictionaries
for clue in uncertain_answers.keys():
initial_word = uncertain_answers[clue]
clue_flips = get_word_flips(initial_word, 10) # flip then segment
clue_positions = [key for key, value in self.crossword.variables.items() if value['clue'] == clue]
for clue_position in clue_positions:
cells = sorted([cell for cell in self.bp_cells if clue_position in cell.crossing_clues], key=lambda c: c.position)
if len(cells) == len(initial_word):
break
for flip in clue_flips:
if len(flip) != len(cells):
import pdb; pdb.set_trace()
assert len(flip) == len(cells)
for i in range(len(flip)):
if flip[i] != initial_word[i]:
candidate_replacements.append([(cells[i], flip[i])])
break
# also add candidates based on uncertainties in the letters, e.g., if we said P but G also had some probability, try G too
for cell_id, cell in enumerate(self.bp_cells):
probs = np.exp(cell.log_probs)
above_threshold = list(probs > 0.01)
new_characters = ['ABCDEFGHIJKLMNOPQRSTUVWXYZ'[i] for i in range(26) if above_threshold[i]]
# used = set()
# new_characters = [x for x in new_characters if x not in used and (used.add(x) or True)] # unique the set
new_characters = [x for x in new_characters if x != grid[cell.position[0]][cell.position[1]]] # ignore if its the same as the original solution
if len(new_characters) > 0:
for new_character in new_characters:
id = '_'.join([str(cell.position), new_character])
if id not in replacement_id_set:
candidate_replacements.append([(cell, new_character)])
replacement_id_set.add(id)
# create composite flips based on things in the same row/column
composite_replacements = []
for i in range(len(candidate_replacements)):
for j in range(i+1, len(candidate_replacements)):
flip1, flip2 = candidate_replacements[i], candidate_replacements[j]
if flip1[0][0] != flip2[0][0]:
if len(set(flip1[0][0].crossing_clues + flip2[0][0].crossing_clues)) < 4: # shared clue
composite_replacements.append(flip1 + flip2)
candidate_replacements += composite_replacements
#print('\ncandidate replacements')
for cr in candidate_replacements:
modified_grid = deepcopy(grid)
for cell, letter in cr:
modified_grid[cell.position[0]][cell.position[1]] = letter
variables = set(sum([cell.crossing_vars for cell, _ in cr], []))
for var in variables:
original_fill = ''.join([grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
modified_fill = ''.join([modified_grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
#print('original:', original_fill, 'modified:', modified_fill)
return candidate_replacements
def get_uncertain_answers(self, grid):
original_qa_pairs = {} # the original puzzle preds that we will try to improve
# first save what the argmax word-level prediction was for each grid cell just to make life easier
for var in self.crossword.variables:
# read the current word off the grid
cells = self.crossword.variables[var]["cells"]
word = []
for cell in cells:
word.append(grid[cell[0]][cell[1]])
word = ''.join(word)
for cell in self.bp_cells: # loop through all cells
if cell.position in cells: # if this cell is in the word we are currently handling
# save {clue, answer} pair into this cell
cell.prediction[self.crossword.variables[var]['clue']] = word
original_qa_pairs[self.crossword.variables[var]['clue']] = word
uncertain_answers = {}
# find uncertain answers
# right now the heuristic we use is any answer that is not in the answer set
for clue in original_qa_pairs.keys():
if original_qa_pairs[clue] not in self.answer_set:
uncertain_answers[clue] = original_qa_pairs[clue]
return uncertain_answers
def score_grid(self, grid):
clues = []
answers = []
for clue, cells in self.bp_cells_by_clue.items():
letters = ''.join([grid[cell.position[0]][cell.position[1]] for cell in sorted(list(cells), key=lambda c: c.position)])
clues.append(self.crossword.variables[clue]['clue'])
answers.append(letters)
scores = t5_reranker_score_with_clue(self.reranker, self.tokenizer, self.reranker_model_type, clues, answers)
return sum(scores)
def greedy_sequential_word_solution(self, return_grids = False):
all_grids = []
# after we've run BP, we run a simple greedy search to get the final.
# We repeatedly pick the highest-log-prob candidate across all clues which fits the grid, and fill it.
# at the end, if you have any cells left (due to missing gold candidates) just fill it with the argmax on that letter.
cache = [(deepcopy(var.words), deepcopy(var.log_probs)) for var in self.bp_vars]
grid = [["" for _ in row] for row in self.crossword.letter_grid]
unfilled_cells = set([cell.position for cell in self.bp_cells])
for var in self.bp_vars:
# postprocess log probs to estimate probability that you don't have the right candidate
var.log_probs = var.log_probs + math.log(1 - LETTER_SMOOTHING_FACTOR[var.length])
best_per_var = [var.log_probs.max() for var in self.bp_vars]
while not all([x is None for x in best_per_var]):
all_grids.append(deepcopy(grid))
best_index = best_per_var.index(max([x for x in best_per_var if x is not None]))
best_var = self.bp_vars[best_index]
best_word = best_var.words[best_var.log_probs.argmax()]
for i, cell in enumerate(best_var.ordered_cells):
letter = best_word[i]
grid[cell.position[0]][cell.position[1]] = letter
if cell.position in unfilled_cells:
unfilled_cells.remove(cell.position)
for var in cell.crossing_vars:
if var != best_var:
cell_index = var.ordered_cells.index(cell)
keep_indices = [j for j in range(len(var.words)) if var.words[j][cell_index] == letter]
var.words = [var.words[j] for j in keep_indices]
var.log_probs = var.log_probs[keep_indices]
var_index = self.bp_vars.index(var)
if len(keep_indices) > 0:
best_per_var[var_index] = var.log_probs.max()
else:
best_per_var[var_index] = None
best_var.words = []
best_var.log_probs = best_var.log_probs[[]]
best_per_var[best_index] = None
unfilled_cells_count = 0
for cell in self.bp_cells:
if cell.position in unfilled_cells:
unfilled_cells_count += 1
grid[cell.position[0]][cell.position[1]] = string.ascii_uppercase[cell.log_probs.argmax()]
for var, (words, log_probs) in zip(self.bp_vars, cache): # restore state
var.words = words
var.log_probs = log_probs
if return_grids:
return grid, all_grids
else:
return grid
def iterative_improvement(self, grid):
# check the grid for uncertain areas and save those words to be analyzed in local search, aka looking for alternate candidates
uncertain_answers = self.get_uncertain_answers(grid)
self.candidate_replacements = self.get_candidate_replacements(uncertain_answers, grid)
# print('\nstarting iterative improvement')
original_grid_score = self.score_grid(grid)
possible_edits = []
for replacements in self.candidate_replacements:
modified_grid = deepcopy(grid)
for cell, letter in replacements:
modified_grid[cell.position[0]][cell.position[1]] = letter
modified_grid_score = self.score_grid(modified_grid)
# print('candidate edit')
variables = set(sum([cell.crossing_vars for cell, _ in replacements], []))
for var in variables:
original_fill = ''.join([grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
modified_fill = ''.join([modified_grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
clue_index = list(set(var.ordered_cells[0].crossing_clues).intersection(*[set(cell.crossing_clues) for cell in var.ordered_cells]))[0]
# print('original:', original_fill, 'modified:', modified_fill)
# print('gold answer', self.crossword.variables[clue_index]['gold'])
# print('clue', self.crossword.variables[clue_index]['clue'])
# print('original score:', original_grid_score, 'modified score:', modified_grid_score)
if modified_grid_score - original_grid_score > 0.5:
# print('found a possible edit')
possible_edits.append((modified_grid, modified_grid_score, replacements))
# print()
if len(possible_edits) > 0:
variables_modified = set()
possible_edits = sorted(possible_edits, key=lambda x: x[1], reverse=True)
selected_edits = []
for edit in possible_edits:
replacements = edit[2]
variables = set(sum([cell.crossing_vars for cell, _ in replacements], []))
if len(variables_modified.intersection(variables)) == 0: # we can do multiple updates at once if they don't share clues
variables_modified.update(variables)
selected_edits.append(edit)
new_grid = deepcopy(grid)
for edit in selected_edits:
# print('\nactually applying edit')
replacements = edit[2]
for cell, letter in replacements:
new_grid[cell.position[0]][cell.position[1]] = letter
variables = set(sum([cell.crossing_vars for cell, _ in replacements], []))
for var in variables:
original_fill = ''.join([grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
modified_fill = ''.join([new_grid[cell.position[0]][cell.position[1]] for cell in var.ordered_cells])
# print('original:', original_fill, 'modified:', modified_fill)
return new_grid, True
else:
return grid, False