Spaces:
Sleeping
Sleeping
| import sys | |
| import logging | |
| import copy | |
| import torch | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from utils import factory | |
| from utils.data_manager import DataManager | |
| from torch.utils.data import DataLoader | |
| from utils.toolkit import count_parameters | |
| import os | |
| import numpy as np | |
| import json | |
| import argparse | |
| import torch.multiprocessing | |
| torch.multiprocessing.set_sharing_strategy('file_system') | |
| def _set_device(args): | |
| device_type = args["device"] | |
| gpus = [] | |
| for device in device_type: | |
| if device == -1: | |
| device = torch.device("cpu") | |
| else: | |
| device = torch.device("cuda:{}".format(device)) | |
| gpus.append(device) | |
| args["device"] = gpus | |
def get_methods(object, spacing=20):
    """Print the callable attributes of ``object`` with a short doc summary.

    One line is printed per attribute: the name left-justified to ``spacing``
    columns, a space, then the first 90 characters of the attribute's
    ``__doc__`` with every run of whitespace collapsed to a single space.

    Args:
        object: Any object; its attributes are enumerated with ``dir()``.
        spacing: Column width the attribute name is padded to.
    """
    def _is_listed(name):
        # An attribute whose lookup raises is still listed, so that the
        # failure is reported in the printing pass below.
        try:
            return callable(getattr(object, name))
        except Exception:
            return True

    names = [str(name) for name in dir(object) if _is_listed(name)]

    for name in names:
        try:
            summary = str(getattr(object, name).__doc__)[0:90]
            print(str(name.ljust(spacing)) + ' ' + ' '.join(summary.split()))
        except Exception:
            print(name.ljust(spacing) + ' ' + ' getattr() failed')
def load_model(args):
    """Build the model named in ``args`` and load its checkpoint.

    Args:
        args: Settings dict. ``args["device"]`` is first converted by
            ``_set_device``; ``args["model_name"]`` selects the model through
            ``factory.get_model`` and ``args["checkpoint"]`` is handed to the
            model's ``load_checkpoint``.

    Returns:
        The model object returned by ``factory.get_model``, after
        ``load_checkpoint`` has been called on it.
    """
    _set_device(args)
    learner = factory.get_model(args["model_name"], args)
    checkpoint_path = args["checkpoint"]
    learner.load_checkpoint(checkpoint_path)
    return learner
def evaluate(args):
    """Evaluate a checkpointed model on the test split and print the result.

    Creates the log directory, configures logging to a file and to stdout,
    seeds torch, logs all settings, loads the model, builds a test
    ``DataLoader`` over ``model.class_list`` and prints the two values
    returned by ``model.eval_task``.

    Args:
        args: Settings dict. Reads ``model_name``, ``dataset``, ``data``,
            ``init_cls``, ``increment``, ``prefix``, ``seed``,
            ``convnet_type`` and ``batch_size`` (plus whatever ``load_model``
            reads). Writes ``logfilename`` and ``csv_name``.
    """
    logs_name = "logs/{}/{}_{}/{}/{}".format(
        args["model_name"],
        args["dataset"],
        args['data'],
        args['init_cls'],
        args['increment'],
    )
    # exist_ok avoids the check-then-create race and replaces the two
    # identical "if not exists: makedirs" blocks of the previous version.
    os.makedirs(logs_name, exist_ok=True)

    args['csv_name'] = "{}_{}_{}".format(
        args["prefix"],
        args["seed"],
        args["convnet_type"],
    )
    # NOTE: despite its key name, 'logfilename' stores the log *directory*.
    args['logfilename'] = logs_name
    # The log file lives inside the log directory and shares the csv stem.
    logfilename = "{}/{}".format(logs_name, args['csv_name'])

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )

    _set_random()
    print_args(args)
    model = load_model(args)

    # The second positional argument is passed as False; presumably a
    # shuffle flag -- confirm against DataManager's signature.
    data_manager = DataManager(
        args["dataset"],
        False,
        args["seed"],
        args["init_cls"],
        args["increment"],
        path=args["data"],
    )
    test_set = data_manager.get_dataset(
        model.class_list, source="test", mode="test"
    )
    loader = DataLoader(
        test_set,
        batch_size=args['batch_size'],
        shuffle=True,
        num_workers=8,
    )
    cnn_acc, nme_acc = model.eval_task(loader, group=1, mode="test")
    print(cnn_acc, nme_acc)
def main():
    """Parse the command line, merge in the JSON settings and run ``evaluate``.

    The parsed arguments are turned into a dict and then updated with the
    contents of the JSON file named by ``--config``, so a key present in both
    takes its value from the JSON file.
    """
    cli = setup_parser().parse_args()
    settings = vars(cli)  # argparse Namespace viewed as a plain dict
    settings.update(load_json(cli.config))  # JSON values are applied on top
    evaluate(settings)
def load_json(settings_path):
    """Load and return the JSON document stored at ``settings_path``.

    Args:
        settings_path: Path of the JSON settings file.

    Returns:
        The decoded JSON value (a dict for the settings files used here).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # locale default encoding.
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)
| def _set_random(): | |
| torch.manual_seed(1) | |
| torch.cuda.manual_seed(1) | |
| torch.cuda.manual_seed_all(1) | |
| torch.backends.cudnn.deterministic = True | |
| torch.backends.cudnn.benchmark = False | |
def setup_parser():
    """Create the command-line parser for the evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` with three string options:
        ``--config`` (default ``./exps/finetune.json``), ``-d/--data`` and
        ``-c/--checkpoint`` (both defaulting to ``None``).
    """
    parser = argparse.ArgumentParser(
        description='Reproduce of multiple continual learning algorthms.')

    # (flags, keyword arguments) for each option, registered in order.
    option_table = (
        (('--config',),
         {'type': str, 'default': './exps/finetune.json',
          'help': 'Json file of settings.'}),
        (('-d', '--data'),
         {'type': str, 'help': 'Path of the data folder'}),
        (('-c', '--checkpoint'),
         {'type': str,
          'help': 'Path of checkpoint file if resume training'}),
    )
    for flags, spec in option_table:
        parser.add_argument(*flags, **spec)
    return parser
def print_args(args):
    """Log each entry of ``args`` as ``"key: value"`` at INFO level.

    Messages go through the root logger (``logging.info``), one per entry,
    in the mapping's iteration order.

    Args:
        args: Mapping of setting names to values.
    """
    for entry in args.items():
        logging.info("{}: {}".format(*entry))
| if __name__ == '__main__': | |
| main() | |