File size: 1,976 Bytes
0392181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# Modified to add trojan result extraction options
# --------------------------------------------------------

import os, copy
from openvqa.datasets.dataset_loader import DatasetLoader
from utils.train_engine import train_engine
from utils.test_engine import test_engine
from utils.extract_engine import extract_engine

class Execution:
    """Load data for, and dispatch, an OpenVQA run described by a config.

    ``__C`` is the run configuration object. This class reads ``RUN_MODE``,
    ``EVAL_EVERY_EPOCH``, ``RESUME``, ``VERSION`` and ``LOG_PATH`` from it,
    and hands it to the train / test / extract engines.
    """

    def __init__(self, __C):
        """Store the config and load the dataset(s) the run will need.

        The main dataset is loaded for every ``RUN_MODE`` except
        ``'extract'`` (``extract_engine`` only receives the config). When
        ``EVAL_EVERY_EPOCH`` is set, a second dataset is loaded from a deep
        copy of the config with ``RUN_MODE='val'`` and ``VER='clean'``.
        """
        self.__C = __C

        # Always define the attribute so it exists in every mode; it stays
        # None in 'extract' mode, which does not use a pre-loaded dataset.
        self.dataset = None
        if __C.RUN_MODE != 'extract':
            print('Loading dataset........')
            self.dataset = DatasetLoader(__C).DataSet()

        # When evaluation is triggered after every epoch, build a separate
        # config (deep copy, so the training config is not mutated) with
        # RUN_MODE = 'val' and load the validation set from it.
        self.dataset_eval = None
        if __C.EVAL_EVERY_EPOCH:
            __C_eval = copy.deepcopy(__C)
            setattr(__C_eval, 'RUN_MODE', 'val')
            # Modification: force the per-epoch eval set to the 'clean'
            # version. Presumably VER selects clean vs. trojaned data
            # (per the file header) — confirm against the dataset loader.
            setattr(__C_eval, 'VER', 'clean')

            print('Loading validation set for per-epoch evaluation........')
            self.dataset_eval = DatasetLoader(__C_eval).DataSet()

    def run(self, run_mode):
        """Run the engine selected by ``run_mode``.

        Args:
            run_mode: one of ``'train'``, ``'val'``, ``'test'`` or
                ``'extract'``.

        Raises:
            SystemExit: with status -1 for any other ``run_mode``.
        """
        if run_mode == 'train':
            # A fresh (non-resumed) run starts from an empty log file.
            if self.__C.RESUME is False:
                self.empty_log(self.__C.VERSION)
            train_engine(self.__C, self.dataset, self.dataset_eval)

        elif run_mode == 'val':
            test_engine(self.__C, self.dataset, validation=True)

        elif run_mode == 'test':
            test_engine(self.__C, self.dataset)

        elif run_mode == 'extract':
            extract_engine(self.__C)

        else:
            # Report the bad mode, then terminate with status -1 as before.
            # ``raise SystemExit`` is what the ``exit()`` helper does, but it
            # does not depend on the ``site`` module having been loaded.
            print('Unknown run mode: {}'.format(run_mode))
            raise SystemExit(-1)

    def empty_log(self, version):
        """Delete ``<LOG_PATH>/log_run_<version>.txt`` if it exists.

        Args:
            version: the run version string used in the log file name.
        """
        print('Initializing log file........')
        log_file = os.path.join(self.__C.LOG_PATH,
                                'log_run_' + version + '.txt')
        # EAFP: attempt the removal directly instead of checking first,
        # which avoids an exists-then-remove race.
        try:
            os.remove(log_file)
        except FileNotFoundError:
            # No log from a previous run: nothing to clear.
            pass
        print('Finished!')
        print('')