'''
We implement `iCaRL+RMM` and `FOSTER+RMM` in [rmm.py](models/rmm.py), and the
`Pretraining Stage` of `RMM` in [rmm_train.py](rmm_train.py). Use the following
training script to run it:

```bash
python rmm_train.py --config=./exps/rmm-pretrain.json
```
'''
import json
import argparse
import sys
import logging
import copy
import os
import random

import numpy as np
import torch

from utils import factory
from utils.data_manager import DataManager
from utils.rl_utils.ddpg import DDPG
from utils.rl_utils.rl_utils import ReplayBuffer


class CILEnv:
    """Wraps class-incremental learning (CIL) training as an RL environment."""

    def __init__(self, args) -> None:
        self._args = copy.deepcopy(args)
        # Candidate (init_cls, increment) CIL settings; one is sampled per episode.
        self.settings = [(50, 2), (50, 5), (50, 10), (50, 20), (10, 10), (20, 20), (5, 5)]
        # self.settings = [(5, 5)]  # Debug
        self._args["init_cls"], self._args["increment"] = self.settings[
            np.random.randint(len(self.settings))
        ]
        self.data_manager = DataManager(
            self._args["dataset"],
            self._args["shuffle"],
            self._args["seed"],
            self._args["init_cls"],
            self._args["increment"],
        )
        self.model = factory.get_model(self._args["model_name"], self._args)

    @property
    def nb_task(self):
        return self.data_manager.nb_tasks

    @property
    def cur_task(self):
        return self.model._cur_task

    def get_task_size(self, task_id):
        return self.data_manager.get_task_size(task_id)

    def reset(self):
        """Sample a fresh CIL setting and return (state, reward, done, info)."""
        self._args["init_cls"], self._args["increment"] = self.settings[
            np.random.randint(len(self.settings))
        ]
        self.data_manager = DataManager(
            self._args["dataset"],
            self._args["shuffle"],
            self._args["seed"],
            self._args["init_cls"],
            self._args["increment"],
        )
        self.model = factory.get_model(self._args["model_name"], self._args)
        info = "start new task: dataset: {}, init_cls: {}, increment: {}".format(
            self._args["dataset"], self._args["init_cls"], self._args["increment"]
        )
        return np.array([self.get_task_size(0) / 100, 0]), None, False, info

    def step(self, action):
        """Train one incremental task with the given (m_rate, c_rate) action."""
        self.model._m_rate_list.append(action[0])
        self.model._c_rate_list.append(action[1])
        self.model.incremental_train(self.data_manager)
        cnn_accy, nme_accy = self.model.eval_task()
        self.model.after_task()
        done = self.cur_task == self.nb_task - 1
        info = "running task [{}/{}]: dataset: {}, increment: {}, cnn_accy top1: {}, top5: {}".format(
            self.model._known_classes,
            100,  # The total number of classes is hardcoded to 100 here.
            self._args["dataset"],
            self._args["increment"],
            cnn_accy["top1"],
            cnn_accy["top5"],
        )
        return (
            np.array(
                [
                    self.get_task_size(self.cur_task + 1) / 100 if not done else 0.0,
                    self.model.memory_size
                    / (self.model.memory_size + self.model.new_memory_size),
                ]
            ),
            cnn_accy["top1"] / 100,
            done,
            info,
        )


def _train(args):
    logs_name = "logs/RL-CIL/{}/".format(args["model_name"])
    if not os.path.exists(logs_name):
        os.makedirs(logs_name)
    logfilename = "logs/RL-CIL/{}/{}_{}_{}_{}_{}".format(
        args["model_name"],
        args["prefix"],
        args["seed"],
        args["model_name"],
        args["convnet_type"],
        args["dataset"],
    )
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    _set_random()
    _set_device(args)
    print_args(args)

    # DDPG hyperparameters.
    actor_lr = 5e-4
    critic_lr = 5e-3
    num_episodes = 200
    hidden_dim = 32
    gamma = 0.98
    tau = 0.005
    buffer_size = 1000
    minimal_size = 50
    batch_size = 32
    sigma = 0.2  # Std of the Gaussian action noise, encouraging the off-policy algorithm to explore.
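
    # How the agent sees CILEnv (read off reset() and step() above):
    #   state  = [next task's class count / 100, memory ratio m / (m + new_m)]
    #   action = [m_rate, c_rate], appended to the model's per-task rate lists
    #   reward = the task's top-1 CNN accuracy / 100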
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") env = CILEnv(args) replay_buffer = ReplayBuffer(buffer_size) agent = DDPG( 2, 1, 4, hidden_dim, False, 1, sigma, actor_lr, critic_lr, tau, gamma, device ) for iteration in range(num_episodes): state, *_, info = env.reset() logging.info(info) done = False while not done: action = agent.take_action(state) logging.info(f"take action: m_rate {action[0]}, c_rate {action[1]}") next_state, reward, done, info = env.step(action) logging.info(info) replay_buffer.add(state, action, reward, next_state, done) state = next_state if replay_buffer.size() > minimal_size: b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size) transition_dict = { "states": b_s, "actions": b_a, "next_states": b_ns, "rewards": b_r, "dones": b_d, } agent.update(transition_dict) def _set_device(args): device_type = args["device"] gpus = [] for device in device_type: if device_type == -1: device = torch.device("cpu") else: device = torch.device("cuda:{}".format(device)) gpus.append(device) args["device"] = gpus def _set_random(): random.seed(1) torch.manual_seed(1) torch.cuda.manual_seed(1) torch.cuda.manual_seed_all(1) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def print_args(args): for key, value in args.items(): logging.info("{}: {}".format(key, value)) def train(args): seed_list = copy.deepcopy(args["seed"]) device = copy.deepcopy(args["device"]) for seed in seed_list: args["seed"] = seed args["device"] = device _train(args) def main(): args = setup_parser().parse_args() param = load_json(args.config) args = vars(args) # Converting argparse Namespace to a dict. args.update(param) # Add parameters from json train(args) def load_json(settings_path): with open(settings_path) as data_file: param = json.load(data_file) return param def setup_parser(): parser = argparse.ArgumentParser( description="Reproduce of multiple continual learning algorthms." ) parser.add_argument( "--config", type=str, default="./exps/finetune.json", help="Json file of settings.", ) return parser if __name__ == "__main__": main()