# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for constructing vocabulary, converting the examples to integer format and building the required masks for batch computation Author: aneelakantan (Arvind Neelakantan) | |
""" | |
from __future__ import print_function

import copy
import numbers
import numpy as np
import wiki_data


def return_index(a):
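  """Returns the index of the first element of `a` that equals 1.0."""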
  for i in range(len(a)):
    if (a[i] == 1.0):
      return i


def construct_vocab(data, utility, add_word=False):
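  """Builds the vocabulary (utility.words, word_ids, word_count,
  reverse_word_ids) from the questions and column names of examples that are
  not marked bad, and stores deep copies of each example's original columns
  and column names. When add_word is True only the copies are made.
  """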
  ans = []
  for example in data:
    sent = ""
    for word in example.question:
      if (not (isinstance(word, numbers.Number))):
        sent += word + " "
    example.original_nc = copy.deepcopy(example.number_columns)
    example.original_wc = copy.deepcopy(example.word_columns)
    example.original_nc_names = copy.deepcopy(example.number_column_names)
    example.original_wc_names = copy.deepcopy(example.word_column_names)
    if (add_word):
      continue
    number_found = 0
    if (not (example.is_bad_example)):
      for word in example.question:
        if (isinstance(word, numbers.Number)):
          number_found += 1
        else:
          if (word not in utility.word_ids):
            utility.words.append(word)
            utility.word_count[word] = 1
            utility.word_ids[word] = len(utility.word_ids)
            utility.reverse_word_ids[utility.word_ids[word]] = word
          else:
            utility.word_count[word] += 1
      for col_name in example.word_column_names:
        for word in col_name:
          if (isinstance(word, numbers.Number)):
            number_found += 1
          else:
            if (word not in utility.word_ids):
              utility.words.append(word)
              utility.word_count[word] = 1
              utility.word_ids[word] = len(utility.word_ids)
              utility.reverse_word_ids[utility.word_ids[word]] = word
            else:
              utility.word_count[word] += 1
      for col_name in example.number_column_names:
        for word in col_name:
          if (isinstance(word, numbers.Number)):
            number_found += 1
          else:
            if (word not in utility.word_ids):
              utility.words.append(word)
              utility.word_count[word] = 1
              utility.word_ids[word] = len(utility.word_ids)
              utility.reverse_word_ids[utility.word_ids[word]] = word
            else:
              utility.word_count[word] += 1


def word_lookup(word, utility):
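  """Returns `word` if it is in the vocabulary and the UNK token otherwise."""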
  if (word in utility.word_ids):
    return word
  else:
    return utility.unk_token


def convert_to_int_2d_and_pad(a, utility):
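  """Converts each entry (a list of tokens) to word ids, truncated or padded
  with the dummy token to FLAGS.max_entry_length.
  """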
  ans = []
  #print a
  for b in a:
    temp = []
    if (len(b) > utility.FLAGS.max_entry_length):
      b = b[0:utility.FLAGS.max_entry_length]
    for remaining in range(len(b), utility.FLAGS.max_entry_length):
      b.append(utility.dummy_token)
    assert len(b) == utility.FLAGS.max_entry_length
    for word in b:
      temp.append(utility.word_ids[word_lookup(word, utility)])
    ans.append(temp)
  #print ans
  return ans


def convert_to_bool_and_pad(a, utility):
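  """Converts a numpy matrix to lists of booleans (True where an entry is at
  least 1) and pads each row with False up to FLAGS.max_elements.
  """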
  a = a.tolist()
  for i in range(len(a)):
    for j in range(len(a[i])):
      if (a[i][j] < 1):
        a[i][j] = False
      else:
        a[i][j] = True
    a[i] = a[i] + [False] * (utility.FLAGS.max_elements - len(a[i]))
  return a


seen_tables = {}


def partial_match(question, table, number):
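  """Marks table entries that partially match the question.

  For number tables a cell matches when it equals a question token; for word
  tables a cell matches when any question token occurs inside it. Returns a
  0/1 matrix shaped like `table` and a dict of the first-dimension indices
  that contain a match.
  """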
  answer = []
  match = {}
  for i in range(len(table)):
    temp = []
    for j in range(len(table[i])):
      temp.append(0)
    answer.append(temp)
  for i in range(len(table)):
    for j in range(len(table[i])):
      for word in question:
        if (number):
          if (word == table[i][j]):
            answer[i][j] = 1.0
            match[i] = 1.0
        else:
          if (word in table[i][j]):
            answer[i][j] = 1.0
            match[i] = 1.0
  return answer, match


def exact_match(question, table, number):
  #performs exact match operation
  answer = []
  match = {}
  matched_indices = []
  for i in range(len(table)):
    temp = []
    for j in range(len(table[i])):
      temp.append(0)
    answer.append(temp)
  for i in range(len(table)):
    for j in range(len(table[i])):
      if (number):
        for word in question:
          if (word == table[i][j]):
            match[i] = 1.0
            answer[i][j] = 1.0
      else:
        table_entry = table[i][j]
        for k in range(len(question)):
          if (k + len(table_entry) <= len(question)):
            if (table_entry == question[k:(k + len(table_entry))]):
              #if(len(table_entry) == 1):
              #print "match: ", table_entry, question
              match[i] = 1.0
              answer[i][j] = 1.0
              matched_indices.append((k, len(table_entry)))
  return answer, match, matched_indices


def partial_column_match(question, table, number):
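  """Returns a 0/1 vector marking the column names that contain any question
  token.
  """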
  answer = []
  for i in range(len(table)):
    answer.append(0)
  for i in range(len(table)):
    for word in question:
      if (word in table[i]):
        answer[i] = 1.0
  return answer


def exact_column_match(question, table, number):
  #performs exact match on column names
  answer = []
  matched_indices = []
  for i in range(len(table)):
    answer.append(0)
  for i in range(len(table)):
    table_entry = table[i]
    for k in range(len(question)):
      if (k + len(table_entry) <= len(question)):
        if (table_entry == question[k:(k + len(table_entry))]):
          answer[i] = 1.0
          matched_indices.append((k, len(table_entry)))
  return answer, matched_indices


def get_max_entry(a):
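  """Returns the most frequent entry of `a` (ignoring "UNK, ") when it occurs
  more than once, and -1.0 otherwise.
  """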
  e = {}
  for w in a:
    if (w != "UNK, "):
      if (w in e):
        e[w] += 1
      else:
        e[w] = 1
  if (len(e) > 0):
    (key, val) = sorted(e.items(), key=lambda x: -1 * x[1])[0]
    if (val > 1):
      return key
    else:
      return -1.0
  else:
    return -1.0


def list_join(a):
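  """Joins the elements of `a` into a single comma-separated string,
  e.g. list_join([1, "a"]) returns "1, a, ".
  """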
ans = "" | |
for w in a: | |
ans += str(w) + ", " | |
return ans | |
def group_by_max(table, number): | |
#computes the most frequently occurring entry in a column | |
answer = [] | |
for i in range(len(table)): | |
temp = [] | |
for j in range(len(table[i])): | |
temp.append(0) | |
answer.append(temp) | |
for i in range(len(table)): | |
if (number): | |
curr = table[i] | |
else: | |
curr = [list_join(w) for w in table[i]] | |
max_entry = get_max_entry(curr) | |
#print i, max_entry | |
for j in range(len(curr)): | |
if (max_entry == curr[j]): | |
answer[i][j] = 1.0 | |
else: | |
answer[i][j] = 0.0 | |
return answer | |
def pick_one(a): | |
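  """Returns True if any row of `a` contains a 1.0 entry."""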
  for i in range(len(a)):
    if (1.0 in a[i]):
      return True
  return False


def check_processed_cols(col, utility):
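  """Returns True if `col` contains at least one value that is neither
  FLAGS.pad_int nor FLAGS.bad_number_pre_process.
  """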
  return True in [
      True for y in col
      if (y != utility.FLAGS.pad_int and
          y != utility.FLAGS.bad_number_pre_process)
  ]


def complete_wiki_processing(data, utility, train=True):
  #convert to integers and padding
  processed_data = []
  num_bad_examples = 0
  for example in data:
    number_found = 0
    if (example.is_bad_example):
      num_bad_examples += 1
    if (not (example.is_bad_example)):
      example.string_question = example.question[:]
      #entry match
      example.processed_number_columns = example.processed_number_columns[:]
      example.processed_word_columns = example.processed_word_columns[:]
      example.word_exact_match, word_match, matched_indices = exact_match(
          example.string_question, example.original_wc, number=False)
      example.number_exact_match, number_match, _ = exact_match(
          example.string_question, example.original_nc, number=True)
      if (not (pick_one(example.word_exact_match)) and not (
          pick_one(example.number_exact_match))):
        assert len(word_match) == 0
        assert len(number_match) == 0
        example.word_exact_match, word_match = partial_match(
            example.string_question, example.original_wc, number=False)
      #group by max
      example.word_group_by_max = group_by_max(example.original_wc, False)
      example.number_group_by_max = group_by_max(example.original_nc, True)
      #column name match
      example.word_column_exact_match, wcol_matched_indices = exact_column_match(
          example.string_question, example.original_wc_names, number=False)
      example.number_column_exact_match, ncol_matched_indices = exact_column_match(
          example.string_question, example.original_nc_names, number=False)
      if (not (1.0 in example.word_column_exact_match) and not (
          1.0 in example.number_column_exact_match)):
        example.word_column_exact_match = partial_column_match(
            example.string_question, example.original_wc_names, number=False)
        example.number_column_exact_match = partial_column_match(
            example.string_question, example.original_nc_names, number=False)
      if (len(word_match) > 0 or len(number_match) > 0):
        example.question.append(utility.entry_match_token)
      if (1.0 in example.word_column_exact_match or
          1.0 in example.number_column_exact_match):
        example.question.append(utility.column_match_token)
      example.string_question = example.question[:]
      example.number_lookup_matrix = np.transpose(
          example.number_lookup_matrix)[:]
      example.word_lookup_matrix = np.transpose(example.word_lookup_matrix)[:]
      example.columns = example.number_columns[:]
      example.word_columns = example.word_columns[:]
      example.len_total_cols = len(example.word_column_names) + len(
          example.number_column_names)
      example.column_names = example.number_column_names[:]
      example.word_column_names = example.word_column_names[:]
      example.string_column_names = example.number_column_names[:]
      example.string_word_column_names = example.word_column_names[:]
      example.sorted_number_index = []
      example.sorted_word_index = []
      example.column_mask = []
      example.word_column_mask = []
      example.processed_column_mask = []
      example.processed_word_column_mask = []
      example.word_column_entry_mask = []
      example.question_attention_mask = []
      example.question_number = example.question_number_1 = -1
      example.question_attention_mask = []
      example.ordinal_question = []
      example.ordinal_question_one = []
      new_question = []
      if (len(example.number_columns) > 0):
        example.len_col = len(example.number_columns[0])
      else:
        example.len_col = len(example.word_columns[0])
      for (start, length) in matched_indices:
        for j in range(length):
          example.question[start + j] = utility.unk_token
      #print example.question
      for word in example.question:
        if (isinstance(word, numbers.Number) or wiki_data.is_date(word)):
          if (not (isinstance(word, numbers.Number)) and
              wiki_data.is_date(word)):
            word = word.replace("X", "").replace("-", "")
          number_found += 1
          if (number_found == 1):
            example.question_number = word
            if (len(example.ordinal_question) > 0):
              example.ordinal_question[len(example.ordinal_question) - 1] = 1.0
            else:
              example.ordinal_question.append(1.0)
          elif (number_found == 2):
            example.question_number_1 = word
            if (len(example.ordinal_question_one) > 0):
              example.ordinal_question_one[len(example.ordinal_question_one) -
                                           1] = 1.0
            else:
              example.ordinal_question_one.append(1.0)
        else:
          new_question.append(word)
          example.ordinal_question.append(0.0)
          example.ordinal_question_one.append(0.0)
      example.question = [
          utility.word_ids[word_lookup(w, utility)] for w in new_question
      ]
      example.question_attention_mask = [0.0] * len(example.question)
      #when the first question number occurs before a word
      example.ordinal_question = example.ordinal_question[0:len(
          example.question)]
      example.ordinal_question_one = example.ordinal_question_one[0:len(
          example.question)]
      #question-padding
      example.question = [utility.word_ids[utility.dummy_token]] * (
          utility.FLAGS.question_length - len(example.question)
      ) + example.question
      example.question_attention_mask = [-10000.0] * (
          utility.FLAGS.question_length - len(example.question_attention_mask)
      ) + example.question_attention_mask
      example.ordinal_question = [0.0] * (
          utility.FLAGS.question_length -
          len(example.ordinal_question)) + example.ordinal_question
      example.ordinal_question_one = [0.0] * (
          utility.FLAGS.question_length -
          len(example.ordinal_question_one)) + example.ordinal_question_one
      if (True):
        #number columns and related-padding
        num_cols = len(example.columns)
        start = 0
        for column in example.number_columns:
          if (check_processed_cols(example.processed_number_columns[start],
                                   utility)):
            example.processed_column_mask.append(0.0)
          sorted_index = sorted(
              range(len(example.processed_number_columns[start])),
              key=lambda k: example.processed_number_columns[start][k],
              reverse=True)
          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(sorted_index))
          example.sorted_number_index.append(sorted_index)
          example.columns[start] = column + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(column))
          example.processed_number_columns[start] += [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements -
              len(example.processed_number_columns[start]))
          start += 1
          example.column_mask.append(0.0)
        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
          example.sorted_number_index.append([utility.FLAGS.pad_int] *
                                             (utility.FLAGS.max_elements))
          example.columns.append([utility.FLAGS.pad_int] *
                                 (utility.FLAGS.max_elements))
          example.processed_number_columns.append([utility.FLAGS.pad_int] *
                                                  (utility.FLAGS.max_elements))
          example.number_exact_match.append([0.0] *
                                            (utility.FLAGS.max_elements))
          example.number_group_by_max.append([0.0] *
                                             (utility.FLAGS.max_elements))
          example.column_mask.append(-100000000.0)
          example.processed_column_mask.append(-100000000.0)
          example.number_column_exact_match.append(0.0)
          example.column_names.append([utility.dummy_token])
        #word column and related-padding
        start = 0
        word_num_cols = len(example.word_columns)
        for column in example.word_columns:
          if (check_processed_cols(example.processed_word_columns[start],
                                   utility)):
            example.processed_word_column_mask.append(0.0)
          sorted_index = sorted(
              range(len(example.processed_word_columns[start])),
              key=lambda k: example.processed_word_columns[start][k],
              reverse=True)
          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(sorted_index))
          example.sorted_word_index.append(sorted_index)
          column = convert_to_int_2d_and_pad(column, utility)
          example.word_columns[start] = column + [[
              utility.word_ids[utility.dummy_token]
          ] * utility.FLAGS.max_entry_length] * (utility.FLAGS.max_elements -
                                                 len(column))
          example.processed_word_columns[start] += [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements -
              len(example.processed_word_columns[start]))
          example.word_column_entry_mask.append([0] * len(column) + [
              utility.word_ids[utility.dummy_token]
          ] * (utility.FLAGS.max_elements - len(column)))
          start += 1
          example.word_column_mask.append(0.0)
        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
          example.sorted_word_index.append([utility.FLAGS.pad_int] *
                                           (utility.FLAGS.max_elements))
          example.word_columns.append([[utility.word_ids[utility.dummy_token]] *
                                       utility.FLAGS.max_entry_length] *
                                      (utility.FLAGS.max_elements))
          example.word_column_entry_mask.append(
              [utility.word_ids[utility.dummy_token]] *
              (utility.FLAGS.max_elements))
          example.word_exact_match.append([0.0] * (utility.FLAGS.max_elements))
          example.word_group_by_max.append([0.0] * (utility.FLAGS.max_elements))
          example.processed_word_columns.append([utility.FLAGS.pad_int] *
                                                (utility.FLAGS.max_elements))
          example.word_column_mask.append(-100000000.0)
          example.processed_word_column_mask.append(-100000000.0)
          example.word_column_exact_match.append(0.0)
          example.word_column_names.append([utility.dummy_token] *
                                           utility.FLAGS.max_entry_length)
      seen_tables[example.table_key] = 1
      #convert column and word column names to integers
      example.column_ids = convert_to_int_2d_and_pad(example.column_names,
                                                     utility)
      example.word_column_ids = convert_to_int_2d_and_pad(
          example.word_column_names, utility)
      for i_em in range(len(example.number_exact_match)):
        example.number_exact_match[i_em] = example.number_exact_match[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.number_exact_match[i_em]))
        example.number_group_by_max[i_em] = example.number_group_by_max[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.number_group_by_max[i_em]))
      for i_em in range(len(example.word_exact_match)):
        example.word_exact_match[i_em] = example.word_exact_match[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.word_exact_match[i_em]))
        example.word_group_by_max[i_em] = example.word_group_by_max[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.word_group_by_max[i_em]))
      example.exact_match = example.number_exact_match + example.word_exact_match
      example.group_by_max = example.number_group_by_max + example.word_group_by_max
      example.exact_column_match = example.number_column_exact_match + example.word_column_exact_match
      #answer and related mask, padding
      if (example.is_lookup):
        example.answer = example.calc_answer
        example.number_print_answer = example.number_lookup_matrix.tolist()
        example.word_print_answer = example.word_lookup_matrix.tolist()
        for i_answer in range(len(example.number_print_answer)):
          example.number_print_answer[i_answer] = example.number_print_answer[
              i_answer] + [0.0] * (utility.FLAGS.max_elements -
                                   len(example.number_print_answer[i_answer]))
        for i_answer in range(len(example.word_print_answer)):
          example.word_print_answer[i_answer] = example.word_print_answer[
              i_answer] + [0.0] * (utility.FLAGS.max_elements -
                                   len(example.word_print_answer[i_answer]))
        example.number_lookup_matrix = convert_to_bool_and_pad(
            example.number_lookup_matrix, utility)
        example.word_lookup_matrix = convert_to_bool_and_pad(
            example.word_lookup_matrix, utility)
        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
          example.number_lookup_matrix.append([False] *
                                              utility.FLAGS.max_elements)
          example.number_print_answer.append([0.0] * utility.FLAGS.max_elements)
        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
          example.word_lookup_matrix.append([False] *
                                            utility.FLAGS.max_elements)
          example.word_print_answer.append([0.0] * utility.FLAGS.max_elements)
        example.print_answer = example.number_print_answer + example.word_print_answer
      else:
        example.answer = example.calc_answer
        example.print_answer = [[0.0] * (utility.FLAGS.max_elements)] * (
            utility.FLAGS.max_number_cols + utility.FLAGS.max_word_cols)
      #question_number masks
      if (example.question_number == -1):
        example.question_number_mask = np.zeros([utility.FLAGS.max_elements])
      else:
        example.question_number_mask = np.ones([utility.FLAGS.max_elements])
      if (example.question_number_1 == -1):
        example.question_number_one_mask = -10000.0
      else:
        example.question_number_one_mask = np.float64(0.0)
      if (example.len_col > utility.FLAGS.max_elements):
        continue
      processed_data.append(example)
  return processed_data


def add_special_words(utility):
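  """Adds the entry-match, column-match, dummy and UNK tokens to the
  vocabulary and records their ids on `utility`.
  """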
  utility.words.append(utility.entry_match_token)
  utility.word_ids[utility.entry_match_token] = len(utility.word_ids)
  utility.reverse_word_ids[utility.word_ids[
      utility.entry_match_token]] = utility.entry_match_token
  utility.entry_match_token_id = utility.word_ids[utility.entry_match_token]
  print("entry match token: ", utility.word_ids[
      utility.entry_match_token], utility.entry_match_token_id)
  utility.words.append(utility.column_match_token)
  utility.word_ids[utility.column_match_token] = len(utility.word_ids)
  utility.reverse_word_ids[utility.word_ids[
      utility.column_match_token]] = utility.column_match_token
  utility.column_match_token_id = utility.word_ids[utility.column_match_token]
  print("column match token: ", utility.word_ids[
      utility.column_match_token], utility.column_match_token_id)
  utility.words.append(utility.dummy_token)
  utility.word_ids[utility.dummy_token] = len(utility.word_ids)
  utility.reverse_word_ids[utility.word_ids[
      utility.dummy_token]] = utility.dummy_token
  utility.dummy_token_id = utility.word_ids[utility.dummy_token]
  utility.words.append(utility.unk_token)
  utility.word_ids[utility.unk_token] = len(utility.word_ids)
  utility.reverse_word_ids[utility.word_ids[
      utility.unk_token]] = utility.unk_token


def perform_word_cutoff(utility):
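  """Drops words that occur fewer than FLAGS.word_cutoff times from the
  vocabulary, keeping the special tokens.
  """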
  if (utility.FLAGS.word_cutoff > 0):
    for word in list(utility.word_ids.keys()):
      if (word in utility.word_count and
          utility.word_count[word] < utility.FLAGS.word_cutoff and
          word != utility.unk_token and word != utility.dummy_token and
          word != utility.entry_match_token and
          word != utility.column_match_token):
        utility.word_ids.pop(word)
        utility.words.remove(word)


def word_dropout(question, utility):
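  """Replaces non-dummy word ids in `question` with the UNK id whenever a
  uniform random draw exceeds FLAGS.word_dropout_prob.
  """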
  if (utility.FLAGS.word_dropout_prob > 0.0):
    new_question = []
    for i in range(len(question)):
      if (question[i] != utility.dummy_token_id and
          utility.random.random() > utility.FLAGS.word_dropout_prob):
        new_question.append(utility.word_ids[utility.unk_token])
      else:
        new_question.append(question[i])
    return new_question
  else:
    return question


def generate_feed_dict(data, curr, batch_size, gr, train=False, utility=None):
  #prepare feed dict dictionary
  feed_dict = {}
  feed_examples = []
  for j in range(batch_size):
    feed_examples.append(data[curr + j])
  if (train):
    feed_dict[gr.batch_question] = [
        word_dropout(feed_examples[j].question, utility)
        for j in range(batch_size)
    ]
  else:
    feed_dict[gr.batch_question] = [
        feed_examples[j].question for j in range(batch_size)
    ]
  feed_dict[gr.batch_question_attention_mask] = [
      feed_examples[j].question_attention_mask for j in range(batch_size)
  ]
  feed_dict[
      gr.batch_answer] = [feed_examples[j].answer for j in range(batch_size)]
  feed_dict[gr.batch_number_column] = [
      feed_examples[j].columns for j in range(batch_size)
  ]
  feed_dict[gr.batch_processed_number_column] = [
      feed_examples[j].processed_number_columns for j in range(batch_size)
  ]
  feed_dict[gr.batch_processed_sorted_index_number_column] = [
      feed_examples[j].sorted_number_index for j in range(batch_size)
  ]
  feed_dict[gr.batch_processed_sorted_index_word_column] = [
      feed_examples[j].sorted_word_index for j in range(batch_size)
  ]
  feed_dict[gr.batch_question_number] = np.array(
      [feed_examples[j].question_number for j in range(batch_size)]).reshape(
          (batch_size, 1))
  feed_dict[gr.batch_question_number_one] = np.array(
      [feed_examples[j].question_number_1 for j in range(batch_size)]).reshape(
          (batch_size, 1))
  feed_dict[gr.batch_question_number_mask] = [
      feed_examples[j].question_number_mask for j in range(batch_size)
  ]
  feed_dict[gr.batch_question_number_one_mask] = np.array(
      [feed_examples[j].question_number_one_mask for j in range(batch_size)
      ]).reshape((batch_size, 1))
  feed_dict[gr.batch_print_answer] = [
      feed_examples[j].print_answer for j in range(batch_size)
  ]
  feed_dict[gr.batch_exact_match] = [
      feed_examples[j].exact_match for j in range(batch_size)
  ]
  feed_dict[gr.batch_group_by_max] = [
      feed_examples[j].group_by_max for j in range(batch_size)
  ]
  feed_dict[gr.batch_column_exact_match] = [
      feed_examples[j].exact_column_match for j in range(batch_size)
  ]
  feed_dict[gr.batch_ordinal_question] = [
      feed_examples[j].ordinal_question for j in range(batch_size)
  ]
  feed_dict[gr.batch_ordinal_question_one] = [
      feed_examples[j].ordinal_question_one for j in range(batch_size)
  ]
  feed_dict[gr.batch_number_column_mask] = [
      feed_examples[j].column_mask for j in range(batch_size)
  ]
  feed_dict[gr.batch_number_column_names] = [
      feed_examples[j].column_ids for j in range(batch_size)
  ]
  feed_dict[gr.batch_processed_word_column] = [
      feed_examples[j].processed_word_columns for j in range(batch_size)
  ]
  feed_dict[gr.batch_word_column_mask] = [
      feed_examples[j].word_column_mask for j in range(batch_size)
  ]
  feed_dict[gr.batch_word_column_names] = [
      feed_examples[j].word_column_ids for j in range(batch_size)
  ]
  feed_dict[gr.batch_word_column_entry_mask] = [
      feed_examples[j].word_column_entry_mask for j in range(batch_size)
  ]
  return feed_dict