Spaces:
Runtime error
Runtime error
Upload main.py
Browse files
main.py
ADDED
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import division
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
from extract_feature import build_model, run_image, get_img_feat
|
5 |
+
|
6 |
+
# warnings.filterwarnings("ignore", category=FutureWarning)
|
7 |
+
# warnings.filterwarnings("ignore", message="size changed")
|
8 |
+
warnings.filterwarnings("ignore")
|
9 |
+
|
10 |
+
import sys
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
import math
|
14 |
+
import random
|
15 |
+
|
16 |
+
try:
|
17 |
+
import Queue as queue
|
18 |
+
except ImportError:
|
19 |
+
import queue
|
20 |
+
import threading
|
21 |
+
import h5py
|
22 |
+
import json
|
23 |
+
import numpy as np
|
24 |
+
import tensorflow as tf
|
25 |
+
from termcolor import colored, cprint
|
26 |
+
|
27 |
+
from config import config, loadDatasetConfig, parseArgs
|
28 |
+
from preprocess import Preprocesser, bold, bcolored, writeline, writelist
|
29 |
+
from model import MACnet
|
30 |
+
from collections import defaultdict
|
31 |
+
|
32 |
+
|
33 |
+
############################################# loggers #############################################
|
34 |
+
|
35 |
+
# Writes log header to file
|
36 |
+
def logInit():
|
37 |
+
with open(config.logFile(), "a+") as outFile:
|
38 |
+
writeline(outFile, config.expName)
|
39 |
+
headers = ["epoch", "trainAcc", "valAcc", "trainLoss", "valLoss"]
|
40 |
+
if config.evalTrain:
|
41 |
+
headers += ["evalTrainAcc", "evalTrainLoss"]
|
42 |
+
if config.extra:
|
43 |
+
if config.evalTrain:
|
44 |
+
headers += ["thAcc", "thLoss"]
|
45 |
+
headers += ["vhAcc", "vhLoss"]
|
46 |
+
headers += ["time", "lr"]
|
47 |
+
|
48 |
+
writelist(outFile, headers)
|
49 |
+
# lr assumed to be last
|
50 |
+
|
51 |
+
|
52 |
+
# Writes log record to file
|
53 |
+
def logRecord(epoch, epochTime, lr, trainRes, evalRes, extraEvalRes):
|
54 |
+
with open(config.logFile(), "a+") as outFile:
|
55 |
+
record = [epoch, trainRes["acc"], evalRes["val"]["acc"], trainRes["loss"], evalRes["val"]["loss"]]
|
56 |
+
if config.evalTrain:
|
57 |
+
record += [evalRes["evalTrain"]["acc"], evalRes["evalTrain"]["loss"]]
|
58 |
+
if config.extra:
|
59 |
+
if config.evalTrain:
|
60 |
+
record += [extraEvalRes["evalTrain"]["acc"], extraEvalRes["evalTrain"]["loss"]]
|
61 |
+
record += [extraEvalRes["val"]["acc"], extraEvalRes["val"]["loss"]]
|
62 |
+
record += [epochTime, lr]
|
63 |
+
|
64 |
+
writelist(outFile, record)
|
65 |
+
|
66 |
+
|
67 |
+
# Gets last logged epoch and learning rate
|
68 |
+
def lastLoggedEpoch():
|
69 |
+
with open(config.logFile(), "r") as inFile:
|
70 |
+
lastLine = list(inFile)[-1].split(",")
|
71 |
+
epoch = int(lastLine[0])
|
72 |
+
lr = float(lastLine[-1])
|
73 |
+
return epoch, lr
|
74 |
+
|
75 |
+
|
76 |
+
################################## printing, output and analysis ##################################
|
77 |
+
|
78 |
+
# Analysis by type
|
79 |
+
analysisQuestionLims = [(0, 18), (19, float("inf"))]
|
80 |
+
analysisProgramLims = [(0, 12), (13, float("inf"))]
|
81 |
+
|
82 |
+
toArity = lambda instance: instance["programSeq"][-1].split("_", 1)[0]
|
83 |
+
toType = lambda instance: instance["programSeq"][-1].split("_", 1)[1]
|
84 |
+
|
85 |
+
|
86 |
+
def fieldLenIsInRange(field):
|
87 |
+
return lambda instance, group: \
|
88 |
+
(len(instance[field]) >= group[0] and
|
89 |
+
len(instance[field]) <= group[1])
|
90 |
+
|
91 |
+
|
92 |
+
# Groups instances based on a key
|
93 |
+
def grouperKey(toKey):
|
94 |
+
def grouper(instances):
|
95 |
+
res = defaultdict(list)
|
96 |
+
for instance in instances:
|
97 |
+
res[toKey(instance)].append(instance)
|
98 |
+
return res
|
99 |
+
|
100 |
+
return grouper
|
101 |
+
|
102 |
+
|
103 |
+
# Groups instances according to their match to condition
|
104 |
+
def grouperCond(groups, isIn):
|
105 |
+
def grouper(instances):
|
106 |
+
res = {}
|
107 |
+
for group in groups:
|
108 |
+
res[group] = (instance for instance in instances if isIn(instance, group))
|
109 |
+
return res
|
110 |
+
|
111 |
+
return grouper
|
112 |
+
|
113 |
+
|
114 |
+
groupers = {
|
115 |
+
"questionLength": grouperCond(analysisQuestionLims, fieldLenIsInRange("questionSeq")),
|
116 |
+
"programLength": grouperCond(analysisProgramLims, fieldLenIsInRange("programSeq")),
|
117 |
+
"arity": grouperKey(toArity),
|
118 |
+
"type": grouperKey(toType)
|
119 |
+
}
|
120 |
+
|
121 |
+
|
122 |
+
# Computes average
|
123 |
+
def avg(instances, field):
|
124 |
+
if len(instances) == 0:
|
125 |
+
return 0.0
|
126 |
+
return sum(instances[field]) / len(instances)
|
127 |
+
|
128 |
+
|
129 |
+
# Prints analysis of questions loss and accuracy by their group
|
130 |
+
def printAnalysis(res):
|
131 |
+
if config.analysisType != "":
|
132 |
+
print("Analysis by {type}".format(type=config.analysisType))
|
133 |
+
groups = groupers[config.analysisType](res["preds"])
|
134 |
+
for key in groups:
|
135 |
+
instances = groups[key]
|
136 |
+
avgLoss = avg(instances, "loss")
|
137 |
+
avgAcc = avg(instances, "acc")
|
138 |
+
num = len(instances)
|
139 |
+
print("Group {key}: Loss: {loss}, Acc: {acc}, Num: {num}".format(key, avgLoss, avgAcc, num))
|
140 |
+
|
141 |
+
|
142 |
+
# Print results for a tier
|
143 |
+
def printTierResults(tierName, res, color):
|
144 |
+
if res is None:
|
145 |
+
return
|
146 |
+
|
147 |
+
print("{tierName} Loss: {loss}, {tierName} accuracy: {acc}".format(tierName=tierName,
|
148 |
+
loss=bcolored(res["loss"], color),
|
149 |
+
acc=bcolored(res["acc"], color)))
|
150 |
+
|
151 |
+
printAnalysis(res)
|
152 |
+
|
153 |
+
|
154 |
+
# Prints dataset results (for several tiers)
|
155 |
+
def printDatasetResults(trainRes, evalRes):
|
156 |
+
printTierResults("Training", trainRes, "magenta")
|
157 |
+
printTierResults("Training EMA", evalRes["evalTrain"], "red")
|
158 |
+
printTierResults("Validation", evalRes["val"], "cyan")
|
159 |
+
|
160 |
+
|
161 |
+
# Writes predictions for several tiers
|
162 |
+
def writePreds(preprocessor, evalRes):
|
163 |
+
preprocessor.writePreds(evalRes, "_")
|
164 |
+
|
165 |
+
|
166 |
+
############################################# session #############################################
|
167 |
+
# Initializes TF session. Sets GPU memory configuration.
|
168 |
+
def setSession():
|
169 |
+
sessionConfig = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
|
170 |
+
if config.allowGrowth:
|
171 |
+
sessionConfig.gpu_options.allow_growth = True
|
172 |
+
if config.maxMemory < 1.0:
|
173 |
+
sessionConfig.gpu_options.per_process_gpu_memory_fraction = config.maxMemory
|
174 |
+
return sessionConfig
|
175 |
+
|
176 |
+
|
177 |
+
############################################## savers #############################################
|
178 |
+
# Initializes savers (standard, optional exponential-moving-average and optional for subset of variables)
|
179 |
+
def setSavers(model):
|
180 |
+
saver = tf.train.Saver(max_to_keep=config.weightsToKeep)
|
181 |
+
|
182 |
+
subsetSaver = None
|
183 |
+
if config.saveSubset:
|
184 |
+
isRelevant = lambda var: any(s in var.name for s in config.varSubset)
|
185 |
+
relevantVars = [var for var in tf.global_variables() if isRelevant(var)]
|
186 |
+
subsetSaver = tf.train.Saver(relevantVars, max_to_keep=config.weightsToKeep, allow_empty=True)
|
187 |
+
|
188 |
+
emaSaver = None
|
189 |
+
if config.useEMA:
|
190 |
+
emaSaver = tf.train.Saver(model.emaDict, max_to_keep=config.weightsToKeep)
|
191 |
+
|
192 |
+
return {
|
193 |
+
"saver": saver,
|
194 |
+
"subsetSaver": subsetSaver,
|
195 |
+
"emaSaver": emaSaver
|
196 |
+
}
|
197 |
+
|
198 |
+
|
199 |
+
################################### restore / initialize weights ##################################
|
200 |
+
# Restores weights of specified / last epoch if on restore mod.
|
201 |
+
# Otherwise, initializes weights.
|
202 |
+
def loadWeights(sess, saver, init):
|
203 |
+
if config.restoreEpoch > 0 or config.restore:
|
204 |
+
# restore last epoch only if restoreEpoch isn't set
|
205 |
+
if config.restoreEpoch == 0:
|
206 |
+
# restore last logged epoch
|
207 |
+
config.restoreEpoch, config.lr = lastLoggedEpoch()
|
208 |
+
print(bcolored("Restoring epoch {} and lr {}".format(config.restoreEpoch, config.lr), "cyan"))
|
209 |
+
print(bcolored("Restoring weights", "blue"))
|
210 |
+
print(config.weightsFile(config.restoreEpoch))
|
211 |
+
saver.restore(sess, config.weightsFile(config.restoreEpoch))
|
212 |
+
epoch = config.restoreEpoch
|
213 |
+
else:
|
214 |
+
print(bcolored("Initializing weights", "blue"))
|
215 |
+
sess.run(init)
|
216 |
+
logInit()
|
217 |
+
epoch = 0
|
218 |
+
|
219 |
+
return epoch
|
220 |
+
|
221 |
+
|
222 |
+
###################################### training / evaluation ######################################
|
223 |
+
# Chooses data to train on (main / extra) data.
|
224 |
+
def chooseTrainingData(data):
|
225 |
+
trainingData = data["main"]["train"]
|
226 |
+
alterData = None
|
227 |
+
|
228 |
+
if config.extra:
|
229 |
+
if config.trainExtra:
|
230 |
+
if config.extraVal:
|
231 |
+
trainingData = data["extra"]["val"]
|
232 |
+
else:
|
233 |
+
trainingData = data["extra"]["train"]
|
234 |
+
if config.alterExtra:
|
235 |
+
alterData = data["extra"]["train"]
|
236 |
+
|
237 |
+
return trainingData, alterData
|
238 |
+
|
239 |
+
|
240 |
+
#### evaluation
|
241 |
+
# Runs evaluation on train / val / test datasets.
|
242 |
+
def runEvaluation(sess, model, data, epoch, evalTrain=True, evalTest=False, getAtt=None):
|
243 |
+
if getAtt is None:
|
244 |
+
getAtt = config.getAtt
|
245 |
+
res = {"evalTrain": None, "val": None, "test": None}
|
246 |
+
|
247 |
+
if data is not None:
|
248 |
+
if evalTrain and config.evalTrain:
|
249 |
+
res["evalTrain"] = runEpoch(sess, model, data["evalTrain"], train=False, epoch=epoch, getAtt=getAtt)
|
250 |
+
|
251 |
+
res["val"] = runEpoch(sess, model, data["val"], train=False, epoch=epoch, getAtt=getAtt)
|
252 |
+
|
253 |
+
if evalTest or config.test:
|
254 |
+
res["test"] = runEpoch(sess, model, data["test"], train=False, epoch=epoch, getAtt=getAtt)
|
255 |
+
|
256 |
+
return res
|
257 |
+
|
258 |
+
|
259 |
+
## training conditions (comparing current epoch result to prior ones)
|
260 |
+
def improveEnough(curr, prior, lr):
|
261 |
+
prevRes = prior["prev"]["res"]
|
262 |
+
currRes = curr["res"]
|
263 |
+
|
264 |
+
if prevRes is None:
|
265 |
+
return True
|
266 |
+
|
267 |
+
prevTrainLoss = prevRes["train"]["loss"]
|
268 |
+
currTrainLoss = currRes["train"]["loss"]
|
269 |
+
lossDiff = prevTrainLoss - currTrainLoss
|
270 |
+
|
271 |
+
notImprove = ((lossDiff < 0.015 and prevTrainLoss < 0.5 and lr > 0.00002) or \
|
272 |
+
(lossDiff < 0.008 and prevTrainLoss < 0.15 and lr > 0.00001) or \
|
273 |
+
(lossDiff < 0.003 and prevTrainLoss < 0.10 and lr > 0.000005))
|
274 |
+
# (prevTrainLoss < 0.2 and config.lr > 0.000015)
|
275 |
+
|
276 |
+
return not notImprove
|
277 |
+
|
278 |
+
|
279 |
+
def better(currRes, bestRes):
|
280 |
+
return currRes["val"]["acc"] > bestRes["val"]["acc"]
|
281 |
+
|
282 |
+
|
283 |
+
############################################## data ###############################################
|
284 |
+
#### instances and batching
|
285 |
+
# Trims sequences based on their max length.
|
286 |
+
def trim2DVectors(vectors, vectorsLengths):
|
287 |
+
maxLength = np.max(vectorsLengths)
|
288 |
+
return vectors[:, :maxLength]
|
289 |
+
|
290 |
+
|
291 |
+
# Trims batch based on question length.
|
292 |
+
def trimData(data):
|
293 |
+
data["questions"] = trim2DVectors(data["questions"], data["questionLengths"])
|
294 |
+
return data
|
295 |
+
|
296 |
+
|
297 |
+
# Gets batch / bucket size.
|
298 |
+
def getLength(data):
|
299 |
+
return len(data["instances"])
|
300 |
+
|
301 |
+
|
302 |
+
# Selects the data entries that match the indices.
|
303 |
+
def selectIndices(data, indices):
|
304 |
+
def select(field, indices):
|
305 |
+
if type(field) is np.ndarray:
|
306 |
+
return field[indices]
|
307 |
+
if type(field) is list:
|
308 |
+
return [field[i] for i in indices]
|
309 |
+
else:
|
310 |
+
return field
|
311 |
+
|
312 |
+
selected = {k: select(d, indices) for k, d in data.items()}
|
313 |
+
return selected
|
314 |
+
|
315 |
+
|
316 |
+
# Batches data into a a list of batches of batchSize.
|
317 |
+
# Shuffles the data by default.
|
318 |
+
def getBatches(data, batchSize=None, shuffle=True):
|
319 |
+
batches = []
|
320 |
+
|
321 |
+
dataLen = getLength(data)
|
322 |
+
if batchSize is None or batchSize > dataLen:
|
323 |
+
batchSize = dataLen
|
324 |
+
|
325 |
+
indices = np.arange(dataLen)
|
326 |
+
if shuffle:
|
327 |
+
np.random.shuffle(indices)
|
328 |
+
|
329 |
+
for batchStart in range(0, dataLen, batchSize):
|
330 |
+
batchIndices = indices[batchStart: batchStart + batchSize]
|
331 |
+
# if len(batchIndices) == batchSize?
|
332 |
+
if len(batchIndices) >= config.gpusNum:
|
333 |
+
batch = selectIndices(data, batchIndices)
|
334 |
+
batches.append(batch)
|
335 |
+
# batchesIndices.append((data, batchIndices))
|
336 |
+
|
337 |
+
return batches
|
338 |
+
|
339 |
+
|
340 |
+
#### image batches
|
341 |
+
# Opens image files.
|
342 |
+
def openImageFiles(images):
|
343 |
+
images["imagesFile"] = h5py.File(images["imagesFilename"], "r")
|
344 |
+
images["imagesIds"] = None
|
345 |
+
if config.dataset == "NLVR":
|
346 |
+
with open(images["imageIdsFilename"], "r") as imageIdsFile:
|
347 |
+
images["imagesIds"] = json.load(imageIdsFile)
|
348 |
+
|
349 |
+
# Closes image files.
|
350 |
+
|
351 |
+
|
352 |
+
def closeImageFiles(images):
|
353 |
+
images["imagesFile"].close()
|
354 |
+
|
355 |
+
|
356 |
+
# Loads an images from file for a given data batch.
|
357 |
+
def loadImageBatch(images, batch):
|
358 |
+
imagesFile = images["imagesFile"]
|
359 |
+
id2idx = images["imagesIds"]
|
360 |
+
toIndex = lambda imageId: imageId
|
361 |
+
if id2idx is not None:
|
362 |
+
toIndex = lambda imageId: id2idx[imageId]
|
363 |
+
imageBatch = np.stack([imagesFile["features"][toIndex(imageId)] for imageId in batch["imageIds"]], axis=0)
|
364 |
+
|
365 |
+
return {"images": imageBatch, "imageIds": batch["imageIds"]}
|
366 |
+
|
367 |
+
|
368 |
+
# Loads images for several num batches in the batches list from start index.
|
369 |
+
def loadImageBatches(images, batches, start, num):
|
370 |
+
batches = batches[start: start + num]
|
371 |
+
return [loadImageBatch(images, batch) for batch in batches]
|
372 |
+
|
373 |
+
|
374 |
+
#### data alternation
|
375 |
+
# Alternates main training batches with extra data.
|
376 |
+
def alternateData(batches, alterData, dataLen):
|
377 |
+
alterData = alterData["data"][0] # data isn't bucketed for altered data
|
378 |
+
|
379 |
+
# computes number of repetitions
|
380 |
+
needed = math.ceil(len(batches) / config.alterNum)
|
381 |
+
print(bold("Extra batches needed: %d") % needed)
|
382 |
+
perData = math.ceil(getLength(alterData) / config.batchSize)
|
383 |
+
print(bold("Batches per extra data: %d") % perData)
|
384 |
+
repetitions = math.ceil(needed / perData)
|
385 |
+
print(bold("reps: %d") % repetitions)
|
386 |
+
|
387 |
+
# make alternate batches
|
388 |
+
alterBatches = []
|
389 |
+
for _ in range(repetitions):
|
390 |
+
repBatches = getBatches(alterData, batchSize=config.batchSize)
|
391 |
+
random.shuffle(repBatches)
|
392 |
+
alterBatches += repBatches
|
393 |
+
print(bold("Batches num: %d") + len(alterBatches))
|
394 |
+
|
395 |
+
# alternate data with extra data
|
396 |
+
curr = len(batches) - 1
|
397 |
+
for alterBatch in alterBatches:
|
398 |
+
if curr < 0:
|
399 |
+
# print(colored("too many" + str(curr) + " " + str(len(batches)),"red"))
|
400 |
+
break
|
401 |
+
batches.insert(curr, alterBatch)
|
402 |
+
dataLen += getLength(alterBatch)
|
403 |
+
curr -= config.alterNum
|
404 |
+
|
405 |
+
return batches, dataLen
|
406 |
+
|
407 |
+
|
408 |
+
############################################ threading ############################################
|
409 |
+
|
410 |
+
imagesQueue = queue.Queue(maxsize=20) # config.tasksNum
|
411 |
+
inQueue = queue.Queue(maxsize=1)
|
412 |
+
outQueue = queue.Queue(maxsize=1)
|
413 |
+
|
414 |
+
|
415 |
+
# Runs a worker thread(s) to load images while training .
|
416 |
+
class StoppableThread(threading.Thread):
|
417 |
+
# Thread class with a stop() method. The thread itself has to check
|
418 |
+
# regularly for the stopped() condition.
|
419 |
+
|
420 |
+
def __init__(self, images, batches): # i
|
421 |
+
super(StoppableThread, self).__init__()
|
422 |
+
# self.i = i
|
423 |
+
self.images = images
|
424 |
+
self.batches = batches
|
425 |
+
self._stop_event = threading.Event()
|
426 |
+
|
427 |
+
# def __init__(self, args):
|
428 |
+
# super(StoppableThread, self).__init__(args = args)
|
429 |
+
# self._stop_event = threading.Event()
|
430 |
+
|
431 |
+
# def __init__(self, target, args):
|
432 |
+
# super(StoppableThread, self).__init__(target = target, args = args)
|
433 |
+
# self._stop_event = threading.Event()
|
434 |
+
|
435 |
+
def stop(self):
|
436 |
+
self._stop_event.set()
|
437 |
+
|
438 |
+
def stopped(self):
|
439 |
+
return self._stop_event.is_set()
|
440 |
+
|
441 |
+
def run(self):
|
442 |
+
while not self.stopped():
|
443 |
+
try:
|
444 |
+
batchNum = inQueue.get(timeout=60)
|
445 |
+
nextItem = loadImageBatches(self.images, self.batches, batchNum, int(config.taskSize / 2))
|
446 |
+
outQueue.put(nextItem)
|
447 |
+
# inQueue.task_done()
|
448 |
+
except:
|
449 |
+
pass
|
450 |
+
# print("worker %d done", self.i)
|
451 |
+
|
452 |
+
|
453 |
+
def loaderRun(images, batches):
|
454 |
+
batchNum = 0
|
455 |
+
|
456 |
+
# if config.workers == 2:
|
457 |
+
# worker = StoppableThread(images, batches) # i,
|
458 |
+
# worker.daemon = True
|
459 |
+
# worker.start()
|
460 |
+
|
461 |
+
# while batchNum < len(batches):
|
462 |
+
# inQueue.put(batchNum + int(config.taskSize / 2))
|
463 |
+
# nextItem1 = loadImageBatches(images, batches, batchNum, int(config.taskSize / 2))
|
464 |
+
# nextItem2 = outQueue.get()
|
465 |
+
|
466 |
+
# nextItem = nextItem1 + nextItem2
|
467 |
+
# assert len(nextItem) == min(config.taskSize, len(batches) - batchNum)
|
468 |
+
# batchNum += config.taskSize
|
469 |
+
|
470 |
+
# imagesQueue.put(nextItem)
|
471 |
+
|
472 |
+
# worker.stop()
|
473 |
+
# else:
|
474 |
+
while batchNum < len(batches):
|
475 |
+
nextItem = loadImageBatches(images, batches, batchNum, config.taskSize)
|
476 |
+
assert len(nextItem) == min(config.taskSize, len(batches) - batchNum)
|
477 |
+
batchNum += config.taskSize
|
478 |
+
imagesQueue.put(nextItem)
|
479 |
+
|
480 |
+
# print("manager loader done")
|
481 |
+
|
482 |
+
|
483 |
+
########################################## stats tracking #########################################
|
484 |
+
# Computes exponential moving average.
|
485 |
+
def emaAvg(avg, value):
|
486 |
+
if avg is None:
|
487 |
+
return value
|
488 |
+
emaRate = 0.98
|
489 |
+
return avg * emaRate + value * (1 - emaRate)
|
490 |
+
|
491 |
+
|
492 |
+
# Initializes training statistics.
|
493 |
+
def initStats():
|
494 |
+
return {
|
495 |
+
"totalBatches": 0,
|
496 |
+
"totalData": 0,
|
497 |
+
"totalLoss": 0.0,
|
498 |
+
"totalCorrect": 0,
|
499 |
+
"loss": 0.0,
|
500 |
+
"acc": 0.0,
|
501 |
+
"emaLoss": None,
|
502 |
+
"emaAcc": None,
|
503 |
+
}
|
504 |
+
|
505 |
+
|
506 |
+
# Updates statistics with training results of a batch
|
507 |
+
def updateStats(stats, res, batch):
|
508 |
+
stats["totalBatches"] += 1
|
509 |
+
stats["totalData"] += getLength(batch)
|
510 |
+
|
511 |
+
stats["totalLoss"] += res["loss"]
|
512 |
+
stats["totalCorrect"] += res["correctNum"]
|
513 |
+
|
514 |
+
stats["loss"] = stats["totalLoss"] / stats["totalBatches"]
|
515 |
+
stats["acc"] = stats["totalCorrect"] / stats["totalData"]
|
516 |
+
|
517 |
+
stats["emaLoss"] = emaAvg(stats["emaLoss"], res["loss"])
|
518 |
+
stats["emaAcc"] = emaAvg(stats["emaAcc"], res["acc"])
|
519 |
+
|
520 |
+
return stats
|
521 |
+
|
522 |
+
|
523 |
+
# auto-encoder ae = {:2.4f} autoEncLoss,
|
524 |
+
# Translates training statistics into a string to print
|
525 |
+
def statsToStr(stats, res, epoch, batchNum, dataLen, startTime):
|
526 |
+
formatStr = "\reb {epoch},{batchNum} ({dataProcessed} / {dataLen:5d}), " + \
|
527 |
+
"t = {time} ({loadTime:2.2f}+{trainTime:2.2f}), " + \
|
528 |
+
"lr {lr}, l = {loss}, a = {acc}, avL = {avgLoss}, " + \
|
529 |
+
"avA = {avgAcc}, g = {gradNorm:2.4f}, " + \
|
530 |
+
"emL = {emaLoss:2.4f}, emA = {emaAcc:2.4f}; " + \
|
531 |
+
"{expname}" # {machine}/{gpu}"
|
532 |
+
|
533 |
+
s_epoch = bcolored("{:2d}".format(epoch), "green")
|
534 |
+
s_batchNum = "{:3d}".format(batchNum)
|
535 |
+
s_dataProcessed = bcolored("{:5d}".format(stats["totalData"]), "green")
|
536 |
+
s_dataLen = dataLen
|
537 |
+
s_time = bcolored("{:2.2f}".format(time.time() - startTime), "green")
|
538 |
+
s_loadTime = res["readTime"]
|
539 |
+
s_trainTime = res["trainTime"]
|
540 |
+
s_lr = bold(config.lr)
|
541 |
+
s_loss = bcolored("{:2.4f}".format(res["loss"]), "blue")
|
542 |
+
s_acc = bcolored("{:2.4f}".format(res["acc"]), "blue")
|
543 |
+
s_avgLoss = bcolored("{:2.4f}".format(stats["loss"]), "blue")
|
544 |
+
s_avgAcc = bcolored("{:2.4f}".format(stats["acc"]), "red")
|
545 |
+
s_gradNorm = res["gradNorm"]
|
546 |
+
s_emaLoss = stats["emaLoss"]
|
547 |
+
s_emaAcc = stats["emaAcc"]
|
548 |
+
s_expname = config.expName
|
549 |
+
# s_machine = bcolored(config.dataPath[9:11],"green")
|
550 |
+
# s_gpu = bcolored(config.gpus,"green")
|
551 |
+
|
552 |
+
return formatStr.format(epoch=s_epoch, batchNum=s_batchNum, dataProcessed=s_dataProcessed,
|
553 |
+
dataLen=s_dataLen, time=s_time, loadTime=s_loadTime,
|
554 |
+
trainTime=s_trainTime, lr=s_lr, loss=s_loss, acc=s_acc,
|
555 |
+
avgLoss=s_avgLoss, avgAcc=s_avgAcc, gradNorm=s_gradNorm,
|
556 |
+
emaLoss=s_emaLoss, emaAcc=s_emaAcc, expname=s_expname)
|
557 |
+
# machine = s_machine, gpu = s_gpu)
|
558 |
+
|
559 |
+
|
560 |
+
# collectRuntimeStats, writer = None,
|
561 |
+
'''
|
562 |
+
Runs an epoch with model and session over the data.
|
563 |
+
1. Batches the data and optionally mix it with the extra alterData.
|
564 |
+
2. Start worker threads to load images in parallel to training.
|
565 |
+
3. Runs model for each batch, and gets results (e.g. loss, accuracy).
|
566 |
+
4. Updates and prints statistics based on batch results.
|
567 |
+
5. Once in a while (every config.saveEvery), save weights.
|
568 |
+
|
569 |
+
Args:
|
570 |
+
sess: TF session to run with.
|
571 |
+
|
572 |
+
model: model to process data. Has runBatch method that process a given batch.
|
573 |
+
(See model.py for further details).
|
574 |
+
|
575 |
+
data: data to use for training/evaluation.
|
576 |
+
|
577 |
+
epoch: epoch number.
|
578 |
+
|
579 |
+
saver: TF saver to save weights
|
580 |
+
|
581 |
+
calle: a method to call every number of iterations (config.calleEvery)
|
582 |
+
|
583 |
+
alterData: extra data to mix with main data while training.
|
584 |
+
|
585 |
+
getAtt: True to return model attentions.
|
586 |
+
'''
|
587 |
+
|
588 |
+
|
589 |
+
def main(question, image):
|
590 |
+
with open(config.configFile(), "a+") as outFile:
|
591 |
+
json.dump(vars(config), outFile)
|
592 |
+
|
593 |
+
# set gpus
|
594 |
+
if config.gpus != "":
|
595 |
+
config.gpusNum = len(config.gpus.split(","))
|
596 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus
|
597 |
+
|
598 |
+
tf.logging.set_verbosity(tf.logging.ERROR)
|
599 |
+
|
600 |
+
# process data
|
601 |
+
print(bold("Preprocess data..."))
|
602 |
+
start = time.time()
|
603 |
+
preprocessor = Preprocesser()
|
604 |
+
cnn_model = build_model()
|
605 |
+
imageData = get_img_feat(cnn_model, image)
|
606 |
+
qData, embeddings, answerDict = preprocessor.preprocessData(question)
|
607 |
+
data = {'data': qData, 'image': imageData}
|
608 |
+
print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))
|
609 |
+
|
610 |
+
# build model
|
611 |
+
print(bold("Building model..."))
|
612 |
+
start = time.time()
|
613 |
+
model = MACnet(embeddings, answerDict)
|
614 |
+
print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))
|
615 |
+
|
616 |
+
# initializer
|
617 |
+
init = tf.global_variables_initializer()
|
618 |
+
|
619 |
+
# savers
|
620 |
+
savers = setSavers(model)
|
621 |
+
saver, emaSaver = savers["saver"], savers["emaSaver"]
|
622 |
+
|
623 |
+
# sessionConfig
|
624 |
+
sessionConfig = setSession()
|
625 |
+
|
626 |
+
with tf.Session(config=sessionConfig) as sess:
|
627 |
+
|
628 |
+
# ensure no more ops are added after model is built
|
629 |
+
sess.graph.finalize()
|
630 |
+
|
631 |
+
# restore / initialize weights, initialize epoch variable
|
632 |
+
epoch = loadWeights(sess, saver, init)
|
633 |
+
print(epoch)
|
634 |
+
start = time.time()
|
635 |
+
if epoch > 0:
|
636 |
+
if config.useEMA:
|
637 |
+
emaSaver.restore(sess, config.weightsFile(epoch))
|
638 |
+
else:
|
639 |
+
saver.restore(sess, config.weightsFile(epoch))
|
640 |
+
|
641 |
+
evalRes = model.runBatch(sess, data['data'], data['image'], False)
|
642 |
+
|
643 |
+
print("took {:.2f} seconds".format(time.time() - start))
|
644 |
+
|
645 |
+
print(evalRes)
|
646 |
+
|
647 |
+
|
648 |
+
if __name__ == '__main__':
|
649 |
+
parseArgs()
|
650 |
+
loadDatasetConfig[config.dataset]()
|
651 |
+
question = 'How many text objects are located at the bottom side of table?'
|
652 |
+
imagePath = './mac-layoutLM-sample/PDF_val_64.png'
|
653 |
+
main(question, imagePath)
|