import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import FOSTERNet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy
# Please refer to https://github.com/G-U-N/ECCV22-FOSTER for the full source code to reproduce FOSTER.
EPSILON = 1e-8
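# FOSTER trains in two stages per incremental task: (1) feature boosting, where a new
# backbone is appended and trained alongside the frozen old backbone with a class-balanced
# classification loss, an auxiliary loss on the new branch, and knowledge distillation from
# the old logits; (2) feature compression, where the two-backbone teacher is distilled into
# a single-backbone student (SNet) with balanced knowledge distillation.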
class FOSTER(BaseLearner):
def __init__(self, args):
super().__init__(args)
self.args = args
self._network = FOSTERNet(args, False)
self._snet = None
self.beta1 = args["beta1"]
self.beta2 = args["beta2"]
self.per_cls_weights = None
self.is_teacher_wa = args["is_teacher_wa"]
self.is_student_wa = args["is_student_wa"]
self.lambda_okd = args["lambda_okd"]
self.wa_value = args["wa_value"]
self.oofc = args["oofc"].lower()
def after_task(self):
self._known_classes = self._total_classes
logging.info("Exemplar size: {}".format(self.exemplar_size))
def load_checkpoint(self, filename):
checkpoint = torch.load(filename)
self._known_classes = len(checkpoint["classes"])
self.class_list = np.array(checkpoint["classes"])
self.label_list = checkpoint["label_list"]
self._network.update_fc(self._known_classes)
self._network.load_checkpoint(checkpoint["network"])
self._network.to(self._device)
self._cur_task = 0
def save_checkpoint(self, filename):
self._network.cpu()
save_dict = {
"classes": self.data_manager.get_class_list(self._cur_task),
"network": {
"convnet": self._network.convnets[0].state_dict(),
"fc": self._network.fc.state_dict()
},
"label_list": self.data_manager.get_label_list(self._cur_task),
}
torch.save(save_dict, "./{}/{}_{}.pkl".format(filename, self.args['model_name'], self._cur_task))
def incremental_train(self, data_manager):
self.data_manager = data_manager
        # Merge the label list restored from a checkpoint with the labels of the new data manager.
        if hasattr(self.data_manager, 'label_list') and hasattr(self, 'label_list'):
            self.data_manager.label_list = list(self.label_list.values()) + self.data_manager.label_list
self._cur_task += 1
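        # From the second incremental task onward, continue training from the compressed
        # student produced by the previous task's feature-compression stage.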
if self._cur_task > 1:
self._network = self._snet
self._total_classes = self._known_classes + data_manager.get_task_size(
self._cur_task
)
self._network.update_fc(self._total_classes)
self._network_module_ptr = self._network
logging.info(
"Learning on {}-{}".format(self._known_classes, self._total_classes)
)
if self._cur_task > 0:
for p in self._network.convnets[0].parameters():
p.requires_grad = False
for p in self._network.oldfc.parameters():
p.requires_grad = False
logging.info("All params: {}".format(count_parameters(self._network)))
logging.info(
"Trainable params: {}".format(count_parameters(self._network, True))
)
train_dataset = data_manager.get_dataset(
np.arange(self._known_classes, self._total_classes),
source="train",
mode="train",
appendent=self._get_memory(),
)
self.train_loader = DataLoader(
train_dataset,
batch_size=self.args["batch_size"],
shuffle=True,
num_workers=self.args["num_workers"],
pin_memory=True,
)
test_dataset = data_manager.get_dataset(
np.arange(0, self._total_classes), source="test", mode="test"
)
self.test_loader = DataLoader(
test_dataset,
batch_size=self.args["batch_size"],
shuffle=False,
num_workers=self.args["num_workers"],
)
if len(self._multiple_gpus) > 1:
self._network = nn.DataParallel(self._network, self._multiple_gpus)
self._train(self.train_loader, self.test_loader)
#self.build_rehearsal_memory(data_manager, self.samples_per_class)
if len(self._multiple_gpus) > 1:
self._network = self._network.module
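    # Set train/eval modes: the newest backbone trains, while the old backbone
    # stays frozen in eval mode after the first task.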
def train(self):
self._network_module_ptr.train()
self._network_module_ptr.convnets[-1].train()
if self._cur_task >= 1:
self._network_module_ptr.convnets[0].eval()
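    # Dispatch training: plain cross-entropy for the first task; otherwise feature boosting
    # with class-balanced weights, optional teacher weight alignment, then feature compression.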
def _train(self, train_loader, test_loader):
self._network.to(self._device)
if hasattr(self._network, "module"):
self._network_module_ptr = self._network.module
if self._cur_task == 0:
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, self._network.parameters()),
momentum=0.9,
lr=self.args["init_lr"],
weight_decay=self.args["init_weight_decay"],
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=self.args["init_epochs"]
)
self._init_train(train_loader, test_loader, optimizer, scheduler)
else:
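            # Class-balanced weights via the effective number of samples (Cui et al., CVPR 2019):
            # weight_c is proportional to (1 - beta) / (1 - beta^n_c), normalized to sum to the number of classes.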
cls_num_list = [self.samples_old_class] * self._known_classes + [
self.samples_new_class(i)
for i in range(self._known_classes, self._total_classes)
]
effective_num = 1.0 - np.power(self.beta1, cls_num_list)
per_cls_weights = (1.0 - self.beta1) / np.array(effective_num)
per_cls_weights = (
per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
)
logging.info("per cls weights : {}".format(per_cls_weights))
self.per_cls_weights = torch.FloatTensor(per_cls_weights).to(self._device)
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, self._network.parameters()),
lr=self.args["lr"],
momentum=0.9,
weight_decay=self.args["weight_decay"],
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=self.args["boosting_epochs"]
)
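            # "az" (always zero): zero the fc weights that map the old backbone's features to the
            # new classes, so new classes are predicted only from the new backbone's features.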
if self.oofc == "az":
for i, p in enumerate(self._network_module_ptr.fc.parameters()):
if i == 0:
p.data[
self._known_classes :, : self._network_module_ptr.out_dim
] = torch.tensor(0.0)
elif self.oofc != "ft":
assert 0, "not implemented"
self._feature_boosting(train_loader, test_loader, optimizer, scheduler)
if self.is_teacher_wa:
self._network_module_ptr.weight_align(
self._known_classes,
self._total_classes - self._known_classes,
self.wa_value,
)
else:
logging.info("do not weight align teacher!")
cls_num_list = [self.samples_old_class] * self._known_classes + [
self.samples_new_class(i)
for i in range(self._known_classes, self._total_classes)
]
effective_num = 1.0 - np.power(self.beta2, cls_num_list)
per_cls_weights = (1.0 - self.beta2) / np.array(effective_num)
per_cls_weights = (
per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
)
logging.info("per cls weights : {}".format(per_cls_weights))
self.per_cls_weights = torch.FloatTensor(per_cls_weights).to(self._device)
self._feature_compression(train_loader, test_loader)
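    # First-task training: standard supervised learning with cross-entropy on the initial classes.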
def _init_train(self, train_loader, test_loader, optimizer, scheduler):
prog_bar = tqdm(range(self.args["init_epochs"]))
for _, epoch in enumerate(prog_bar):
self.train()
losses = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(
self._device, non_blocking=True
), targets.to(self._device, non_blocking=True)
logits = self._network(inputs)["logits"]
loss = F.cross_entropy(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses += loss.item()
_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
if epoch % 5 == 0:
test_acc = self._compute_accuracy(self._network, test_loader)
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["init_epochs"],
losses / len(train_loader),
train_acc,
test_acc,
)
else:
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["init_epochs"],
losses / len(train_loader),
train_acc,
)
prog_bar.set_description(info)
logging.info(info)
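    # Boosting stage: train the expanded network with three terms: class-balanced cross-entropy
    # on the combined logits, cross-entropy on the new branch's fe_logits, and temperature-scaled
    # distillation of the old-class logits against the frozen old branch.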
def _feature_boosting(self, train_loader, test_loader, optimizer, scheduler):
prog_bar = tqdm(range(self.args["boosting_epochs"]))
for _, epoch in enumerate(prog_bar):
self.train()
losses = 0.0
losses_clf = 0.0
losses_fe = 0.0
losses_kd = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(
self._device, non_blocking=True
), targets.to(self._device, non_blocking=True)
outputs = self._network(inputs)
logits, fe_logits, old_logits = (
outputs["logits"],
outputs["fe_logits"],
outputs["old_logits"].detach(),
)
loss_clf = F.cross_entropy(logits / self.per_cls_weights, targets)
loss_fe = F.cross_entropy(fe_logits, targets)
loss_kd = self.lambda_okd * _KD_loss(
logits[:, : self._known_classes], old_logits, self.args["T"]
)
loss = loss_clf + loss_fe + loss_kd
optimizer.zero_grad()
loss.backward()
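                # In "az" mode, also zero the corresponding gradients so the masked
                # old-feature/new-class weights stay at zero throughout boosting.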
if self.oofc == "az":
for i, p in enumerate(self._network_module_ptr.fc.parameters()):
if i == 0:
p.grad.data[
self._known_classes :,
: self._network_module_ptr.out_dim,
] = torch.tensor(0.0)
elif self.oofc != "ft":
assert 0, "not implemented"
optimizer.step()
losses += loss.item()
losses_fe += loss_fe.item()
losses_clf += loss_clf.item()
losses_kd += (
self._known_classes / self._total_classes
) * loss_kd.item()
_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
if epoch % 5 == 0:
test_acc = self._compute_accuracy(self._network, test_loader)
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["boosting_epochs"],
losses / len(train_loader),
losses_clf / len(train_loader),
losses_fe / len(train_loader),
losses_kd / len(train_loader),
train_acc,
test_acc,
)
else:
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["boosting_epochs"],
losses / len(train_loader),
losses_clf / len(train_loader),
losses_fe / len(train_loader),
losses_kd / len(train_loader),
train_acc,
)
prog_bar.set_description(info)
logging.info(info)
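    # Compression stage: distill the two-backbone teacher into a single-backbone student (SNet)
    # using balanced knowledge distillation, then optionally weight-align the student.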
def _feature_compression(self, train_loader, test_loader):
self._snet = FOSTERNet(self.args, False)
self._snet.update_fc(self._total_classes)
if len(self._multiple_gpus) > 1:
self._snet = nn.DataParallel(self._snet, self._multiple_gpus)
if hasattr(self._snet, "module"):
self._snet_module_ptr = self._snet.module
else:
self._snet_module_ptr = self._snet
self._snet.to(self._device)
self._snet_module_ptr.convnets[0].load_state_dict(
self._network_module_ptr.convnets[0].state_dict()
)
self._snet_module_ptr.copy_fc(self._network_module_ptr.oldfc)
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, self._snet.parameters()),
lr=self.args["lr"],
momentum=0.9,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=self.args["compression_epochs"]
)
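        # The boosted two-branch network acts as a fixed teacher; only the student is optimized.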
self._network.eval()
prog_bar = tqdm(range(self.args["compression_epochs"]))
for _, epoch in enumerate(prog_bar):
self._snet.train()
losses = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(
self._device, non_blocking=True
), targets.to(self._device, non_blocking=True)
dark_logits = self._snet(inputs)["logits"]
with torch.no_grad():
outputs = self._network(inputs)
logits, old_logits, fe_logits = (
outputs["logits"],
outputs["old_logits"],
outputs["fe_logits"],
)
loss_dark = self.BKD(dark_logits, logits, self.args["T"])
loss = loss_dark
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses += loss.item()
_, preds = torch.max(dark_logits[: targets.shape[0]], dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
if epoch % 5 == 0:
test_acc = self._compute_accuracy(self._snet, test_loader)
info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["compression_epochs"],
losses / len(train_loader),
train_acc,
test_acc,
)
else:
info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["compression_epochs"],
losses / len(train_loader),
train_acc,
)
prog_bar.set_description(info)
logging.info(info)
if len(self._multiple_gpus) > 1:
self._snet = self._snet.module
if self.is_student_wa:
self._snet.weight_align(
self._known_classes,
self._total_classes - self._known_classes,
self.wa_value,
)
else:
logging.info("do not weight align student!")
if self._cur_task > 1:
self._network = self._snet
self._snet.eval()
y_pred, y_true = [], []
for _, (_, inputs, targets) in enumerate(test_loader):
inputs = inputs.to(self._device, non_blocking=True)
with torch.no_grad():
outputs = self._snet(inputs)["logits"]
predicts = torch.topk(
outputs, k=self.topk, dim=1, largest=True, sorted=True
)[1]
y_pred.append(predicts.cpu().numpy())
y_true.append(targets.cpu().numpy())
y_pred = np.concatenate(y_pred)
y_true = np.concatenate(y_true)
cnn_accy = self._evaluate(y_pred, y_true)
logging.info("darknet eval: ")
logging.info("CNN top1 curve: {}".format(cnn_accy["top1"]))
logging.info("CNN top5 curve: {}".format(cnn_accy["top5"]))
@property
def samples_old_class(self):
if self._fixed_memory:
return self._memory_per_class
else:
assert self._total_classes != 0, "Total classes is 0"
return self._memory_size // self._known_classes
def samples_new_class(self, index):
if self.args["dataset"] == "cifar100":
return 500
else:
return self.data_manager.getlen(index)
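    # Balanced knowledge distillation: the teacher's softened probabilities are re-weighted by
    # the per-class weights and re-normalized before computing the soft cross-entropy.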
def BKD(self, pred, soft, T):
pred = torch.log_softmax(pred / T, dim=1)
soft = torch.softmax(soft / T, dim=1)
soft = soft * self.per_cls_weights
soft = soft / soft.sum(1)[:, None]
return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
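# Standard temperature-scaled knowledge-distillation loss between student and teacher logits.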
def _KD_loss(pred, soft, T):
pred = torch.log_softmax(pred / T, dim=1)
soft = torch.softmax(soft / T, dim=1)
return -1 * torch.mul(soft, pred).sum() / pred.shape[0]