# --------------------------------------------------------
# mcan-vqa (Deep Modular Co-Attention Networks)
# Licensed under The MIT License [see LICENSE for details]
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------
import sys
sys.path.append('../')

from openvqa.utils.ans_punct import prep_ans
from openvqa.core.path_cfgs import PATH
import json, re

path = PATH()
# Pre-read the raw GQA question files for every split
ques_dict_preread = {
    'train': json.load(open(path.RAW_PATH['gqa']['train'], 'r')),
    'val': json.load(open(path.RAW_PATH['gqa']['val'], 'r')),
    'testdev': json.load(open(path.RAW_PATH['gqa']['testdev'], 'r')),
    'test': json.load(open(path.RAW_PATH['gqa']['test'], 'r')),
}

# Merge all splits to build the question-word vocabulary
stat_ques_dict = {
    **ques_dict_preread['train'],
    **ques_dict_preread['val'],
    **ques_dict_preread['testdev'],
    **ques_dict_preread['test'],
}
# Merge only the annotated splits for answer statistics
# (the test split has no ground-truth answers)
stat_ans_dict = {
    **ques_dict_preread['train'],
    **ques_dict_preread['val'],
    **ques_dict_preread['testdev'],
}
def tokenize(stat_ques_dict):
    # Build the question-word vocabulary: PAD/UNK/CLS are reserved,
    # every other word gets the next free index. Also track the
    # longest question (in words) across all splits.
    token_to_ix = {
        'PAD': 0,
        'UNK': 1,
        'CLS': 2,
    }

    max_token = 0
    for qid in stat_ques_dict:
        ques = stat_ques_dict[qid]['question']
        # Lowercase, drop punctuation, and split hyphens/slashes into spaces
        words = re.sub(
            r"([.,'!?\"()*#:;])",
            '',
            ques.lower()
        ).replace('-', ' ').replace('/', ' ').split()

        if len(words) > max_token:
            max_token = len(words)

        for word in words:
            if word not in token_to_ix:
                token_to_ix[word] = len(token_to_ix)

    return token_to_ix, max_token
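
# Minimal usage sketch for tokenize() on a toy split (hypothetical
# qid and question, just to show the cleaning and indexing rules):
#
#   toy = {'q1': {'question': "What's on the table-top?"}}
#   t2i, m = tokenize(toy)
#   # the apostrophe and '?' are stripped and '-' becomes a space, so
#   # t2i holds PAD/UNK/CLS plus: whats, on, the, table, top; m == 5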
def ans_stat(stat_ans_dict):
    # Build the answer vocabulary and its inverse mapping
    ans_to_ix = {}
    ix_to_ans = {}

    for qid in stat_ans_dict:
        ans = stat_ans_dict[qid]['answer']
        ans = prep_ans(ans)

        if ans not in ans_to_ix:
            ix = len(ans_to_ix)
            ix_to_ans[ix] = ans
            ans_to_ix[ans] = ix

    return ans_to_ix, ix_to_ans
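
# Quick sanity sketch of the two answer maps (hypothetical answers,
# assuming prep_ans() leaves these simple strings unchanged):
#
#   a2i, i2a = ans_stat({'q1': {'answer': 'red'}, 'q2': {'answer': 'blue'}})
#   # a2i == {'red': 0, 'blue': 1}; i2a == {0: 'red', 1: 'blue'}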

token_to_ix, max_token = tokenize(stat_ques_dict)
ans_to_ix, ix_to_ans = ans_stat(stat_ans_dict)
# print(ans_to_ix)
# print(ix_to_ans)
# print(token_to_ix)
# print(len(token_to_ix))
# print(max_token)
# Persist the four lookup structures for the GQA dataset loader
json.dump([ans_to_ix, ix_to_ans, token_to_ix, max_token], open('../openvqa/datasets/gqa/dicts.json', 'w'))
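
# A downstream loader would read the four objects back in dump order.
# Note that json.dump() turns the integer keys of ix_to_ans into
# strings, so they need converting back, e.g.:
#
#   ans_to_ix, ix_to_ans, token_to_ix, max_token = \
#       json.load(open('../openvqa/datasets/gqa/dicts.json', 'r'))
#   ix_to_ans = {int(ix): ans for ix, ans in ix_to_ans.items()}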