Ujjwal123 committed on
Commit
a04b340
1 Parent(s): acb82e9

Copied the whole API code from the Django project and updated the Dockerfile

.gitignore CHANGED
@@ -1,3 +1,4 @@
+Inference_components/
 *.png
 *.jpg
 *.mp4
BPSolver_inf.py ADDED
@@ -0,0 +1,211 @@
+import math
+import string
+from collections import defaultdict
+from copy import deepcopy
+
+import numpy as np
+from scipy.special import log_softmax, softmax
+from tqdm import trange
+
+from Utils_inf import print_grid, get_word_flips
+from Solver_inf import Solver
+# the probability of each alphabetical character appearing in the crossword
+UNIGRAM_PROBS = [('A', 0.0897379968935765), ('B', 0.02121248877769636), ('C', 0.03482206634145926), ('D', 0.03700942543460491), ('E', 0.1159773210750429), ('F', 0.017257461694024614), ('G', 0.025429024796296124), ('H', 0.033122967601502), ('I', 0.06800036223479956), ('J', 0.00294611331754349), ('K', 0.013860682888259786), ('L', 0.05130800574373874), ('M', 0.027962776827660175), ('N', 0.06631994270448001), ('O', 0.07374646543246745), ('P', 0.026750756212433214), ('Q', 0.001507814175439393), ('R', 0.07080460813737305), ('S', 0.07410988246048224), ('T', 0.07242993582154593), ('U', 0.0289272388037645), ('V', 0.009153522059555467), ('W', 0.01434705167591524), ('X', 0.003096729223103298), ('Y', 0.01749958208224007), ('Z', 0.002659777584995724)]
+
+# LETTER_SMOOTHING_FACTOR controls how much we interpolate with the unigram LM. TODO: this should be tuned.
+# Right now it is set according to the probability that the answer is not in the answer set (indexed by answer length).
+LETTER_SMOOTHING_FACTOR = [0.0, 0.0, 0.04395604395604396, 0.0001372495196266813, 0.0005752186417796561, 0.0019841824329989103, 0.0048042463338563764, 0.013325257419745608, 0.027154447774285505, 0.06513517299341645, 0.12527790128946198, 0.22003002358996354, 0.23172376584839494, 0.254873006497342, 0.3985086992543496, 0.2764976958525346, 0.672645739910314, 0.6818181818181818, 0.8571428571428571, 0.8245614035087719, 0.8, 0.71900826446281, 0.0]
+
+class BPVar:
+    def __init__(self, name, variable, candidates, cells):
+        self.name = name
+        cells_by_position = {}
+        for cell in cells:
+            cells_by_position[cell.position] = cell
+            cell._connect(self)
+        self.length = len(cells)
+        self.ordered_cells = [cells_by_position[pos] for pos in variable['cells']]
+        self.candidates = candidates
+        self.words = self.candidates['words']
+        self.word_indices = np.array([[string.ascii_uppercase.index(l) for l in fill] for fill in self.candidates['words']])  # words x length matrix of letter indices
+        self.scores = -np.array([self.candidates['weights'][fill] for fill in self.candidates['words']])  # the incoming 'weights' are costs
+        self.prior_log_probs = log_softmax(self.scores)
+        self.log_probs = log_softmax(self.scores)
+        self.directional_scores = [np.zeros(len(self.log_probs)) for _ in range(len(self.ordered_cells))]
+
+    def _propagate_to_var(self, other, belief_state):
+        assert other in self.ordered_cells
+        other_idx = self.ordered_cells.index(other)
+        letter_scores = belief_state
+        self.directional_scores[other_idx] = letter_scores[self.word_indices[:, other_idx]]
+
+    def _postprocess(self, all_letter_probs):
+        # unigram smoothing
+        unigram_probs = np.array([x[1] for x in UNIGRAM_PROBS])
+        for i in range(len(all_letter_probs)):
+            all_letter_probs[i] = (1 - LETTER_SMOOTHING_FACTOR[self.length]) * all_letter_probs[i] + LETTER_SMOOTHING_FACTOR[self.length] * unigram_probs
+        return all_letter_probs
+
+    def sync_state(self):
+        self.log_probs = log_softmax(sum(self.directional_scores) + self.prior_log_probs)
+
+    def propagate(self):
+        all_letter_probs = []
+        for i in range(len(self.ordered_cells)):
+            word_scores = self.log_probs - self.directional_scores[i]
+            word_probs = softmax(word_scores)
+            letter_probs = (self.candidates['bit_array'][:, i] * np.expand_dims(word_probs, axis=0)).sum(axis=1) + 1e-8
+            all_letter_probs.append(letter_probs)
+        all_letter_probs = self._postprocess(all_letter_probs)  # unigram postprocessing
+        for i, cell in enumerate(self.ordered_cells):
+            cell._propagate_to_cell(self, np.log(all_letter_probs[i]))
+
+
+class BPCell:
+    def __init__(self, position, clue_pair):
+        self.crossing_clues = clue_pair
+        self.position = tuple(position)
+        self.letters = list(string.ascii_uppercase)
+        self.log_probs = np.log(np.array([1. / len(self.letters) for _ in range(len(self.letters))]))
+        self.crossing_vars = []
+        self.directional_scores = []
+        self.prediction = {}
+
+    def _connect(self, other):
+        self.crossing_vars.append(other)
+        self.directional_scores.append(None)
+        assert len(self.crossing_vars) <= 2
+
+    def _propagate_to_cell(self, other, belief_state):
+        assert other in self.crossing_vars
+        other_idx = self.crossing_vars.index(other)
+        self.directional_scores[other_idx] = belief_state
+
+    def sync_state(self):
+        self.log_probs = log_softmax(sum(self.directional_scores))
+
+    def propagate(self):
+        assert len(self.crossing_vars) == 2
+        for i, v in enumerate(self.crossing_vars):
+            v._propagate_to_var(self, self.directional_scores[1 - i])
+
+
+class BPSolver(Solver):
+    def __init__(self,
+                 crossword,
+                 model_path,
+                 ans_tsv_path,
+                 dense_embd_path,
+                 max_candidates=5000,
+                 process_id=0,
+                 model_type='bert',
+                 **kwargs):
+        super().__init__(crossword,
+                         model_path,
+                         ans_tsv_path,
+                         dense_embd_path,
+                         max_candidates=max_candidates,
+                         process_id=process_id,
+                         model_type=model_type,
+                         **kwargs)
+        self.crossword = crossword
+
+        # our answer set
+        self.answer_set = set()
+        with open(ans_tsv_path, 'r') as rf:
+            for line in rf:
+                w = ''.join([c.upper() for c in (line.split('\t')[-1]).upper() if c in string.ascii_uppercase])
+                self.answer_set.add(w)
+        self.reset()
+
+    def reset(self):
+        self.bp_cells = []
+        self.bp_cells_by_clue = defaultdict(lambda: [])
+        for position, clue_pair in self.crossword.grid_cells.items():
+            cell = BPCell(position, clue_pair)
+            self.bp_cells.append(cell)
+            for clue in clue_pair:
+                self.bp_cells_by_clue[clue].append(cell)
+        self.bp_vars = []
+        for key, value in self.crossword.variables.items():
+            var = BPVar(key, value, self.candidates[key], self.bp_cells_by_clue[key])
+            self.bp_vars.append(var)
+
+    def solve(self, num_iters=10, iterative_improvement_steps=5, return_greedy_states=False, return_ii_states=False):
+        # run belief propagation for num_iters iterations
+        print('beginning BP iterations')
+        for _ in trange(num_iters):
+            for var in self.bp_vars:
+                var.propagate()
+            for cell in self.bp_cells:
+                cell.sync_state()
+            for cell in self.bp_cells:
+                cell.propagate()
+            for var in self.bp_vars:
+                var.sync_state()
+        print('done BP iterations')
+
+        # get the current grid via greedy selection from the marginals
+        if return_greedy_states:
+            grid, all_grids = self.greedy_sequential_word_solution(return_grids=True)
+        else:
+            grid = self.greedy_sequential_word_solution()
+            all_grids = []
+        # print('=====Greedy search grid=====')
+        # print_grid(grid)
+
+        if iterative_improvement_steps < 1:
+            if return_greedy_states or return_ii_states:
+                return grid, all_grids
+            else:
+                return grid
+
+    def greedy_sequential_word_solution(self, return_grids=False):
+        all_grids = []
+        # after we've run BP, we run a simple greedy search to get the final grid.
+        # We repeatedly pick the highest-log-prob candidate across all clues that fits the grid, and fill it in.
+        # At the end, any cells still unfilled (due to missing gold candidates) get the argmax of their letter marginals.
+        cache = [(deepcopy(var.words), deepcopy(var.log_probs)) for var in self.bp_vars]
+
+        grid = [["" for _ in row] for row in self.crossword.letter_grid]
+        unfilled_cells = set([cell.position for cell in self.bp_cells])
+        for var in self.bp_vars:
+            # postprocess log probs to estimate the probability that we don't have the right candidate
+            var.log_probs = var.log_probs + math.log(1 - LETTER_SMOOTHING_FACTOR[var.length])
+        best_per_var = [var.log_probs.max() for var in self.bp_vars]
+        while not all([x is None for x in best_per_var]):
+            all_grids.append(deepcopy(grid))
+            best_index = best_per_var.index(max([x for x in best_per_var if x is not None]))
+            best_var = self.bp_vars[best_index]
+            best_word = best_var.words[best_var.log_probs.argmax()]
+            # print('greedy filling in', best_word)
+            for i, cell in enumerate(best_var.ordered_cells):
+                letter = best_word[i]
+                grid[cell.position[0]][cell.position[1]] = letter
+                if cell.position in unfilled_cells:
+                    unfilled_cells.remove(cell.position)
+                for var in cell.crossing_vars:
+                    if var != best_var:
+                        cell_index = var.ordered_cells.index(cell)
+                        keep_indices = [j for j in range(len(var.words)) if var.words[j][cell_index] == letter]
+                        var.words = [var.words[j] for j in keep_indices]
+                        var.log_probs = var.log_probs[keep_indices]
+                        var_index = self.bp_vars.index(var)
+                        if len(keep_indices) > 0:
+                            best_per_var[var_index] = var.log_probs.max()
+                        else:
+                            best_per_var[var_index] = None
+            best_var.words = []
+            best_var.log_probs = best_var.log_probs[[]]
+            best_per_var[best_index] = None
+        for cell in self.bp_cells:
+            if cell.position in unfilled_cells:
+                grid[cell.position[0]][cell.position[1]] = string.ascii_uppercase[cell.log_probs.argmax()]
+
+        for var, (words, log_probs) in zip(self.bp_vars, cache):  # restore state
+            var.words = words
+            var.log_probs = log_probs
+        if return_grids:
+            return grid, all_grids
+        else:
+            return grid
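For reference, a minimal end-to-end sketch of how this solver is wired together. The filenames match the artifacts the Dockerfile below fetches into `Inference_components/`; which checkpoint pairs with which embedding file is an assumption, and `example.puz` is a hypothetical input:

```python
# Hypothetical usage sketch; component paths follow the Dockerfile's Inference_components/ downloads.
from Normal_utils_inf import puz_to_json
from Crossword_inf import Crossword
from BPSolver_inf import BPSolver

puzzle = Crossword(puz_to_json("example.puz"))  # any .puz file
solver = BPSolver(
    puzzle,
    model_path="Inference_components/distilbert_EPOCHs_7_COMPLETE.bin",
    ans_tsv_path="Inference_components/all_answer_list.tsv",
    dense_embd_path="Inference_components/distilbert_7_epochs_embeddings.pkl",
    max_candidates=5000,
    model_type="distilbert",
)
# iterative_improvement_steps=0 returns the greedy grid directly,
# since the iterative-improvement branch is not part of this file.
grid = solver.solve(num_iters=10, iterative_improvement_steps=0)
```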
Crossword_inf.py ADDED
@@ -0,0 +1,56 @@
+from Utils_inf import clean
+
+class Crossword:
+    def __init__(self, data):
+        self.initialize_grids(grid=data["grid"])
+        self.initialize_clues(clues=data["clues"])
+        self.initialize_variables()
+
+    def initialize_grids(self, grid):
+        self.letter_grid = [[grid[j][i][1] if isinstance(grid[j][i], list) else "" for i in
+                             range(len(grid[0]))] for j in range(len(grid))]
+        self.number_grid = [[grid[j][i][0] if isinstance(grid[j][i], list) else "" for i in
+                             range(len(grid[0]))] for j in range(len(grid))]
+        self.grid_cells = {}
+
+    def initialize_clues(self, clues):
+        self.across = clues["across"]
+        self.down = clues["down"]
+
+    def initialize_variable(self, position, clues, across=True):
+        row, col = position
+        cell_number = self.number_grid[row][col]
+        assert cell_number in clues, "Missing clue"
+        word_id = cell_number + "A" if across else cell_number + "D"
+        clue = clean(clues[cell_number][0])
+        answer = clean(clues[cell_number][1])
+        for idx in range(len(answer)):
+            cell = (row, col + idx) if across else (row + idx, col)
+            if cell in self.grid_cells:
+                self.grid_cells[cell].append(word_id)
+            else:
+                self.grid_cells[cell] = [word_id]
+            if word_id in self.variables:
+                self.variables[word_id]["cells"].append(cell)
+            else:
+                self.variables[word_id] = {"clue": clue, "gold": answer, "cells": [cell], "crossing": []}
+
+    def initialize_crossing(self):
+        for word_id in self.variables:
+            cells = self.variables[word_id]["cells"]
+            crossing_ids = []
+            for cell in cells:
+                crossing_ids += list(filter(lambda x: x != word_id, self.grid_cells[cell]))
+            self.variables[word_id]["crossing"] = crossing_ids
+
+    def initialize_variables(self):
+        self.variables = {}
+        for row in range(len(self.number_grid)):
+            for col in range(len(self.number_grid[0])):
+                cell_number = self.number_grid[row][col]
+                if cell_number != "":
+                    if cell_number in self.across:
+                        self.initialize_variable((row, col), self.across, across=True)
+                    if cell_number in self.down:
+                        self.initialize_variable((row, col), self.down, across=False)
+        self.initialize_crossing()
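As a rough sketch of the expected input shape: a hypothetical 2x2 grid where each open cell is `[number, letter]` and each clue value is `[clue text, answer]`. This assumes `clean` (from `Utils_inf`, which is not part of this commit) normalizes a clue/answer string without changing the answer's length:

```python
from Crossword_inf import Crossword

data = {
    "grid": [[["1", "C"], ["2", "A"]],
             [["3", "T"], ["",  "X"]]],
    "clues": {
        "across": {"1": ["Pet that meows, briefly", "CA"],
                   "3": ["Lone Star State, briefly", "TX"]},
        "down":   {"1": ["Scan, briefly", "CT"],
                   "2": ["Guitar, slangily", "AX"]},
    },
}
cw = Crossword(data)
print(cw.variables["1A"]["crossing"])  # e.g. ['1D', '2D']
```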
Data_utils_inf.py ADDED
@@ -0,0 +1,175 @@
+import json
+import logging
+import math
+import pickle
+import random
+from typing import List, Iterator, Callable
+
+from torch import Tensor as T
+
+logger = logging.getLogger()
+
+def read_serialized_data_from_files(paths: List[str]) -> List:
+    results = []
+    for i, path in enumerate(paths):
+        with open(path, "rb") as reader:
+            logger.info("Reading file %s", path)
+            data = pickle.load(reader)
+            results.extend(data)
+            logger.info("Aggregated data size: {}".format(len(results)))
+    logger.info("Total data size: {}".format(len(results)))
+    return results
+
+def read_data_from_json_files(paths: List[str], upsample_rates: List = None) -> List:
+    results = []
+    if upsample_rates is None:
+        upsample_rates = [1] * len(paths)
+
+    assert len(upsample_rates) == len(
+        paths
+    ), "up-sample rates parameter doesn't match the number of input files"
+
+    for i, path in enumerate(paths):
+        with open(path, "r", encoding="utf-8") as f:
+            logger.info("Reading file %s" % path)
+            data = json.load(f)
+            upsample_factor = int(upsample_rates[i])
+            data = data * upsample_factor
+            results.extend(data)
+            logger.info("Aggregated data size: {}".format(len(results)))
+    return results
+
+
+class ShardedDataIterator(object):
+    """
+    General-purpose data iterator to be used for PyTorch's DDP mode where every node should handle its own part of
+    the data.
+    Instead of cutting data shards by their min size, it sets the number of iterations by the maximum shard size.
+    It fills the extra samples by just taking the first samples in a shard.
+    It can also optionally enforce an identical batch size for all iterations (might be useful for DP mode).
+    """
+
+    def __init__(
+        self,
+        data: list,
+        shard_id: int = 0,
+        num_shards: int = 1,
+        batch_size: int = 1,
+        shuffle=True,
+        shuffle_seed: int = 0,
+        offset: int = 0,
+        strict_batch_size: bool = False,
+    ):
+
+        self.data = data
+        total_size = len(data)
+
+        self.shards_num = max(num_shards, 1)
+        self.shard_id = max(shard_id, 0)
+
+        samples_per_shard = math.ceil(total_size / self.shards_num)
+
+        self.shard_start_idx = self.shard_id * samples_per_shard
+
+        self.shard_end_idx = min(self.shard_start_idx + samples_per_shard, total_size)
+
+        if strict_batch_size:
+            self.max_iterations = math.ceil(samples_per_shard / batch_size)
+        else:
+            self.max_iterations = int(samples_per_shard / batch_size)
+
+        logger.debug(
+            "samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d",
+            samples_per_shard,
+            self.shard_start_idx,
+            self.shard_end_idx,
+            self.max_iterations,
+        )
+
+        self.iteration = offset  # to track in-shard iteration status
+        self.shuffle = shuffle
+        self.batch_size = batch_size
+        self.shuffle_seed = shuffle_seed
+        self.strict_batch_size = strict_batch_size
+
+    def total_data_len(self) -> int:
+        return len(self.data)
+
+    def iterate_data(self, epoch: int = 0) -> Iterator[List]:
+        if self.shuffle:
+            # to be able to resume, the same shuffling should be used when starting from a failed/stopped iteration
+            epoch_rnd = random.Random(self.shuffle_seed + epoch)
+            epoch_rnd.shuffle(self.data)
+
+        # if resuming iteration somewhere in the middle of the epoch, one needs to adjust max_iterations
+
+        max_iterations = self.max_iterations - self.iteration
+
+        shard_samples = self.data[self.shard_start_idx : self.shard_end_idx]
+        for i in range(
+            self.iteration * self.batch_size, len(shard_samples), self.batch_size
+        ):
+            items = shard_samples[i : i + self.batch_size]
+            if self.strict_batch_size and len(items) < self.batch_size:
+                logger.debug("Extending batch to max size")
+                items.extend(shard_samples[0 : self.batch_size - len(items)])
+            self.iteration += 1
+            yield items
+
+        # some shards may be done iterating while others are at the last batch. Just return the first batch
+        while self.iteration < max_iterations:
+            logger.debug("Fulfilling non-complete shard={}".format(self.shard_id))
+            self.iteration += 1
+            batch = shard_samples[0 : self.batch_size]
+            yield batch
+
+        logger.debug(
+            "Finished iterating, iteration={}, shard={}".format(
+                self.iteration, self.shard_id
+            )
+        )
+        # reset the iteration status
+        self.iteration = 0
+
+    def get_iteration(self) -> int:
+        return self.iteration
+
+    def apply(self, visitor_func: Callable):
+        for sample in self.data:
+            visitor_func(sample)
+
+
+def normalize_question(question: str) -> str:
+    if question[-1] == "?":
+        question = question[:-1]
+    return question
+
+
+class Tensorizer(object):
+    """
+    Component for all text-to-model input data conversions and related utility methods
+    """
+
+    # Note: title, if present, is supposed to be put before the text (i.e. optional title + document body)
+    def text_to_tensor(
+        self, text: str, title: str = None, add_special_tokens: bool = True
+    ):
+        raise NotImplementedError
+
+    def get_pair_separator_ids(self) -> T:
+        raise NotImplementedError
+
+    def get_pad_id(self) -> int:
+        raise NotImplementedError
+
+    def get_attn_mask(self, tokens_tensor: T):
+        raise NotImplementedError
+
+    def is_sub_word_id(self, token_id: int):
+        raise NotImplementedError
+
+    def to_string(self, token_ids, skip_special_tokens=True):
+        raise NotImplementedError
+
+    def set_pad_to_max(self, pad: bool):
+        raise NotImplementedError
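A small self-contained sketch of the sharding behavior (no model or files needed):

```python
from Data_utils_inf import ShardedDataIterator

data = list(range(10))
# shard 0 of 2 sees the first half of the data; shuffle=False keeps it deterministic
it = ShardedDataIterator(data, shard_id=0, num_shards=2, batch_size=3, shuffle=False)
for batch in it.iterate_data(epoch=0):
    print(batch)  # batches drawn from this shard's slice
```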
Dockerfile CHANGED
@@ -20,4 +20,11 @@ WORKDIR $HOME/app
 
 COPY --chown=user . $HOME/app/
 
+# Use resolve/ (the raw file) rather than blob/ (the HTML page) so ADD fetches the actual artifacts
+ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/all_answer_list.tsv $HOME/app/Inference_components/
+ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/distilbert_7_epochs_embeddings.pkl $HOME/app/Inference_components/
+ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/distilbert_EPOCHs_7_COMPLETE.bin $HOME/app/Inference_components/
+ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/dpr_biencoder_trained_500k.bin $HOME/app/Inference_components/
+ADD --chown=user https://huggingface.co/prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder/resolve/main/embeddings_all_answers_json_0.pkl $HOME/app/Inference_components/
+
 CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
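Note the `resolve/` path segment above: `blob/` URLs return the repository's HTML page, while `resolve/` serves the raw file. If the artifacts were instead fetched at runtime, `huggingface_hub` offers an equivalent (a sketch, assuming that package is installed in the image):

```python
from huggingface_hub import hf_hub_download

# downloads (and caches) the raw file from the same model repo
path = hf_hub_download(
    repo_id="prajesh069/clue-answer.multi-answer-scoring.dual-bert-encoder",
    filename="all_answer_list.tsv",
)
```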
Faiss_Indexers_inf.py ADDED
@@ -0,0 +1,214 @@
+import os
+import time
+import logging
+import pickle
+from typing import List, Tuple, Iterator
+
+import faiss
+import numpy as np
+
+
+logger = logging.getLogger()
+
+
+class DenseIndexer(object):
+    def __init__(self, buffer_size: int = 50000):
+        self.buffer_size = buffer_size
+        self.index_id_to_db_id = []
+        self.index = None
+
+    def index_data(self, vector_files: List[str]):
+        start_time = time.time()
+        buffer = []
+        for i, item in enumerate(iterate_encoded_files(vector_files)):
+            db_id, doc_vector = item
+            buffer.append((db_id, doc_vector))
+            if 0 < self.buffer_size == len(buffer):
+                # indexing in batches is beneficial for many faiss index types
+                self._index_batch(buffer)
+                logger.info(
+                    "data indexed %d, used_time: %f sec.",
+                    len(self.index_id_to_db_id),
+                    time.time() - start_time,
+                )
+                buffer = []
+        self._index_batch(buffer)
+
+        indexed_cnt = len(self.index_id_to_db_id)
+        logger.info("Total data indexed %d", indexed_cnt)
+        logger.info("Data indexing completed.")
+
+    def _index_batch(self, data: List[Tuple[object, np.ndarray]]):
+        raise NotImplementedError
+
+    def search_knn(
+        self, query_vectors: np.ndarray, top_docs: int
+    ) -> List[Tuple[List[object], List[float]]]:
+        raise NotImplementedError
+
+    def serialize(self, file: str):
+        logger.info("Serializing index to %s", file)
+
+        if os.path.isdir(file):
+            index_file = os.path.join(file, "index.dpr")
+            meta_file = os.path.join(file, "index_meta.dpr")
+        else:
+            index_file = file + ".index.dpr"
+            meta_file = file + ".index_meta.dpr"
+
+        faiss.write_index(self.index, index_file)
+        with open(meta_file, mode="wb") as f:
+            pickle.dump(self.index_id_to_db_id, f)
+
+    def deserialize_from(self, file: str):
+        logger.info("Loading index from %s", file)
+
+        if os.path.isdir(file):
+            index_file = os.path.join(file, "index.dpr")
+            meta_file = os.path.join(file, "index_meta.dpr")
+        else:
+            index_file = file + ".index.dpr"
+            meta_file = file + ".index_meta.dpr"
+
+        self.index = faiss.read_index(index_file)
+        logger.info(
+            "Loaded index of type %s and size %d", type(self.index), self.index.ntotal
+        )
+
+        with open(meta_file, "rb") as reader:
+            self.index_id_to_db_id = pickle.load(reader)
+        assert (
+            len(self.index_id_to_db_id) == self.index.ntotal
+        ), "Deserialized index_id_to_db_id should match faiss index size"
+
+    def _update_id_mapping(self, db_ids: List):
+        self.index_id_to_db_id.extend(db_ids)
+
+
+class DenseFlatIndexer(DenseIndexer):
+    def __init__(self, vector_sz: int, buffer_size: int = 50000):
+        super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)
+        # res = faiss.StandardGpuResources()
+        # cpu_index = faiss.IndexFlatIP(vector_sz)
+        # self.index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
+        self.index = faiss.IndexFlatIP(vector_sz)
+        self.all_vectors = None
+
+    def _index_batch(self, data: List[Tuple[object, np.ndarray]]):
+        db_ids = [t[0] for t in data]
+        vectors = [np.reshape(t[1], (1, -1)) for t in data]
+        vectors = np.concatenate(vectors, axis=0)
+        self._update_id_mapping(db_ids)
+        self.index.add(vectors)
+        # if self.all_vectors is None:
+        #     self.all_vectors = vectors
+        # else:
+        #     self.all_vectors = np.concatenate((self.all_vectors, vectors), axis=0)
+
+    def search_knn(
+        self, query_vectors: np.ndarray, top_docs: int
+    ) -> List[Tuple[List[object], List[float]]]:
+        scores, indexes = self.index.search(query_vectors, top_docs)
+        # convert to external ids
+        db_ids = [
+            [self.index_id_to_db_id[i] for i in query_top_idxs]
+            for query_top_idxs in indexes
+        ]
+        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
+        return result
+
+
+class DenseHNSWFlatIndexer(DenseIndexer):
+    """
+    Efficient index for retrieval. Note: default settings are for high accuracy but also high RAM usage
+    """
+
+    def __init__(
+        self,
+        vector_sz: int,
+        buffer_size: int = 50000,
+        store_n: int = 512,
+        ef_search: int = 128,
+        ef_construction: int = 200,
+    ):
+        super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size)
+
+        # IndexHNSWFlat supports L2 similarity only,
+        # so we have to apply a DOT -> L2 similarity space conversion with the help of an extra dimension
+        index = faiss.IndexHNSWFlat(vector_sz + 1, store_n)
+        index.hnsw.efSearch = ef_search
+        index.hnsw.efConstruction = ef_construction
+        self.index = index
+        self.phi = None
+
+    def index_data(self, vector_files: List[str]):
+        self._set_phi(vector_files)
+        super(DenseHNSWFlatIndexer, self).index_data(vector_files)
+
+    def _set_phi(self, vector_files: List[str]):
+        """
+        Calculates the max norm from the whole data and assigns it to self.phi: necessary to transform IP -> L2 space
+        :param vector_files: file names to get passage vectors from
+        :return:
+        """
+        phi = 0
+        for i, item in enumerate(iterate_encoded_files(vector_files)):
+            db_id, doc_vector = item
+            norms = (doc_vector ** 2).sum()
+            phi = max(phi, norms)
+        logger.info("HNSWF DotProduct -> L2 space phi={}".format(phi))
+        self.phi = phi
+
+    def _index_batch(self, data: List[Tuple[object, np.ndarray]]):
+        # max norm is required before putting all vectors in the index to convert inner product similarity to L2
+        if self.phi is None:
+            raise RuntimeError(
+                "Max norm needs to be calculated from all data at once, "
+                "results will be unpredictable otherwise. "
+                "Run `_set_phi()` before calling this method."
+            )
+
+        db_ids = [t[0] for t in data]
+        vectors = [np.reshape(t[1], (1, -1)) for t in data]
+
+        norms = [(doc_vector ** 2).sum() for doc_vector in vectors]
+        aux_dims = [np.sqrt(self.phi - norm) for norm in norms]
+        hnsw_vectors = [
+            np.hstack((doc_vector, aux_dims[i].reshape(-1, 1)))
+            for i, doc_vector in enumerate(vectors)
+        ]
+        hnsw_vectors = np.concatenate(hnsw_vectors, axis=0)
+
+        self._update_id_mapping(db_ids)
+        self.index.add(hnsw_vectors)
+
+    def search_knn(
+        self, query_vectors: np.ndarray, top_docs: int
+    ) -> List[Tuple[List[object], List[float]]]:
+
+        aux_dim = np.zeros(len(query_vectors), dtype="float32")
+        query_hnsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1)))
+        logger.info("query_hnsw_vectors %s", query_hnsw_vectors.shape)
+        scores, indexes = self.index.search(query_hnsw_vectors, top_docs)
+        # convert to external ids
+        db_ids = [
+            [self.index_id_to_db_id[i] for i in query_top_idxs]
+            for query_top_idxs in indexes
+        ]
+        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
+        return result
+
+    def deserialize_from(self, file: str):
+        super(DenseHNSWFlatIndexer, self).deserialize_from(file)
+        # to trigger warning on subsequent indexing
+        self.phi = None
+
+
+def iterate_encoded_files(vector_files: str) -> Iterator[Tuple[object, np.ndarray]]:
+    # for i, file in enumerate(vector_files):
+    logger.info("Reading file %s", vector_files)
+    with open(vector_files, "rb") as reader:
+        doc_vectors = pickle.load(reader)
+        for doc in doc_vectors:
+            db_id, doc_vector = doc
+            yield db_id, doc_vector
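A self-contained sketch of the flat-index round trip. Random vectors stand in for real clue embeddings, and one batch is fed to `_index_batch` directly instead of going through `index_data`'s pickle-file path (faiss expects float32):

```python
import numpy as np
from Faiss_Indexers_inf import DenseFlatIndexer

rng = np.random.default_rng(0)
indexer = DenseFlatIndexer(vector_sz=8)
batch = [(f"ans{i}", rng.standard_normal(8).astype("float32")) for i in range(100)]
indexer._index_batch(batch)  # normally driven by index_data() over a pickled embedding file
queries = rng.standard_normal((2, 8)).astype("float32")
for db_ids, scores in indexer.search_knn(queries, top_docs=5):
    print(db_ids, scores)  # top-5 external ids and inner-product scores per query
```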
Inference_components/test.py DELETED
@@ -1 +0,0 @@
-print('hello')
Model_utils_inf.py ADDED
@@ -0,0 +1,160 @@
+import collections
+import glob
+import logging
+import os
+from typing import List
+
+import torch
+from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
+from torch.serialization import default_restore_location
+
+logger = logging.getLogger()
+
+CheckpointState = collections.namedtuple(
+    "CheckpointState",
+    [
+        "model_dict",
+        "optimizer_dict",
+        "scheduler_dict",
+        "offset",
+        "epoch",
+        "encoder_params",
+    ],
+)
+
+def setup_for_distributed_mode(
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    device: object,
+    n_gpu: int = 1,
+    local_rank: int = -1,
+    fp16: bool = False,
+    fp16_opt_level: str = "O1",
+) -> (nn.Module, torch.optim.Optimizer):
+    model.to(device)
+    if fp16:
+        try:
+            import apex
+            from apex import amp
+
+            apex.amp.register_half_function(torch, "einsum")
+        except ImportError:
+            raise ImportError(
+                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
+            )
+
+        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)
+
+    if n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
+    if local_rank != -1:
+        model = torch.nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[local_rank],
+            output_device=local_rank,
+            find_unused_parameters=True,
+        )
+    return model, optimizer
+
+
+def move_to_cuda(sample):
+    if len(sample) == 0:
+        return {}
+
+    def _move_to_cuda(maybe_tensor):
+        if torch.is_tensor(maybe_tensor):
+            return maybe_tensor.cuda()
+        elif isinstance(maybe_tensor, dict):
+            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
+        elif isinstance(maybe_tensor, list):
+            return [_move_to_cuda(x) for x in maybe_tensor]
+        elif isinstance(maybe_tensor, tuple):
+            return [_move_to_cuda(x) for x in maybe_tensor]
+        else:
+            return maybe_tensor
+
+    return _move_to_cuda(sample)
+
+
+def move_to_device(sample, device):
+    if len(sample) == 0:
+        return {}
+
+    def _move_to_device(maybe_tensor, device):
+        if torch.is_tensor(maybe_tensor):
+            return maybe_tensor.to(device)
+        elif isinstance(maybe_tensor, dict):
+            return {
+                key: _move_to_device(value, device)
+                for key, value in maybe_tensor.items()
+            }
+        elif isinstance(maybe_tensor, list):
+            return [_move_to_device(x, device) for x in maybe_tensor]
+        elif isinstance(maybe_tensor, tuple):
+            return [_move_to_device(x, device) for x in maybe_tensor]
+        else:
+            return maybe_tensor
+
+    return _move_to_device(sample, device)
+
+
+def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1):
+    """Create a schedule with a learning rate that decreases linearly after
+    linearly increasing during a warmup period.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < warmup_steps:
+            return float(current_step) / float(max(1, warmup_steps))
+        return max(
+            0.0,
+            float(training_steps - current_step)
+            / float(max(1, training_steps - warmup_steps)),
+        )
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def init_weights(modules: List):
+    for module in modules:
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+def get_model_obj(model: nn.Module):
+    return model.module if hasattr(model, "module") else model
+
+
+def get_model_file(args, file_prefix) -> str:
+    if args.model_file and os.path.exists(args.model_file):
+        return args.model_file
+
+    out_cp_files = (
+        glob.glob(os.path.join(args.output_dir, file_prefix + "*"))
+        if args.output_dir
+        else []
+    )
+    logger.info("Checkpoint files %s", out_cp_files)
+    model_file = None
+
+    if len(out_cp_files) > 0:
+        model_file = max(out_cp_files, key=os.path.getctime)
+    return model_file
+
+
+def load_states_from_checkpoint(model_file: str) -> CheckpointState:
+    logger.info("Reading saved model from %s", model_file)
+    if isinstance(model_file, tuple):
+        model_file = model_file[0]
+    state_dict = torch.load(
+        model_file, map_location=lambda s, l: default_restore_location(s, "cpu")
+    )
+    logger.info("model_state_dict keys %s", state_dict.keys())
+    return CheckpointState(**state_dict)
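For example, the warmup-then-linear-decay schedule behaves like this (a toy optimizer; nothing model-specific is assumed):

```python
import torch
from Model_utils_inf import get_schedule_linear

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = get_schedule_linear(optimizer, warmup_steps=10, training_steps=100)
for step in range(100):
    optimizer.step()
    scheduler.step()  # LR ramps to 1e-3 over 10 steps, then decays linearly to 0
```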
Models_inf.py ADDED
@@ -0,0 +1,391 @@
+# This file contains the inference code for loading and running the closed-book and open-book QA models
+import os
+import csv
+import glob
+import gzip
+import string
+import sys
+from typing import List, Tuple, Dict
+import re
+import math
+import collections
+
+import numpy as np
+import unicodedata
+import torch
+from torch import Tensor as T
+from torch import nn
+
+from models import init_biencoder_components
+from Options_inf import setup_args_gpu, print_args, set_encoder_params_from_state
+from Faiss_Indexers_inf import DenseIndexer, DenseFlatIndexer
+from Data_utils_inf import Tensorizer
+from Model_utils_inf import load_states_from_checkpoint, get_model_obj
+
+
+SEGMENTER_CACHE = {}
+RERANKER_CACHE = {}
+
+def setup_closedbook(model_path, ans_tsv_path, dense_embd_path, process_id, model_type):
+    dpr = DPRForCrossword(
+        model_path,
+        ans_tsv_path,
+        dense_embd_path,
+        retrievalmodel=False,
+        process_id=process_id,
+        model_type=model_type
+    )
+    return dpr
+
+def preprocess_clue_fn(clue):
+    clue = str(clue)
+
+    # https://stackoverflow.com/questions/517923/what-is-the-best-way-to-remove-accents-normalize-in-a-python-unicode-string
+    clue = ''.join(c for c in unicodedata.normalize('NFD', clue) if unicodedata.category(c) != 'Mn')
+
+    clue = re.sub("\x17|\x18|\x93|\x94|“|”|''|\"\"", "\"", clue)
+    clue = re.sub("\x85|…", "...", clue)
+    clue = re.sub("\x91|\x92|‘|’", "'", clue)
+
+    clue = re.sub("‚", ",", clue)
+    clue = re.sub("—|–", "-", clue)
+    clue = re.sub("¢", " cents", clue)
+    clue = re.sub(r"¿|¡|^;|\{|\}", "", clue)
+    clue = re.sub("÷", "division", clue)
+    clue = re.sub("°", " degrees", clue)
+
+    euro = re.search("^£[0-9]+(,*[0-9]*){0,}| £[0-9]+(,*[0-9]*){0,}", clue)
+    if euro:
+        num = clue[:euro.end()]
+        rest_clue = clue[euro.end():]
+        clue = num + " Euros" + rest_clue
+        clue = re.sub(", Euros", " Euros", clue)
+        clue = re.sub("Euros [Mm]illion", "million Euros", clue)
+        clue = re.sub("Euros [Bb]illion", "billion Euros", clue)
+        clue = re.sub("Euros[Kk]", "K Euros", clue)
+        clue = re.sub(" K Euros", "K Euros", clue)
+        clue = re.sub("£", "", clue)
+
+    clue = re.sub(r" *\(\d{1,},*\)$| *\(\d{1,},* \d{1,}\)$", "", clue)
+
+    clue = re.sub("&amp;", "&", clue)
+    clue = re.sub("&lt;", "<", clue)
+    clue = re.sub("&gt;", ">", clue)
+
+    clue = re.sub(r"e\.g\.|for ex\.", "for example", clue)
+    clue = re.sub(r": [Aa]bbreviat\.|: [Aa]bbrev\.|: [Aa]bbrv\.|: [Aa]bbrv|: [Aa]bbr\.|: [Aa]bbr", " abbreviation", clue)
+    clue = re.sub(r"abbr\.|abbrv\.", "abbreviation", clue)
+    clue = re.sub(r"Abbr\.|Abbrv\.", "Abbreviation", clue)
+    clue = re.sub(r"\(anag\.\)|\(anag\)", "(anagram)", clue)
+    clue = re.sub(r"org\.", "organization", clue)
+    clue = re.sub(r"Org\.", "Organization", clue)
+    clue = re.sub(r"Grp\.|Gp\.", "Group", clue)
+    clue = re.sub(r"grp\.|gp\.", "group", clue)
+    clue = re.sub(r": Sp\.", " (Spanish)", clue)
+    clue = re.sub(r"\(Sp\.\)|Sp\.", "(Spanish)", clue)
+    clue = re.sub(r"Ave\.", "Avenue", clue)
+    clue = re.sub(r"Sch\.", "School", clue)
+    clue = re.sub(r"sch\.", "school", clue)
+    clue = re.sub(r"Agcy\.", "Agency", clue)
+    clue = re.sub(r"agcy\.", "agency", clue)
+    clue = re.sub(r"Co\.", "Company", clue)
+    clue = re.sub(r"co\.", "company", clue)
+    clue = re.sub(r"No\.", "Number", clue)
+    clue = re.sub(r"no\.", "number", clue)
+    clue = re.sub(r": [Vv]ar\.", " variable", clue)
+    clue = re.sub(r"Subj\.", "Subject", clue)
+    clue = re.sub(r"subj\.", "subject", clue)
+    clue = re.sub(r"Subjs\.", "Subjects", clue)
+    clue = re.sub(r"subjs\.", "subjects", clue)
+
+    theme_clue = re.search(r"^.+\|[A-Z]{1,}", clue)
+    if theme_clue:
+        clue = re.sub(r"\|", " | ", clue)
+
+    if "Partner of" in clue:
+        clue = re.sub("Partner of", "", clue)
+        clue = clue + " and ___"
+
+    link = re.search("^.+-.+ [Ll]ink$", clue)
+    if link:
+        no_link = re.search("^.+-.+ ", clue)
+        x_y = clue[no_link.start():no_link.end() - 1]
+        x_y_lst = x_y.split("-")
+        clue = x_y_lst[0] + " ___ " + x_y_lst[1]
+
+    follower = re.search("^.+ [Ff]ollower$", clue)
+    if follower:
+        no_follower = re.search("^.+ ", clue)
+        x = clue[:no_follower.end() - 1]
+        clue = x + " ___"
+
+    preceder = re.search("^.+ [Pp]receder$", clue)
+    if preceder:
+        no_preceder = re.search("^.+ ", clue)
+        x = clue[:no_preceder.end() - 1]
+        clue = "___ " + x
+
+    if re.search("--[^A-Za-z]|--$", clue):
+        clue = re.sub("--", "__", clue)
+    if not re.search("_-[A-Za-z]|_-$", clue):
+        clue = re.sub("_-", "__", clue)
+
+    clue = re.sub("_{2,}", "___", clue)
+
+    clue = re.sub(r"\?$", " (wordplay)", clue)
+
+    nonverbal = re.search(r"\[[^0-9]+,* *[^0-9]*\]", clue)
+    if nonverbal:
+        clue = re.sub(r"\[|\]", "", clue)
+        clue = clue + " (nonverbal)"
+
+    if clue[:4] == "\"\"\" " and clue[-4:] == " \"\"\"":
+        clue = "\"" + clue[4:-4] + "\""
+    if clue[:4] == "''' " and clue[-4:] == " '''":
+        clue = "'" + clue[4:-4] + "'"
+    if clue[:3] == "\"\"\"" and clue[-3:] == "\"\"\"":
+        clue = "\"" + clue[3:-3] + "\""
+    if clue[:3] == "'''" and clue[-3:] == "'''":
+        clue = "'" + clue[3:-3] + "'"
+
+    return clue
+
+
+def answer_clues(dpr, clues, max_answers, output_strings=False):
+    clues = [preprocess_clue_fn(c.rstrip()) for c in clues]
+    outputs = dpr.answer_clues_closedbook(clues, max_answers, output_strings=output_strings)
+    return outputs
+
+class DenseRetriever(object):
+    """
+    Does passage retrieval over the provided index and question encoder
+    """
+    def __init__(
+        self,
+        question_encoder: nn.Module,
+        batch_size: int,
+        tensorizer: Tensorizer,
+        index: DenseIndexer,
+        device=None,
+        model_type='bert'
+    ):
+        self.question_encoder = question_encoder
+        self.batch_size = batch_size
+        self.tensorizer = tensorizer
+        self.index = index
+        self.device = device
+        self.model_type = model_type
+
+    def generate_question_vectors(self, questions: List[str]) -> T:
+        n = len(questions)
+        bsz = self.batch_size
+        query_vectors = []
+        self.question_encoder.eval()
+
+        with torch.no_grad():
+            for j, batch_start in enumerate(range(0, n, bsz)):
+                batch_token_tensors = [
+                    self.tensorizer.text_to_tensor(q)
+                    for q in questions[batch_start : batch_start + bsz]
+                ]
+
+                q_ids_batch = torch.stack(batch_token_tensors, dim=0).to(self.device)
+                q_seg_batch = torch.zeros_like(q_ids_batch).to(self.device)
+                # q_attn_mask = self.tensorizer.get_attn_mask(q_ids_batch)
+                q_attn_mask = (q_ids_batch != 0)
+
+                if self.model_type == 'bert':
+                    _, out, _ = self.question_encoder(q_ids_batch, q_seg_batch, q_attn_mask)
+                elif self.model_type == 'distilbert':
+                    _, out, _ = self.question_encoder(q_ids_batch, q_attn_mask)
+
+                query_vectors.extend(out.cpu().split(1, dim=0))
+
+        query_tensor = torch.cat(query_vectors, dim=0)
+        assert query_tensor.size(0) == len(questions)
+        return query_tensor
+
+    def get_top_docs(self, query_vectors: np.ndarray, top_docs: int = 100) -> List[Tuple[List[object], List[float]]]:
+        """
+        Does the retrieval of the best-matching passages given the query vectors batch
+        :param query_vectors:
+        :param top_docs:
+        :return:
+        """
+        results = self.index.search_knn(query_vectors, top_docs)
+        return results
+
+class FakeRetrieverArgs:
+    """Used to suppress the existing argparse inside DPR so we can have our own argparse"""
+    def __init__(self):
+        self.do_lower_case = False
+        self.pretrained_model_cfg = None
+        self.encoder_model_type = None
+        self.model_file = None
+        self.projection_dim = 0
+        self.sequence_length = 512
+        self.do_fill_lower_case = False
+        self.desegment_valid_fill = False
+        self.no_cuda = True
+        self.local_rank = -1
+        self.fp16 = False
+        self.fp16_opt_level = "O1"
+
+
+class DPRForCrossword(object):
+    """Closed-book model for crossword clue answering"""
+
+    def __init__(
+        self,
+        model_file,
+        ctx_file,
+        encoded_ctx_file,
+        batch_size=16,
+        retrievalmodel=False,
+        process_id=0,
+        model_type='bert'
+    ):
+        self.retrievalmodel = retrievalmodel  # am I a Wikipedia retrieval model or a closed-book model?
+        args = FakeRetrieverArgs()
+        args.model_file = model_file
+        args.ctx_file = ctx_file
+        args.encoded_ctx_file = encoded_ctx_file
+        args.batch_size = batch_size
+        # self.device = torch.device("cuda:" + str(process_id % torch.cuda.device_count()))
+        self.device = 'cpu'
+        self.model_type = model_type
+
+        setup_args_gpu(args)
+        saved_state = load_states_from_checkpoint(args.model_file)
+        set_encoder_params_from_state(saved_state.encoder_params, args)
+
+        tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
+
+        question_encoder = encoder.question_model
+        question_encoder = question_encoder.to(self.device)
+        question_encoder.eval()
+
+        # load weights from the model file
+        model_to_load = get_model_obj(question_encoder)
+
+        prefix_len = len("question_model.")
+        question_encoder_state = {
+            key[prefix_len:]: value
+            for (key, value) in saved_state.model_dict.items()
+            if key.startswith("question_model.")
+        }
+        model_to_load.load_state_dict(question_encoder_state, strict=False)
+        vector_size = model_to_load.get_out_size()
+
+        index = DenseFlatIndexer(vector_size, 50000)
+
+        self.retriever = DenseRetriever(
+            question_encoder,
+            args.batch_size,
+            tensorizer,
+            index,
+            self.device,
+            self.model_type
+        )
+
+        # index all passages
+        embd_file_path = args.encoded_ctx_file
+        if isinstance(embd_file_path, str):
+            file_path = embd_file_path
+        else:
+            file_path = embd_file_path[0]
+        self.retriever.index.index_data(file_path)
+
+        self.all_passages = self.load_passages(args.ctx_file)
+        self.fill2id = {}
+        for key in self.all_passages.keys():
+            self.fill2id[
+                "".join(
+                    [
+                        letter
+                        for letter in self.all_passages[key][1].upper()
+                        if letter in string.ascii_uppercase
+                    ]
+                )
+            ] = key
+
+        # might as well uppercase and remove non-alphas from the fills before we start, to save time later
+        if not retrievalmodel:
+            temp = {}
+            for my_id in self.all_passages.keys():
+                temp[my_id] = "".join([c.upper() for c in self.all_passages[my_id][1] if c.upper() in string.ascii_uppercase])
+            self.len_all_passages = len(list(self.all_passages.values()))
+            self.all_passages = temp
+
+
+    @staticmethod
+    def load_passages(ctx_file: str) -> Dict[object, Tuple[str, str]]:
+        docs = {}
+        if isinstance(ctx_file, tuple):
+            ctx_file = ctx_file[0]
+        if ctx_file.endswith(".gz"):
+            with gzip.open(ctx_file, "rt") as tsvfile:
+                reader = csv.reader(
+                    tsvfile,
+                    delimiter="\t",
+                )
+                # file format: doc_id, doc_text, title
+                for row in reader:
+                    if row[0] != "id":
+                        docs[row[0]] = (row[1], row[2])
+        else:
+            with open(ctx_file) as tsvfile:
+                reader = csv.reader(
+                    tsvfile,
+                    delimiter="\t",
+                )
+                # file format: doc_id, doc_text, title
+                for row in reader:
+                    if row[0] != "id":
+                        docs[row[0]] = (row[1], row[2])
+        return docs
+
+    def answer_clues_closedbook(self, questions, max_answers, output_strings=False):
+        # assumes clues are preprocessed
+        assert not self.retrievalmodel
+        questions_tensor = self.retriever.generate_question_vectors(questions)
+
+        if max_answers > self.len_all_passages:
+            max_answers = self.len_all_passages
+
+        # get top k results
+        top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_answers)
+
+        if not output_strings:
+            return top_ids_and_scores
+        else:
+            # get the string forms
+            all_answers = []
+            all_scores = []
+            for ans in top_ids_and_scores:
+                all_answers.append(list(map(self.all_passages.get, ans[0])))
+                all_scores.append(ans[1])
+            return all_answers, all_scores
+
+    def get_wikipedia_docs(self, questions, max_docs):
+        # assumes clues are preprocessed
+        assert self.retrievalmodel
+        questions_tensor = self.retriever.generate_question_vectors(questions)
+
+        # get top k results; add 2 in case of duplicates (see below)
+        top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_docs + 2)
+
+        all_paragraphs = []
+        for ans in top_ids_and_scores:
+            paragraphs = []
+            for i in range(len(ans[0])):
+                id_ = ans[0][i]
+                id_ = id_.replace("wiki:", "")
+                mydocument = self.all_passages[id_]
+                if mydocument in paragraphs:
+                    print("woah, duplicate!!!")
+                    continue
+                paragraphs.append(mydocument)
+            all_paragraphs.append(paragraphs[0:max_docs])
+
+        return all_paragraphs
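The clue normalization is easiest to see on a couple of inputs (a sketch; note that importing `Models_inf` requires the DPR `models` package to be importable, since it is imported at module load time):

```python
from Models_inf import preprocess_clue_fn

# the enumeration "(5)" is stripped and abbreviations like "Org."/"Abbr." are expanded
print(preprocess_clue_fn("Org. for retirees: Abbr. (5)"))
# a trailing "?" is rewritten as an explicit "(wordplay)" marker
print(preprocess_clue_fn("Makes a splash?"))
```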
Normal_utils_inf.py ADDED
@@ -0,0 +1,65 @@
+import puz
+import re
+import unicodedata
+import sys
+
+def puz_to_json(fname):
+    """ Converts a puzzle in .puz format to .json format
+    """
+    p = puz.read(fname)
+    numbering = p.clue_numbering()
+
+    grid = [[None for _ in range(p.width)] for _ in range(p.height)]
+    for row_idx in range(p.height):
+        cell = row_idx * p.width
+        row_solution = p.solution[cell:cell + p.width]
+        for col_index, item in enumerate(row_solution):
+            if p.solution[cell + col_index:cell + col_index + 1] == '.':
+                grid[row_idx][col_index] = 'BLACK'
+            else:
+                grid[row_idx][col_index] = ["", row_solution[col_index: col_index + 1]]
+
+    across_clues = {}
+    for clue in numbering.across:
+        answer = ''.join(p.solution[clue['cell'] + i] for i in range(clue['len']))
+        across_clues[str(clue['num'])] = [clue['clue'] + ' ', ' ' + answer]
+        grid[int(clue['cell'] / p.width)][clue['cell'] % p.width][0] = str(clue['num'])
+
+    down_clues = {}
+    for clue in numbering.down:
+        answer = ''.join(p.solution[clue['cell'] + i * numbering.width] for i in range(clue['len']))
+        down_clues[str(clue['num'])] = [clue['clue'] + ' ', ' ' + answer]
+        grid[int(clue['cell'] / p.width)][clue['cell'] % p.width][0] = str(clue['num'])
+
+
+    mydict = {'metadata': {'date': None, 'rows': p.height, 'cols': p.width}, 'clues': {'across': across_clues, 'down': down_clues}, 'grid': grid}
+    return mydict
+
+def puz_to_pairs(filepath):
+    """ Takes in a filepath pointing to a .puz file and returns a list of (clue, fill) pairs in a list
+    """
+    p = puz.read(filepath)
+
+    numbering = p.clue_numbering()
+
+    grid = [[None for _ in range(p.width)] for _ in range(p.height)]
+    for row_idx in range(p.height):
+        cell = row_idx * p.width
+        row_solution = p.solution[cell:cell + p.width]
+        for col_index, item in enumerate(row_solution):
+            if p.solution[cell + col_index:cell + col_index + 1] == '.':
+                grid[row_idx][col_index] = 'BLACK'
+            else:
+                grid[row_idx][col_index] = ["", row_solution[col_index: col_index + 1]]
+
+    pairs = {}
+    for clue in numbering.across:
+        answer = ''.join(p.solution[clue['cell'] + i] for i in range(clue['len']))
+        pairs[clue['clue']] = answer
+
+    for clue in numbering.down:
+        answer = ''.join(p.solution[clue['cell'] + i * numbering.width] for i in range(clue['len']))
+        pairs[clue['clue']] = answer
+
+    return [(k, v) for k, v in pairs.items()]
+
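Typical usage (a sketch; `example.puz` is a hypothetical puzzle file):

```python
from Normal_utils_inf import puz_to_json, puz_to_pairs

data = puz_to_json("example.puz")
print(data["metadata"])                  # {'date': None, 'rows': ..., 'cols': ...}
print(puz_to_pairs("example.puz")[:3])   # [('clue text', 'ANSWER'), ...]
```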
Options_inf.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import random
5
+ import socket
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ logger = logging.getLogger()
11
+
12
+
13
+ def add_tokenizer_params(parser: argparse.ArgumentParser):
14
+ parser.add_argument(
15
+ "--do_lower_case",
16
+ action="store_true",
17
+ help="Whether to lower case the input text. True for uncased models, False for cased models.",
18
+ )
19
+
20
+
21
+ def add_encoder_params(parser: argparse.ArgumentParser):
22
+ """
23
+ Common parameters to initialize an encoder-based model
24
+ """
25
+ parser.add_argument(
26
+ "--pretrained_model_cfg",
27
+ default=None,
28
+ type=str,
29
+ help="config name for model initialization",
30
+ )
31
+ parser.add_argument(
32
+ "--encoder_model_type",
33
+ default=None,
34
+ type=str,
35
+ help="model type. One of [hf_bert, pytext_bert, fairseq_roberta]",
36
+ )
37
+ parser.add_argument(
38
+ "--pretrained_file",
39
+ type=str,
40
+ help="Some encoders need to be initialized from a file",
41
+ )
42
+ parser.add_argument(
43
+ "--model_file",
44
+ default=None,
45
+ type=str,
46
+ help="Saved bi-encoder checkpoint file to initialize the model",
47
+ )
48
+ parser.add_argument(
49
+ "--projection_dim",
50
+ default=0,
51
+ type=int,
52
+ help="Extra linear layer on top of standard bert/roberta encoder",
53
+ )
54
+ parser.add_argument(
55
+ "--sequence_length",
56
+ type=int,
57
+ default=512,
58
+ help="Max length of the encoder input sequence",
59
+ )
60
+ parser.add_argument(
61
+ "--do_fill_lower_case",
62
+ action="store_true",
63
+ help="Make all fills lower case. e.g. for cased models such as roberta"
64
+ )
65
+ parser.add_argument(
66
+ "--desegment_valid_fill",
67
+ action="store_true",
68
+ help="Desegment model fill output for validation"
69
+ )
70
+
71
+
72
+ def add_training_params(parser: argparse.ArgumentParser):
73
+ """
74
+ Common parameters for training
75
+ """
76
+ add_cuda_params(parser)
77
+ parser.add_argument(
78
+ "--train_file", default=None, type=str, help="File pattern for the train set"
79
+ )
80
+ parser.add_argument("--dev_file", default=None, type=str, help="")
81
+
82
+ parser.add_argument(
83
+ "--batch_size", default=2, type=int, help="Amount of questions per batch"
84
+ )
85
+ parser.add_argument(
86
+ "--dev_batch_size",
87
+ type=int,
88
+ default=4,
89
+ help="amount of questions per batch for dev set validation",
90
+ )
91
+ parser.add_argument(
92
+ "--seed",
93
+ type=int,
94
+ default=0,
95
+ help="random seed for initialization and dataset shuffling",
96
+ )
97
+
98
+ parser.add_argument(
99
+ "--adam_eps", default=1e-8, type=float, help="Epsilon for Adam optimizer."
100
+ )
101
+ parser.add_argument(
102
+ "--adam_betas",
103
+ default="(0.9, 0.999)",
104
+ type=str,
105
+ help="Betas for Adam optimizer.",
106
+ )
107
+
108
+ parser.add_argument(
109
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
110
+ )
111
+ parser.add_argument("--log_batch_step", default=100, type=int, help="")
112
+ parser.add_argument("--train_rolling_loss_step", default=100, type=int, help="")
113
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="")
114
+ parser.add_argument(
115
+ "--learning_rate",
116
+ default=1e-5,
117
+ type=float,
118
+ help="The initial learning rate for Adam.",
119
+ )
120
+
121
+ parser.add_argument(
122
+ "--warmup_steps", default=100, type=int, help="Linear warmup over warmup_steps."
123
+ )
124
+ parser.add_argument("--dropout", default=0.1, type=float, help="")
125
+
126
+ parser.add_argument(
127
+ "--gradient_accumulation_steps",
128
+ type=int,
129
+ default=1,
130
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
131
+ )
132
+ parser.add_argument(
133
+ "--num_train_epochs",
134
+ default=3.0,
135
+ type=float,
136
+ help="Total number of training epochs to perform.",
137
+ )
138
+
139
+
140
+ def add_cuda_params(parser: argparse.ArgumentParser):
141
+ parser.add_argument(
142
+ "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
143
+ )
144
+ parser.add_argument(
145
+ "--local_rank",
146
+ type=int,
147
+ default=-1,
148
+ help="local_rank for distributed training on gpus",
149
+ )
150
+ parser.add_argument(
151
+ "--fp16",
152
+ action="store_true",
153
+ help="Whether to use 16-bit float precision instead of 32-bit",
154
+ )
155
+
156
+ parser.add_argument(
157
+ "--fp16_opt_level",
158
+ type=str,
159
+ default="O1",
160
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
161
+ "See details at https://nvidia.github.io/apex/amp.html",
162
+ )
163
+
164
+
165
+ def add_reader_preprocessing_params(parser: argparse.ArgumentParser):
166
+ parser.add_argument(
167
+ "--gold_passages_src",
168
+ type=str,
169
+ help="File with the original dataset passages (json format). Required for train set",
170
+ )
171
+ parser.add_argument(
172
+ "--gold_passages_src_dev",
173
+ type=str,
174
+ help="File with the original dataset passages (json format). Required for dev set",
175
+ )
176
+ parser.add_argument(
177
+ "--num_workers",
178
+ type=int,
179
+ default=16,
180
+ help="number of parallel processes to binarize reader data",
181
+ )
182
+
183
+
184
+ def get_encoder_checkpoint_params_names():
185
+ return [
186
+ "do_lower_case",
187
+ "pretrained_model_cfg",
188
+ "encoder_model_type",
189
+ "pretrained_file",
190
+ "projection_dim",
191
+ "sequence_length",
192
+ ]
193
+
194
+
195
+ def get_encoder_params_state(args):
196
+ """
197
+ Selects the param values to be saved in a checkpoint, so that a trained model faile can be used for downstream
198
+ tasks without the need to specify these parameter again
199
+ :return: Dict of params to memorize in a checkpoint
200
+ """
201
+ params_to_save = get_encoder_checkpoint_params_names()
202
+
203
+ r = {}
204
+ for param in params_to_save:
205
+ r[param] = getattr(args, param)
206
+ return r
207
+
208
+
209
+ def set_encoder_params_from_state(state, args):
210
+ if not state:
211
+ return
212
+ params_to_save = get_encoder_checkpoint_params_names()
213
+
214
+ override_params = [
215
+ (param, state[param])
216
+ for param in params_to_save
217
+ if param in state and state[param]
218
+ ]
219
+ for param, value in override_params:
220
+ if hasattr(args, param):
221
+ logger.warning(
222
+ "Overriding args parameter value from checkpoint state. Param = %s, value = %s",
223
+ param,
224
+ value,
225
+ )
226
+ setattr(args, param, value)
227
+ return args
228
+
229
+
230
+ def set_seed(args):
231
+ seed = args.seed
232
+ random.seed(seed)
233
+ np.random.seed(seed)
234
+ torch.manual_seed(seed)
235
+ if args.n_gpu > 0:
236
+ torch.cuda.manual_seed_all(seed)
237
+
238
+
239
+ def setup_args_gpu(args):
240
+ """
241
+ Set up CUDA, GPU and distributed-training arguments
242
+ """
243
+
244
+ if args.local_rank == -1 or args.no_cuda: # single-node multi-gpu (or cpu) mode
245
+ device = torch.device(
246
+ "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
247
+ )
248
+ args.n_gpu = torch.cuda.device_count()
249
+ else: # distributed mode
250
+ torch.cuda.set_device(args.local_rank)
251
+ device = torch.device("cuda", args.local_rank)
252
+ torch.distributed.init_process_group(backend="nccl")
253
+ args.n_gpu = 1
254
+ args.device = device
255
+ ws = os.environ.get("WORLD_SIZE")
256
+
257
+ args.distributed_world_size = int(ws) if ws else 1
258
+
259
+ logger.info(
260
+ "Initialized host %s as d.rank %d on device=%s, n_gpu=%d, world size=%d",
261
+ socket.gethostname(),
262
+ args.local_rank,
263
+ device,
264
+ args.n_gpu,
265
+ args.distributed_world_size,
266
+ )
267
+ logger.info("16-bits training: %s ", args.fp16)
268
+
269
+
270
+ def print_args(args):
271
+ logger.info(" **************** CONFIGURATION **************** ")
272
+ for key, val in sorted(vars(args).items()):
273
+ keystr = "{}".format(key) + (" " * (30 - len(key)))
274
+ logger.info("%s --> %s", keystr, val)
275
+ logger.info(" **************** CONFIGURATION **************** ")
Solver_inf.py ADDED
@@ -0,0 +1,129 @@
1
+ import re
2
+ from collections import defaultdict
3
+ import string
4
+
5
+ from scipy.special import softmax
6
+ import numpy as np
7
+
8
+ from Models_inf import answer_clues, setup_closedbook
9
+
10
+ class Solver:
11
+ """
12
+ This class represents an abstraction over different types of crossword solvers. Each puzzle contains
13
+ a list of clues, which are associated with (weighted) values for each candidate answer.
14
+
15
+ Args:
16
+ crossword (Crossword): puzzle to solve
17
+ max_candidates (int): number of answer candidates to consider per clue
18
+ """
19
+ def __init__(self, crossword, model_path, ans_tsv_path, dense_embd_path, max_candidates=1000, process_id = 0, model_type = 'bert'):
20
+ self.crossword = crossword
21
+ self.max_candidates = max_candidates
22
+ self.process_id = process_id
23
+ self.model_path = model_path
24
+ self.ans_tsv_path = ans_tsv_path
25
+ self.dense_embd_glob = dense_embd_path  # a trailing comma here would wrap the path in a 1-tuple and break the glob lookup
26
+ self.model_type = model_type
27
+ self.get_candidates()
28
+
29
+ def get_candidates(self):
30
+ # get answers from neural model and fill up data structures with the results
31
+ chars = string.ascii_uppercase
32
+ self.char_map = {char: idx for idx, char in enumerate(chars)}
33
+ self.candidates = {}
34
+
35
+ all_clues = []
36
+ for var in self.crossword.variables:
37
+ all_clues.append(self.crossword.variables[var]['clue'])
38
+
39
+ # replaces stuff like "Opposite of 29-across" with "Opposite of X", where X is the clue for 29-across
40
+ r = re.compile(r'([0-9]+)[-\s](down|across)', re.IGNORECASE)
41
+ matches = [(idx, r.search(clue)) for idx, clue in enumerate(all_clues) if r.search(clue) is not None]
42
+ for (idx, match) in matches:
43
+ clue = all_clues[idx]
44
+ var = str(match.group(1)) + str(match.group(2)[0]).upper()
45
+ if var in self.crossword.variables:
46
+ clue = clue[:match.start()] + self.crossword.variables[var]['clue'] + clue[match.end():]
47
+ all_clues[idx] = clue
48
+
49
+ # print("MODEL PATH: ", type(self.dense_embd_glob))
50
+ # get predictions
51
+ dpr = setup_closedbook(self.model_path, self.ans_tsv_path, self.dense_embd_glob, self.process_id, self.model_type)
52
+ all_words, all_scores = answer_clues(dpr, all_clues, max_answers=self.max_candidates, output_strings=True)
53
+ for index, var in enumerate(self.crossword.variables):
54
+ length = len(self.crossword.variables[var]["gold"])
55
+ self.candidates[var] = {"words": [], "bit_array": None, "weights": {}}
56
+
57
+ clue = all_clues[index]
58
+ words, scores = all_words[index], all_scores[index]
59
+ # remove answers that are not of the correct length
60
+ keep_positions = []
61
+ for word_index, word in enumerate(words):
62
+ if len(word) == length:
63
+ keep_positions.append(word_index)
64
+ words = [words[i] for i in keep_positions]
65
+ scores = [scores[i] for i in keep_positions]
66
+ scores = list(-np.log(softmax(np.array(scores) / 0.75)))
67
+
68
+ for word, score in zip(words, scores):
69
+ self.candidates[var]["weights"][word] = score
70
+
71
+ # for debugging purposes, print the rank of the gold answer on our candidate list
72
+ # the gold answer is otherwise *not* used in any way during solving
73
+ # if self.crossword.variables[var]["gold"] in words:
74
+ # print(clue, self.crossword.variables[var]["gold"], words.index(self.crossword.variables[var]["gold"]))
75
+ # else:
76
+ # print('not found', clue, self.crossword.variables[var]["gold"])
77
+
78
+ # fill up some data structures used later in solving
81
+ weights = self.candidates[var]["weights"]
82
+ self.candidates[var]["words"] = sorted(weights, key=weights.get)
83
+ self.candidates[var]["bit_array"] = np.zeros((len(chars), length, len(self.candidates[var]["words"])))
84
+ self.candidates[var]["single_query_cache"] = [defaultdict(lambda:[]) for _ in range(len(chars))]
85
+ self.candidates[var]["single_query_cache_indices"] = [defaultdict(lambda:[]) for _ in range(len(chars))]
86
+ for word_idx, word in enumerate(self.candidates[var]["words"]):
87
+ for pos_idx, char in enumerate(word):
88
+ char_idx = self.char_map[char]
89
+ self.candidates[var]["bit_array"][char_idx, pos_idx, word_idx] = 1
90
+ self.candidates[var]["single_query_cache"][pos_idx][char].append(word)
91
+ self.candidates[var]["single_query_cache_indices"][pos_idx][char].append(word_idx)
92
+ # NOTE: TODO, it's possible to cache more here in exchange for doing more work at init time
93
+
94
+ # cleanup a bit
95
+ del dpr
96
+
97
+ def evaluate(self, solution):
98
+ # print puzzle accuracy results given a generated solution
99
+ letters_correct = 0
100
+ letters_total = 0
101
+ for i in range(len(self.crossword.letter_grid)):
102
+ for j in range(len(self.crossword.letter_grid[0])):
103
+ if self.crossword.letter_grid[i][j] != "":
104
+ letters_correct += (self.crossword.letter_grid[i][j] == solution[i][j])
105
+ letters_total += 1
106
+ words_correct = 0
107
+ words_total = 0
108
+ for var in self.crossword.variables:
109
+ cells = self.crossword.variables[var]["cells"]
110
+ matching_cells = [self.crossword.letter_grid[cell[0]][cell[1]] == solution[cell[0]][cell[1]] for cell in cells]
111
+ if len(cells) == sum(matching_cells):
112
+ words_correct += 1
113
+ else:
114
+ # print('evaluation: correct word', ''.join([self.crossword.letter_grid[cell[0]][cell[1]] for cell in cells]), 'our prediction:', ''.join([solution[cell[0]][cell[1]] for cell in cells]))
115
+ pass
116
+ words_total += 1
117
+
118
+ print("Letters Correct: {}/{} | Words Correct: {}/{}".format(int(letters_correct), int(letters_total), int(words_correct), int(words_total)))
119
+ print("Letters Correct: {}% | Words Correct: {}%".format(float(letters_correct/letters_total*100), float(words_correct/words_total*100)))
120
+
121
+ info = {
122
+ "total_letters" : int(letters_total),
123
+ "total_words" : int(words_total),
124
+ "correct_letters" : int(letters_correct),
125
+ "correct_words" : int(words_correct),
126
+ "correct_letters_percent" : float(letters_correct/letters_total*100),
127
+ "correct_words_percent" : float(words_correct/words_total*100),
128
+ }
129
+ return info
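+
+ # Rough usage sketch (paths are placeholders, not the shipped artifact names):
+ #   solver = Solver(crossword, "biencoder.bin", "answers.tsv", "embeddings_0*",
+ #                   max_candidates=1000, model_type='distilbert')
+ #   info = solver.evaluate(proposed_grid)  # prints and returns letter/word accuracy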
Strict_json.py ADDED
@@ -0,0 +1,57 @@
1
+ import json
2
+
3
+ def json_CA_json_converter(json_file_path, is_path):
4
+ if is_path:
5
+ with open(json_file_path, "r") as file:
6
+ data = json.load(file)
7
+ else:
8
+ data = json_file_path
9
+
10
+ json_conversion_dict = {}
11
+
12
+ rows = data['size']['rows']
13
+ cols = data['size']['cols']
14
+
15
+ clues = data['clues']
16
+ answers = data['answers']
17
+
18
+ json_conversion_dict['metadata'] = {'rows': rows, 'cols': cols}
19
+
20
+ across_clue_answer = {}
21
+ down_clue_answer = {}
22
+
23
+ for clue, ans in zip(clues['across'], answers['across']):
24
+ split_clue = clue.split(' ')
25
+ clue_num = split_clue[0][:-1]
26
+ clue_ = " ".join(split_clue[1:])
27
+ clue_ = clue_.replace("[", '').replace("]", '')
28
+ across_clue_answer[clue_num] = [clue_, ans]
29
+
30
+ for clue, ans in zip(clues['down'], answers['down']):
31
+ split_clue = clue.split(' ')
32
+ clue_num = split_clue[0][:-1]
33
+ clue_ = " ".join(split_clue[1:])
34
+ clue_ = clue_.replace("[", '').replace("]", '')
35
+ down_clue_answer[clue_num] = [clue_, ans]
36
+
37
+ json_conversion_dict['clues'] = {'across' : across_clue_answer, 'down' : down_clue_answer}
38
+
39
+ grid_info = data['grid']
40
+ grid_num = data['gridnums']
41
+
42
+ grid_info_list = []
43
+ for i in range(rows):
44
+ row_list = []
45
+ for j in range(cols):
46
+ # the flat grid/gridnums lists are row-major, so cell (i, j) lives at index i * cols + j
+ if grid_info[i * cols + j] == '.':
47
+ row_list.append('BLACK')
48
+ else:
49
+ if grid_num[i * cols + j] == 0:
50
+ row_list.append(['', grid_info[i * cols + j]])
51
+ else:
52
+ row_list.append([str(grid_num[i * cols + j]), grid_info[i * cols + j]])
53
+ grid_info_list.append(row_list)
54
+
55
+ json_conversion_dict['grid'] = grid_info_list
56
+
57
+ return json_conversion_dict
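+
+ # Example (hypothetical file): convert a NYT-style puzzle JSON into the solver format:
+ #   puzzle = json_CA_json_converter("puzzle.json", is_path=True)
+ #   rows, cols = puzzle['metadata']['rows'], puzzle['metadata']['cols']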
Utils_inf.py ADDED
@@ -0,0 +1,89 @@
1
+ import json
2
+ import puz
3
+ import wordsegment
4
+ import math
5
+ from wordsegment import load, segment, clean
6
+ import os
7
+ load()
8
+
9
+ dictionary = set([a.strip() for a in open('./words_alpha.txt','r').readlines()])
10
+
11
+ def num_words(fill):
12
+ '''segment the text into multiple words and count how many words the text has in total'''
13
+ segmented = segment(fill)
14
+ prob = 0.0
15
+ for word in segmented:
16
+ if word not in dictionary:
17
+ return 999, -9999999999999
18
+ prob += math.log(wordsegment.UNIGRAMS[word])
19
+ return (len(segmented), prob)
20
+
21
+ def get_word_flips(fill, num_candidates=10):
22
+ '''
23
+ We take as input a word/phrase that is probably misspelled, something like iluveyou. We then try flipping each one of the letters
24
+ to all other letters. We then segment those texts into multiple words using num_words, e.g., iloveyou -> i love you. We return the candidates
25
+ that segment into the fewest number of words.
26
+ '''
27
+ results = {}
28
+ min_length = 999
29
+ fill = wordsegment.clean(fill)  # explicit: the clean() defined later in this file shadows the wordsegment import
30
+ for index, char in enumerate(fill):
31
+ for new_letter in 'abcdefghijklmnopqrstuvwxyz':
32
+ new_fill = list(fill)
33
+ new_fill[index] = new_letter
34
+ new_fill = ''.join(new_fill)
35
+ curr_num_words, prob = num_words(new_fill)
36
+ if curr_num_words not in results:
37
+ results[curr_num_words] = []
38
+ results[curr_num_words].append((new_fill, prob))
39
+ if curr_num_words < min_length:
40
+ min_length = curr_num_words
41
+ if min_length == 999:
42
+ return [fill.upper()]
43
+ all_results = sum([sorted(results[length], key=lambda x:-x[1]) for length in sorted(list(results.keys()))], [])
44
+ return [a[0].upper() for a in all_results[0:num_candidates]]
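+
+ # e.g. (illustrative): get_word_flips("iluveyou") should rank "ILOVEYOU" near the top,
+ # since flipping the 'u' yields a string that segments into the common words "i love you"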
45
+
46
+ def convert_puz(fname):
47
+ # requires pypuz library to run
48
+ # converts a puzzle in .puz format to .json format
49
+ p = puz.read(fname)
50
+
51
+ numbering = p.clue_numbering()
52
+
53
+ grid = [[None for _ in range(p.width)] for _ in range(p.height)]
54
+ for row_idx in range(p.height):
55
+ cell = row_idx * p.width
56
+ row_solution = p.solution[cell:cell + p.width]
57
+ for col_index, item in enumerate(row_solution):
58
+ if p.solution[cell + col_index:cell + col_index + 1] == '.':
59
+ grid[row_idx][col_index] = 'BLACK'
60
+ else:
61
+ grid[row_idx][col_index] = ["", row_solution[col_index: col_index + 1]]
62
+
63
+ across_clues = {}
64
+ for clue in numbering.across:
65
+ answer = ''.join(p.solution[clue['cell'] + i] for i in range(clue['len']))
66
+ across_clues[str(clue['num'])] = [clue['clue'] + ' ', ' ' + answer]
67
+ grid[int(clue['cell'] / p.width)][clue['cell'] % p.width][0] = str(clue['num'])
68
+
69
+ down_clues = {}
70
+ for clue in numbering.down:
71
+ answer = ''.join(p.solution[clue['cell'] + i * numbering.width] for i in range(clue['len']))
72
+ down_clues[str(clue['num'])] = [clue['clue'] + ' ', ' ' + answer]
73
+ grid[int(clue['cell'] / p.width)][clue['cell'] % p.width][0] = str(clue['num'])
74
+
75
+
76
+ mydict = {'metadata': {'date': None, 'rows': p.height, 'cols': p.width}, 'clues': {'across': across_clues, 'down': down_clues}, 'grid': grid}
77
+ return mydict
78
+
79
+ def clean(text):
80
+ '''
81
+ :param text: question or answer text
82
+ :return: text with line breaks and trailing spaces removed
83
+ '''
84
+ return " ".join(text.strip().split())
85
+
86
+ def print_grid(letter_grid):
87
+ for row in letter_grid:
88
+ row = [" " if val == "" else val for val in row]
89
+ print("".join(row), flush=True)
extractpuzzle.py ADDED
@@ -0,0 +1,792 @@
1
+ import cv2
2
+ import numpy as np
3
+ import math
4
+ from sklearn.linear_model import LinearRegression
5
+ import pytesseract
6
+ import re
7
+ import matplotlib.pyplot as plt
8
+
9
+ pytesseract.pytesseract.tesseract_cmd = 'C:/Program Files/Tesseract-OCR/tesseract.exe'
10
+ image_path = "try heree.jpg"
11
+
12
+ def first_preprocessing(image):
13
+ gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
14
+ canny = cv2.Canny(gray,75,25)
15
+ contours,hierarchies = cv2.findContours(canny,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE)
16
+ sorted_contours = sorted(contours,key = cv2.contourArea,reverse = True)
17
+ largest_contour = sorted_contours[0]
18
+ box = cv2.boundingRect(sorted_contours[0])
19
+ x = box[0]
20
+ y = box[1]
21
+ w = box[2]
22
+ h = box[3]
23
+ result = cv2.rectangle(image, (x, y), (x + w, y + h), (255, 255, 255), -1)
24
+ return result
25
+
26
+ def remove_head(image):
27
+ custom_config = r'--oem 3 --psm 6' # Tesseract OCR configuration
28
+ detected_text = pytesseract.image_to_string(image, config=custom_config)
29
+ lines = detected_text.split('\n')
30
+
31
+ # Find the first line containing some text
32
+ line_index = 0
33
+ for i, line in enumerate(lines):
34
+ if line.strip() != '':
35
+ line_index = i
36
+ break
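+ # note: line_index and first_newline_idx below are indices into the OCR text,
+ # not pixel coordinates, so the whitewash rectangle is only a rough header mask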
37
+ first_newline_idx = detected_text.find('\n')
38
+ result = cv2.rectangle(image, (0, line_index), (image.shape[1], first_newline_idx), (255,255,255), thickness=cv2.FILLED)
39
+ return result
40
+
41
+ def second_preprocessing(image):
42
+ gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
43
+ canny = cv2.Canny(gray,75,25)
44
+ contours,hierarchies = cv2.findContours(canny,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE)
45
+ sorted_contours = sorted(contours,key = cv2.contourArea,reverse = True)
46
+ largest_contour = sorted_contours[0]
47
+ box2 = cv2.boundingRect(sorted_contours[0])
48
+ x = box2[0]
49
+ y = box2[1]
50
+ w = box2[2]
51
+ h = box2[3]
52
+ result2 = cv2.rectangle(image, (x, y), (x + w, y + h), (255, 255, 255), -1)
53
+ return result2
54
+
55
+ def find_vertical_profile(image):
56
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
57
+ _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
58
+ vertical_profile = np.sum(binary, axis=0)
59
+ return vertical_profile
60
+
61
+ def detect_steepest_changes(projection_profile, threshold=0.4, start_idx=0, min_valley_width=10, min_search_width=50):
62
+ differences = np.diff(projection_profile)
63
+ change_points = np.where(np.abs(differences) > threshold * np.max(np.abs(differences)))[0]
64
+ left_boundaries = []
65
+ right_boundaries = []
66
+
67
+ for idx in change_points:
68
+ if idx <= start_idx:
69
+ continue
70
+
71
+ if idx - start_idx >= min_search_width:
72
+ decreasing_profile = projection_profile[idx:]
73
+ if np.any(decreasing_profile > 0):
74
+ right_boundary = idx + np.argmin(decreasing_profile)
75
+ right_boundaries.append(right_boundary)
76
+ else:
77
+ continue
78
+ valley_start = max(start_idx, idx - min_valley_width)
79
+ valley_start = valley_start-40
80
+ valley_end = min(idx + min_valley_width, len(projection_profile) - 1)
81
+ valley = valley_start + np.argmin(projection_profile[valley_start:valley_end])
82
+ left_boundaries.append(valley)
83
+
84
+ break
85
+
86
+ return left_boundaries, right_boundaries
87
+
88
+ def crop_text_columns(image, projection_profile, threshold=0.4):
89
+ start_idx = 0
90
+ text_columns = []
91
+
92
+ while True:
93
+ left_boundaries, right_boundaries = detect_steepest_changes(projection_profile, threshold, start_idx)
94
+ if not left_boundaries or not right_boundaries:
95
+ break
96
+ left = left_boundaries[0]
97
+ right = right_boundaries[0]
98
+ text_column = image[:, left:right]
99
+ text_columns.append(text_column)
100
+
101
+ start_idx = right
102
+
103
+ return text_columns
104
+
105
+
106
+ def parse_clues(clue_text):
107
+ lines = clue_text.split('\n')
108
+ clues = {}
109
+ number = None
110
+ column = 0
111
+ for line in lines:
112
+ if "column separation" in line:
113
+ column += 1
114
+ continue
115
+ pattern = r"^(\d+(?:\.\d+)?)\s*(.+)" # Updated pattern to handle decimal point numbers for clues
116
+ match = re.search(pattern, line)
117
+ if match:
118
+ number = float(match.group(1)) # Convert the matched number to float if there is a decimal point
119
+ if number not in clues:
120
+ clues[number] = [column,match.group(2).strip()]
121
+ else:
122
+ continue
123
+ elif number is None:
124
+ continue
125
+ elif clues[number][0] != column:
126
+ continue
127
+ else:
128
+ clues[number][1] += " " + line.strip() # Append to the previous clue if it's a multiline clue
129
+
130
+ return clues
131
+
132
+ def parse_crossword_clues(text):
133
+ # Check if "Down" clues are present
134
+ match = re.search(r'[dD][oO][wW][nN]\n', text)
135
+ if match:
136
+ across_clues, down_clues = re.split(r'[dD][oO][wW][nN]\n', text)
137
+ else:
138
+ # If "Down" clues are not present, set down_clues to an empty string
139
+ across_clues, down_clues = text, ""
140
+
141
+ across = parse_clues(across_clues)
142
+ down = parse_clues(down_clues)
143
+
144
+ return across, down
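+
+ # e.g. (illustrative): the OCR line "12 Opposite of day" parses to
+ # clues[12.0] = [column_index, "Opposite of day"]; unmatched lines are appended
+ # to the previous clue as multiline continuations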
145
+
146
+
147
+ def classify_text(filtered_columns):
148
+ text = ""
149
+ custom_config = r'--oem 3 --psm 6'
150
+ for i, column in enumerate(filtered_columns):
151
+ column2 = cv2.cvtColor(column, cv2.COLOR_BGR2RGB)
152
+ scale_factor = 2.0 # You can adjust this value
153
+
154
+ # Calculate the new dimensions after scaling
155
+ new_width = int(column2.shape[1] * scale_factor)
156
+ new_height = int(column2.shape[0] * scale_factor)
157
+
158
+ # Resize the image using OpenCV
159
+ scaled_image = cv2.resize(column2, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
160
+
161
+ # Apply image enhancement techniques
162
+ denoised_image = cv2.fastNlMeansDenoising(scaled_image, None, h=10, templateWindowSize=7, searchWindowSize=21)
163
+ enhanced_image = cv2.cvtColor(denoised_image, cv2.COLOR_BGR2GRAY) # Convert to grayscale # Apply histogram equalization
164
+ detected_text = pytesseract.image_to_string(enhanced_image, config=custom_config)
165
+ # print(detected_text)
166
+ text+=detected_text
167
+ across_clues, down_clues = parse_crossword_clues(text)
168
+ return across_clues,down_clues
169
+
170
+ def get_text(image):
171
+ image = cv2.cvtColor(image,cv2.COLOR_GRAY2BGR)
172
+ result = first_preprocessing(image)
173
+ result1 = remove_head(result)
174
+ result2 = second_preprocessing(result1)
175
+ vertical_profile = find_vertical_profile(result2)
176
+ combined_columns = crop_text_columns(result2,vertical_profile)
177
+ across,down = classify_text(combined_columns)
178
+ return across,down
179
+
180
+
181
+ ################################ Grid Extraction begins here ###########################
182
+ ########################################################################################
183
+
184
+
185
+ # for applying non max suppression of the contours
186
+ def calculate_iou(image, contour1, contour2):
187
+ # Create masks for each contour
188
+ mask1 = np.zeros_like(image, dtype=np.uint8)
189
+ cv2.drawContours(mask1, [contour1], -1, 255, thickness=cv2.FILLED)
190
+
191
+ mask2 = np.zeros_like(image, dtype=np.uint8)
192
+ cv2.drawContours(mask2, [contour2], -1, 255, thickness=cv2.FILLED)
193
+
194
+ # Find the intersection between the two masks
195
+ intersection = cv2.bitwise_and(mask1, mask2)
196
+
197
+ # Calculate the intersection area
198
+ intersection_area = cv2.countNonZero(intersection)
199
+
200
+ # Approximate the union area via the convex hull of both contours (not exact, but works alright here)
201
+ union_area = cv2.contourArea(cv2.convexHull(np.concatenate((contour1, contour2))))
202
+
203
+ # Calculate the IoU
204
+ iou = intersection_area / union_area
205
+ return iou
206
+
207
+ # remove overlapping, non-square, and non-quadrilateral contours
208
+ # this checks every contour against every other contour (quadratic cost), so be careful
209
+ def filter_contours(img_gray2, contours, iou_threshold = 0.6, asp_ratio = 1,tolerance = 0.5):
210
+ # Remove overlapping contours, removing that are not square
211
+ filtered_contours = []
212
+ epsilon = 0.02
213
+ for contour in contours:
214
+
215
+ # Approximate the contour to reduce the number of points
216
+ epsilon_multiplier = epsilon * cv2.arcLength(contour, True)
217
+ approximated_contour = cv2.approxPolyDP(contour, epsilon_multiplier, True)
218
+
219
+ # find the aspect ratio of the contour, if it is close to 1 then keep it otherwise discard
220
+ _,_,w,h = cv2.boundingRect(approximated_contour)
221
+ if(abs(float(w)/h - asp_ratio) > tolerance ): continue
222
+
223
+ # Calculate the IoU with all existing contours
224
+ iou_values = [calculate_iou(img_gray2,np.array(approximated_contour), np.array(existing_contour)) for existing_contour in filtered_contours]
225
+
226
+ # If the IoU value with all existing contours is below the threshold, add the current contour
227
+ if not any(iou_value > iou_threshold for iou_value in iou_values):
228
+ filtered_contours.append(approximated_contour)
229
+
230
+ return filtered_contours
231
+
232
+ # https://stackoverflow.com/questions/383480/intersection-of-two-lines-defined-in-rho-theta-parameterization/383527#383527
233
+ # Define the parametricIntersect function
234
+ def parametricIntersect(r1, t1, r2, t2):
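+ # solves rho1 = x*cos(t1) + y*sin(t1) and rho2 = x*cos(t2) + y*sin(t2) for (x, y);
+ # the determinant d is zero when the two lines are parallel (no intersection returned)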
235
+ ct1 = np.cos(t1)
236
+ st1 = np.sin(t1)
237
+ ct2 = np.cos(t2)
238
+ st2 = np.sin(t2)
239
+ d = ct1 * st2 - st1 * ct2
240
+ if d != 0.0:
241
+ x = int((st2 * r1 - st1 * r2) / d)
242
+ y = int((-ct2 * r1 + ct1 * r2) / d)
243
+ return x, y
244
+ else:
245
+ return None
246
+
247
+ # Group the coordinates into lists such that the points in each list may belong to one line
248
+ def group_lines(coordinates,axis=0,threshold=10):
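+ # points whose coordinate along `axis` differs by less than `threshold` share a group;
+ # groups with fewer than 5 points are dropped as noise rather than fitted as lines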
249
+ sorted_coordinates = list(sorted(coordinates,key=lambda x: x[axis]))
250
+ groups = []
251
+ current_group = []
252
+
253
+ for i in range(len(sorted_coordinates)):
254
+ if i!=0 and abs(current_group[0][axis] - sorted_coordinates[i][axis]) > threshold: # condition to change the group
255
+ if len(current_group) > 4:
256
+ groups.append(current_group)
257
+ current_group = []
258
+ current_group.append(sorted_coordinates[i]) # condition to append to the group
259
+ if(len(current_group) > 4):
260
+ groups.append(current_group)
261
+ return groups
262
+
263
+ # Use the Grouped Lines to Fit a line using Linear Regression
264
+ def fit_lines(grouped_lines,is_horizontal = False):
265
+ actual_lines = []
266
+ for coordinates in grouped_lines:
267
+ # Converting into numpy array
268
+ coordinates_arr = np.array(coordinates)
269
+ # Separate the x and y coordinates
270
+ x = coordinates_arr[:, 0]
271
+ y = coordinates_arr[:, 1]
272
+ # Fit a linear regression model
273
+ regressor = LinearRegression()
274
+ regressor.fit(y.reshape(-1, 1), x)
275
+ # Get the slope and intercept of the fitted line
276
+ slope = regressor.coef_[0]
277
+ intercept = regressor.intercept_
278
+
279
+ if(is_horizontal):
280
+ intercept = np.mean(y)
281
+ actual_lines.append((slope,intercept))
282
+
283
+ return actual_lines
284
+
285
+ # Calculates the average absolute difference between consecutive elements of an array
286
+ def average_distance(arr):
287
+ n = len(arr)
288
+ distance_sum = 0
289
+
290
+ for i in range(n - 1):
291
+ distance_sum += abs(arr[i+1] - arr[i])
292
+
293
+ average = distance_sum / (n - 1)
294
+ return average
295
+
296
+ # If two adjacent lines are nearer than some threshold, merge them
297
+ # Returns results in y = mx + b form
298
+ def average_out_similar_lines(lines_m_c,lines_coord,del_threshold,is_horizontal=False):
299
+ averaged_lines = []
300
+ i = 0
301
+ while(i < len(lines_m_c) - 1):
302
+
303
+ _, intercept1 = lines_m_c[i]
304
+ _, intercept2 = lines_m_c[i + 1]
305
+
306
+ if abs(intercept2 - intercept1) < del_threshold:
307
+ new_points = np.array(lines_coord[i] + lines_coord[i+1][:-1])
308
+ # Separate the x and y coordinates
309
+ x = new_points[:, 0]
310
+ y = new_points[:, 1]
311
+
312
+ # Fit a linear regression model
313
+ regressor = LinearRegression()
314
+ regressor.fit(y.reshape(-1, 1), x)
315
+
316
+ # Get the slope and intercept of the fitted line
317
+ slope = regressor.coef_[0]
318
+ intercept = regressor.intercept_
319
+
320
+ if(is_horizontal):
321
+ intercept = np.mean(y)
322
+ averaged_lines.append((slope,intercept))
323
+ i+=2
324
+ else:
325
+ averaged_lines.append(lines_m_c[i])
326
+ i+=1
327
+ if(i < len(lines_m_c)):
328
+ averaged_lines.append(lines_m_c[i])
329
+
330
+ return averaged_lines
331
+
332
+ # If two adjacent lines are nearer than some threshold, merge them
333
+ # Returns results in normalized direction-vector form (vx, vy, x, y)
334
+ def average_out_similar_lines1(lines_m_c,lines_coord,del_threshold):
335
+ averaged_lines = []
336
+ i = 0
337
+ while(i < len(lines_m_c) - 1):
338
+
339
+ _, intercept1 = lines_m_c[i]
340
+ _, intercept2 = lines_m_c[i + 1]
341
+
342
+ if abs(intercept2 - intercept1) < del_threshold:
343
+ new_points = np.array(lines_coord[i] + lines_coord[i+1][:-1])
344
+ coordinates = np.array(new_points)
345
+ points = coordinates[:, None, :].astype(np.int32)
346
+ # Fit a line using linear regression
347
+ [vx, vy, x, y] = cv2.fitLine(points, cv2.DIST_L2, 0, 0.01, 0.01)
348
+ averaged_lines.append((vx, vy, x, y))
349
+ i+=2
350
+ else:
351
+ new_points = np.array(lines_coord[i])
352
+
353
+ coordinates = np.array(new_points)
354
+ points = coordinates[:, None, :].astype(np.int32)
355
+ # Fit a line using linear regression
356
+ [vx, vy, x, y] = cv2.fitLine(points, cv2.DIST_L2, 0, 0.01, 0.01)
357
+ averaged_lines.append((vx, vy, x, y))
358
+ i+=1
359
+ if(i < len(lines_m_c)):
360
+ new_points = np.array(lines_coord[i])
361
+ coordinates = np.array(new_points)
362
+ points = coordinates[:, None, :].astype(np.int32)
363
+ # Fit a line using linear regression
364
+ [vx, vy, x, y] = cv2.fitLine(points, cv2.DIST_L2, 0, 0.01, 0.01)
365
+ averaged_lines.append((vx, vy, x, y))
366
+
367
+ return averaged_lines
368
+
369
+ def get_square_color(image, box):
370
+
371
+ # Determine the size of the square region
372
+ square_size = (box[1][0] - box[0][0]) / 3
373
+
374
+ # Determine the coordinates of the square region inside the box
375
+ top_left = (box[0][0] + square_size, box[0][1] + square_size)
376
+ bottom_right = (box[0][0] + square_size*2, box[0][1] + square_size*2)
377
+
378
+ # Extract the square region from the image
379
+ square_region = image[int(top_left[1]):int(bottom_right[1]), int(top_left[0]):int(bottom_right[0])]
380
+
381
+ # Calculate the mean pixel value of the square region
382
+ mean_value = np.mean(square_region)
383
+
384
+ # Determine whether the square region is predominantly black or white
385
+ if mean_value < 128:
386
+ square_color = "."
387
+ else:
388
+ square_color = " "
389
+
390
+ return square_color
391
+
392
+ # accepts image in grayscale
393
+ def extract_grid(image):
394
+
395
+ # Apply Gaussian blur to reduce noise and improve edge detection
396
+ blurred = cv2.GaussianBlur(image, (3, 3), 0)
397
+ # Apply Canny edge detection
398
+ edges = cv2.Canny(blurred, 50, 150)
399
+
400
+ # Apply dilation to connect nearby edges and make them more contiguous
401
+ kernel = np.ones((5, 5), np.uint8)
402
+ dilated = cv2.dilate(edges, kernel, iterations=1)
403
+
404
+ # detecting contours on the dilated Canny edge image
405
+ # (RETR_LIST retrieves every contour, without nesting hierarchy)
406
+ contours, _ = cv2.findContours(dilated, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
407
+
408
+ # sorting the contours by the descending order area of the contour
409
+ sorted_contours = list(sorted(contours, key=cv2.contourArea,reverse=True))
410
+ # filter the 10 largest contours with NMS, keeping only roughly square ones (aspect ratio ~1)
411
+ filtered_contours = filter_contours(image, sorted_contours[0:10],iou_threshold=0.6,asp_ratio=1,tolerance=0.2)
412
+
413
+ # largest Contour Extraction
414
+ largest_contour = []
415
+ if(len(filtered_contours)):
416
+ largest_contour = filtered_contours[0]
417
+ else:
418
+ largest_contour = sorted_contours[0]
419
+
420
+ # --- Performing Perspective warp of the largest contour ---
421
+ coordinates_list = []
422
+
423
+ if(largest_contour.shape != (4,1,2)):
424
+ largest_contour = cv2.convexHull(largest_contour)
425
+ if(largest_contour.shape != (4,1,2)):
426
+ rect = cv2.minAreaRect(largest_contour)
427
+ largest_contour = cv2.boxPoints(rect)
428
+ largest_contour = largest_contour.astype('int')
429
+
430
+ coordinates_list = largest_contour.reshape(4, 2).tolist()
431
+
432
+ # Convert coordinates_list to a numpy array
433
+ coordinates_array = np.array(coordinates_list)
434
+
435
+ # Find the convex hull of the points
436
+ hull = cv2.convexHull(coordinates_array)
437
+
438
+ # Find the extreme points of the convex hull
439
+ extreme_points = np.squeeze(hull)
440
+
441
+ # Sort the extreme points by their x and y coordinates to determine the order
442
+ sorted_points = extreme_points[np.lexsort((extreme_points[:, 1], extreme_points[:, 0]))]
443
+
444
+ # Extract top left, bottom right, top right, and bottom left points
445
+ tl = sorted_points[0]
446
+ tr = sorted_points[1]
447
+ bl = sorted_points[2]
448
+ br = sorted_points[3]
449
+
450
+ if(tr[1] < tl[1]):
451
+ tl,tr = tr,tl
452
+ if(br[1] < bl[1]):
453
+ bl,br = br,bl
454
+
455
+ # Define pts1
456
+ pts1 = [tl, bl, tr, br]
457
+
458
+ # Calculate the bounding rectangle coordinates
459
+ x, y, w, h = 0,0,400,400
460
+ # Define pts2 as the corners of the bounding rectangle
461
+ pts2 = [[3, 3], [400, 3], [3, 400], [400, 400]]
462
+
463
+ # Calculate the perspective transformation matrix
464
+ matrix = cv2.getPerspectiveTransform(np.float32(pts1), np.float32(pts2))
465
+
466
+ # Apply the perspective transformation to the cropped_image
467
+ transformed_img = cv2.warpPerspective(image, matrix, (403, 403))
468
+ cropped_image = transformed_img.copy()
469
+
470
+ plt.figure(figsize=(12,8))
471
+ plt.axis("off")
472
+ plt.imsave("noice1.jpg",cv2.cvtColor(cropped_image,cv2.COLOR_GRAY2RGB))
473
+
474
+ # if the largest contour was not exactly a quadrilateral
475
+
476
+ # -- Performing Hough Transform --
477
+
478
+ similarity_threshold = math.floor(w/30) # Thresholds for filtering Similar Hough Lines
479
+
480
+ # Applying Gaussian blur to reduce noise and improve edge detection
481
+ blurred = cv2.GaussianBlur(cropped_image, (5, 5), 0)
482
+ # Perform Canny edge detection on the GrayScale Image
483
+ edges = cv2.Canny(blurred, 50, 150)
484
+ lines = cv2.HoughLines(edges, 1, np.pi/180, 200)
485
+
486
+ # Filter out similar lines
487
+ filtered_lines = []
488
+ for line in lines:
489
+ for r_theta in lines:
490
+ arr = np.array(r_theta[0], dtype=np.float64)
491
+ rho, theta = arr
492
+ is_similar = False
493
+ for filtered_line in filtered_lines:
494
+ filtered_rho, filtered_theta = filtered_line
495
+ # similarity threshold is 10
496
+ if abs(rho - filtered_rho) < similarity_threshold and abs(theta - filtered_theta) < np.pi/180 * similarity_threshold:
497
+ is_similar = True
498
+ break
499
+ if not is_similar:
500
+ filtered_lines.append((rho, theta))
501
+
502
+ # Filter out the horizontal and the vertical lines
503
+ horizontal_lines = []
504
+ vertical_lines = []
505
+ for rho, theta in filtered_lines:
506
+ a = np.cos(theta)
507
+ b = np.sin(theta)
508
+ x0 = a * rho
509
+ y0 = b * rho
510
+ x1 = int(x0 + 1000 * (-b))
511
+ y1 = int(y0 + 1000 * (a))
512
+ x2 = int(x0 - 1000 * (-b))
513
+ y2 = int(y0 - 1000 * (a))
514
+
515
+ slope = (y2 - y1) / (x2 - x1 + 0.0001)
516
+ # arctan(0.18) is roughly 10 degrees, so slopes below this threshold count as horizontal
517
+ if( abs(slope) <= 0.18 ):
518
+ horizontal_lines.append((rho,theta))
519
+ elif (abs(slope) > 6):
520
+ vertical_lines.append((rho,theta))
521
+
522
+ # Find the intersection points of horizontal and vertical lines
523
+ hough_corners = []
524
+ for h_rho, h_theta in horizontal_lines:
525
+ for v_rho, v_theta in vertical_lines:
526
+ x, y = parametricIntersect(h_rho, h_theta, v_rho, v_theta)
527
+ if x is not None and y is not None:
528
+ hough_corners.append((x, y))
529
+
530
+ # -- Performing Harris Corner Detection --
531
+
532
+ # Create CLAHE object with specified clip limit
533
+ clahe = cv2.createCLAHE(clipLimit=3, tileGridSize=(8, 8))
534
+ clahe_image = clahe.apply(cropped_image)
535
+
536
+ # Harris corner detection on the CLAHE-equalized image
537
+ dst = cv2.cornerHarris(clahe_image,2,3,0.04)
538
+ ret,dst = cv2.threshold(dst,0.1*dst.max(),255,0)
539
+ dst = np.uint8(dst)
540
+ dst = cv2.dilate(dst,None)
541
+ ret, labels, stats, centroids = cv2.connectedComponentsWithStats(dst)
542
+ criteria = (cv2.TERM_CRITERIA_EPS+cv2.TermCriteria_MAX_ITER,100,0.001)
543
+ harris_corners = cv2.cornerSubPix(clahe_image,np.float32(centroids),(5,5),(-1,-1),criteria)
544
+
545
+ drawn_image = cv2.cvtColor(cropped_image, cv2.COLOR_GRAY2BGR)
546
+ for i in harris_corners:
547
+ x,y = i
548
+ image2 = cv2.circle(drawn_image, (int(x),int(y)), radius=0, color=(0, 0, 255), thickness=3)
549
+
550
+ # -- Using Regression Model to approximate horizontal and vertical Lines
551
+
552
+ # reducing to 0 decimal places
553
+ corners1 = list(map(lambda coord: (round(coord[0], 0), round(coord[1], 0)), harris_corners))
554
+
555
+ # adding the corners obtained from hough transform
556
+ corners1 += hough_corners
557
+
558
+ # removing the duplicate corners
559
+ corners_no_dup = list(set(corners1))
560
+
561
+ min_cell_width = w/30
562
+ min_cell_height = h/30
563
+
564
+ # grouping coordinates into probable point sets that could each fit a horizontal or vertical line
565
+ vertical_lines = group_lines(corners_no_dup,0,min_cell_height)
566
+ horizontal_lines = group_lines(corners_no_dup,1,min_cell_height)
567
+
568
+ actual_vertical_lines = fit_lines(vertical_lines)
569
+ actual_horizontal_lines = fit_lines(horizontal_lines,is_horizontal=True)
570
+
571
+
572
+ # Lines obtained from above method are not appropriate, we have to refine them
573
+
574
+ x_probable = [i[1] for i in actual_horizontal_lines] # looking at the intercepts
575
+ y_probable = [i[1] for i in actual_vertical_lines]
576
+
577
+ del_x_avg = average_distance(x_probable)
578
+ del_y_avg = average_distance(y_probable)
579
+
580
+ averaged_horizontal_lines1 = [] # This step here is fishy and needs refinement
581
+ averaged_vertical_lines1 = []
582
+ multiplier = 0.95
583
+ i = 0
584
+ while(1):
585
+ averaged_horizontal_lines = average_out_similar_lines(actual_horizontal_lines,horizontal_lines,del_y_avg*multiplier,is_horizontal=True)
586
+ averaged_vertical_lines = average_out_similar_lines(actual_vertical_lines,vertical_lines,del_x_avg*multiplier,is_horizontal=False)
587
+ i += 1
588
+ if(i >= 20 or len(averaged_horizontal_lines) == len(averaged_vertical_lines)):
589
+ break
590
+ else:
591
+ multiplier -= 0.05
592
+
593
+ averaged_horizontal_lines1 = average_out_similar_lines1(actual_horizontal_lines,horizontal_lines,del_y_avg*multiplier)
594
+ averaged_vertical_lines1 = average_out_similar_lines1(actual_vertical_lines,vertical_lines,del_x_avg*multiplier)
595
+
596
+
597
+ # plotting the lines to image to find the intersection points
598
+ drawn_image6 = np.ones_like(cropped_image)*255
599
+ for vx,vy,cx,cy in averaged_horizontal_lines1 + averaged_vertical_lines1:
600
+ w = cropped_image.shape[1]
601
+ cv2.line(drawn_image6, (int(cx-vx*w), int(cy-vy*w)), (int(cx+vx*w), int(cy+vy*w)), (0, 0, 255),1,cv2.LINE_AA)
602
+
603
+ # -- Finding Intersection points --
604
+
605
+ # Applying Harris Corner Detection to find the intersection points
606
+ mesh_image = drawn_image6.copy()
607
+ dst = cv2.cornerHarris(mesh_image,2,3,0.04)
608
+
609
+ ret,dst = cv2.threshold(dst,0.1*dst.max(),255,0)
610
+ dst = np.uint8(dst)
611
+ dst = cv2.dilate(dst,None)
612
+ ret, labels, stats, centroids = cv2.connectedComponentsWithStats(dst)
613
+ criteria = (cv2.TERM_CRITERIA_EPS+cv2.TermCriteria_MAX_ITER,100,0.001)
614
+ harris_corners = cv2.cornerSubPix(mesh_image,np.float32(centroids),(5,5),(-1,-1),criteria)
615
+ drawn_image = cv2.cvtColor(drawn_image6, cv2.COLOR_GRAY2BGR)
616
+ harris_corners = list(sorted(harris_corners[1:],key = lambda x : x[1]))
617
+
618
+ # -- Finding out the grid color --
619
+
620
+
621
+ grayscale = cropped_image.copy()
622
+ # Perform adaptive thresholding to obtain binary image
623
+ _, binary = cv2.threshold(grayscale, 128, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
624
+
625
+ # Perform morphological operations to remove small text regions
626
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
627
+ binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)  # opening removes small text regions (MORPH_ELLIPSE is a kernel-shape flag, not a morphology op)
628
+
629
+ # Invert the binary image
630
+ inverted_binary = cv2.bitwise_not(binary)
631
+
632
+ # Restore the image by blending the inverted binary image with the grayscale image
633
+ restored_image = cv2.bitwise_or(inverted_binary, grayscale)
634
+
635
+ # Apply morphological opening to remove small black dots
636
+ kernel_opening = np.ones((3, 3), np.uint8)
637
+ opened_image = cv2.morphologyEx(restored_image, cv2.MORPH_OPEN, kernel_opening, iterations=1)
638
+
639
+ # Apply morphological closing to further refine the restored image
640
+ kernel_closing = np.ones((5, 5), np.uint8)
641
+ refined_image = cv2.morphologyEx(opened_image, cv2.MORPH_CLOSE, kernel_closing, iterations=1)
642
+
643
+ # finding out the grid corner
644
+ grid = []
645
+ grid_nums = []
646
+ across_clue_num = []
647
+ down_clue_num = []
648
+
649
+ sorted_corners = np.array(list(sorted(harris_corners,key=lambda x:x[1])))
650
+ if(len(sorted_corners) == len(averaged_horizontal_lines1) * len(averaged_vertical_lines1)):
651
+ sorted_corners_grouped = []
652
+ for i in range(0,len(sorted_corners),len(averaged_vertical_lines1)):
653
+ temp_arr = sorted_corners[i:i+len(averaged_vertical_lines1)]
654
+ temp_arr = list(sorted(temp_arr,key=lambda x: x[0]))
655
+ sorted_corners_grouped.append(temp_arr)
656
+
657
+ for h_line_idx in range(0,len(sorted_corners_grouped)-1):
658
+ for corner_idx in range(0,len(sorted_corners_grouped[h_line_idx])-1):
659
+ # grabbing the four box coordinates
660
+ box = [sorted_corners_grouped[h_line_idx][corner_idx],sorted_corners_grouped[h_line_idx][corner_idx+1],
661
+ sorted_corners_grouped[h_line_idx+1][corner_idx],sorted_corners_grouped[h_line_idx+1][corner_idx+1]]
662
+ grid.append(get_square_color(refined_image,box))
663
+
664
+ grid_formatted = []
665
+ for i in range(0, len(grid), len(averaged_vertical_lines1) - 1):
666
+ grid_formatted.append(grid[i:i + len(averaged_vertical_lines1) - 1])
667
+
668
+
669
+ # if (x,y) is present in these array the cell (x,y) is already accounted as a part of answer of across or down
670
+ in_horizontal = []
671
+ in_vertical = []
672
+
673
+ num = 0
674
+
675
+
676
+
677
+ for x in range(0, len(averaged_vertical_lines1) - 1):
678
+ for y in range(0, len(averaged_horizontal_lines1) - 1):
679
+
680
+ # if the cell is black there's no need to number
681
+ if grid_formatted[x][y] == '.':
682
+ grid_nums.append(0)
683
+ continue
684
+
685
+ # if the cell is part of both horizontal and vertical cell then there's no need to number
686
+ horizontal_presence = (x, y) in in_horizontal
687
+ vertical_presence = (x, y) in in_vertical
688
+
689
+ # present in both 1 1
690
+ if horizontal_presence and vertical_presence:
691
+ grid_nums.append(0)
692
+ continue
693
+
694
+ # present in the vertical run only, i.e. 0 1
695
+ if not horizontal_presence and vertical_presence:
696
+ horizontal_length = 0
697
+ temp_horizontal_arr = []
698
+ # iterate in x direction until the end of the grid or until a black box is found
699
+ while x + horizontal_length < len(averaged_horizontal_lines1) - 1 and grid_formatted[x + horizontal_length][y] != '.':
700
+ temp_horizontal_arr.append((x + horizontal_length, y))
701
+ horizontal_length += 1
702
+ # if horizontal length is greater than 1, then append the temp_horizontal_arr to in_horizontal array
703
+ if horizontal_length > 1:
704
+ in_horizontal.extend(temp_horizontal_arr)
705
+ num += 1
706
+ across_clue_num.append(num)
707
+ grid_nums.append(num)
708
+ continue
709
+ grid_nums.append(0)
710
+ # present in one 1 0
711
+ if not vertical_presence and horizontal_presence:
712
+ # do the same for vertical
713
+ vertical_length = 0
714
+ temp_vertical_arr = []
715
+ # iterate in y direction until the end of the grid or until a black box is found
716
+ while y + vertical_length < len(averaged_vertical_lines1) - 1 and grid_formatted[x][y+vertical_length] != '.':
717
+ temp_vertical_arr.append((x, y+vertical_length))
718
+ vertical_length += 1
719
+ # if vertical length is greater than 1, then append the temp_vertical_arr to in_vertical array
720
+ if vertical_length > 1:
721
+ in_vertical.extend(temp_vertical_arr)
722
+ num += 1
723
+ down_clue_num.append(num)
724
+ grid_nums.append(num)
725
+ continue
726
+ grid_nums.append(0)
727
+
728
+ if(not horizontal_presence and not vertical_presence):
729
+
730
+ horizontal_length = 0
731
+ temp_horizontal_arr = []
732
+ # iterate in x direction until the end of the grid or until a black box is found
733
+ while x + horizontal_length < len(averaged_horizontal_lines1) - 1 and grid_formatted[x + horizontal_length][y] != '.':
734
+ temp_horizontal_arr.append((x + horizontal_length, y))
735
+ horizontal_length += 1
736
+ # if horizontal length is greater than 1, then append the temp_horizontal_arr to in_horizontal array
737
+
738
+ # do the same for vertical
739
+ vertical_length = 0
740
+ temp_vertical_arr = []
741
+ # iterate in y direction until the end of the grid or until a black box is found
742
+ while y + vertical_length < len(averaged_vertical_lines1) - 1 and grid_formatted[x][y+vertical_length] != '.':
743
+ temp_vertical_arr.append((x, y+vertical_length))
744
+ vertical_length += 1
745
+ # if vertical length is greater than 1, then append the temp_vertical_arr to in_vertical array
746
+
747
+ if horizontal_length > 1 and vertical_length > 1:  # the cell starts both an across and a down answer
748
+ in_horizontal.extend(temp_horizontal_arr)
749
+ in_vertical.extend(temp_vertical_arr)
750
+ num += 1
751
+ across_clue_num.append(num)
752
+ down_clue_num.append(num)
753
+ grid_nums.append(num)
754
+ elif vertical_length > 1:
755
+ in_vertical.extend(temp_vertical_arr)
756
+ num += 1
757
+ down_clue_num.append(num)
758
+ grid_nums.append(num)
759
+ elif horizontal_length > 1:
760
+ in_horizontal.extend(temp_horizontal_arr)
761
+ num += 1
762
+ across_clue_num.append(num)
763
+ grid_nums.append(num)
764
+ else:
765
+ grid_nums.append(0)
766
+
767
+
768
+ size = { 'rows' : len(averaged_horizontal_lines1)-1,
769
+ 'cols' : len(averaged_vertical_lines1)-1,
770
+ }
771
+
772
+ puzzle_dict = {  # avoid shadowing the built-in dict
773
+ 'size' : size,
774
+ 'grid' : grid,
775
+ 'gridnums': grid_nums,
776
+ 'across_nums': down_clue_num,
777
+ 'down_nums' : across_clue_num,
778
+ 'clues':{
779
+ 'across' : [],
780
+ 'down': []
781
+ }
782
+ }
783
+
784
+ return puzzle_dict
785
+
786
+ if __name__ == "__main__":
787
+ img = cv2.imread("D:\\D\\Major Project files\\opencv\\movie.png",0)
788
+ down = extract_grid(img)
789
+ print(down)
790
+ # img = Image.open("chalena3.jpg")
791
+ # img_gray = img.convert("L")
792
+ # print(extract_grid(img_gray))
main.py CHANGED
@@ -1,6 +1,31 @@
1
- from fastapi import FastAPI
2
  app = FastAPI()
3
 
4
  @app.get("/")
5
  async def index():
6
- return {"message": "Hello World"}
1
+ from fastapi import FastAPI, Request
2
+ import os
3
+ from Crossword_inf import Crossword
4
+ from BPSolver_inf import BPSolver
5
+ from Strict_json import json_CA_json_converter
6
+
7
+ import json
8
+
9
+ MODEL_PATH = os.path.join("Inference_components","dpr_biencoder_trained_500k.bin")
10
+ ANS_TSV_PATH = os.path.join("Inference_components","all_answer_list.tsv")
11
+ DENSE_EMBD_PATH = os.path.join("Inference_components","embeddings_all_answers_json_0*")
12
+
13
+ MODEL_PATH_DISTIL = os.path.join("Inference_components","distilbert_EPOCHs_7_COMPLETE.bin")
14
+ ANS_TSV_PATH_DISTIL = os.path.join("Inference_components","all_answer_list.tsv")
15
+ DENSE_EMBD_PATH_DISTIL = os.path.join("Inference_components","distilbert_7_epochs_embeddings.pkl")
16
+
17
+
18
  app = FastAPI()
19
 
20
  @app.get("/")
21
  async def index():
22
+ return {"message": "Hello World"}
23
+
24
+ @app.post("/solve")
25
+ async def solve(request: Request):
26
+ payload = await request.json()  # avoid shadowing the imported json module
27
+ puzzle = json_CA_json_converter(payload, False)
28
+ crossword = Crossword(puzzle)
29
+ solver = BPSolver(crossword, model_path = MODEL_PATH_DISTIL, ans_tsv_path = ANS_TSV_PATH_DISTIL, dense_embd_path = DENSE_EMBD_PATH_DISTIL, max_candidates = 40000, model_type = 'distilbert')
30
+ solution = solver.solve(num_iters = 100, iterative_improvement_steps = 0)
31
+ return solution, solver.evaluate(solution)
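+
+ # Hypothetical smoke test (assumes the app is served on localhost:8000):
+ #   curl -X POST http://localhost:8000/solve -H "Content-Type: application/json" -d @puzzle.json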
models/__init__.py ADDED
@@ -0,0 +1,38 @@
1
+ def init_hf_bert_biencoder(args, **kwargs):
2
+ from .hf_models import get_bert_biencoder_components
3
+ return get_bert_biencoder_components(args, **kwargs)
4
+
5
+ def init_hf_distilbert_biencoder(args, **kwargs):
6
+ from .hf_models import get_distilbert_biencoder_components
7
+ return get_distilbert_biencoder_components(args, **kwargs)
8
+
9
+ def init_hf_bert_tenzorizer(args, **kwargs):
10
+ from .hf_models import get_bert_tensorizer
11
+ return get_bert_tensorizer(args)
12
+
13
+ def init_hf_distilbert_tenzorizer(args, **kwargs):
14
+ from .hf_models import get_distilbert_tensorizer
15
+ return get_distilbert_tensorizer(args)
16
+
17
+
18
+ BIENCODER_INITIALIZERS = {
19
+ 'hf_bert': init_hf_bert_biencoder,
20
+ 'hf_distilbert': init_hf_distilbert_biencoder
21
+ }
22
+
23
+ TENSORIZER_INITIALIZERS = {
24
+ 'hf_bert': init_hf_bert_tenzorizer,
25
+ 'hf_distilbert': init_hf_distilbert_tenzorizer
26
+ }
27
+
28
+ def init_comp(initializers_dict, type, args, **kwargs):
29
+ if type in initializers_dict:
30
+ return initializers_dict[type](args, **kwargs)
31
+ else:
32
+ raise RuntimeError('unsupported model type: {}'.format(type))
33
+
34
+ def init_biencoder_components(encoder_type: str, args, **kwargs):
35
+ return init_comp(BIENCODER_INITIALIZERS, encoder_type, args, **kwargs)
36
+
37
+ def init_tenzorizer(encoder_type: str, args, **kwargs):
38
+ return init_comp(TENSORIZER_INITIALIZERS, encoder_type, args, **kwargs)
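+
+ # Hedged usage sketch: in DPR-style code these factories usually return a
+ # (tensorizer, biencoder, optimizer) tuple, but the exact shape depends on hf_models:
+ #   tensorizer, biencoder, optimizer = init_biencoder_components('hf_distilbert', args)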
models/biencoder.py ADDED
@@ -0,0 +1,427 @@
1
+ import collections
2
+ import logging
3
+ import random
4
+ from typing import Tuple, List
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import Tensor as T
10
+ from torch import nn
11
+
12
+ import sys
13
+ import os
14
+
15
+ current_dir = os.path.dirname(__file__)
16
+ data_utils_path = os.path.join(current_dir, '..')
17
+ sys.path.append(data_utils_path)
18
+
19
+ from Data_utils_inf import Tensorizer
20
+ from Data_utils_inf import normalize_question
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ BiEncoderBatch = collections.namedtuple(
25
+ "BiENcoderInput",
26
+ [
27
+ "question_ids",
28
+ "question_segments",
29
+ "context_ids",
30
+ "ctx_segments",
31
+ "is_positive",
32
+ "hard_negatives",
33
+ ],
34
+ )
35
+
36
+ def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T:
37
+ """
38
+ calculates q->ctx scores for every row in ctx_vectors
39
+ :param q_vectors:
40
+ :param ctx_vectors:
41
+ :return:
42
+ """
43
+ # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2
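+ # e.g. q_vectors of shape (3, 768) against ctx_vectors of shape (10, 768) -> a (3, 10) score matrix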
44
+ r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1))
45
+ return r
46
+
47
+
48
+ def cosine_scores(q_vector: T, ctx_vectors: T):
49
+ # unlike dot_product_scores this is row-wise: shapes must broadcast, giving one similarity per row rather than an n1 x n2 matrix
50
+ return F.cosine_similarity(q_vector, ctx_vectors, dim=1)
51
+
52
+
53
+ class BiEncoder(nn.Module):
54
+ """Bi-Encoder model component. Encapsulates query/question and context/passage encoders."""
55
+
56
+ def __init__(
57
+ self,
58
+ question_model: nn.Module,
59
+ ctx_model: nn.Module,
60
+ fix_q_encoder: bool = False,
61
+ fix_ctx_encoder: bool = False,
62
+ ):
63
+ super(BiEncoder, self).__init__()
64
+ self.question_model = question_model
65
+ self.ctx_model = ctx_model
66
+ self.fix_q_encoder = fix_q_encoder
67
+ self.fix_ctx_encoder = fix_ctx_encoder
68
+
69
+ @staticmethod
70
+ def get_representation(
71
+ sub_model: nn.Module,
72
+ ids: T,
73
+ segments: T,
74
+ attn_mask: T,
75
+ fix_encoder: bool = False,
76
+ ) -> (T, T, T):
77
+ sequence_output = None
78
+ pooled_output = None
79
+ hidden_states = None
80
+ if ids is not None:
81
+ if fix_encoder:
82
+ with torch.no_grad():
83
+                     sequence_output, pooled_output, hidden_states = sub_model(
+                         ids, segments, attn_mask
+                     )
+
+                 if sub_model.training:
+                     sequence_output.requires_grad_(requires_grad=True)
+                     pooled_output.requires_grad_(requires_grad=True)
+             else:
+                 sequence_output, pooled_output, hidden_states = sub_model(
+                     ids, segments, attn_mask
+                 )
+
+         return sequence_output, pooled_output, hidden_states
+
+     def forward(
+         self,
+         question_ids: T,
+         question_segments: T,
+         question_attn_mask: T,
+         context_ids: T,
+         ctx_segments: T,
+         ctx_attn_mask: T,
+     ) -> Tuple[T, T]:
+
+         _q_seq, q_pooled_out, _q_hidden = self.get_representation(
+             self.question_model,
+             question_ids,
+             question_segments,
+             question_attn_mask,
+             self.fix_q_encoder,
+         )
+         _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation(
+             self.ctx_model,
+             context_ids,
+             ctx_segments,
+             ctx_attn_mask,
+             self.fix_ctx_encoder,
+         )
+
+         return q_pooled_out, ctx_pooled_out
+
+     @classmethod
+     def create_biencoder_input(
+         cls,
+         samples: List,
+         tensorizer: Tensorizer,
+         insert_title: bool,
+         num_hard_negatives: int = 0,
+         num_other_negatives: int = 0,
+         shuffle: bool = True,
+         shuffle_positives: bool = False,
+         do_lower_fill: bool = False,
+         desegment_valid_fill: bool = False,
+     ) -> BiEncoderBatch:
+         """
+         Creates a batch of the biencoder training tuple.
+         :param samples: list of data items (from json) to create the batch for
+         :param tensorizer: component that creates model input tensors from a text sequence
+         :param insert_title: enables title insertion at the beginning of the context sequences
+         :param num_hard_negatives: number of hard negatives per question (taken from samples' pools)
+         :param num_other_negatives: number of other negatives per question (taken from samples' pools)
+         :param shuffle: shuffles the negative passage pools
+         :param shuffle_positives: shuffles the positive passage pools
+         :param do_lower_fill: lowercases the positive and negative context texts before tensorization
+         :param desegment_valid_fill: currently unused in this method
+         :return: BiEncoderBatch tuple
+         """
+         question_tensors = []
+         ctx_tensors = []
+         positive_ctx_indices = []
+         hard_neg_ctx_indices = []
+
+         for sample in samples:
+             # ctx+ & [ctx-] composition
+             # as of now, take the first (gold) ctx+ only
+             if shuffle and shuffle_positives:
+                 positive_ctxs = sample["positive_ctxs"]
+                 positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
+             else:
+                 positive_ctx = sample["positive_ctxs"][0]
+             if do_lower_fill:
+                 positive_ctx["text"] = positive_ctx["text"].lower()
+             neg_ctxs = sample["negative_ctxs"]
+             hard_neg_ctxs = sample["hard_negative_ctxs"]
+             if do_lower_fill:
+                 neg_ctxs = [{"text": ctx["text"].lower(), "title": ctx["title"]} for ctx in neg_ctxs]
+                 hard_neg_ctxs = [{"text": ctx["text"].lower(), "title": ctx["title"]} for ctx in hard_neg_ctxs]
+             question = normalize_question(sample["question"])
+
+             if shuffle:
+                 random.shuffle(neg_ctxs)
+                 random.shuffle(hard_neg_ctxs)
+
+             neg_ctxs = neg_ctxs[0:num_other_negatives]
+             hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]
+
+             all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
+             # hard negatives come after the positive and the other negatives
+             hard_negatives_start_idx = 1 + len(neg_ctxs)
+             hard_negatives_end_idx = hard_negatives_start_idx + len(hard_neg_ctxs)
+
+             current_ctxs_len = len(ctx_tensors)
+
+             sample_ctxs_tensors = [
+                 tensorizer.text_to_tensor(
+                     ctx["text"], title=ctx["title"] if insert_title else None
+                 )
+                 for ctx in all_ctxs
+             ]
+
+             ctx_tensors.extend(sample_ctxs_tensors)
+             positive_ctx_indices.append(current_ctxs_len)
+             hard_neg_ctx_indices.append(
+                 [
+                     i
+                     for i in range(
+                         current_ctxs_len + hard_negatives_start_idx,
+                         current_ctxs_len + hard_negatives_end_idx,
+                     )
+                 ]
+             )
+
+             question_tensors.append(tensorizer.text_to_tensor(question))
+
+         ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
+         questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)
+
+         ctx_segments = torch.zeros_like(ctxs_tensor)
+         question_segments = torch.zeros_like(questions_tensor)
+
+         return BiEncoderBatch(
+             questions_tensor,
+             question_segments,
+             ctxs_tensor,
+             ctx_segments,
+             positive_ctx_indices,
+             hard_neg_ctx_indices,
+         )
+
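For reference, a minimal sketch of how `create_biencoder_input` consumes a DPR-style sample. The toy tensorizer is a hypothetical stand-in for the real `Tensorizer` (only `text_to_tensor` is needed), and the call assumes it runs inside this module so that `normalize_question` and `BiEncoderBatch` are in scope:

```python
import torch

class ToyTensorizer:
    # hypothetical stand-in: returns fixed-length dummy token ids
    def text_to_tensor(self, text, title=None):
        return torch.zeros(16, dtype=torch.long)

sample = {
    "question": "Capital of France (5)",
    "positive_ctxs": [{"text": "PARIS", "title": "fill"}],
    "negative_ctxs": [{"text": "BERLIN", "title": "fill"}],
    "hard_negative_ctxs": [{"text": "NANCY", "title": "fill"}],
}

batch = BiEncoder.create_biencoder_input(
    [sample], ToyTensorizer(), insert_title=True,
    num_hard_negatives=1, num_other_negatives=1, shuffle=False,
)
# the batch stacks 1 question row and 3 context rows (positive first)
# and records the positive / hard-negative row indices per question
```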
+
+ class DistilBertBiEncoder(nn.Module):
+     """Bi-Encoder model component. Encapsulates query/question and context/passage encoders."""
+
+     def __init__(
+         self,
+         question_model: nn.Module,
+         ctx_model: nn.Module,
+         fix_q_encoder: bool = False,
+         fix_ctx_encoder: bool = False,
+     ):
+         super(DistilBertBiEncoder, self).__init__()
+         self.question_model = question_model
+         self.ctx_model = ctx_model
+         self.fix_q_encoder = fix_q_encoder
+         self.fix_ctx_encoder = fix_ctx_encoder
+
+     @staticmethod
+     def get_representation(
+         sub_model: nn.Module,
+         ids: T,
+         segments: T,
+         attn_mask: T,
+         fix_encoder: bool = False,
+     ) -> Tuple[T, T, T]:
+         sequence_output = None
+         pooled_output = None
+         hidden_states = None
+         if ids is not None:
+             if fix_encoder:
+                 with torch.no_grad():
+                     # unlike BERT, DistilBERT takes no token_type_ids, so segments are not passed
+                     sequence_output, pooled_output, hidden_states = sub_model(
+                         ids, attn_mask
+                     )
+
+                 if sub_model.training:
+                     sequence_output.requires_grad_(requires_grad=True)
+                     pooled_output.requires_grad_(requires_grad=True)
+             else:
+                 sequence_output, pooled_output, hidden_states = sub_model(
+                     ids, attn_mask
+                 )
+
+         return sequence_output, pooled_output, hidden_states
+
+     def forward(
+         self,
+         question_ids: T,
+         question_segments: T,
+         question_attn_mask: T,
+         context_ids: T,
+         ctx_segments: T,
+         ctx_attn_mask: T,
+     ) -> Tuple[T, T]:
+
+         _q_seq, q_pooled_out, _q_hidden = self.get_representation(
+             self.question_model,
+             question_ids,
+             question_segments,
+             question_attn_mask,
+             self.fix_q_encoder,
+         )
+         _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation(
+             self.ctx_model,
+             context_ids,
+             ctx_segments,
+             ctx_attn_mask,
+             self.fix_ctx_encoder,
+         )
+
+         return q_pooled_out, ctx_pooled_out
+
+     @classmethod
+     def create_biencoder_input(
+         cls,
+         samples: List,
+         tensorizer: Tensorizer,
+         insert_title: bool,
+         num_hard_negatives: int = 0,
+         num_other_negatives: int = 0,
+         shuffle: bool = True,
+         shuffle_positives: bool = False,
+         do_lower_fill: bool = False,
+         desegment_valid_fill: bool = False,
+     ) -> BiEncoderBatch:
+         """
+         Creates a batch of the biencoder training tuple.
+         :param samples: list of data items (from json) to create the batch for
+         :param tensorizer: component that creates model input tensors from a text sequence
+         :param insert_title: enables title insertion at the beginning of the context sequences
+         :param num_hard_negatives: number of hard negatives per question (taken from samples' pools)
+         :param num_other_negatives: number of other negatives per question (taken from samples' pools)
+         :param shuffle: shuffles the negative passage pools
+         :param shuffle_positives: shuffles the positive passage pools
+         :param do_lower_fill: lowercases the positive and negative context texts before tensorization
+         :param desegment_valid_fill: currently unused in this method
+         :return: BiEncoderBatch tuple
+         """
+         question_tensors = []
+         ctx_tensors = []
+         positive_ctx_indices = []
+         hard_neg_ctx_indices = []
+
+         for sample in samples:
+             # ctx+ & [ctx-] composition
+             # as of now, take the first (gold) ctx+ only
+             if shuffle and shuffle_positives:
+                 positive_ctxs = sample["positive_ctxs"]
+                 positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
+             else:
+                 positive_ctx = sample["positive_ctxs"][0]
+             if do_lower_fill:
+                 positive_ctx["text"] = positive_ctx["text"].lower()
+             neg_ctxs = sample["negative_ctxs"]
+             hard_neg_ctxs = sample["hard_negative_ctxs"]
+             if do_lower_fill:
+                 neg_ctxs = [{"text": ctx["text"].lower(), "title": ctx["title"]} for ctx in neg_ctxs]
+                 hard_neg_ctxs = [{"text": ctx["text"].lower(), "title": ctx["title"]} for ctx in hard_neg_ctxs]
+             question = normalize_question(sample["question"])
+
+             if shuffle:
+                 random.shuffle(neg_ctxs)
+                 random.shuffle(hard_neg_ctxs)
+
+             neg_ctxs = neg_ctxs[0:num_other_negatives]
+             hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]
+
+             all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
+             # hard negatives come after the positive and the other negatives
+             hard_negatives_start_idx = 1 + len(neg_ctxs)
+             hard_negatives_end_idx = hard_negatives_start_idx + len(hard_neg_ctxs)
+
+             current_ctxs_len = len(ctx_tensors)
+
+             sample_ctxs_tensors = [
+                 tensorizer.text_to_tensor(
+                     ctx["text"], title=ctx["title"] if insert_title else None
+                 )
+                 for ctx in all_ctxs
+             ]
+
+             ctx_tensors.extend(sample_ctxs_tensors)
+             positive_ctx_indices.append(current_ctxs_len)
+             hard_neg_ctx_indices.append(
+                 [
+                     i
+                     for i in range(
+                         current_ctxs_len + hard_negatives_start_idx,
+                         current_ctxs_len + hard_negatives_end_idx,
+                     )
+                 ]
+             )
+
+             question_tensors.append(tensorizer.text_to_tensor(question))
+
+         ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
+         questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)
+
+         ctx_segments = torch.zeros_like(ctxs_tensor)
+         question_segments = torch.zeros_like(questions_tensor)
+
+         return BiEncoderBatch(
+             questions_tensor,
+             question_segments,
+             ctxs_tensor,
+             ctx_segments,
+             positive_ctx_indices,
+             hard_neg_ctx_indices,
+         )
+
+
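The only substantive difference from `BiEncoder` above is that `get_representation` calls the sub-model without segment ids, since DistilBERT takes no `token_type_ids`. A quick illustration against the Hugging Face API:

```python
from transformers import DistilBertModel, DistilBertTokenizer

tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

enc = tok("across clue", return_tensors="pt")
print(enc.keys())  # input_ids and attention_mask only; no token_type_ids
out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(out.last_hidden_state.shape)  # torch.Size([1, seq_len, 768])
```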
+
+ class BiEncoderNllLoss(object):
+     def calc(
+         self,
+         q_vectors: T,
+         ctx_vectors: T,
+         positive_idx_per_question: list,
+         hard_negative_idx_per_question: list = None,
+     ) -> Tuple[T, int]:
+         """
+         Computes nll loss for the given lists of question and ctx vectors.
+         Note that although hard_negative_idx_per_question is not currently in use, it can serve
+         loss modifications, e.g. a weighted NLL with different factors for hard vs. regular negatives.
+         :return: a tuple of the loss value and the number of correct predictions per batch
+         """
+         scores = self.get_scores(q_vectors, ctx_vectors)
+
+         if len(q_vectors.size()) > 1:
+             q_num = q_vectors.size(0)
+             scores = scores.view(q_num, -1)
+
+         softmax_scores = F.log_softmax(scores, dim=1)
+
+         loss = F.nll_loss(
+             softmax_scores,
+             torch.tensor(positive_idx_per_question).to(softmax_scores.device),
+             reduction="mean",
+         )
+
+         max_score, max_idxs = torch.max(softmax_scores, 1)
+         correct_predictions_count = (
+             max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)
+         ).sum()
+         return loss, correct_predictions_count
+
+     @staticmethod
+     def get_scores(q_vector: T, ctx_vectors: T) -> T:
+         f = BiEncoderNllLoss.get_similarity_function()
+         return f(q_vector, ctx_vectors)
+
+     @staticmethod
+     def get_similarity_function():
+         return dot_product_scores
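`dot_product_scores` is not shown in this diff; assuming it is the usual DPR similarity (`q @ ctx.T`), the in-batch-negative objective reduces to:

```python
import torch
import torch.nn.functional as F

q_vectors = torch.randn(2, 768)     # 2 questions
ctx_vectors = torch.randn(4, 768)   # 4 contexts: each gold ctx plus negatives
positive_idx_per_question = [0, 2]  # row index of the gold context per question

scores = torch.matmul(q_vectors, ctx_vectors.transpose(0, 1))  # (2, 4)
log_probs = F.log_softmax(scores, dim=1)
loss = F.nll_loss(
    log_probs, torch.tensor(positive_idx_per_question), reduction="mean"
)
# every other context row in the batch acts as a free negative for each question
```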
models/hf_models.py ADDED
@@ -0,0 +1,368 @@
+ import logging
+ from typing import Tuple
+
+ import torch
+ from torch import Tensor as T
+ from torch import nn
+ from transformers import BertConfig, BertModel
+ from transformers.optimization import AdamW
+ from transformers import BertTokenizer
+ from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig
+
+ import sys
+ import os
+
+ # make the parent directory importable so Data_utils_inf can be found
+ current_dir = os.path.dirname(__file__)
+ data_utils_path = os.path.join(current_dir, '..')
+ sys.path.append(data_utils_path)
+
+ from Data_utils_inf import Tensorizer
+ from .biencoder import BiEncoder, DistilBertBiEncoder
+
+ logger = logging.getLogger(__name__)
+
+ def count_parameters(model):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ def get_bert_biencoder_components(args, inference_only: bool = False, **kwargs):
+     dropout = args.dropout if hasattr(args, "dropout") else 0.0
+     question_encoder = HFBertEncoder.init_encoder(
+         args.pretrained_model_cfg,
+         projection_dim=args.projection_dim,
+         dropout=dropout,
+         **kwargs
+     )
+     ctx_encoder = HFBertEncoder.init_encoder(
+         args.pretrained_model_cfg,
+         projection_dim=args.projection_dim,
+         dropout=dropout,
+         **kwargs
+     )
+
+     fix_ctx_encoder = (
+         args.fix_ctx_encoder if hasattr(args, "fix_ctx_encoder") else False
+     )
+     biencoder = BiEncoder(
+         question_encoder, ctx_encoder, fix_ctx_encoder=fix_ctx_encoder
+     )
+
+     optimizer = (
+         get_optimizer(
+             biencoder,
+             learning_rate=args.learning_rate,
+             adam_eps=args.adam_eps,
+             weight_decay=args.weight_decay,
+         )
+         if not inference_only
+         else None
+     )
+
+     tensorizer = get_bert_tensorizer(args)
+
+     return tensorizer, biencoder, optimizer
+
+ def get_distilbert_biencoder_components(args, inference_only: bool = False, **kwargs):
+     dropout = args.dropout if hasattr(args, "dropout") else 0.0
+     question_encoder = HFDistilBertEncoder.init_encoder(
+         args.pretrained_model_cfg,
+         projection_dim=args.projection_dim,
+         dropout=dropout,
+         **kwargs
+     )
+     ctx_encoder = HFDistilBertEncoder.init_encoder(
+         args.pretrained_model_cfg,
+         projection_dim=args.projection_dim,
+         dropout=dropout,
+         **kwargs
+     )
+
+     fix_ctx_encoder = (
+         args.fix_ctx_encoder if hasattr(args, "fix_ctx_encoder") else False
+     )
+     biencoder = DistilBertBiEncoder(
+         question_encoder, ctx_encoder, fix_ctx_encoder=fix_ctx_encoder
+     )
+
+     optimizer = (
+         get_optimizer(
+             biencoder,
+             learning_rate=args.learning_rate,
+             adam_eps=args.adam_eps,
+             weight_decay=args.weight_decay,
+         )
+         if not inference_only
+         else None
+     )
+
+     tensorizer = get_distilbert_tensorizer(args)
+
+     return tensorizer, biencoder, optimizer
+
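For reference, the `args` object the two factories expect; the attribute names are taken from the code above, and the concrete values are illustrative only:

```python
from types import SimpleNamespace

args = SimpleNamespace(
    pretrained_model_cfg="bert-base-uncased",  # any compatible HF checkpoint
    projection_dim=0,      # 0 keeps the encoder's native hidden size
    dropout=0.1,
    sequence_length=256,
    do_lower_case=True,
    learning_rate=1e-5,
    adam_eps=1e-8,
    weight_decay=0.0,
)

tensorizer, biencoder, optimizer = get_bert_biencoder_components(args)
# with inference_only=True the optimizer slot comes back as None
```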
+ def get_bert_tensorizer(args, tokenizer=None):
+     if not tokenizer:
+         tokenizer = get_bert_tokenizer(
+             args.pretrained_model_cfg, do_lower_case=args.do_lower_case
+         )
+     return BertTensorizer(tokenizer, args.sequence_length)
+
+ def get_distilbert_tensorizer(args, tokenizer=None):
+     if not tokenizer:
+         tokenizer = get_distilbert_tokenizer(
+             args.pretrained_model_cfg, do_lower_case=args.do_lower_case
+         )
+     return DistilBertTensorizer(tokenizer, args.sequence_length)
+
+
+ def get_optimizer(
+     model: nn.Module,
+     learning_rate: float = 1e-5,
+     adam_eps: float = 1e-8,
+     weight_decay: float = 0.0,
+ ) -> torch.optim.Optimizer:
+     no_decay = ["bias", "LayerNorm.weight"]
+
+     optimizer_grouped_parameters = [
+         {
+             "params": [
+                 p
+                 for n, p in model.named_parameters()
+                 if not any(nd in n for nd in no_decay)
+             ],
+             "weight_decay": weight_decay,
+         },
+         {
+             "params": [
+                 p
+                 for n, p in model.named_parameters()
+                 if any(nd in n for nd in no_decay)
+             ],
+             "weight_decay": 0.0,
+         },
+     ]
+     optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps)
+     return optimizer
+
+
+ def get_bert_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True):
+     return BertTokenizer.from_pretrained(
+         pretrained_cfg_name, do_lower_case=do_lower_case
+     )
+
+ def get_distilbert_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True):
+     # DistilBERT reuses BERT's WordPiece tokenizer, so the HF tokenizer class works as-is
+     return DistilBertTokenizer.from_pretrained(
+         pretrained_cfg_name, do_lower_case=do_lower_case
+     )
+
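`get_optimizer` follows the standard BERT fine-tuning recipe: weight decay is applied everywhere except biases and LayerNorm weights. A small sanity check (illustrative):

```python
model = nn.Linear(4, 4)  # stand-in module with one weight and one bias
opt = get_optimizer(model, learning_rate=2e-5, weight_decay=0.01)
print([len(g["params"]) for g in opt.param_groups])   # [1, 1]
print([g["weight_decay"] for g in opt.param_groups])  # [0.01, 0.0] - bias undecayed
```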
+ class HFDistilBertEncoder(DistilBertModel):
+     def __init__(self, config, project_dim: int = 0):
+         DistilBertModel.__init__(self, config)
+         assert config.hidden_size > 0, "Encoder hidden_size can't be zero"
+         self.encode_proj = (
+             nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None
+         )
+         self.init_weights()
+
+     @classmethod
+     def init_encoder(
+         cls, cfg_name: str, projection_dim: int = 0, dropout: float = 0.1, **kwargs
+     ) -> DistilBertModel:
+         cfg = DistilBertConfig.from_pretrained(cfg_name if cfg_name else "distilbert-base-uncased")
+         if dropout != 0:
+             # DistilBertConfig names its dropout fields differently from BertConfig
+             cfg.dropout = dropout
+             cfg.attention_dropout = dropout
+         return cls.from_pretrained(
+             cfg_name, config=cfg, project_dim=projection_dim, **kwargs
+         )
+
+     def forward(
+         self, input_ids: T, attention_mask: T
+     ) -> Tuple[T, ...]:
+         if self.config.output_hidden_states:
+             outputs = super().forward(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+             )
+             sequence_output = outputs.last_hidden_state
+             # DistilBERT has no pooler, so the [CLS] token state serves as the pooled output
+             pooled_output = outputs.last_hidden_state[:, 0, :]
+             hidden_states = outputs.hidden_states
+         else:
+             hidden_states = None
+             outputs = super().forward(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+             )
+             sequence_output = outputs.last_hidden_state
+             pooled_output = outputs.last_hidden_state[:, 0, :]
+
+         if self.encode_proj:
+             pooled_output = self.encode_proj(pooled_output)
+         return sequence_output, pooled_output, hidden_states
+
+     def get_out_size(self):
+         if self.encode_proj:
+             return self.encode_proj.out_features
+         return self.config.hidden_size
+
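When `projection_dim` is nonzero, a linear head is stacked on the pooled output and `get_out_size` reports the projected size, e.g. (illustrative):

```python
enc = HFDistilBertEncoder.init_encoder(
    "distilbert-base-uncased", projection_dim=128, dropout=0.1
)
print(enc.get_out_size())  # 128: encode_proj maps hidden_size -> 128
```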
+
+ class HFBertEncoder(BertModel):
+     def __init__(self, config, project_dim: int = 0):
+         BertModel.__init__(self, config)
+         assert config.hidden_size > 0, "Encoder hidden_size can't be zero"
+         self.encode_proj = (
+             nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None
+         )
+         self.init_weights()
+
+     @classmethod
+     def init_encoder(
+         cls, cfg_name: str, projection_dim: int = 0, dropout: float = 0.1, **kwargs
+     ) -> BertModel:
+         cfg = BertConfig.from_pretrained(cfg_name if cfg_name else "bert-base-uncased")
+         if dropout != 0:
+             cfg.attention_probs_dropout_prob = dropout
+             cfg.hidden_dropout_prob = dropout
+         return cls.from_pretrained(
+             cfg_name, config=cfg, project_dim=projection_dim, **kwargs
+         )
+
+     def forward(
+         self, input_ids: T, token_type_ids: T, attention_mask: T
+     ) -> Tuple[T, ...]:
+         if self.config.output_hidden_states:
+             outputs = super().forward(
+                 input_ids=input_ids,
+                 token_type_ids=token_type_ids,
+                 attention_mask=attention_mask,
+             )
+             sequence_output = outputs.last_hidden_state
+             pooled_output = outputs.pooler_output
+             hidden_states = outputs.hidden_states
+         else:
+             hidden_states = None
+             outputs = super().forward(
+                 input_ids=input_ids,
+                 token_type_ids=token_type_ids,
+                 attention_mask=attention_mask,
+             )
+             sequence_output = outputs.last_hidden_state
+             pooled_output = outputs.pooler_output
+
+         if self.encode_proj:
+             pooled_output = self.encode_proj(pooled_output)
+         return sequence_output, pooled_output, hidden_states
+
+     def get_out_size(self):
+         if self.encode_proj:
+             return self.encode_proj.out_features
+         return self.config.hidden_size
+
+
+ class DistilBertTensorizer(Tensorizer):
+     def __init__(
+         self, tokenizer: DistilBertTokenizer, max_length: int, pad_to_max: bool = True
+     ):
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         self.pad_to_max = pad_to_max
+
+     def text_to_tensor(
+         self, text: str, title: str = None, add_special_tokens: bool = True
+     ):
+         if isinstance(text, float):
+             text = 'nan'
+         text = text.strip()
+
+         # the tokenizer's automatic padding is explicitly disabled due to its inconsistent behavior
+         if title:
+             token_ids = self.tokenizer.encode(
+                 title,
+                 text_pair=text,
+                 add_special_tokens=add_special_tokens,
+                 max_length=self.max_length,
+                 pad_to_max_length=False,
+                 truncation=True,
+             )
+         else:
+             token_ids = self.tokenizer.encode(
+                 text,
+                 add_special_tokens=add_special_tokens,
+                 max_length=self.max_length,
+                 pad_to_max_length=False,
+                 truncation=True,
+             )
+
+         seq_len = self.max_length
+         if self.pad_to_max and len(token_ids) < seq_len:
+             token_ids = token_ids + [self.tokenizer.pad_token_id] * (
+                 seq_len - len(token_ids)
+             )
+         if len(token_ids) > seq_len:
+             token_ids = token_ids[0:seq_len]
+             token_ids[-1] = self.tokenizer.sep_token_id
+
+         return torch.tensor(token_ids)
+
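`text_to_tensor` pads manually up to `max_length` and, whenever it has to truncate, forces the final token back to `[SEP]`. For example (illustrative):

```python
tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
tensorizer = DistilBertTensorizer(tok, max_length=16)
ids = tensorizer.text_to_tensor("long crossword clue text", title="clue")
print(ids.shape)       # torch.Size([16]); short inputs are padded with [PAD]
print(ids[-1].item())  # the pad id here; the sep id whenever truncation occurred
```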
+ class BertTensorizer(Tensorizer):
+     def __init__(
+         self, tokenizer: BertTokenizer, max_length: int, pad_to_max: bool = True
+     ):
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         self.pad_to_max = pad_to_max
+
+     def text_to_tensor(
+         self, text: str, title: str = None, add_special_tokens: bool = True
+     ):
+         if isinstance(text, float):
+             text = 'nan'
+         text = text.strip()
+
+         # the tokenizer's automatic padding is explicitly disabled due to its inconsistent behavior
+         if title:
+             token_ids = self.tokenizer.encode(
+                 title,
+                 text_pair=text,
+                 add_special_tokens=add_special_tokens,
+                 max_length=self.max_length,
+                 pad_to_max_length=False,
+                 truncation=True,
+             )
+         else:
+             token_ids = self.tokenizer.encode(
+                 text,
+                 add_special_tokens=add_special_tokens,
+                 max_length=self.max_length,
+                 pad_to_max_length=False,
+                 truncation=True,
+             )
+
+         seq_len = self.max_length
+         if self.pad_to_max and len(token_ids) < seq_len:
+             token_ids = token_ids + [self.tokenizer.pad_token_id] * (
+                 seq_len - len(token_ids)
+             )
+         if len(token_ids) > seq_len:
+             token_ids = token_ids[0:seq_len]
+             token_ids[-1] = self.tokenizer.sep_token_id
+
+         return torch.tensor(token_ids)
+
+     def get_pair_separator_ids(self) -> T:
+         return torch.tensor([self.tokenizer.sep_token_id])
+
+     def get_pad_id(self) -> int:
+         return self.tokenizer.pad_token_id
+
+     def get_attn_mask(self, tokens_tensor: T) -> T:
+         return tokens_tensor != self.get_pad_id()
+
+     def is_sub_word_id(self, token_id: int):
+         token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
+         return token.startswith("##") or token.startswith(" ##")
+
+     def to_string(self, token_ids, skip_special_tokens=True):
+         # pass the argument through instead of hard-coding True
+         return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
+     def set_pad_to_max(self, do_pad: bool):
+         self.pad_to_max = do_pad
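`get_attn_mask` recovers the attention mask from the pad id, so a padded tensor from `text_to_tensor` can be fed straight to the encoder (assuming `tensorizer` is a `BertTensorizer` instance):

```python
ids = tensorizer.text_to_tensor("some clue").unsqueeze(0)  # shape (1, max_length)
mask = tensorizer.get_attn_mask(ids)  # True on real tokens, False on padding
```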
requirements.txt CHANGED
@@ -1,2 +1,8 @@
  fastapi== 0.104.1
  uvicorn[standard]
+ puzpy
+ transformers
+ wordsegment
+ torch
+ faiss
+
words_alpha.txt ADDED
The diff for this file is too large to render. See raw diff