# Scraped page header (not part of the program): "Spaces: Runtime error".
from __future__ import print_function

import argparse
import io
import json
import os
import sys

import numpy as np

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dataset import Dictionary  # noqa: E402 -- needs the sys.path tweak above
def make_dictionary(dataroot):
    """Build a Dictionary from the VQA v2 question JSON files.

    Reads the train2014, val2014, test2015 and test-dev2015 question files
    from ``<dataroot>/clean/`` and passes every question string to
    ``Dictionary.tokenize`` with a second argument of ``True``. That argument
    presumably makes the dictionary register unseen tokens; confirm against
    ``dataset.Dictionary``.

    Args:
        dataroot: Directory containing a ``clean`` sub-directory with the
            four ``v2_OpenEnded_mscoco_*_questions.json`` files.

    Returns:
        The populated ``Dictionary`` instance.

    Raises:
        IOError/OSError: if one of the question files cannot be opened.
        KeyError: if a file has no top-level ``'questions'`` key or an
            entry has no ``'question'`` key.
    """
    dictionary = Dictionary()
    files = [
        'v2_OpenEnded_mscoco_train2014_questions.json',
        'v2_OpenEnded_mscoco_val2014_questions.json',
        'v2_OpenEnded_mscoco_test2015_questions.json',
        'v2_OpenEnded_mscoco_test-dev2015_questions.json',
    ]
    for path in files:
        question_path = os.path.join(dataroot, 'clean', path)
        # Close each file deterministically (the bare open() leaked handles)
        # and decode as UTF-8 rather than the locale's default encoding.
        with io.open(question_path, 'r', encoding='utf-8') as f:
            qs = json.load(f)['questions']
        for q in qs:
            dictionary.tokenize(q['question'], True)
    return dictionary
def create_glove_embedding_init(idx2word, glove_file):
    """Build an embedding-initialisation matrix from a GloVe text file.

    Each non-blank line of ``glove_file`` is parsed as ``word v1 v2 ... vD``
    split on single spaces. The dimension D is taken from the first
    non-blank line.

    Args:
        idx2word: Sequence of vocabulary words; row ``i`` of the returned
            matrix corresponds to ``idx2word[i]``.
        glove_file: Path to a UTF-8 encoded GloVe text file.

    Returns:
        A tuple ``(weights, word2emb)``:
        - ``weights`` is a float32 array of shape ``(len(idx2word), D)``.
          Rows for words not present in the GloVe file stay zero.
        - ``word2emb`` maps every GloVe word to its vector (float64 array).

    Raises:
        ValueError: if ``glove_file`` contains no non-blank lines.
    """
    word2emb = {}
    emb_dim = None
    # Explicit UTF-8: GloVe vocabularies contain non-ASCII tokens, and the
    # locale default encoding may not be able to decode them.
    with io.open(glove_file, 'r', encoding='utf-8') as f:
        # Stream lines instead of readlines(): the raw text of a GloVe file
        # is large, and holding it next to the parsed vectors wastes memory.
        for line in f:
            if not line.strip():
                continue  # skip blank lines (e.g. a trailing empty line)
            vals = line.rstrip('\n').split(' ')
            if emb_dim is None:
                emb_dim = len(vals) - 1
                print('embedding dim is %d' % emb_dim)
            word2emb[vals[0]] = np.array(list(map(float, vals[1:])))
    if emb_dim is None:
        raise ValueError('no embeddings found in %s' % glove_file)

    weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32)
    for idx, word in enumerate(idx2word):
        vec = word2emb.get(word)
        if vec is not None:
            weights[idx] = vec
    return weights, word2emb
def create_dictionary(dataroot, emb_dim):
    """Make sure the question dictionary and GloVe init matrix exist on disk.

    Both artefacts are cached under ``dataroot``:
    - ``dictionary.pkl`` is built with ``make_dictionary`` only if it is
      missing. It is then always loaded back with
      ``Dictionary.load_from_file``.
    - ``glove6b_init_<emb_dim>d.npy`` is computed from
      ``glove/glove.6B.<emb_dim>d.txt`` only if it is missing.

    Args:
        dataroot: Data directory holding the cached files and inputs.
        emb_dim: GloVe embedding size used to pick the input and output files.

    Returns:
        None.
    """
    dict_file = os.path.join(dataroot, 'dictionary.pkl')
    if not os.path.isfile(dict_file):
        make_dictionary(dataroot).dump_to_file(dict_file)
    else:
        print('FOUND EXISTING DICTIONARY: ' + dict_file)
    # Load the (possibly just-written) dictionary back from disk.
    dictionary = Dictionary.load_from_file(dict_file)

    glove_out = os.path.join(dataroot, 'glove6b_init_%dd.npy' % emb_dim)
    if os.path.isfile(glove_out):
        print('FOUND EXISTING GLOVE FILE: ' + glove_out)
        return
    glove_file = os.path.join(dataroot, 'glove/glove.6B.%dd.txt' % emb_dim)
    # Only the weight matrix is persisted; the word->vector map is discarded.
    weights, _ = create_glove_embedding_init(dictionary.idx2word, glove_file)
    np.save(glove_out, weights)
if __name__ == '__main__':
    # Command-line entry point: build (or reuse) the dictionary and the
    # GloVe initialisation matrix for the chosen embedding size.
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataroot', type=str, default='../data/')
    cli.add_argument('--emb_dim', type=int, default=300)
    opts = cli.parse_args()
    create_dictionary(dataroot=opts.dataroot, emb_dim=opts.emb_dim)