ydin0771 committed on
Commit
94566d7
•
1 Parent(s): 5115fbb

Upload main.py

Files changed (1)
  1. main.py +653 -0
main.py ADDED
@@ -0,0 +1,653 @@
from __future__ import division
import warnings

from extract_feature import build_model, run_image, get_img_feat

# warnings.filterwarnings("ignore", category=FutureWarning)
# warnings.filterwarnings("ignore", message="size changed")
warnings.filterwarnings("ignore")

import sys
import os
import time
import math
import random

try:
    import Queue as queue
except ImportError:
    import queue
import threading
import h5py
import json
import numpy as np
import tensorflow as tf
from termcolor import colored, cprint

from config import config, loadDatasetConfig, parseArgs
from preprocess import Preprocesser, bold, bcolored, writeline, writelist
from model import MACnet
from collections import defaultdict

############################################# loggers #############################################

# Writes the log header to file
def logInit():
    with open(config.logFile(), "a+") as outFile:
        writeline(outFile, config.expName)
        headers = ["epoch", "trainAcc", "valAcc", "trainLoss", "valLoss"]
        if config.evalTrain:
            headers += ["evalTrainAcc", "evalTrainLoss"]
        if config.extra:
            if config.evalTrain:
                headers += ["thAcc", "thLoss"]
            headers += ["vhAcc", "vhLoss"]
        headers += ["time", "lr"]

        writelist(outFile, headers)
        # lr assumed to be last


# Writes a log record to file
def logRecord(epoch, epochTime, lr, trainRes, evalRes, extraEvalRes):
    with open(config.logFile(), "a+") as outFile:
        record = [epoch, trainRes["acc"], evalRes["val"]["acc"], trainRes["loss"], evalRes["val"]["loss"]]
        if config.evalTrain:
            record += [evalRes["evalTrain"]["acc"], evalRes["evalTrain"]["loss"]]
        if config.extra:
            if config.evalTrain:
                record += [extraEvalRes["evalTrain"]["acc"], extraEvalRes["evalTrain"]["loss"]]
            record += [extraEvalRes["val"]["acc"], extraEvalRes["val"]["loss"]]
        record += [epochTime, lr]

        writelist(outFile, record)


# Gets the last logged epoch and learning rate
def lastLoggedEpoch():
    with open(config.logFile(), "r") as inFile:
        lastLine = list(inFile)[-1].split(",")
        epoch = int(lastLine[0])
        lr = float(lastLine[-1])
    return epoch, lr

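# A sketch of the resulting log layout (assuming writeline / writelist emit one
# comma-separated line per call, which lastLoggedEpoch's split(",") relies on;
# the values below are made up for illustration):
#
#   expName
#   epoch,trainAcc,valAcc,trainLoss,valLoss,time,lr
#   1,0.4521,0.4387,1.9041,1.9312,512.33,0.0001
#   2,0.5618,0.5402,1.5220,1.5671,498.10,0.0001
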
################################## printing, output and analysis ##################################

# Analysis by type
analysisQuestionLims = [(0, 18), (19, float("inf"))]
analysisProgramLims = [(0, 12), (13, float("inf"))]

toArity = lambda instance: instance["programSeq"][-1].split("_", 1)[0]
toType = lambda instance: instance["programSeq"][-1].split("_", 1)[1]


def fieldLenIsInRange(field):
    return lambda instance, group: \
        (len(instance[field]) >= group[0] and
         len(instance[field]) <= group[1])


# Groups instances based on a key
def grouperKey(toKey):
    def grouper(instances):
        res = defaultdict(list)
        for instance in instances:
            res[toKey(instance)].append(instance)
        return res

    return grouper


# Groups instances according to their match to a condition
def grouperCond(groups, isIn):
    def grouper(instances):
        res = {}
        for group in groups:
            # use a list (not a generator) so each group supports len() and repeated iteration
            res[group] = [instance for instance in instances if isIn(instance, group)]
        return res

    return grouper


groupers = {
    "questionLength": grouperCond(analysisQuestionLims, fieldLenIsInRange("questionSeq")),
    "programLength": grouperCond(analysisProgramLims, fieldLenIsInRange("programSeq")),
    "arity": grouperKey(toArity),
    "type": grouperKey(toType)
}


# Computes the average of a field over a group of instances
def avg(instances, field):
    if len(instances) == 0:
        return 0.0
    return sum(instance[field] for instance in instances) / len(instances)


# Prints analysis of question loss and accuracy by group
def printAnalysis(res):
    if config.analysisType != "":
        print("Analysis by {type}".format(type=config.analysisType))
        groups = groupers[config.analysisType](res["preds"])
        for key in groups:
            instances = groups[key]
            avgLoss = avg(instances, "loss")
            avgAcc = avg(instances, "acc")
            num = len(instances)
            print("Group {key}: Loss: {loss}, Acc: {acc}, Num: {num}".format(
                key=key, loss=avgLoss, acc=avgAcc, num=num))


# Prints results for a tier
def printTierResults(tierName, res, color):
    if res is None:
        return

    print("{tierName} Loss: {loss}, {tierName} accuracy: {acc}".format(
        tierName=tierName,
        loss=bcolored(res["loss"], color),
        acc=bcolored(res["acc"], color)))

    printAnalysis(res)


# Prints dataset results (for several tiers)
def printDatasetResults(trainRes, evalRes):
    printTierResults("Training", trainRes, "magenta")
    printTierResults("Training EMA", evalRes["evalTrain"], "red")
    printTierResults("Validation", evalRes["val"], "cyan")


# Writes predictions for several tiers
def writePreds(preprocessor, evalRes):
    preprocessor.writePreds(evalRes, "_")

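# A small usage sketch of the analysis helpers (hypothetical instance, for
# illustration only):
#
#   preds = [{"programSeq": ["scene", "filter_color"],
#             "questionSeq": ["how", "many"], "loss": 0.3, "acc": 1.0}]
#   groupers["arity"](preds)  # -> {"filter": [<the instance>]}
#   avg(preds, "loss")        # -> 0.3
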
############################################# session #############################################
# Initializes the TF session. Sets GPU memory configuration.
def setSession():
    sessionConfig = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    if config.allowGrowth:
        sessionConfig.gpu_options.allow_growth = True
    if config.maxMemory < 1.0:
        sessionConfig.gpu_options.per_process_gpu_memory_fraction = config.maxMemory
    return sessionConfig

############################################## savers #############################################
# Initializes savers (standard, optional exponential-moving-average, and an optional saver
# for a subset of the variables)
def setSavers(model):
    saver = tf.train.Saver(max_to_keep=config.weightsToKeep)

    subsetSaver = None
    if config.saveSubset:
        isRelevant = lambda var: any(s in var.name for s in config.varSubset)
        relevantVars = [var for var in tf.global_variables() if isRelevant(var)]
        subsetSaver = tf.train.Saver(relevantVars, max_to_keep=config.weightsToKeep, allow_empty=True)

    emaSaver = None
    if config.useEMA:
        emaSaver = tf.train.Saver(model.emaDict, max_to_keep=config.weightsToKeep)

    return {
        "saver": saver,
        "subsetSaver": subsetSaver,
        "emaSaver": emaSaver
    }

################################### restore / initialize weights ##################################
# Restores the weights of the specified / last epoch if in restore mode.
# Otherwise, initializes the weights.
def loadWeights(sess, saver, init):
    if config.restoreEpoch > 0 or config.restore:
        # restore the last epoch only if restoreEpoch isn't set
        if config.restoreEpoch == 0:
            # restore the last logged epoch
            config.restoreEpoch, config.lr = lastLoggedEpoch()
        print(bcolored("Restoring epoch {} and lr {}".format(config.restoreEpoch, config.lr), "cyan"))
        print(bcolored("Restoring weights", "blue"))
        print(config.weightsFile(config.restoreEpoch))
        saver.restore(sess, config.weightsFile(config.restoreEpoch))
        epoch = config.restoreEpoch
    else:
        print(bcolored("Initializing weights", "blue"))
        sess.run(init)
        logInit()
        epoch = 0

    return epoch

###################################### training / evaluation ######################################
# Chooses the data to train on (main / extra data).
def chooseTrainingData(data):
    trainingData = data["main"]["train"]
    alterData = None

    if config.extra:
        if config.trainExtra:
            if config.extraVal:
                trainingData = data["extra"]["val"]
            else:
                trainingData = data["extra"]["train"]
        if config.alterExtra:
            alterData = data["extra"]["train"]

    return trainingData, alterData


#### evaluation
# Runs evaluation on the train / val / test datasets.
def runEvaluation(sess, model, data, epoch, evalTrain=True, evalTest=False, getAtt=None):
    if getAtt is None:
        getAtt = config.getAtt
    res = {"evalTrain": None, "val": None, "test": None}

    if data is not None:
        if evalTrain and config.evalTrain:
            res["evalTrain"] = runEpoch(sess, model, data["evalTrain"], train=False, epoch=epoch, getAtt=getAtt)

        res["val"] = runEpoch(sess, model, data["val"], train=False, epoch=epoch, getAtt=getAtt)

        if evalTest or config.test:
            res["test"] = runEpoch(sess, model, data["test"], train=False, epoch=epoch, getAtt=getAtt)

    return res


## training conditions (comparing the current epoch's results to prior ones)
def improveEnough(curr, prior, lr):
    prevRes = prior["prev"]["res"]
    currRes = curr["res"]

    if prevRes is None:
        return True

    prevTrainLoss = prevRes["train"]["loss"]
    currTrainLoss = currRes["train"]["loss"]
    lossDiff = prevTrainLoss - currTrainLoss

    notImprove = ((lossDiff < 0.015 and prevTrainLoss < 0.5 and lr > 0.00002) or
                  (lossDiff < 0.008 and prevTrainLoss < 0.15 and lr > 0.00001) or
                  (lossDiff < 0.003 and prevTrainLoss < 0.10 and lr > 0.000005))
    # (prevTrainLoss < 0.2 and config.lr > 0.000015)

    return not notImprove

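# A worked example of the decay condition above (hypothetical numbers): with
# lr = 0.0001, prevTrainLoss = 0.40 and currTrainLoss = 0.395, lossDiff = 0.005,
# so the first clause fires (0.005 < 0.015, 0.40 < 0.5, 0.0001 > 0.00002) and
# improveEnough returns False, i.e. the caller should decay the learning rate.
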

def better(currRes, bestRes):
    return currRes["val"]["acc"] > bestRes["val"]["acc"]

############################################## data ###############################################
#### instances and batching
# Trims sequences based on their max length.
def trim2DVectors(vectors, vectorsLengths):
    maxLength = np.max(vectorsLengths)
    return vectors[:, :maxLength]


# Trims a batch based on question length.
def trimData(data):
    data["questions"] = trim2DVectors(data["questions"], data["questionLengths"])
    return data


# Gets the batch / bucket size.
def getLength(data):
    return len(data["instances"])


# Selects the data entries that match the indices.
def selectIndices(data, indices):
    def select(field, indices):
        if type(field) is np.ndarray:
            return field[indices]
        elif type(field) is list:
            return [field[i] for i in indices]
        else:
            return field

    selected = {k: select(d, indices) for k, d in data.items()}
    return selected


# Batches data into a list of batches of batchSize.
# Shuffles the data by default.
def getBatches(data, batchSize=None, shuffle=True):
    batches = []

    dataLen = getLength(data)
    if batchSize is None or batchSize > dataLen:
        batchSize = dataLen

    indices = np.arange(dataLen)
    if shuffle:
        np.random.shuffle(indices)

    for batchStart in range(0, dataLen, batchSize):
        batchIndices = indices[batchStart: batchStart + batchSize]
        # if len(batchIndices) == batchSize?
        if len(batchIndices) >= config.gpusNum:
            batch = selectIndices(data, batchIndices)
            batches.append(batch)
        # batchesIndices.append((data, batchIndices))

    return batches

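# A minimal usage sketch (hypothetical fields; assumes config.gpusNum == 1 so the
# trailing partial batch is kept):
#
#   data = {"instances": list(range(10)),
#           "questions": np.zeros((10, 30), dtype=np.int32),
#           "questionLengths": np.full(10, 12)}
#   for batch in getBatches(data, batchSize=4):   # batches of 4, 4, 2
#       batch = trimData(batch)                   # "questions" trimmed to width 12
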
#### image batches
# Opens the image files.
def openImageFiles(images):
    images["imagesFile"] = h5py.File(images["imagesFilename"], "r")
    images["imagesIds"] = None
    if config.dataset == "NLVR":
        with open(images["imageIdsFilename"], "r") as imageIdsFile:
            images["imagesIds"] = json.load(imageIdsFile)


# Closes the image files.
def closeImageFiles(images):
    images["imagesFile"].close()


# Loads the images from file for a given data batch.
def loadImageBatch(images, batch):
    imagesFile = images["imagesFile"]
    id2idx = images["imagesIds"]
    toIndex = lambda imageId: imageId
    if id2idx is not None:
        toIndex = lambda imageId: id2idx[imageId]
    imageBatch = np.stack([imagesFile["features"][toIndex(imageId)] for imageId in batch["imageIds"]], axis=0)

    return {"images": imageBatch, "imageIds": batch["imageIds"]}


# Loads images for num batches from the batches list, starting at the start index.
def loadImageBatches(images, batches, start, num):
    batches = batches[start: start + num]
    return [loadImageBatch(images, batch) for batch in batches]

#### data alternation
# Alternates main training batches with the extra data.
def alternateData(batches, alterData, dataLen):
    alterData = alterData["data"][0]  # data isn't bucketed for altered data

    # compute the number of repetitions
    needed = math.ceil(len(batches) / config.alterNum)
    print(bold("Extra batches needed: %d") % needed)
    perData = math.ceil(getLength(alterData) / config.batchSize)
    print(bold("Batches per extra data: %d") % perData)
    # int() for Python 2, where math.ceil returns a float
    repetitions = int(math.ceil(needed / perData))
    print(bold("reps: %d") % repetitions)

    # make the alternate batches
    alterBatches = []
    for _ in range(repetitions):
        repBatches = getBatches(alterData, batchSize=config.batchSize)
        random.shuffle(repBatches)
        alterBatches += repBatches
    print(bold("Batches num: %d") % len(alterBatches))

    # alternate the main data with the extra data
    curr = len(batches) - 1
    for alterBatch in alterBatches:
        if curr < 0:
            # print(colored("too many" + str(curr) + " " + str(len(batches)), "red"))
            break
        batches.insert(curr, alterBatch)
        dataLen += getLength(alterBatch)
        curr -= config.alterNum

    return batches, dataLen

############################################ threading ############################################

imagesQueue = queue.Queue(maxsize=20)  # config.tasksNum
inQueue = queue.Queue(maxsize=1)
outQueue = queue.Queue(maxsize=1)


# Worker thread(s) that load images while training.
class StoppableThread(threading.Thread):
    # Thread class with a stop() method. The thread itself has to check
    # regularly for the stopped() condition.

    def __init__(self, images, batches):  # i
        super(StoppableThread, self).__init__()
        # self.i = i
        self.images = images
        self.batches = batches
        self._stop_event = threading.Event()

    # def __init__(self, args):
    #     super(StoppableThread, self).__init__(args = args)
    #     self._stop_event = threading.Event()

    # def __init__(self, target, args):
    #     super(StoppableThread, self).__init__(target = target, args = args)
    #     self._stop_event = threading.Event()

    def stop(self):
        self._stop_event.set()

    def stopped(self):
        return self._stop_event.is_set()

    def run(self):
        while not self.stopped():
            try:
                batchNum = inQueue.get(timeout=60)
                nextItem = loadImageBatches(self.images, self.batches, batchNum, int(config.taskSize / 2))
                outQueue.put(nextItem)
                # inQueue.task_done()
            except:
                # swallows queue.Empty timeouts so the loop can re-check the stop flag
                pass
        # print("worker %d done", self.i)


def loaderRun(images, batches):
    batchNum = 0

    # if config.workers == 2:
    #     worker = StoppableThread(images, batches)  # i,
    #     worker.daemon = True
    #     worker.start()
    #
    #     while batchNum < len(batches):
    #         inQueue.put(batchNum + int(config.taskSize / 2))
    #         nextItem1 = loadImageBatches(images, batches, batchNum, int(config.taskSize / 2))
    #         nextItem2 = outQueue.get()
    #
    #         nextItem = nextItem1 + nextItem2
    #         assert len(nextItem) == min(config.taskSize, len(batches) - batchNum)
    #         batchNum += config.taskSize
    #
    #         imagesQueue.put(nextItem)
    #
    #     worker.stop()
    # else:
    while batchNum < len(batches):
        nextItem = loadImageBatches(images, batches, batchNum, config.taskSize)
        assert len(nextItem) == min(config.taskSize, len(batches) - batchNum)
        batchNum += config.taskSize
        imagesQueue.put(nextItem)

    # print("manager loader done")

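# A minimal consumption sketch (hypothetical wiring; the real training loop lives
# in the full MAC training script): the loader fills imagesQueue ahead of the
# trainer, which drains it one task (config.taskSize image batches) at a time.
#
#   loader = threading.Thread(target=loaderRun, args=(images, batches))
#   loader.daemon = True
#   loader.start()
#   for batchNum, batch in enumerate(batches):
#       if batchNum % config.taskSize == 0:
#           imageBatches = imagesQueue.get()  # blocks until the loader catches up
#       imageBatch = imageBatches[batchNum % config.taskSize]
#       ...  # run the model on (batch, imageBatch)
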
########################################## stats tracking #########################################
# Computes an exponential moving average.
def emaAvg(avg, value):
    if avg is None:
        return value
    emaRate = 0.98
    return avg * emaRate + value * (1 - emaRate)
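
# For example, emaAvg(0.50, 1.00) == 0.50 * 0.98 + 1.00 * 0.02 == 0.51: each new
# value nudges the running average by only 2%, smoothing per-batch loss/acc noise.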

# Initializes training statistics.
def initStats():
    return {
        "totalBatches": 0,
        "totalData": 0,
        "totalLoss": 0.0,
        "totalCorrect": 0,
        "loss": 0.0,
        "acc": 0.0,
        "emaLoss": None,
        "emaAcc": None,
    }


# Updates the statistics with the training results of a batch
def updateStats(stats, res, batch):
    stats["totalBatches"] += 1
    stats["totalData"] += getLength(batch)

    stats["totalLoss"] += res["loss"]
    stats["totalCorrect"] += res["correctNum"]

    stats["loss"] = stats["totalLoss"] / stats["totalBatches"]
    stats["acc"] = stats["totalCorrect"] / stats["totalData"]

    stats["emaLoss"] = emaAvg(stats["emaLoss"], res["loss"])
    stats["emaAcc"] = emaAvg(stats["emaAcc"], res["acc"])

    return stats

# auto-encoder ae = {:2.4f} autoEncLoss,
# Translates the training statistics into a string to print
def statsToStr(stats, res, epoch, batchNum, dataLen, startTime):
    formatStr = "\reb {epoch},{batchNum} ({dataProcessed} / {dataLen:5d}), " + \
                "t = {time} ({loadTime:2.2f}+{trainTime:2.2f}), " + \
                "lr {lr}, l = {loss}, a = {acc}, avL = {avgLoss}, " + \
                "avA = {avgAcc}, g = {gradNorm:2.4f}, " + \
                "emL = {emaLoss:2.4f}, emA = {emaAcc:2.4f}; " + \
                "{expname}"  # {machine}/{gpu}

    s_epoch = bcolored("{:2d}".format(epoch), "green")
    s_batchNum = "{:3d}".format(batchNum)
    s_dataProcessed = bcolored("{:5d}".format(stats["totalData"]), "green")
    s_dataLen = dataLen
    s_time = bcolored("{:2.2f}".format(time.time() - startTime), "green")
    s_loadTime = res["readTime"]
    s_trainTime = res["trainTime"]
    s_lr = bold(config.lr)
    s_loss = bcolored("{:2.4f}".format(res["loss"]), "blue")
    s_acc = bcolored("{:2.4f}".format(res["acc"]), "blue")
    s_avgLoss = bcolored("{:2.4f}".format(stats["loss"]), "blue")
    s_avgAcc = bcolored("{:2.4f}".format(stats["acc"]), "red")
    s_gradNorm = res["gradNorm"]
    s_emaLoss = stats["emaLoss"]
    s_emaAcc = stats["emaAcc"]
    s_expname = config.expName
    # s_machine = bcolored(config.dataPath[9:11], "green")
    # s_gpu = bcolored(config.gpus, "green")

    return formatStr.format(epoch=s_epoch, batchNum=s_batchNum, dataProcessed=s_dataProcessed,
                            dataLen=s_dataLen, time=s_time, loadTime=s_loadTime,
                            trainTime=s_trainTime, lr=s_lr, loss=s_loss, acc=s_acc,
                            avgLoss=s_avgLoss, avgAcc=s_avgAcc, gradNorm=s_gradNorm,
                            emaLoss=s_emaLoss, emaAcc=s_emaAcc, expname=s_expname)
    # machine=s_machine, gpu=s_gpu)

# collectRuntimeStats, writer = None,
'''
Runs an epoch with the model and session over the data.
1. Batches the data and optionally mixes it with the extra alterData.
2. Starts worker threads to load images in parallel to training.
3. Runs the model for each batch and gets the results (e.g. loss, accuracy).
4. Updates and prints statistics based on the batch results.
5. Once in a while (every config.saveEvery), saves the weights.

Args:
    sess: TF session to run with.

    model: model to process the data. Has a runBatch method that processes a given batch.
    (See model.py for further details).

    data: data to use for training / evaluation.

    epoch: epoch number.

    saver: TF saver to save the weights.

    calle: a method to call every config.calleEvery iterations.

    alterData: extra data to mix with the main data while training.

    getAtt: True to return the model's attentions.
'''

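# NOTE: runEvaluation above calls runEpoch, which this trimmed script does not
# define (the full loop described by the docstring lives in the original MAC
# training script). Below is a minimal sketch built only from this file's own
# helpers, assuming a tier looks like {"data": <instances>, "images": <image-file
# info dict>} and that model.runBatch(sess, batch, imageFeatures, train) returns
# a dict with "loss", "acc" and "correctNum"; the real signature may differ.
# epoch and getAtt are accepted for call-site compatibility but unused here.
def runEpoch(sess, model, data, train, epoch, getAtt=False):
    stats = initStats()
    openImageFiles(data["images"])
    for batch in getBatches(data["data"], batchSize=config.batchSize, shuffle=train):
        batch = trimData(batch)
        imageBatch = loadImageBatch(data["images"], batch)
        res = model.runBatch(sess, batch, imageBatch["images"], train)
        stats = updateStats(stats, res, batch)
    closeImageFiles(data["images"])
    return stats

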
def main(question, image):
    with open(config.configFile(), "a+") as outFile:
        json.dump(vars(config), outFile)

    # set gpus
    if config.gpus != "":
        config.gpusNum = len(config.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus

    tf.logging.set_verbosity(tf.logging.ERROR)

    # process data
    print(bold("Preprocess data..."))
    start = time.time()
    preprocessor = Preprocesser()
    cnn_model = build_model()
    imageData = get_img_feat(cnn_model, image)
    qData, embeddings, answerDict = preprocessor.preprocessData(question)
    data = {'data': qData, 'image': imageData}
    print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))

    # build model
    print(bold("Building model..."))
    start = time.time()
    model = MACnet(embeddings, answerDict)
    print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))

    # initializer
    init = tf.global_variables_initializer()

    # savers
    savers = setSavers(model)
    saver, emaSaver = savers["saver"], savers["emaSaver"]

    # sessionConfig
    sessionConfig = setSession()

    with tf.Session(config=sessionConfig) as sess:

        # ensure no more ops are added after the model is built
        sess.graph.finalize()

        # restore / initialize weights, initialize epoch variable
        epoch = loadWeights(sess, saver, init)
        print(epoch)
        start = time.time()
        if epoch > 0:
            if config.useEMA:
                emaSaver.restore(sess, config.weightsFile(epoch))
            else:
                saver.restore(sess, config.weightsFile(epoch))

        evalRes = model.runBatch(sess, data['data'], data['image'], False)

        print("took {:.2f} seconds".format(time.time() - start))

        print(evalRes)


if __name__ == '__main__':
    parseArgs()
    loadDatasetConfig[config.dataset]()
    question = 'How many text objects are located at the bottom side of table?'
    imagePath = './mac-layoutLM-sample/PDF_val_64.png'
    main(question, imagePath)