|
''' |
|
Re-implementation of SimpleCIL (https://arxiv.org/abs/2303.07338) without pre-trained weights. |
|
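The training process is as follows: in the first stage, train the whole model with cross-entropy; then, at the end of every stage (including the first), overwrite the classifier weights of that stage's classes with their class-mean feature prototypes. Subsequent stages do not train the backbone, so the classifier for all seen classes consists of prototypes.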
|
Please refer to the original implementation (https://github.com/zhoudw-zdw/RevisitingCIL) if you are using pre-trained weights. |
|
''' |
|
import logging |
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
from torch.serialization import load |
|
from tqdm import tqdm |
|
from torch import optim |
|
from torch.nn import functional as F |
|
from torch.utils.data import DataLoader |
|
from utils.inc_net import SimpleCosineIncrementalNet |
|
from models.base import BaseLearner |
|
from utils.toolkit import target2onehot, tensor2numpy |
|
|
|
|
|
# DataLoader settings shared by every loader this learner builds.
num_workers: int = 8
batch_size: int = 32

# NOTE(review): not referenced by the learner below (the first-stage schedule
# is cosine annealing); kept in case other modules import it.
milestones: list = [40, 80]
|
|
|
class SimpleCIL(BaseLearner):
    """SimpleCIL learner trained from scratch (no pre-trained weights).

    Only the first task trains the network (cross-entropy, SGD with a cosine
    learning-rate schedule). At the end of every task, including the first,
    the classifier row of each class in the current task is overwritten with
    the mean backbone feature ("prototype") of that class. Later tasks never
    update the backbone.
    """

    def __init__(self, args):
        """Build the cosine-classifier network and read hyper-parameters.

        Args:
            args: experiment configuration dict. ``min_lr`` is optional and
                falls back to 1e-8 when it is missing or ``None``.
        """
        super().__init__(args)
        self._network = SimpleCosineIncrementalNet(args, False)
        # .get() so configs that omit "min_lr" entirely do not raise KeyError.
        min_lr = args.get("min_lr")
        self.min_lr = min_lr if min_lr is not None else 1e-8
        self.args = args

    def load_checkpoint(self, filename):
        """Restore a checkpoint written by ``save_checkpoint``.

        Sets ``_total_classes``, ``class_list`` and ``label_list`` from the
        file, resizes the classifier to match, delegates weight loading to the
        network's own ``load_checkpoint`` and moves the network to
        ``self._device``. ``_cur_task`` and ``_known_classes`` are not stored
        in the checkpoint and are left unchanged.
        """
        checkpoint = torch.load(filename)
        classes = checkpoint["classes"]
        self._total_classes = len(classes)
        self.class_list = np.array(classes)
        self.label_list = checkpoint["label_list"]
        print("Class list: ", self.class_list)
        self._network.update_fc(self._total_classes)
        self._network.load_checkpoint(checkpoint["network"])
        self._network.to(self._device)

    def after_task(self):
        """Mark every class seen so far as known before the next task."""
        self._known_classes = self._total_classes

    def save_checkpoint(self, filename):
        """Save class metadata and backbone/classifier weights for this task.

        The file is written to ``./<filename>/<model_name>_<cur_task>.pkl``;
        ``filename`` names a directory, which ``torch.save`` does not create.
        """
        # Serialize CPU tensors so the checkpoint does not pin a device.
        self._network.cpu()
        try:
            save_dict = {
                "classes": self.data_manager.get_class_list(self._cur_task),
                "network": {
                    "convnet": self._network.convnet.state_dict(),
                    "fc": self._network.fc.state_dict(),
                },
                "label_list": self.data_manager.get_label_list(self._cur_task),
            }
            torch.save(
                save_dict,
                "./{}/{}_{}.pkl".format(filename, self.args['model_name'], self._cur_task),
            )
        finally:
            # Move back so evaluation/training after a save still runs on the
            # configured device instead of mixing CPU weights with device inputs.
            self._network.to(self._device)

    def replace_fc(self, trainloader, model, args):
        """Overwrite classifier rows with per-class mean features.

        Args:
            trainloader: yields ``(_, data, label)`` batches for the classes
                whose prototypes should be (re)computed.
            model: network used to extract features; it is switched to eval
                mode and its ``"features"`` output is averaged per class.
            args: unused; kept for interface compatibility.

        Returns:
            ``model`` (in eval mode).
        """
        model = model.eval()
        embeddings, labels = [], []
        with torch.no_grad():
            for _, data, label in trainloader:
                # Use the learner's device (not the implicit current CUDA
                # device) so this works on CPU and on non-default GPUs.
                data = data.to(self._device)
                embeddings.append(model(data)["features"].cpu())
                labels.append(label.cpu())

        if not embeddings:  # empty loader: nothing to average
            return model

        embedding_list = torch.cat(embeddings, dim=0)
        label_list = torch.cat(labels, dim=0)

        # During multi-GPU training self._network is wrapped in DataParallel.
        network = self._network.module if isinstance(self._network, nn.DataParallel) else self._network
        weight = network.fc.weight.data
        # Only classes actually present in the loader get a prototype, so no
        # row is ever overwritten with the NaN mean of an empty selection.
        for class_index in label_list.unique().tolist():
            proto = embedding_list[label_list == class_index].mean(0)
            weight[class_index] = proto.to(weight.device)
        return model

    def incremental_train(self, data_manager):
        """Run one incremental task: grow the classifier, build loaders, train.

        Args:
            data_manager: provides per-task class lists, task sizes and
                datasets via ``get_dataset``.
        """
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
        self._network.update_fc(self._total_classes)
        logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))

        self.class_list = np.array(data_manager.get_class_list(self._cur_task))
        self.data_manager = data_manager
        new_classes = np.arange(self._known_classes, self._total_classes)

        train_dataset = data_manager.get_dataset(new_classes, source="train", mode="train")
        self.train_dataset = train_dataset
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )

        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        # Training images with test-time transforms, used only to compute
        # prototypes; order is irrelevant, so no shuffling (deterministic sums).
        train_dataset_for_protonet = data_manager.get_dataset(new_classes, source="train", mode="test")
        self.train_loader_for_protonet = DataLoader(
            train_dataset_for_protonet, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        if len(self._multiple_gpus) > 1:
            print('Multiple GPUs')
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        try:
            self._train(self.train_loader, self.test_loader, self.train_loader_for_protonet)
        finally:
            # Always unwrap, even if training fails, so later code sees the raw net.
            if isinstance(self._network, nn.DataParallel):
                self._network = self._network.module

    def _train(self, train_loader, test_loader, train_loader_for_protonet):
        """Train the backbone on the first task only, then refresh prototypes."""
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer = optim.SGD(
                self._network.parameters(),
                momentum=0.9,
                lr=self.args["init_lr"],
                weight_decay=self.args["init_weight_decay"],
            )
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer, T_max=self.args['init_epoch'], eta_min=self.min_lr
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        self.replace_fc(train_loader_for_protonet, self._network, None)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        """Cross-entropy training loop for the first task.

        Runs ``args["init_epoch"]`` epochs, steps ``scheduler`` once per epoch,
        evaluates on ``test_loader`` every 5th epoch (0, 5, 10, ...) and logs
        progress, including an elapsed/remaining-time line per epoch.
        """
        num_epochs = self.args["init_epoch"]
        num_batches = len(train_loader)
        prog_bar = tqdm(range(num_epochs))
        info = None  # stays None when num_epochs == 0
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets).sum().item()
                total += len(targets)

            scheduler.step()
            # Guard against an empty loader instead of dividing by zero.
            train_acc = np.around(correct * 100 / total, decimals=2) if total else 0.0
            avg_loss = losses / num_batches if num_batches else 0.0

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    num_epochs,
                    avg_loss,
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    num_epochs,
                    avg_loss,
                    train_acc,
                )

            elapsed = prog_bar.format_dict["elapsed"]
            rate = prog_bar.format_dict["rate"]
            # rate is None until tqdm has a timing estimate.
            remaining = (prog_bar.total - prog_bar.n) / rate if rate and prog_bar.total else 0
            prog_bar.set_description(info)
            logging.info("Working on task {}: {:.2f}:{:.2f}".format(
                self._cur_task,
                elapsed,
                remaining))

        if info is not None:
            logging.info(info)
        # Read elapsed after the loop so the total includes the final epoch.
        logging.info("Finished on task {}: {:.2f}".format(
            self._cur_task, prog_bar.format_dict["elapsed"]))
|
|
|
|
|
|