import time
import os
import random
import json
import pickle
import numpy as np
from tqdm import tqdm
from termcolor import colored
from program_translator import ProgramTranslator
from config import config
# Print bold text
def bold(txt):
    return colored(str(txt), attrs=["bold"])
# Print bold and colored text
def bcolored(txt, color):
    return colored(str(txt), color, attrs=["bold"])

# Write a line to file
def writeline(f, line):
    f.write(str(line) + "\n")

# Write a list to file
def writelist(f, l):
    writeline(f, ",".join(map(str, l)))
# Converts a 2d list into a padded numpy array, along with the length of each row.
# np.int64 is used because np.int was removed from recent NumPy versions.
def vectorize2DList(items, minX=0, minY=0, dtype=np.int64):
    maxX = max(len(items), minX)
    maxY = max([len(item) for item in items] + [minY])
    t = np.zeros((maxX, maxY), dtype=dtype)
    tLengths = np.zeros((maxX,), dtype=np.int64)
    for i, item in enumerate(items):
        t[i, 0:len(item)] = np.array(item, dtype=dtype)
        tLengths[i] = len(item)
    return t, tLengths
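# Illustrative example (not part of the pipeline): ragged rows are zero-padded
# to the longest row, and per-row lengths are returned alongside the array.
#   vectorize2DList([[4, 7, 9], [5, 2]])
#   -> (array([[4, 7, 9],
#              [5, 2, 0]]), array([3, 2]))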
# Converts a 3d list into a padded numpy array, along with the length of each sub-row.
def vectorize3DList(items, minX=0, minY=0, minZ=0, dtype=np.int64):
    maxX = max(len(items), minX)
    maxY = max([len(item) for item in items] + [minY])
    maxZ = max([len(subitem) for item in items for subitem in item] + [minZ])
    t = np.zeros((maxX, maxY, maxZ), dtype=dtype)
    tLengths = np.zeros((maxX, maxY), dtype=np.int64)
    for i, item in enumerate(items):
        for j, subitem in enumerate(item):
            t[i, j, 0:len(subitem)] = np.array(subitem, dtype=dtype)
            tLengths[i, j] = len(subitem)
    return t, tLengths
'''
Encodes text into integers. Keeps a dictionary mapping string words (symbols)
to their matching integers. Supports encoding and decoding.
'''
class SymbolDict(object):
    def __init__(self, empty=False):
        self.padding = "<PAD>"
        self.unknown = "<UNK>"
        self.start = "<START>"
        self.end = "<END>"
        self.invalidSymbols = [self.padding, self.unknown, self.start, self.end]

        if empty:
            self.sym2id = {}
            self.id2sym = []
        else:
            self.sym2id = {self.padding: 0, self.unknown: 1, self.start: 2, self.end: 3}
            self.id2sym = [self.padding, self.unknown, self.start, self.end]

        self.allSeqs = []

    def getNumSymbols(self):
        return len(self.sym2id)

    def isPadding(self, enc):
        return enc == 0

    def isUnknown(self, enc):
        return enc == 1

    def isStart(self, enc):
        return enc == 2

    def isEnd(self, enc):
        return enc == 3

    def isValid(self, enc):
        return enc < self.getNumSymbols() and enc >= len(self.invalidSymbols)

    def resetSeqs(self):
        self.allSeqs = []

    def addSeq(self, seq):
        self.allSeqs += seq
    # Call to create the words-to-integers vocabulary after reading word sequences with addSeq.
    def createVocab(self, minCount=0):
        counter = {}
        for symbol in self.allSeqs:
            counter[symbol] = counter.get(symbol, 0) + 1
        for symbol in counter:
            if counter[symbol] > minCount and (symbol not in self.sym2id):
                self.sym2id[symbol] = self.getNumSymbols()
                self.id2sym.append(symbol)
    # Encodes a symbol. Returns the matching integer.
    def encodeSym(self, symbol):
        if symbol not in self.sym2id:
            symbol = self.unknown
        return self.sym2id[symbol]
    '''
    Encodes a sequence of symbols.
    Optionally adds start / end symbols.
    Optionally reverses the sequence.
    '''
    def encodeSequence(self, decoded, addStart=False, addEnd=False, reverse=False):
        if reverse:
            decoded = decoded[::-1]  # reversed copy, so the caller's list is not mutated
        if addStart:
            decoded = [self.start] + decoded
        if addEnd:
            decoded = decoded + [self.end]
        encoded = [self.encodeSym(symbol) for symbol in decoded]
        return encoded
    # Decodes an integer into its symbol
    def decodeId(self, enc):
        return self.id2sym[enc] if enc < self.getNumSymbols() else self.unknown
    '''
    Decodes a sequence of integers into their symbols.
    If delim is given, joins the symbols using delim.
    Optionally reverses the resulting sequence.
    '''
    def decodeSequence(self, encoded, delim=None, reverse=False, stopAtInvalid=True):
        length = 0
        for i in range(len(encoded)):
            if not self.isValid(encoded[i]) and stopAtInvalid:
                break
            length += 1
        encoded = encoded[:length]

        decoded = [self.decodeId(enc) for enc in encoded]
        if reverse:
            decoded.reverse()

        if delim is not None:
            return delim.join(decoded)
        return decoded
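# Illustrative usage sketch (not part of the pipeline): the vocabulary is built
# from observed sequences, then used to round-trip between symbols and ids.
#   d = SymbolDict()
#   d.addSeq(["how", "many", "cubes"])
#   d.createVocab()                                 # ids 4, 5, 6 follow the 4 special symbols
#   enc = d.encodeSequence(["how", "many", "spheres"])      # -> [4, 5, 1]  ("spheres" maps to <UNK>)
#   d.decodeSequence(enc, delim=" ", stopAtInvalid=False)   # -> "how many <UNK>"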
'''
Preprocesses a given dataset into numpy arrays.
By calling preprocess, the class:
1. Reads the input data files into a dictionary.
2. Saves the resulting jsons to files, and loads them instead of parsing the input if the files already exist.
3. Initializes word embeddings to random / GloVe.
4. Optionally filters data according to given filters.
5. Encodes and vectorizes the data into numpy arrays.
6. Buckets the data according to the instances' lengths.
'''
class Preprocesser(object):
    def __init__(self):
        self.questionDict = SymbolDict()
        self.answerDict = SymbolDict(empty=True)
        self.qaDict = SymbolDict()

        self.specificDatasetDicts = None

        self.programDict = SymbolDict()
        self.programTranslator = ProgramTranslator(self.programDict, 2)
    '''
    Tokenizes a string into a list of symbols.
    Args:
        text: raw string to tokenize.
        ignoredPuncts: punctuation to remove entirely.
        keptPuncts: punctuation to keep (as separate symbols).
        endPunct: punctuation to remove if it appears at the end of the text.
        delim: delimiter between symbols.
        clean: True to apply the replacement lists to the text.
        replacelistPre: dictionary of replacements to perform on the text before tokenization.
        replacelistPost: dictionary of replacements to perform on the tokens after tokenization.
    '''
    # sentence tokenizer
    allPunct = ["?", "!", "\\", "/", ")", "(", ".", ",", ";", ":"]

    def tokenize(self, text, ignoredPuncts=["?", "!", "\\", "/", ")", "("],
                 keptPuncts=[".", ",", ";", ":"], endPunct=[">", "<", ":"], delim=" ",
                 clean=False, replacelistPre=dict(), replacelistPost=dict()):
        if clean:
            for word in replacelistPre:
                origText = text
                text = text.replace(word, replacelistPre[word])
                if (origText != text):
                    print(origText)
                    print(text)
                    print("")

            for punct in endPunct:
                if text[-1] == punct:
                    print(text)
                    text = text[:-1]
                    print(text)
                    print("")

        for punct in keptPuncts:
            text = text.replace(punct, delim + punct + delim)
        for punct in ignoredPuncts:
            text = text.replace(punct, "")

        ret = text.lower().split(delim)

        if clean:
            origRet = ret
            ret = [replacelistPost.get(word, word) for word in ret]
            if origRet != ret:
                print(origRet)
                print(ret)

        ret = [t for t in ret if t != ""]
        return ret
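    # Illustrative example (a sketch, not part of the pipeline): kept punctuation becomes
    # its own token, ignored punctuation is dropped, and the text is lower-cased.
    #   p = Preprocesser()
    #   p.tokenize("Is there a red cube, behind the sphere?")
    #   -> ['is', 'there', 'a', 'red', 'cube', ',', 'behind', 'the', 'sphere']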
    # Read class' generated files.
    # files interface
    def readFiles(self, instancesFilename):
        with open(instancesFilename, "r") as inFile:
            instances = json.load(inFile)

        with open(config.questionDictFile(), "rb") as inFile:
            self.questionDict = pickle.load(inFile)

        with open(config.answerDictFile(), "rb") as inFile:
            self.answerDict = pickle.load(inFile)

        with open(config.qaDictFile(), "rb") as inFile:
            self.qaDict = pickle.load(inFile)

        return instances
    '''
    Generate class' files. Save json representation of instances and
    symbols-to-integers dictionaries.
    '''
    def writeFiles(self, instances, instancesFilename):
        with open(instancesFilename, "w") as outFile:
            json.dump(instances, outFile)

        with open(config.questionDictFile(), "wb") as outFile:
            pickle.dump(self.questionDict, outFile)

        with open(config.answerDictFile(), "wb") as outFile:
            pickle.dump(self.answerDict, outFile)

        with open(config.qaDictFile(), "wb") as outFile:
            pickle.dump(self.qaDict, outFile)
    # Write prediction json to file and optionally a one-answer-per-line output file
    def writePreds(self, res, tier, suffix=""):
        if res is None:
            return
        preds = res["preds"]
        sortedPreds = sorted(preds, key=lambda instance: instance["index"])
        with open(config.predsFile(tier + suffix), "w") as outFile:
            outFile.write(json.dumps(sortedPreds))
        with open(config.answersFile(tier + suffix), "w") as outFile:
            for instance in sortedPreds:
                writeline(outFile, instance["prediction"])
    # Loads instances from a previously generated instances file (see readFiles).
    # The signature matches the datasetReader call in readData below; datasetFilename
    # and train are unused by this reader.
    def readPDF(self, datasetFilename, instancesFilename, train):
        instances = []
        if os.path.exists(instancesFilename):
            instances = self.readFiles(instancesFilename)
        return instances

    def readData(self, datasetFilename, instancesFilename, train):
        # data extraction
        datasetReader = {
            "PDF": self.readPDF
        }
        return datasetReader[config.dataset](datasetFilename, instancesFilename, train)
    def vectorizeData(self, data):
        # if "SHARED", tie symbol representations in questions and answers
        if config.ansEmbMod == "SHARED":
            qDict = self.qaDict
        else:
            qDict = self.questionDict

        encodedQuestion = [qDict.encodeSequence(d["questionSeq"]) for d in data]
        question, questionL = vectorize2DList(encodedQuestion)

        # pass the whole instances? if heavy then not good
        imageId = [d["imageId"] for d in data]
        instance = data

        return {"question": question,
                "questionLength": questionL,
                "imageId": imageId
                }
    # Separates data based on a field length
    def lseparator(self, key, lims):
        maxI = len(lims)
        def separatorFn(x):
            v = x[key]
            for i, lim in enumerate(lims):
                if len(v) < lim:
                    return i
            return maxI
        return {"separate": separatorFn, "groupsNum": maxI + 1}
    # Buckets data into groups using a separator
    def bucket(self, instances, separator):
        buckets = [[] for i in range(separator["groupsNum"])]
        for instance in instances:
            bucketI = separator["separate"](instance)
            buckets[bucketI].append(instance)
        return [bucket for bucket in buckets if len(bucket) > 0]

    # Re-buckets a bucket list given a separator
    def rebucket(self, buckets, separator):
        res = []
        for bucket in buckets:
            res += self.bucket(bucket, separator)
        return res
    # Buckets data based on question / program length
    def bucketData(self, data, noBucket=False):
        if noBucket:
            buckets = [data]
        else:
            if config.noBucket:
                buckets = [data]
            elif config.noRebucket:
                questionSep = self.lseparator("questionSeq", config.questionLims)
                buckets = self.bucket(data, questionSep)
            else:
                programSep = self.lseparator("programSeq", config.programLims)
                questionSep = self.lseparator("questionSeq", config.questionLims)
                buckets = self.bucket(data, programSep)
                buckets = self.rebucket(buckets, questionSep)
        return buckets
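    # Illustrative sketch (hypothetical limits and variable names, not from config):
    # with lims = [5, 10], lseparator assigns group 0 to sequences shorter than 5 symbols,
    # group 1 to those shorter than 10, and group 2 to the rest; bucket then groups
    # instances accordingly and drops empty groups.
    #   sep = p.lseparator("questionSeq", [5, 10])
    #   buckets = p.bucket(instances, sep)   # instances: list of dicts with a "questionSeq" field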
    '''
    Prepares data:
    1. Filters the data according to the filters below.
    2. Takes only a subset of the data, based on config.trainedNum / config.testedNum.
    3. Buckets the data according to question / program length.
    4. Vectorizes the data into numpy arrays.
    '''
    def prepareData(self, data, train, filterKey=None, noBucket=False):
        filterDefault = {"maxQLength": 0, "maxPLength": 0, "onlyChain": False, "filterOp": 0}

        filterTrain = {"maxQLength": config.tMaxQ, "maxPLength": config.tMaxP,
                       "onlyChain": config.tOnlyChain, "filterOp": config.tFilterOp}

        filterVal = {"maxQLength": config.vMaxQ, "maxPLength": config.vMaxP,
                     "onlyChain": config.vOnlyChain, "filterOp": config.vFilterOp}

        filters = {"train": filterTrain, "evalTrain": filterTrain,
                   "val": filterVal, "test": filterDefault}

        if filterKey is None:
            fltr = filterDefault
        else:
            fltr = filters[filterKey]

        # split data when finetuning on the validation set
        if config.trainExtra and config.extraVal and (config.finetuneNum > 0):
            if train:
                data = data[:config.finetuneNum]
            else:
                data = data[config.finetuneNum:]

        typeFilter = config.typeFilters[fltr["filterOp"]]
        # filter specific settings
        if fltr["onlyChain"]:
            data = [d for d in data if all((len(inputNum) < 2) for inputNum in d["programInputs"])]
        if fltr["maxQLength"] > 0:
            data = [d for d in data if len(d["questionSeq"]) <= fltr["maxQLength"]]
        if fltr["maxPLength"] > 0:
            data = [d for d in data if len(d["programSeq"]) <= fltr["maxPLength"]]
        if len(typeFilter) > 0:
            data = [d for d in data if d["programSeq"][-1] not in typeFilter]
        # run on a subset of the data. If 0 then use all data
        num = config.trainedNum if train else config.testedNum

        # retainVal = True to retain the same sample of the validation set across runs
        if (not train) and (not config.retainVal):
            random.shuffle(data)

        if num > 0:
            data = data[:num]

        # set number to match dataset size
        if train:
            config.trainedNum = len(data)
        else:
            config.testedNum = len(data)

        # bucket
        buckets = self.bucketData(data, noBucket=noBucket)

        # vectorize
        return [self.vectorizeData(bucket) for bucket in buckets]
    # Prepares all the tiers of a dataset. See the prepareData method for further details.
    def prepareDataset(self, dataset, noBucket=False):
        if dataset is None:
            return None

        for tier in dataset:
            if dataset[tier] is not None:
                dataset[tier]["data"] = self.prepareData(dataset[tier]["instances"],
                                                         train=dataset[tier]["train"], filterKey=tier,
                                                         noBucket=noBucket)

        for tier in dataset:
            if dataset[tier] is not None:
                del dataset[tier]["instances"]

        return dataset
    # Initializes word embeddings to random uniform / random normal / GloVe.
    def initializeWordEmbeddings(self, wordsDict=None, noPadding=False):
        # default dictionary to use for embeddings
        if wordsDict is None:
            wordsDict = self.questionDict

        # uniform initialization
        if config.wrdEmbUniform:
            lowInit = -1.0 * config.wrdEmbScale
            highInit = 1.0 * config.wrdEmbScale
            embeddings = np.random.uniform(low=lowInit, high=highInit,
                                           size=(wordsDict.getNumSymbols(), config.wrdEmbDim))
        # normal initialization
        else:
            embeddings = config.wrdEmbScale * np.random.randn(wordsDict.getNumSymbols(),
                                                              config.wrdEmbDim)
        # if wrdEmbRandom = False, use GloVe vectors for words that appear in the vocabulary
        counter = 0
        if (not config.wrdEmbRandom):
            with open(config.wordVectorsFile, 'r') as inFile:
                for line in inFile:
                    line = line.strip().split()
                    word = line[0].lower()
                    vector = [float(x) for x in line[1:]]
                    index = wordsDict.sym2id.get(word)
                    if index is not None:
                        embeddings[index] = vector
                        counter += 1

        print(counter)
        print(self.questionDict.sym2id)
        print(len(self.questionDict.sym2id))
        print(self.answerDict.sym2id)
        print(len(self.answerDict.sym2id))
        print(self.qaDict.sym2id)
        print(len(self.qaDict.sym2id))

        if noPadding:
            return embeddings  # the dictionary has no padding symbol, so keep all rows
        else:
            return embeddings[1:]  # drop the embedding of the padding symbol
    '''
    Initializes word embeddings for question words and optionally for answer words
    (when config.ansEmbMod == "BOTH"). If config.ansEmbMod == "SHARED", ties the embeddings
    of symbols shared between questions and answers.
    '''
    def initializeQAEmbeddings(self):
        # use same embeddings for questions and answers
        if config.ansEmbMod == "SHARED":
            qaEmbeddings = self.initializeWordEmbeddings(self.qaDict)
            ansMap = np.array([self.qaDict.sym2id[sym] for sym in self.answerDict.id2sym])
            embeddings = {"qa": qaEmbeddings, "ansMap": ansMap}
        # use different embeddings for questions and answers
        else:
            qEmbeddings = self.initializeWordEmbeddings(self.questionDict)
            aEmbeddings = None
            if config.ansEmbMod == "BOTH":
                aEmbeddings = self.initializeWordEmbeddings(self.answerDict, noPadding=True)
            embeddings = {"q": qEmbeddings, "a": aEmbeddings}
        return embeddings
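    # Sketch of the returned structure (shapes depend on the vocabulary sizes and
    # config.wrdEmbDim; illustrative only, not guaranteed values):
    #   SHARED:    {"qa": (numQASymbols - 1, wrdEmbDim) array, "ansMap": (numAnswers,) index array}
    #   otherwise: {"q": (numQuestionSymbols - 1, wrdEmbDim) array, "a": answer embeddings or None}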
    '''
    Preprocesses a given dataset into numpy arrays:
    1. Reads the input data files into a dictionary.
    2. Saves the resulting jsons to files, and loads them instead of parsing the input if the files already exist.
    3. Initializes word embeddings to random / GloVe.
    4. Optionally filters data according to given filters.
    5. Encodes and vectorizes the data into numpy arrays.
    6. Buckets the data according to the instances' lengths.
    '''
    def preprocessData(self, question, debug=False):
        # Read the symbols' dictionaries
        print(bold("Loading data..."))
        start = time.time()

        with open(config.questionDictFile(), "rb") as inFile:
            self.questionDict = pickle.load(inFile)

        with open(config.qaDictFile(), "rb") as inFile:
            self.qaDict = pickle.load(inFile)

        with open(config.answerDictFile(), "rb") as inFile:
            self.answerDict = pickle.load(inFile)
        # Light-weight tokenization: strip question marks and commas, lower-case, and split on whitespace
        question = question.replace('?', '').replace(',', '').lower().split()
        encodedQuestion = self.questionDict.encodeSequence(question)
        data = {'question': np.array([encodedQuestion]), 'questionLength': np.array([len(encodedQuestion)])}
print("took {:.2f} seconds".format(time.time() - start)) | |
# Initialize word embeddings (random / glove) | |
print(bold("Loading word vectors...")) | |
start = time.time() | |
embeddings = self.initializeQAEmbeddings() | |
print("took {:.2f} seconds".format(time.time() - start)) | |
answer = 'yes' # DUMMY_ANSWER | |
self.answerDict.addSeq([answer]) | |
self.qaDict.addSeq([answer]) | |
config.questionWordsNum = self.questionDict.getNumSymbols() | |
config.answerWordsNum = self.answerDict.getNumSymbols() | |
return data, embeddings, self.answerDict | |
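# Illustrative usage sketch (assumes config has been set up and the pickled dictionaries
# from a previous run exist on disk; variable names here are hypothetical):
#   preprocessor = Preprocesser()
#   data, embeddings, answerDict = preprocessor.preprocessData("What color is the cube?")
#   data["question"].shape    # (1, question length)
#   data["questionLength"]    # array([question length])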