Spaces:
Runtime error
Runtime error
File size: 1,976 Bytes
0392181 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# Modified to add trojan result extraction options
# --------------------------------------------------------
import os, copy
from openvqa.datasets.dataset_loader import DatasetLoader
from utils.train_engine import train_engine
from utils.test_engine import test_engine
from utils.extract_engine import extract_engine
class Execution:
    """Load the datasets for a run configuration and dispatch to an engine.

    ``__C`` is the run configuration object. The attributes read directly
    here are ``RUN_MODE``, ``EVAL_EVERY_EPOCH``, ``RESUME``, ``VERSION`` and
    ``LOG_PATH``; the object is also handed unchanged to ``DatasetLoader``
    and to the train/test/extract engines.
    """

    def __init__(self, __C):
        """Store the config and, unless extracting, load the dataset(s).

        In ``'extract'`` mode no dataset is loaded (``extract_engine`` only
        receives the config), so ``self.dataset`` and ``self.dataset_eval``
        remain ``None``.
        """
        self.__C = __C
        # Always define both attributes so later access never raises
        # AttributeError, even when loading is skipped in 'extract' mode.
        self.dataset = None
        self.dataset_eval = None

        if __C.RUN_MODE == 'extract':
            return

        print('Loading dataset........')
        self.dataset = DatasetLoader(__C).DataSet()

        # If trigger the evaluation after every epoch
        # Will create a new cfgs with RUN_MODE = 'val'
        if __C.EVAL_EVERY_EPOCH:
            __C_eval = copy.deepcopy(__C)
            setattr(__C_eval, 'RUN_MODE', 'val')
            # modification - force eval set to clean when in train mode
            setattr(__C_eval, 'VER', 'clean')
            print('Loading validation set for per-epoch evaluation........')
            self.dataset_eval = DatasetLoader(__C_eval).DataSet()

    def run(self, run_mode):
        """Run the engine selected by ``run_mode``.

        Args:
            run_mode: one of ``'train'``, ``'val'``, ``'test'`` or
                ``'extract'``.

        Raises:
            SystemExit: with code -1 for any other ``run_mode`` (the same
                exit code the previous ``exit(-1)`` call produced).
        """
        if run_mode == 'train':
            # Start a fresh log unless resuming. ``is False`` is kept on
            # purpose: a RESUME of None does not wipe the existing log.
            if self.__C.RESUME is False:
                self.empty_log(self.__C.VERSION)
            train_engine(self.__C, self.dataset, self.dataset_eval)
        elif run_mode == 'val':
            test_engine(self.__C, self.dataset, validation=True)
        elif run_mode == 'test':
            test_engine(self.__C, self.dataset)
        elif run_mode == 'extract':
            extract_engine(self.__C)
        else:
            # The ``exit`` builtin is injected by ``site`` for interactive
            # use and may be absent (e.g. ``python -S``); raise SystemExit
            # directly, and report the bad mode instead of exiting silently.
            print('Invalid run mode: {}'.format(run_mode))
            raise SystemExit(-1)

    def empty_log(self, version):
        """Delete ``LOG_PATH + '/log_run_' + version + '.txt'`` if present.

        Args:
            version: run version string embedded in the log file name.
        """
        print('Initializing log file........')
        # Same string-concatenated path as before, built once.
        log_file = self.__C.LOG_PATH + '/log_run_' + version + '.txt'
        # EAFP: a missing log is the normal no-op case; this also avoids
        # the race between an exists() check and remove().
        try:
            os.remove(log_file)
        except FileNotFoundError:
            pass
        print('Finished!')
        print('')
|