# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# Modified to add trojan result extraction options
# --------------------------------------------------------

import copy
import os
import sys

from openvqa.datasets.dataset_loader import DatasetLoader
from utils.train_engine import train_engine
from utils.test_engine import test_engine
from utils.extract_engine import extract_engine


class Execution:
    """Load the datasets a run needs and dispatch to the matching engine.

    Supported run modes are 'train', 'val', 'test' and 'extract'.
    """

    def __init__(self, __C):
        """Store the config and load the datasets required by ``__C.RUN_MODE``.

        Args:
            __C: run configuration object. This method reads ``RUN_MODE`` and
                ``EVAL_EVERY_EPOCH``. When ``EVAL_EVERY_EPOCH`` is truthy, a
                deep copy with ``RUN_MODE='val'`` and ``VER='clean'`` is used
                to load a separate per-epoch evaluation dataset.
        """
        self.__C = __C
        # Define the dataset attributes unconditionally so they always exist;
        # in 'extract' mode no dataset is loaded and both stay None.
        self.dataset = None
        self.dataset_eval = None

        if __C.RUN_MODE == 'extract':
            # extract_engine is called with the config only (see run()), so
            # no dataset is loaded here.
            return

        print('Loading dataset........')
        self.dataset = DatasetLoader(__C).DataSet()

        # If evaluation after every epoch is requested, load the evaluation
        # dataset from a separate config copy with RUN_MODE = 'val'.
        if __C.EVAL_EVERY_EPOCH:
            __C_eval = copy.deepcopy(__C)
            setattr(__C_eval, 'RUN_MODE', 'val')
            # Modification: whenever per-epoch evaluation is enabled, the
            # evaluation set is forced to VER = 'clean', regardless of the
            # VER used by the main run.
            setattr(__C_eval, 'VER', 'clean')
            print('Loading validation set for per-epoch evaluation........')
            self.dataset_eval = DatasetLoader(__C_eval).DataSet()

    def run(self, run_mode):
        """Run the engine that matches ``run_mode``.

        Args:
            run_mode: one of 'train', 'val', 'test' or 'extract'.

        Any other value prints an error to stderr and exits the process
        with status -1.
        """
        if run_mode == 'train':
            # A run that is not resuming starts from an emptied log file.
            if self.__C.RESUME is False:
                self.empty_log(self.__C.VERSION)
            train_engine(self.__C, self.dataset, self.dataset_eval)
        elif run_mode == 'val':
            test_engine(self.__C, self.dataset, validation=True)
        elif run_mode == 'test':
            test_engine(self.__C, self.dataset)
        elif run_mode == 'extract':
            extract_engine(self.__C)
        else:
            # sys.exit does not depend on the `site` module, unlike the
            # interactive exit() helper.
            print('Unknown run mode: {!r}'.format(run_mode), file=sys.stderr)
            sys.exit(-1)

    def empty_log(self, version):
        """Delete ``<LOG_PATH>/log_run_<version>.txt`` if it exists.

        Args:
            version: run version string used in the log file name.
        """
        print('Initializing log file........')
        log_file = os.path.join(self.__C.LOG_PATH, 'log_run_' + version + '.txt')
        if os.path.exists(log_file):
            os.remove(log_file)
        print('Finished!')
        print('')