ydin0771 committed on
Commit
ba880ef
•
1 Parent(s): 22774f1

Upload preprocess.py

Files changed (1)
  1. preprocess.py +551 -0

preprocess.py ADDED
@@ -0,0 +1,551 @@
import time
import os
import random
import json
import pickle
import numpy as np
from tqdm import tqdm
from termcolor import colored
from program_translator import ProgramTranslator
from config import config


# Print bold text
def bold(txt):
    return colored(str(txt), attrs=["bold"])


# Print bold and colored text
def bcolored(txt, color):
    return colored(str(txt), color, attrs=["bold"])


# Write a line to file
def writeline(f, line):
    f.write(str(line) + "\n")


# Write a list to file
def writelist(f, l):
    writeline(f, ",".join(map(str, l)))


# 2d list to numpy
def vectorize2DList(items, minX=0, minY=0, dtype=np.int64):
    maxX = max(len(items), minX)
    maxY = max([len(item) for item in items] + [minY])
    t = np.zeros((maxX, maxY), dtype=dtype)
    tLengths = np.zeros((maxX,), dtype=np.int64)
    for i, item in enumerate(items):
        t[i, 0:len(item)] = np.array(item, dtype=dtype)
        tLengths[i] = len(item)
    return t, tLengths


# 3d list to numpy
def vectorize3DList(items, minX=0, minY=0, minZ=0, dtype=np.int64):
    maxX = max(len(items), minX)
    maxY = max([len(item) for item in items] + [minY])
    maxZ = max([len(subitem) for item in items for subitem in item] + [minZ])
    t = np.zeros((maxX, maxY, maxZ), dtype=dtype)
    tLengths = np.zeros((maxX, maxY), dtype=np.int64)
    for i, item in enumerate(items):
        for j, subitem in enumerate(item):
            t[i, j, 0:len(subitem)] = np.array(subitem, dtype=dtype)
            tLengths[i, j] = len(subitem)
    return t, tLengths
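

# A minimal usage sketch of the padding helpers above (editor example, not part
# of the committed file): ragged integer lists are padded into a dense matrix
# plus a vector of the true sequence lengths.
def _exampleVectorize2DList():
    seqs = [[4, 5], [6, 7, 8]]
    t, tLengths = vectorize2DList(seqs)
    # t == [[4, 5, 0], [6, 7, 8]], tLengths == [2, 3]
    return t, tLengths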


'''
Encodes text into integers. Keeps a dictionary between string words (symbols)
and their matching integers. Supports encoding and decoding.
'''


class SymbolDict(object):
    def __init__(self, empty=False):
        self.padding = "<PAD>"
        self.unknown = "<UNK>"
        self.start = "<START>"
        self.end = "<END>"

        self.invalidSymbols = [self.padding, self.unknown, self.start, self.end]

        if empty:
            self.sym2id = {}
            self.id2sym = []
        else:
            self.sym2id = {self.padding: 0, self.unknown: 1, self.start: 2, self.end: 3}
            self.id2sym = [self.padding, self.unknown, self.start, self.end]
        self.allSeqs = []

    def getNumSymbols(self):
        return len(self.sym2id)

    def isPadding(self, enc):
        return enc == 0

    def isUnknown(self, enc):
        return enc == 1

    def isStart(self, enc):
        return enc == 2

    def isEnd(self, enc):
        return enc == 3

    def isValid(self, enc):
        return enc < self.getNumSymbols() and enc >= len(self.invalidSymbols)

    def resetSeqs(self):
        self.allSeqs = []

    def addSeq(self, seq):
        self.allSeqs += seq

    # Call to create the words-to-integers vocabulary after reading word sequences with addSeq.
    def createVocab(self, minCount=0):
        counter = {}
        for symbol in self.allSeqs:
            counter[symbol] = counter.get(symbol, 0) + 1
        for symbol in counter:
            if counter[symbol] > minCount and (symbol not in self.sym2id):
                self.sym2id[symbol] = self.getNumSymbols()
                self.id2sym.append(symbol)

    # Encodes a symbol. Returns the matching integer.
    def encodeSym(self, symbol):
        if symbol not in self.sym2id:
            symbol = self.unknown
        return self.sym2id[symbol]

    '''
    Encodes a sequence of symbols.
    Optionally adds start or end symbols.
    Optionally reverses the sequence.
    '''

    def encodeSequence(self, decoded, addStart=False, addEnd=False, reverse=False):
        if reverse:
            decoded.reverse()
        if addStart:
            decoded = [self.start] + decoded
        if addEnd:
            decoded = decoded + [self.end]
        encoded = [self.encodeSym(symbol) for symbol in decoded]
        return encoded

    # Decodes an integer into its symbol
    def decodeId(self, enc):
        return self.id2sym[enc] if enc < self.getNumSymbols() else self.unknown

    '''
    Decodes a sequence of integers into their symbols.
    If delim is given, joins the symbols using delim.
    Optionally reverses the resulting sequence.
    '''

    def decodeSequence(self, encoded, delim=None, reverse=False, stopAtInvalid=True):
        length = 0
        for i in range(len(encoded)):
            if not self.isValid(encoded[i]) and stopAtInvalid:
                break
            length += 1
        encoded = encoded[:length]

        decoded = [self.decodeId(enc) for enc in encoded]
        if reverse:
            decoded.reverse()

        if delim is not None:
            return delim.join(decoded)

        return decoded
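

# A minimal sketch of the SymbolDict round trip (editor example, not part of
# the committed file): build a vocabulary, then encode and decode a sequence.
def _exampleSymbolDict():
    d = SymbolDict()
    d.addSeq(["what", "color", "is", "the", "cube"])
    d.createVocab()  # symbols seen more than minCount times enter the vocabulary
    enc = d.encodeSequence(["what", "color"], addEnd=True)
    # decoding stops at the first invalid id (here <END>), so this yields "what color"
    return d.decodeSequence(enc, delim=" ")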


'''
Preprocesses a given dataset into numpy arrays.
By calling preprocess, the class:
1. Reads the input data files into a dictionary.
2. Saves the resulting jsons to files, and loads them instead of parsing the input if the files exist.
3. Initializes word embeddings to random / GloVe.
4. Optionally filters data according to given filters.
5. Encodes and vectorizes the data into numpy arrays.
6. Buckets the data according to the instances' length.
'''


class Preprocesser(object):
    def __init__(self):
        self.questionDict = SymbolDict()
        self.answerDict = SymbolDict(empty=True)
        self.qaDict = SymbolDict()

        self.specificDatasetDicts = None

        self.programDict = SymbolDict()
        self.programTranslator = ProgramTranslator(self.programDict, 2)

    '''
    Tokenizes a string into a list of symbols.

    Args:
        text: raw string to tokenize.
        ignoredPuncts: punctuation to remove
        keptPuncts: punctuation to keep (as symbols)
        endPunct: punctuation to remove if it appears at the end
        delim: delimiter between symbols
        clean: True to replace text in the string
        replacelistPre: dictionary of replacements to perform on the text before tokenization
        replacelistPost: dictionary of replacements to perform on the text after tokenization
    '''
    # sentence tokenizer
    allPunct = ["?", "!", "\\", "/", ")", "(", ".", ",", ";", ":"]

    def tokenize(self, text, ignoredPuncts=["?", "!", "\\", "/", ")", "("],
                 keptPuncts=[".", ",", ";", ":"], endPunct=[">", "<", ":"], delim=" ",
                 clean=False, replacelistPre=dict(), replacelistPost=dict()):

        if clean:
            for word in replacelistPre:
                origText = text
                text = text.replace(word, replacelistPre[word])
                if origText != text:
                    print(origText)
                    print(text)
                    print("")

        for punct in endPunct:
            if text[-1] == punct:
                print(text)
                text = text[:-1]
                print(text)
                print("")

        for punct in keptPuncts:
            text = text.replace(punct, delim + punct + delim)

        for punct in ignoredPuncts:
            text = text.replace(punct, "")

        ret = text.lower().split(delim)

        if clean:
            origRet = ret
            ret = [replacelistPost.get(word, word) for word in ret]
            if origRet != ret:
                print(origRet)
                print(ret)

        ret = [t for t in ret if t != ""]
        return ret
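
    # Example behavior of tokenize (editor note): kept punctuation becomes its
    # own token while ignored punctuation is dropped, e.g.
    #   tokenize("Is it red, or blue?") -> ["is", "it", "red", ",", "or", "blue"]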

    # Read the class' generated files.
    # files interface
    def readFiles(self, instancesFilename):
        with open(instancesFilename, "r") as inFile:
            instances = json.load(inFile)

        with open(config.questionDictFile(), "rb") as inFile:
            self.questionDict = pickle.load(inFile)

        with open(config.answerDictFile(), "rb") as inFile:
            self.answerDict = pickle.load(inFile)

        with open(config.qaDictFile(), "rb") as inFile:
            self.qaDict = pickle.load(inFile)

        return instances

    '''
    Generates the class' files. Saves a json representation of the instances
    and the symbols-to-integers dictionaries.
    '''

    def writeFiles(self, instances, instancesFilename):
        with open(instancesFilename, "w") as outFile:
            json.dump(instances, outFile)

        with open(config.questionDictFile(), "wb") as outFile:
            pickle.dump(self.questionDict, outFile)

        with open(config.answerDictFile(), "wb") as outFile:
            pickle.dump(self.answerDict, outFile)

        with open(config.qaDictFile(), "wb") as outFile:
            pickle.dump(self.qaDict, outFile)

    # Write prediction json to file and optionally a one-answer-per-line output file
    def writePreds(self, res, tier, suffix=""):
        if res is None:
            return
        preds = res["preds"]
        sortedPreds = sorted(preds, key=lambda instance: instance["index"])
        with open(config.predsFile(tier + suffix), "w") as outFile:
            outFile.write(json.dumps(sortedPreds))
        with open(config.answersFile(tier + suffix), "w") as outFile:
            for instance in sortedPreds:
                writeline(outFile, instance["prediction"])

    def readPDF(self, datasetFilename, instancesFilename, train):
        # datasetFilename and train are unused here, but the signature matches
        # the datasetReader dispatch in readData below
        instances = []

        if os.path.exists(instancesFilename):
            instances = self.readFiles(instancesFilename)

        return instances

    def readData(self, datasetFilename, instancesFilename, train):
        # data extraction
        datasetReader = {
            "PDF": self.readPDF
        }

        return datasetReader[config.dataset](datasetFilename, instancesFilename, train)

    def vectorizeData(self, data):
        # if "SHARED", tie symbol representations in questions and answers
        if config.ansEmbMod == "SHARED":
            qDict = self.qaDict
        else:
            qDict = self.questionDict

        encodedQuestion = [qDict.encodeSequence(d["questionSeq"]) for d in data]
        question, questionL = vectorize2DList(encodedQuestion)

        # pass the whole instances? if heavy then not good
        imageId = [d["imageId"] for d in data]

        return {"question": question,
                "questionLength": questionL,
                "imageId": imageId
                }

    # Separates data based on a field's length
    def lseparator(self, key, lims):
        maxI = len(lims)

        def separatorFn(x):
            v = x[key]
            for i, lim in enumerate(lims):
                if len(v) < lim:
                    return i
            return maxI

        return {"separate": separatorFn, "groupsNum": maxI + 1}

    # Buckets data into groups using a separator
    def bucket(self, instances, separator):
        buckets = [[] for i in range(separator["groupsNum"])]
        for instance in instances:
            bucketI = separator["separate"](instance)
            buckets[bucketI].append(instance)
        return [bucket for bucket in buckets if len(bucket) > 0]

    # Re-buckets a bucket list given a separator
    def rebucket(self, buckets, separator):
        res = []
        for bucket in buckets:
            res += self.bucket(bucket, separator)
        return res

    # Buckets data based on question / program length
    def bucketData(self, data, noBucket=False):
        if noBucket:
            buckets = [data]
        else:
            if config.noBucket:
                buckets = [data]
            elif config.noRebucket:
                questionSep = self.lseparator("questionSeq", config.questionLims)
                buckets = self.bucket(data, questionSep)
            else:
                programSep = self.lseparator("programSeq", config.programLims)
                questionSep = self.lseparator("questionSeq", config.questionLims)
                buckets = self.bucket(data, programSep)
                buckets = self.rebucket(buckets, questionSep)
        return buckets
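
    # Bucketing example (editor note, limit values hypothetical): with
    # config.questionLims = [10, 20], lseparator("questionSeq", [10, 20]) maps
    # questions of length < 10 to group 0, < 20 to group 1, and all longer
    # questions to group 2; empty groups are then dropped by bucket().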

    '''
    Prepares data:
    1. Filters the data according to the tier's filter settings below.
    2. Takes only a subset of the data based on config.trainedNum / config.testedNum.
    3. Buckets the data according to question / program length.
    4. Vectorizes the data into numpy arrays.
    '''

    def prepareData(self, data, train, filterKey=None, noBucket=False):
        filterDefault = {"maxQLength": 0, "maxPLength": 0, "onlyChain": False, "filterOp": 0}

        filterTrain = {"maxQLength": config.tMaxQ, "maxPLength": config.tMaxP,
                       "onlyChain": config.tOnlyChain, "filterOp": config.tFilterOp}

        filterVal = {"maxQLength": config.vMaxQ, "maxPLength": config.vMaxP,
                     "onlyChain": config.vOnlyChain, "filterOp": config.vFilterOp}

        filters = {"train": filterTrain, "evalTrain": filterTrain,
                   "val": filterVal, "test": filterDefault}

        if filterKey is None:
            fltr = filterDefault
        else:
            fltr = filters[filterKey]

        # split data when finetuning on the validation set
        if config.trainExtra and config.extraVal and (config.finetuneNum > 0):
            if train:
                data = data[:config.finetuneNum]
            else:
                data = data[config.finetuneNum:]

        typeFilter = config.typeFilters[fltr["filterOp"]]
        # filter specific settings
        if fltr["onlyChain"]:
            data = [d for d in data if all((len(inputNum) < 2) for inputNum in d["programInputs"])]
        if fltr["maxQLength"] > 0:
            data = [d for d in data if len(d["questionSeq"]) <= fltr["maxQLength"]]
        if fltr["maxPLength"] > 0:
            data = [d for d in data if len(d["programSeq"]) <= fltr["maxPLength"]]
        if len(typeFilter) > 0:
            data = [d for d in data if d["programSeq"][-1] not in typeFilter]

        # run on a subset of the data. If 0 then use all data
        num = config.trainedNum if train else config.testedNum
        # retainVal = True to retain the same validation sample across runs
        if (not train) and (not config.retainVal):
            random.shuffle(data)
        if num > 0:
            data = data[:num]
        # set the number to match the dataset size
        if train:
            config.trainedNum = len(data)
        else:
            config.testedNum = len(data)

        # bucket
        buckets = self.bucketData(data, noBucket=noBucket)

        # vectorize
        return [self.vectorizeData(bucket) for bucket in buckets]
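
    # Filter example (editor note, value hypothetical): with tier "val" and
    # config.vMaxQ = 25, any instance whose questionSeq is longer than 25
    # tokens is dropped before bucketing; the "test" tier always uses
    # filterDefault, i.e. no filtering.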

    # Prepares all the tiers of a dataset. See the prepareData method for further details.
    def prepareDataset(self, dataset, noBucket=False):
        if dataset is None:
            return None

        for tier in dataset:
            if dataset[tier] is not None:
                dataset[tier]["data"] = self.prepareData(dataset[tier]["instances"],
                                                         train=dataset[tier]["train"], filterKey=tier,
                                                         noBucket=noBucket)

        for tier in dataset:
            if dataset[tier] is not None:
                del dataset[tier]["instances"]

        return dataset

    # Initializes word embeddings to random uniform / random normal / GloVe.
    def initializeWordEmbeddings(self, wordsDict=None, noPadding=False):
        # default dictionary to use for embeddings
        if wordsDict is None:
            wordsDict = self.questionDict

        # uniform initialization
        if config.wrdEmbUniform:
            lowInit = -1.0 * config.wrdEmbScale
            highInit = 1.0 * config.wrdEmbScale
            embeddings = np.random.uniform(low=lowInit, high=highInit,
                                           size=(wordsDict.getNumSymbols(), config.wrdEmbDim))
        # normal initialization
        else:
            embeddings = config.wrdEmbScale * np.random.randn(wordsDict.getNumSymbols(),
                                                              config.wrdEmbDim)

        # if wrdEmbRandom = False, overwrite with GloVe vectors where available
        counter = 0
        if not config.wrdEmbRandom:
            with open(config.wordVectorsFile, 'r') as inFile:
                for line in inFile:
                    line = line.strip().split()
                    word = line[0].lower()
                    vector = [float(x) for x in line[1:]]
                    index = wordsDict.sym2id.get(word)
                    if index is not None:
                        embeddings[index] = vector
                        counter += 1

        print(counter)
        print(self.questionDict.sym2id)
        print(len(self.questionDict.sym2id))
        print(self.answerDict.sym2id)
        print(len(self.answerDict.sym2id))
        print(self.qaDict.sym2id)
        print(len(self.qaDict.sym2id))

        if noPadding:
            return embeddings  # the dictionary has no padding symbol
        else:
            return embeddings[1:]  # drop the embedding of the padding symbol

    '''
    Initializes word embeddings for question words and optionally for answer words
    (when config.ansEmbMod == "BOTH"). If config.ansEmbMod == "SHARED", ties the
    embeddings of question and answer symbols.
    '''

    def initializeQAEmbeddings(self):
        # use the same embeddings for questions and answers
        if config.ansEmbMod == "SHARED":
            qaEmbeddings = self.initializeWordEmbeddings(self.qaDict)
            ansMap = np.array([self.qaDict.sym2id[sym] for sym in self.answerDict.id2sym])
            embeddings = {"qa": qaEmbeddings, "ansMap": ansMap}
        # use different embeddings for questions and answers
        else:
            qEmbeddings = self.initializeWordEmbeddings(self.questionDict)
            aEmbeddings = None
            if config.ansEmbMod == "BOTH":
                aEmbeddings = self.initializeWordEmbeddings(self.answerDict, noPadding=True)
            embeddings = {"q": qEmbeddings, "a": aEmbeddings}
        return embeddings
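
    # Shape sketch (editor note): in "SHARED" mode, ansMap[i] holds the row in
    # the shared qa embedding matrix for answer id i in answerDict, so answer
    # embeddings can be looked up inside the single shared matrix by answer id.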

    '''
    Preprocesses a given dataset into numpy arrays:
    1. Reads the input data files into a dictionary.
    2. Saves the resulting jsons to files, and loads them instead of parsing the input if the files exist.
    3. Initializes word embeddings to random / GloVe.
    4. Optionally filters data according to given filters.
    5. Encodes and vectorizes the data into numpy arrays.
    6. Buckets the data according to the instances' length.
    '''

    def preprocessData(self, question, debug=False):
        # Read data into json and symbol dictionaries
        print(bold("Loading data..."))
        start = time.time()
        with open(config.questionDictFile(), "rb") as inFile:
            self.questionDict = pickle.load(inFile)
        with open(config.qaDictFile(), "rb") as inFile:
            self.qaDict = pickle.load(inFile)
        with open(config.answerDictFile(), "rb") as inFile:
            self.answerDict = pickle.load(inFile)
        question = question.replace('?', '').replace(',', '').lower().split()
        encodedQuestion = self.questionDict.encodeSequence(question)
        data = {'question': np.array([encodedQuestion]), 'questionLength': np.array([len(encodedQuestion)])}
        print("took {:.2f} seconds".format(time.time() - start))

        # Initialize word embeddings (random / GloVe)
        print(bold("Loading word vectors..."))
        start = time.time()
        embeddings = self.initializeQAEmbeddings()
        print("took {:.2f} seconds".format(time.time() - start))

        answer = 'yes'  # DUMMY_ANSWER
        self.answerDict.addSeq([answer])
        self.qaDict.addSeq([answer])

        config.questionWordsNum = self.questionDict.getNumSymbols()
        config.answerWordsNum = self.answerDict.getNumSymbols()

        return data, embeddings, self.answerDict
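

# A minimal end-to-end usage sketch (editor example; assumes the pickled
# dictionaries referenced by config.*DictFile() and the config values used
# above already exist on disk).
if __name__ == "__main__":
    preprocessor = Preprocesser()
    data, embeddings, answerDict = preprocessor.preprocessData("What color is the cube?")
    print(data["question"].shape, data["questionLength"])  # (1, L) and [L]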