from __future__ import division
import warnings

from extract_feature import build_model, run_image, get_img_feat

# warnings.filterwarnings("ignore", category=FutureWarning)
# warnings.filterwarnings("ignore", message="size changed")
warnings.filterwarnings("ignore")

import sys
import os
import time
import math
import random

try:
    import Queue as queue
except ImportError:
    import queue

import threading
import h5py
import json
import numpy as np
import tensorflow as tf
from termcolor import colored, cprint

from config import config, loadDatasetConfig, parseArgs
from preprocess import Preprocesser, bold, bcolored, writeline, writelist
from model import MACnet
from collections import defaultdict
############################################# loggers #############################################

# Writes the log header to file.
def logInit():
    with open(config.logFile(), "a+") as outFile:
        writeline(outFile, config.expName)
        headers = ["epoch", "trainAcc", "valAcc", "trainLoss", "valLoss"]
        if config.evalTrain:
            headers += ["evalTrainAcc", "evalTrainLoss"]
        if config.extra:
            if config.evalTrain:
                headers += ["thAcc", "thLoss"]
            headers += ["vhAcc", "vhLoss"]
        headers += ["time", "lr"]  # lr assumed to be last
        writelist(outFile, headers)
# Writes a log record to file.
def logRecord(epoch, epochTime, lr, trainRes, evalRes, extraEvalRes):
    with open(config.logFile(), "a+") as outFile:
        record = [epoch, trainRes["acc"], evalRes["val"]["acc"], trainRes["loss"], evalRes["val"]["loss"]]
        if config.evalTrain:
            record += [evalRes["evalTrain"]["acc"], evalRes["evalTrain"]["loss"]]
        if config.extra:
            if config.evalTrain:
                record += [extraEvalRes["evalTrain"]["acc"], extraEvalRes["evalTrain"]["loss"]]
            record += [extraEvalRes["val"]["acc"], extraEvalRes["val"]["loss"]]
        record += [epochTime, lr]
        writelist(outFile, record)
# Gets the last logged epoch and learning rate.
def lastLoggedEpoch():
    with open(config.logFile(), "r") as inFile:
        lastLine = list(inFile)[-1].split(",")
    epoch = int(lastLine[0])
    lr = float(lastLine[-1])
    return epoch, lr
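# Hypothetical log-file tail, assuming writelist emits comma-separated values
# (the split(",") above relies on that) and lr sits in the last column:
#   epoch,trainAcc,valAcc,trainLoss,valLoss,time,lr
#   3,0.8512,0.8247,0.4113,0.4724,1021.7,0.0001
# lastLoggedEpoch() would then return (3, 0.0001).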
################################## printing, output and analysis ##################################

# Analysis by type: inclusive (lower, upper) length bounds for each group.
analysisQuestionLims = [(0, 18), (19, float("inf"))]
analysisProgramLims = [(0, 12), (13, float("inf"))]

toArity = lambda instance: instance["programSeq"][-1].split("_", 1)[0]
toType = lambda instance: instance["programSeq"][-1].split("_", 1)[1]

def fieldLenIsInRange(field):
    return lambda instance, group: \
        (len(instance[field]) >= group[0] and
         len(instance[field]) <= group[1])
# Groups instances based on a key.
def grouperKey(toKey):
    def grouper(instances):
        res = defaultdict(list)
        for instance in instances:
            res[toKey(instance)].append(instance)
        return res
    return grouper

# Groups instances according to whether they match a condition.
def grouperCond(groups, isIn):
    def grouper(instances):
        res = {}
        for group in groups:
            # a list rather than a generator, so that len() and repeated
            # iteration work in printAnalysis below
            res[group] = [instance for instance in instances if isIn(instance, group)]
        return res
    return grouper
groupers = {
    "questionLength": grouperCond(analysisQuestionLims, fieldLenIsInRange("questionSeq")),
    "programLength": grouperCond(analysisProgramLims, fieldLenIsInRange("programSeq")),
    "arity": grouperKey(toArity),
    "type": grouperKey(toType)
}
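# Example (hypothetical instance): for
#   instance = {"programSeq": ["scene", "filter_color", "count_object"]}
# toArity(instance) == "count" and toType(instance) == "object", so
# groupers["arity"] buckets instances by the prefix of their last program token.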
# Computes the average of a field over instances.
def avg(instances, field):
    if len(instances) == 0:
        return 0.0
    return sum(instance[field] for instance in instances) / len(instances)
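# e.g. avg([{"acc": 1.0}, {"acc": 0.0}], "acc") == 0.5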
# Prints analysis of question loss and accuracy by group.
def printAnalysis(res):
    if config.analysisType != "":
        print("Analysis by {type}".format(type=config.analysisType))
        groups = groupers[config.analysisType](res["preds"])
        for key in groups:
            instances = groups[key]
            avgLoss = avg(instances, "loss")
            avgAcc = avg(instances, "acc")
            num = len(instances)
            print("Group {key}: Loss: {loss}, Acc: {acc}, Num: {num}".format(
                key=key, loss=avgLoss, acc=avgAcc, num=num))
# Prints results for a tier.
def printTierResults(tierName, res, color):
    if res is None:
        return
    print("{tierName} Loss: {loss}, {tierName} accuracy: {acc}".format(
        tierName=tierName,
        loss=bcolored(res["loss"], color),
        acc=bcolored(res["acc"], color)))
    printAnalysis(res)

# Prints dataset results (for several tiers).
def printDatasetResults(trainRes, evalRes):
    printTierResults("Training", trainRes, "magenta")
    printTierResults("Training EMA", evalRes["evalTrain"], "red")
    printTierResults("Validation", evalRes["val"], "cyan")

# Writes predictions for several tiers.
def writePreds(preprocessor, evalRes):
    preprocessor.writePreds(evalRes, "_")
############################################# session #############################################

# Builds the TF session config. Sets GPU memory options.
def setSession():
    sessionConfig = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    if config.allowGrowth:
        sessionConfig.gpu_options.allow_growth = True
    if config.maxMemory < 1.0:
        sessionConfig.gpu_options.per_process_gpu_memory_fraction = config.maxMemory
    return sessionConfig
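# For instance, with config.maxMemory = 0.5 the process may claim at most half
# of each visible GPU's memory, while allow_growth defers allocation instead of
# grabbing the full fraction upfront.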
############################################## savers #############################################

# Initializes savers (standard, optional exponential-moving-average, and an
# optional saver for a subset of the variables).
def setSavers(model):
    saver = tf.train.Saver(max_to_keep=config.weightsToKeep)

    subsetSaver = None
    if config.saveSubset:
        isRelevant = lambda var: any(s in var.name for s in config.varSubset)
        relevantVars = [var for var in tf.global_variables() if isRelevant(var)]
        subsetSaver = tf.train.Saver(relevantVars, max_to_keep=config.weightsToKeep, allow_empty=True)

    emaSaver = None
    if config.useEMA:
        emaSaver = tf.train.Saver(model.emaDict, max_to_keep=config.weightsToKeep)

    return {
        "saver": saver,
        "subsetSaver": subsetSaver,
        "emaSaver": emaSaver
    }
################################### restore / initialize weights ##################################

# Restores weights of the specified / last epoch if in restore mode.
# Otherwise, initializes weights.
def loadWeights(sess, saver, init):
    if config.restoreEpoch > 0 or config.restore:
        # restore the last logged epoch only if restoreEpoch isn't set
        if config.restoreEpoch == 0:
            config.restoreEpoch, config.lr = lastLoggedEpoch()
        print(bcolored("Restoring epoch {} and lr {}".format(config.restoreEpoch, config.lr), "cyan"))
        print(bcolored("Restoring weights", "blue"))
        print(config.weightsFile(config.restoreEpoch))
        saver.restore(sess, config.weightsFile(config.restoreEpoch))
        epoch = config.restoreEpoch
    else:
        print(bcolored("Initializing weights", "blue"))
        sess.run(init)
        logInit()
        epoch = 0
    return epoch
###################################### training / evaluation ######################################

# Chooses the data to train on (main / extra).
def chooseTrainingData(data):
    trainingData = data["main"]["train"]
    alterData = None

    if config.extra:
        if config.trainExtra:
            if config.extraVal:
                trainingData = data["extra"]["val"]
            else:
                trainingData = data["extra"]["train"]
        if config.alterExtra:
            alterData = data["extra"]["train"]

    return trainingData, alterData
#### evaluation

# Runs evaluation on the train / val / test datasets.
def runEvaluation(sess, model, data, epoch, evalTrain=True, evalTest=False, getAtt=None):
    if getAtt is None:
        getAtt = config.getAtt

    res = {"evalTrain": None, "val": None, "test": None}
    if data is not None:
        if evalTrain and config.evalTrain:
            res["evalTrain"] = runEpoch(sess, model, data["evalTrain"], train=False, epoch=epoch, getAtt=getAtt)
        res["val"] = runEpoch(sess, model, data["val"], train=False, epoch=epoch, getAtt=getAtt)
        if evalTest or config.test:
            res["test"] = runEpoch(sess, model, data["test"], train=False, epoch=epoch, getAtt=getAtt)
    return res
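# Each non-None tier above holds whatever runEpoch (defined alongside the full
# training loop) returns; the printing and logging code in this file assumes at
# least {"loss": float, "acc": float, "preds": list} per tier.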
## training conditions (comparing current epoch results to prior ones)
def improveEnough(curr, prior, lr):
    prevRes = prior["prev"]["res"]
    currRes = curr["res"]

    if prevRes is None:
        return True

    prevTrainLoss = prevRes["train"]["loss"]
    currTrainLoss = currRes["train"]["loss"]
    lossDiff = prevTrainLoss - currTrainLoss

    notImprove = ((lossDiff < 0.015 and prevTrainLoss < 0.5 and lr > 0.00002) or
                  (lossDiff < 0.008 and prevTrainLoss < 0.15 and lr > 0.00001) or
                  (lossDiff < 0.003 and prevTrainLoss < 0.10 and lr > 0.000005))
    # (prevTrainLoss < 0.2 and config.lr > 0.000015)

    return not notImprove
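# Worked example: prevTrainLoss = 0.40, currTrainLoss = 0.39, lr = 1e-4 gives
# lossDiff = 0.01 < 0.015 with prevTrainLoss < 0.5 and lr > 2e-5, so notImprove
# is True and improveEnough returns False (a cue for the caller to decay the lr).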
def better(currRes, bestRes):
    return currRes["val"]["acc"] > bestRes["val"]["acc"]
############################################## data ###############################################
#### instances and batching

# Trims padded sequences down to the longest actual length in the batch.
def trim2DVectors(vectors, vectorsLengths):
    maxLength = np.max(vectorsLengths)
    return vectors[:, :maxLength]
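# e.g. for zero-padded vectors of shape (2, 10) with vectorsLengths = [3, 5],
# the result is vectors[:, :5], i.e. shape (2, 5).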
# Trims a batch based on question length.
def trimData(data):
    data["questions"] = trim2DVectors(data["questions"], data["questionLengths"])
    return data

# Gets the batch / bucket size.
def getLength(data):
    return len(data["instances"])
# Selects the data entries that match the given indices.
def selectIndices(data, indices):
    def select(field, indices):
        if type(field) is np.ndarray:
            return field[indices]
        if type(field) is list:
            return [field[i] for i in indices]
        else:
            return field
    selected = {k: select(d, indices) for k, d in data.items()}
    return selected
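# Example (hypothetical fields): with
#   data = {"questions": np.arange(6).reshape(3, 2), "instances": ["a", "b", "c"], "path": "imgs.h5"}
# selectIndices(data, [0, 2]) indexes the array and list fields by row and
# passes the scalar "path" field through unchanged.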
# Batches data into a list of batches of batchSize.
# Shuffles the data by default.
def getBatches(data, batchSize=None, shuffle=True):
    batches = []

    dataLen = getLength(data)
    if batchSize is None or batchSize > dataLen:
        batchSize = dataLen

    indices = np.arange(dataLen)
    if shuffle:
        np.random.shuffle(indices)

    for batchStart in range(0, dataLen, batchSize):
        batchIndices = indices[batchStart: batchStart + batchSize]
        # if len(batchIndices) == batchSize?
        if len(batchIndices) >= config.gpusNum:
            batch = selectIndices(data, batchIndices)
            batches.append(batch)
            # batchesIndices.append((data, batchIndices))

    return batches
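# e.g. 10 instances with batchSize = 4 yield batches of sizes 4, 4 and 2; the
# trailing short batch is kept only if it still covers config.gpusNum GPUs.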
#### image batches

# Opens the image feature files.
def openImageFiles(images):
    images["imagesFile"] = h5py.File(images["imagesFilename"], "r")
    images["imagesIds"] = None
    if config.dataset == "NLVR":
        with open(images["imageIdsFilename"], "r") as imageIdsFile:
            images["imagesIds"] = json.load(imageIdsFile)

# Closes the image feature files.
def closeImageFiles(images):
    images["imagesFile"].close()
# Loads the images from file for a given data batch.
def loadImageBatch(images, batch):
    imagesFile = images["imagesFile"]
    id2idx = images["imagesIds"]
    toIndex = lambda imageId: imageId
    if id2idx is not None:
        toIndex = lambda imageId: id2idx[imageId]
    imageBatch = np.stack([imagesFile["features"][toIndex(imageId)] for imageId in batch["imageIds"]], axis=0)
    return {"images": imageBatch, "imageIds": batch["imageIds"]}

# Loads images for num batches from the batches list, starting at index start.
def loadImageBatches(images, batches, start, num):
    batches = batches[start: start + num]
    return [loadImageBatch(images, batch) for batch in batches]
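# Assumed HDF5 layout (matching loadImageBatch above): a "features" dataset
# indexed by image position, e.g. of shape (numImages, H, W, C), with imagesIds
# mapping image ids to positions when the dataset (NLVR) requires it.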
#### data alternation

# Alternates main training batches with extra data.
def alternateData(batches, alterData, dataLen):
    alterData = alterData["data"][0]  # data isn't bucketed for alternated data

    # compute the number of needed repetitions
    needed = math.ceil(len(batches) / config.alterNum)
    print(bold("Extra batches needed: %d") % needed)
    perData = math.ceil(getLength(alterData) / config.batchSize)
    print(bold("Batches per extra data: %d") % perData)
    repetitions = math.ceil(needed / perData)
    print(bold("reps: %d") % repetitions)

    # make the alternate batches
    alterBatches = []
    for _ in range(repetitions):
        repBatches = getBatches(alterData, batchSize=config.batchSize)
        random.shuffle(repBatches)
        alterBatches += repBatches
    print(bold("Batches num: %d") % len(alterBatches))

    # alternate the main batches with the extra batches
    curr = len(batches) - 1
    for alterBatch in alterBatches:
        if curr < 0:
            # print(colored("too many" + str(curr) + " " + str(len(batches)), "red"))
            break
        batches.insert(curr, alterBatch)
        dataLen += getLength(alterBatch)
        curr -= config.alterNum

    return batches, dataLen
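# Worked example: with 7 main batches and config.alterNum = 3, curr visits
# 6, 3, 0, so one extra batch is spliced in after every 3 main batches
# (counting from the end of the epoch).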
############################################ threading ############################################

imagesQueue = queue.Queue(maxsize=20)  # config.tasksNum
inQueue = queue.Queue(maxsize=1)
outQueue = queue.Queue(maxsize=1)

# A worker thread that loads images while training.
class StoppableThread(threading.Thread):
    # Thread class with a stop() method. The thread itself has to check
    # regularly for the stopped() condition.
    def __init__(self, images, batches):
        super(StoppableThread, self).__init__()
        self.images = images
        self.batches = batches
        self._stop_event = threading.Event()

    def stop(self):
        self._stop_event.set()

    def stopped(self):
        return self._stop_event.is_set()

    def run(self):
        while not self.stopped():
            try:
                batchNum = inQueue.get(timeout=60)
                nextItem = loadImageBatches(self.images, self.batches, batchNum, int(config.taskSize / 2))
                outQueue.put(nextItem)
                # inQueue.task_done()
            except queue.Empty:
                pass
def loaderRun(images, batches):
    batchNum = 0

    # two-worker variant:
    # if config.workers == 2:
    #     worker = StoppableThread(images, batches)
    #     worker.daemon = True
    #     worker.start()
    #
    #     while batchNum < len(batches):
    #         inQueue.put(batchNum + int(config.taskSize / 2))
    #         nextItem1 = loadImageBatches(images, batches, batchNum, int(config.taskSize / 2))
    #         nextItem2 = outQueue.get()
    #
    #         nextItem = nextItem1 + nextItem2
    #         assert len(nextItem) == min(config.taskSize, len(batches) - batchNum)
    #         batchNum += config.taskSize
    #         imagesQueue.put(nextItem)
    #
    #     worker.stop()
    # else:
    while batchNum < len(batches):
        nextItem = loadImageBatches(images, batches, batchNum, config.taskSize)
        assert len(nextItem) == min(config.taskSize, len(batches) - batchNum)
        batchNum += config.taskSize
        imagesQueue.put(nextItem)

    # print("manager loader done")
########################################## stats tracking #########################################

# Computes an exponential moving average.
def emaAvg(avg, value):
    if avg is None:
        return value
    emaRate = 0.98
    return avg * emaRate + value * (1 - emaRate)
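# e.g. emaAvg(None, 0.5) == 0.5, and
#      emaAvg(0.5, 1.0) == 0.5 * 0.98 + 1.0 * 0.02 == 0.51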
# Initializes training statistics.
def initStats():
    return {
        "totalBatches": 0,
        "totalData": 0,
        "totalLoss": 0.0,
        "totalCorrect": 0,
        "loss": 0.0,
        "acc": 0.0,
        "emaLoss": None,
        "emaAcc": None,
    }
# Updates the statistics with the training results of a batch.
def updateStats(stats, res, batch):
    stats["totalBatches"] += 1
    stats["totalData"] += getLength(batch)
    stats["totalLoss"] += res["loss"]
    stats["totalCorrect"] += res["correctNum"]

    stats["loss"] = stats["totalLoss"] / stats["totalBatches"]
    stats["acc"] = stats["totalCorrect"] / stats["totalData"]

    stats["emaLoss"] = emaAvg(stats["emaLoss"], res["loss"])
    stats["emaAcc"] = emaAvg(stats["emaAcc"], res["acc"])

    return stats
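# e.g. after two batches of 64 with losses 0.6 and 0.4 and 32 + 48 correct
# answers, stats["loss"] == 0.5 and stats["acc"] == 80 / 128 == 0.625.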
# auto-encoder ae = {:2.4f} autoEncLoss,
# Translates the training statistics into a string to print.
def statsToStr(stats, res, epoch, batchNum, dataLen, startTime):
    formatStr = "\reb {epoch},{batchNum} ({dataProcessed} / {dataLen:5d}), " + \
                "t = {time} ({loadTime:2.2f}+{trainTime:2.2f}), " + \
                "lr {lr}, l = {loss}, a = {acc}, avL = {avgLoss}, " + \
                "avA = {avgAcc}, g = {gradNorm:2.4f}, " + \
                "emL = {emaLoss:2.4f}, emA = {emaAcc:2.4f}; " + \
                "{expname}"  # {machine}/{gpu}

    s_epoch = bcolored("{:2d}".format(epoch), "green")
    s_batchNum = "{:3d}".format(batchNum)
    s_dataProcessed = bcolored("{:5d}".format(stats["totalData"]), "green")
    s_dataLen = dataLen
    s_time = bcolored("{:2.2f}".format(time.time() - startTime), "green")
    s_loadTime = res["readTime"]
    s_trainTime = res["trainTime"]
    s_lr = bold(config.lr)
    s_loss = bcolored("{:2.4f}".format(res["loss"]), "blue")
    s_acc = bcolored("{:2.4f}".format(res["acc"]), "blue")
    s_avgLoss = bcolored("{:2.4f}".format(stats["loss"]), "blue")
    s_avgAcc = bcolored("{:2.4f}".format(stats["acc"]), "red")
    s_gradNorm = res["gradNorm"]
    s_emaLoss = stats["emaLoss"]
    s_emaAcc = stats["emaAcc"]
    s_expname = config.expName
    # s_machine = bcolored(config.dataPath[9:11], "green")
    # s_gpu = bcolored(config.gpus, "green")

    return formatStr.format(epoch=s_epoch, batchNum=s_batchNum, dataProcessed=s_dataProcessed,
                            dataLen=s_dataLen, time=s_time, loadTime=s_loadTime,
                            trainTime=s_trainTime, lr=s_lr, loss=s_loss, acc=s_acc,
                            avgLoss=s_avgLoss, avgAcc=s_avgAcc, gradNorm=s_gradNorm,
                            emaLoss=s_emaLoss, emaAcc=s_emaAcc, expname=s_expname)
                            # machine=s_machine, gpu=s_gpu)
# collectRuntimeStats, writer = None,
# Docstring for the full training loop's runEpoch (referenced in runEvaluation
# above but not part of this demo script):
'''
Runs an epoch with the model and session over the data.
1. Batches the data, and optionally mixes it with the extra alterData.
2. Starts worker threads to load images in parallel to training.
3. Runs the model for each batch and gets the results (e.g. loss, accuracy).
4. Updates and prints statistics based on the batch results.
5. Every once in a while (every config.saveEvery), saves the weights.

Args:
    sess: TF session to run with.
    model: model to process the data. Has a runBatch method that processes a given batch
        (see model.py for further details).
    data: data to use for training / evaluation.
    epoch: epoch number.
    saver: TF saver to save the weights.
    calle: a method to call every number of iterations (config.calleEvery).
    alterData: extra data to mix with the main data while training.
    getAtt: True to return the model attentions.
'''
def main(question, image):
    with open(config.configFile(), "a+") as outFile:
        json.dump(vars(config), outFile)

    # set gpus
    if config.gpus != "":
        config.gpusNum = len(config.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus

    tf.logging.set_verbosity(tf.logging.ERROR)

    # process data
    print(bold("Preprocess data..."))
    start = time.time()
    preprocessor = Preprocesser()
    cnn_model = build_model()
    imageData = get_img_feat(cnn_model, image)
    qData, embeddings, answerDict = preprocessor.preprocessData(question)
    data = {'data': qData, 'image': imageData}
    print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))

    # build model
    print(bold("Building model..."))
    start = time.time()
    model = MACnet(embeddings, answerDict)
    print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))

    # initializer
    init = tf.global_variables_initializer()

    # savers
    savers = setSavers(model)
    saver, emaSaver = savers["saver"], savers["emaSaver"]

    # sessionConfig
    sessionConfig = setSession()

    with tf.Session(config=sessionConfig) as sess:
        # ensure no more ops are added after the model is built
        sess.graph.finalize()

        # restore / initialize weights, initialize the epoch variable
        epoch = loadWeights(sess, saver, init)
        print(epoch)

        start = time.time()
        if epoch > 0:
            if config.useEMA:
                emaSaver.restore(sess, config.weightsFile(epoch))
            else:
                saver.restore(sess, config.weightsFile(epoch))

        evalRes = model.runBatch(sess, data['data'], data['image'], False)
        print("took {:.2f} seconds".format(time.time() - start))
        print(evalRes)
if __name__ == '__main__':
    parseArgs()
    loadDatasetConfig[config.dataset]()

    question = 'How many text objects are located at the bottom side of table?'
    imagePath = './mac-layoutLM-sample/PDF_val_64.png'
    main(question, imagePath)