File size: 4,111 Bytes
cb80c28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import sys
import logging
import copy
import torch
from PIL import Image
import torchvision.transforms as transforms
from utils import factory
from utils.data_manager import DataManager
from torch.utils.data import DataLoader
from utils.toolkit import count_parameters
import os
import numpy as np
import json
import argparse
import torch.multiprocessing
# Select the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# when DataLoader workers pass tensors around — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
def _set_device(args):
    device_type = args["device"]
    gpus = []

    for device in device_type:
        if device == -1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda:{}".format(device))

        gpus.append(device)

    args["device"] = gpus

def get_methods(object, spacing=20):
    """Print the callable attributes of ``object`` with a one-line doc summary.

    For every name in ``dir(object)`` whose attribute is callable (or whose
    lookup raises), one line is printed: the name left-justified to
    ``spacing`` columns, a space, and the first 90 characters of the
    attribute's ``__doc__`` with all whitespace runs collapsed to single
    spaces.  If reading the attribute fails while printing, the line says
    ``getattr() failed`` instead.

    Args:
        object: Any object to introspect.  (The name shadows the builtin but
            is kept so keyword callers keep working.)
        spacing: Column width used to pad the method name.
    """
    method_names = []
    for attr_name in dir(object):
        try:
            is_callable = callable(getattr(object, attr_name))
        except Exception:
            # Best effort: attributes whose lookup raises are still listed so
            # that the failure shows up in the printed report below.
            is_callable = True
        if is_callable:
            method_names.append(str(attr_name))

    for name in method_names:
        try:
            doc_head = str(getattr(object, name).__doc__)[0:90]
            # Collapse newlines/indentation of the docstring into single spaces.
            print(str(name.ljust(spacing)) + ' ' + ' '.join(doc_head.split()))
        except Exception:
            print(name.ljust(spacing) + ' ' + ' getattr() failed')

def load_model(args):
    """Create the configured model and load its checkpoint.

    Converts ``args["device"]`` through ``_set_device``, builds the model
    named by ``args["model_name"]`` with ``factory.get_model`` and then calls
    its ``load_checkpoint`` with ``args["checkpoint"]``.

    Args:
        args: Settings dict; must provide "device", "model_name" and
            "checkpoint".

    Returns:
        The model object returned by ``factory.get_model``.
    """
    _set_device(args)
    learner = factory.get_model(args["model_name"], args)
    checkpoint_path = args["checkpoint"]
    learner.load_checkpoint(checkpoint_path)
    return learner
    
def evaluate(args):
    """Evaluate a checkpointed model on the test split and print the results.

    Creates the log directory, configures logging to a file and stdout, seeds
    the RNGs, logs all settings, loads the model, builds a test ``DataLoader``
    over ``model.class_list`` and prints the two values returned by
    ``model.eval_task``.

    Args:
        args: Settings dict.  Reads "model_name", "dataset", "data",
            "init_cls", "increment", "prefix", "seed", "convnet_type" and
            "batch_size" (plus whatever ``load_model`` reads).  The keys
            "logfilename" and "csv_name" are written into it.
    """
    logs_name = "logs/{}/{}_{}/{}/{}".format(
        args["model_name"],
        args["dataset"],
        args['data'],
        args['init_cls'],
        args['increment'],
    )
    # exist_ok replaces the duplicated check-then-create, which was redundant
    # and could race with another process creating the same directory.
    os.makedirs(logs_name, exist_ok=True)

    csv_name = "{}_{}_{}".format(
        args["prefix"],
        args["seed"],
        args["convnet_type"],
    )
    # Log file lives inside the log directory and shares the csv base name.
    logfilename = "{}/{}".format(logs_name, csv_name)

    # NOTE: despite its name, "logfilename" stores the log *directory*; kept
    # unchanged because code outside this function may rely on it.
    args['logfilename'] = logs_name
    args['csv_name'] = csv_name

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    _set_random()
    print_args(args)
    model = load_model(args)

    data_manager = DataManager(
        args["dataset"],
        False,
        args["seed"],
        args["init_cls"],
        args["increment"],
        path=args["data"],
    )
    test_set = data_manager.get_dataset(
        model.class_list, source="test", mode="test"
    )
    # NOTE(review): shuffle=True is kept from the original; shuffling is
    # usually unnecessary for evaluation — confirm eval_task does not need it.
    loader = DataLoader(
        test_set,
        batch_size=args['batch_size'],
        shuffle=True,
        num_workers=8,
    )

    cnn_acc, nme_acc = model.eval_task(loader, group=1, mode="test")
    print(cnn_acc, nme_acc)
def main():
    """Parse the command line, merge in the JSON settings and run evaluation.

    Values from the JSON file named by ``--config`` are applied on top of the
    parsed command-line arguments, so a key present in both takes the JSON
    value.
    """
    namespace = setup_parser().parse_args()
    json_settings = load_json(namespace.config)
    # vars() exposes the Namespace as a plain dict that can be updated.
    settings = vars(namespace)
    settings.update(json_settings)
    evaluate(settings)

def load_json(settings_path):
    """Read a JSON settings file and return its parsed content.

    Args:
        settings_path: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for that file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding.
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)

def _set_random():
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def setup_parser():
    """Build the command-line parser for the evaluation script.

    Options:
        --config: Path of the JSON settings file
            (default ``./exps/finetune.json``).
        -d / --data: Path of the data folder.
        -c / --checkpoint: Path of the checkpoint file to load.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description='Reproduce of multiple continual learning algorithms.'
    )
    parser.add_argument('--config', type=str, default='./exps/finetune.json',
                        help='Json file of settings.')
    parser.add_argument('-d', '--data', type=str,
                        help='Path of the data folder')
    # This script loads the checkpoint for evaluation, not to resume training.
    parser.add_argument('-c', '--checkpoint', type=str,
                        help='Path of the checkpoint file to evaluate')
    return parser

def print_args(args):
    """Log every setting as a ``key: value`` line at INFO level.

    Args:
        args: Mapping of setting names to values.
    """
    for entry in args.items():
        # entry is a (key, value) pair, unpacked into the two placeholders.
        logging.info("{}: {}".format(*entry))
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()