# Extracted from a Hugging Face Spaces page (Space status at capture time: Running).
import re
import string
from collections import defaultdict

import numpy as np
from scipy.special import softmax

from Models_inf import answer_clues, setup_closedbook
class Solver:
    """
    This class represents an abstraction over different types of crossword solvers. Each puzzle contains
    a list of clues, which are associated with (weighted) values for each candidate answer.

    Constructing a Solver immediately calls ``get_candidates()``, which queries the
    answer model for every clue and fills ``self.candidates``.

    Args:
        crossword (Crossword): puzzle to solve. This class reads ``crossword.variables``
            (a mapping whose values expose ``'clue'``, ``'gold'`` and ``'cells'``) and
            ``crossword.letter_grid``.
        model_path: forwarded unchanged to ``setup_closedbook``.
        ans_tsv_path: forwarded unchanged to ``setup_closedbook``.
        dense_embd_path: wrapped in a one-element tuple and forwarded to ``setup_closedbook``.
        max_candidates (int): number of answer candidates to consider per clue
        process_id: forwarded unchanged to ``setup_closedbook``.
        model_type: forwarded unchanged to ``setup_closedbook``.
    """

    # Matches cross references such as "29-across" / "12 Down" inside a clue.
    _CROSS_REFERENCE = re.compile(r'([0-9]+)[-\s](down|across)', re.IGNORECASE)

    def __init__(self, crossword, model_path, ans_tsv_path, dense_embd_path, max_candidates=100, process_id=0, model_type='bert'):
        self.crossword = crossword
        self.max_candidates = max_candidates
        self.process_id = process_id
        self.model_path = model_path
        self.ans_tsv_path = ans_tsv_path
        # NOTE(review): the original assignment ended with a trailing comma, so this
        # attribute has always been a 1-tuple rather than a bare path. That behaviour is
        # kept (now spelled explicitly); confirm setup_closedbook really expects a tuple.
        self.dense_embd_glob = (dense_embd_path,)
        self.model_type = model_type
        self.get_candidates()

    def get_candidates(self):
        """Query the answer model for every clue and populate ``self.candidates``.

        Side effects: sets ``self.char_map`` (uppercase letter -> index) and
        ``self.candidates`` (variable -> candidate entry, see
        ``_build_candidate_entry``).
        """
        # get answers from neural model and fill up data structures with the results
        self.char_map = {char: idx for idx, char in enumerate(string.ascii_uppercase)}
        self.candidates = {}

        variables = self.crossword.variables
        all_clues = self._resolve_clue_references([variables[var]['clue'] for var in variables])

        # get predictions
        dpr = setup_closedbook(self.model_path, self.ans_tsv_path, self.dense_embd_glob, self.process_id, self.model_type)
        all_words, all_scores = answer_clues(dpr, all_clues, max_answers=self.max_candidates, output_strings=True)

        for index, var in enumerate(variables):
            length = len(variables[var]["gold"])
            self.candidates[var] = self._build_candidate_entry(all_words[index], all_scores[index], length)
        # NOTE: TODO, it's possible to cache more here in exchange for doing more work at init time

        # cleanup a bit
        del dpr

    def _resolve_clue_references(self, clues):
        """Return a copy of ``clues`` with the first cross reference in each clue expanded.

        Replaces stuff like "Opposite of 29-across" with "Opposite of X", where X is the
        clue for 29-across. References to variables that do not exist are left untouched.
        """
        variables = self.crossword.variables
        resolved = list(clues)
        for idx, clue in enumerate(clues):
            match = self._CROSS_REFERENCE.search(clue)
            if match is None:
                continue
            # "29" + "across" -> "29A", "12" + "Down" -> "12D"
            var = match.group(1) + match.group(2)[0].upper()
            if var in variables:
                resolved[idx] = clue[:match.start()] + variables[var]['clue'] + clue[match.end():]
        return resolved

    def _build_candidate_entry(self, words, scores, length):
        """Build the candidate entry for one variable.

        Keeps only the words with exactly ``length`` letters, converts their scores to
        weights with ``-log(softmax(score / 0.75))`` and builds the lookup structures.

        Returns:
            dict with keys ``"words"`` (sorted by ascending weight), ``"bit_array"``
            (zeros/ones array of shape (letters, length, n_words)), ``"weights"``
            (word -> weight), ``"single_query_cache"`` and
            ``"single_query_cache_indices"`` (per position: letter -> words / word indices).
        """
        n_chars = len(self.char_map)

        # remove answers that are not of the correct length
        kept = [(word, score) for word, score in zip(words, scores) if len(word) == length]

        weights = {}
        if kept:
            # Guarded because softmax takes a max over the array, which raises when
            # no candidate has the right length.
            costs = -np.log(softmax(np.array([score for _, score in kept]) / 0.75))
            for (word, _), cost in zip(kept, costs):
                weights[word] = cost

        ordered_words = sorted(weights, key=weights.get)
        bit_array = np.zeros((n_chars, length, len(ordered_words)))

        # The caches are indexed by letter position. The original allocated one slot per
        # alphabet letter; keep at least that many and grow for longer answers so that
        # positions >= 26 no longer raise IndexError.
        n_slots = max(n_chars, length)
        cache_words = [defaultdict(list) for _ in range(n_slots)]
        cache_indices = [defaultdict(list) for _ in range(n_slots)]

        for word_idx, word in enumerate(ordered_words):
            for pos_idx, char in enumerate(word):
                bit_array[self.char_map[char], pos_idx, word_idx] = 1
                cache_words[pos_idx][char].append(word)
                cache_indices[pos_idx][char].append(word_idx)

        return {
            "words": ordered_words,
            "bit_array": bit_array,
            "weights": weights,
            "single_query_cache": cache_words,
            "single_query_cache_indices": cache_indices,
        }

    def _count_matches(self, solution):
        """Compare ``solution`` with the gold grid.

        Returns:
            tuple ``(letters_correct, letters_total, words_correct, words_total)`` of ints.
            Cells whose gold value is ``""`` are not counted as letters.
        """
        grid = self.crossword.letter_grid
        letters_correct = 0
        letters_total = 0
        for i, row in enumerate(grid):
            for j, gold in enumerate(row):
                if gold != "":
                    if gold == solution[i][j]:
                        letters_correct += 1
                    letters_total += 1

        words_correct = 0
        words_total = 0
        variables = self.crossword.variables
        for var in variables:
            cells = variables[var]["cells"]
            # A word is correct only when every one of its cells matches.
            if all(grid[cell[0]][cell[1]] == solution[cell[0]][cell[1]] for cell in cells):
                words_correct += 1
            words_total += 1
        return int(letters_correct), int(letters_total), int(words_correct), int(words_total)

    @staticmethod
    def _percent(correct, total):
        """Return ``correct / total`` as a percentage, or 0.0 when ``total`` is zero."""
        if not total:
            return 0.0
        return float(correct / total * 100)

    def evaluate(self, solution, print_log=True):
        """Score ``solution`` against the gold grid and return two log strings.

        Args:
            solution: grid indexable as ``solution[row][col]``.
            print_log (bool): when true, also print both strings.

        Returns:
            tuple ``(letter_frac_log, letter_acc_log)``: the counts line and the
            percentages line.
        """
        letters_correct, letters_total, words_correct, words_total = self._count_matches(solution)
        letter_frac_log = "Letters Correct: {}/{} | Words Correct: {}/{}".format(letters_correct, letters_total, words_correct, words_total)
        letter_acc_log = "Letters Correct: {}% | Words Correct: {}%".format(self._percent(letters_correct, letters_total), self._percent(words_correct, words_total))
        if print_log:
            print(letter_frac_log)
            print(letter_acc_log)
        return letter_frac_log, letter_acc_log

    def evaluate1(self, solution):
        """Print puzzle accuracy results given a generated solution and return them.

        Returns:
            dict with keys ``total_letters``, ``total_words``, ``correct_letters``,
            ``correct_words``, ``correct_letters_percent`` and ``correct_words_percent``.
        """
        letters_correct, letters_total, words_correct, words_total = self._count_matches(solution)
        letters_percent = self._percent(letters_correct, letters_total)
        words_percent = self._percent(words_correct, words_total)
        print("Letters Correct: {}/{} | Words Correct: {}/{}".format(letters_correct, letters_total, words_correct, words_total))
        print("Letters Correct: {}% | Words Correct: {}%".format(letters_percent, words_percent))
        info = {
            "total_letters": letters_total,
            "total_words": words_total,
            "correct_letters": letters_correct,
            "correct_words": words_correct,
            "correct_letters_percent": letters_percent,
            "correct_words_percent": words_percent,
        }
        return info