# models/foster.py
import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import FOSTERNet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy
# Please refer to https://github.com/G-U-N/ECCV22-FOSTER for the full source code to reproduce FOSTER.
EPSILON = 1e-8
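# FOSTER (Feature Boosting and Compression for class-incremental learning) handles
# each incremental task in two stages:
#   1. Feature boosting (_feature_boosting): a new backbone branch is trained next
#      to the frozen old one with a class-balanced CE loss, a CE loss on the new
#      branch's logits, and a KD loss against the frozen old model's logits.
#   2. Feature compression (_feature_compression): the two-branch teacher is
#      distilled into a single-backbone student (self._snet) with a balanced KD
#      loss (BKD), and that student is carried over to the next task.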
class FOSTER(BaseLearner):
def __init__(self, args):
super().__init__(args)
self.args = args
self._network = FOSTERNet(args, False)
self._snet = None
self.beta1 = args["beta1"]
self.beta2 = args["beta2"]
self.per_cls_weights = None
self.is_teacher_wa = args["is_teacher_wa"]
self.is_student_wa = args["is_student_wa"]
self.lambda_okd = args["lambda_okd"]
self.wa_value = args["wa_value"]
self.oofc = args["oofc"].lower()
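        # Further keys read from `args` later in this file: "attack", "batch_size",
        # "num_workers", "init_lr", "init_weight_decay", "init_epochs", "lr",
        # "weight_decay", "boosting_epochs", "compression_epochs", "T", and "dataset".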
def after_task(self):
self._known_classes = self._total_classes
logging.info("Exemplar size: {}".format(self.exemplar_size))
def incremental_train(self, data_manager=None):
self.data_manager = data_manager
self._cur_task += 1
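        # Once a student exists from a previous task (cur_task > 1), it replaces the
        # working network before the classifier is expanded for the new classes.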
if self._cur_task > 1:
self._network = self._snet
self._total_classes = self._known_classes + data_manager.get_task_size(
self._cur_task
)
self._network.update_fc(self._total_classes)
self._network_module_ptr = self._network
logging.info(
"Learning on {}-{}".format(self._known_classes, self._total_classes)
)
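        # Outside attack mode, freeze the previously learned backbone and the old
        # classifier head; only the remaining (newly added) parameters stay trainable.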
if not self.args["attack"]:
if self._cur_task > 0:
for p in self._network.convnets[0].parameters():
p.requires_grad = False
for p in self._network.oldfc.parameters():
p.requires_grad = False
logging.info("All params: {}".format(count_parameters(self._network)))
logging.info(
"Trainable params: {}".format(count_parameters(self._network, True))
)
train_dataset = data_manager.get_dataset(
np.arange(self._known_classes, self._total_classes),
source="train",
mode="train",
appendent=self._get_memory(),
)
self.train_loader = DataLoader(
train_dataset,
batch_size=self.args["batch_size"],
shuffle=True,
num_workers=self.args["num_workers"],
pin_memory=True,
)
test_dataset = data_manager.get_dataset(
np.arange(0, self._total_classes), source="test", mode="test"
)
self.test_loader = DataLoader(
test_dataset,
batch_size=self.args["batch_size"],
shuffle=False,
num_workers=self.args["num_workers"],
)
if len(self._multiple_gpus) > 1:
self._network = nn.DataParallel(self._network, self._multiple_gpus)
else:
self.train_loader, self.test_loader = None, None
self._train(self.train_loader, self.test_loader)
if not self.args["attack"]:
self.build_rehearsal_memory(data_manager, self.samples_per_class)
if len(self._multiple_gpus) > 1:
self._network = self._network.module
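    # Custom train-mode switch: keep the network and its newest backbone in train
    # mode, but hold the first, already-learned backbone in eval mode once
    # incremental tasks have started, so its batch-norm statistics stay fixed.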
def train(self):
self._network_module_ptr.train()
self._network_module_ptr.convnets[-1].train()
if self._cur_task >= 1:
self._network_module_ptr.convnets[0].eval()
def _train(self, train_loader=None, test_loader=None):
self._network.to(self._device)
if self.args["attack"]:
self._network.eval()
if hasattr(self._network, "module"):
self._network_module_ptr = self._network.module
if self._cur_task == 0:
if self.args["attack"]:
optimizer, scheduler = None, None
else:
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, self._network.parameters()),
momentum=0.9,
lr=self.args["init_lr"],
weight_decay=self.args["init_weight_decay"],
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=self.args["init_epochs"]
)
self._init_train(train_loader, test_loader, optimizer, scheduler)
else:
if self.args["attack"]:
train_loader, test_loader = None, None
else:
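                # Class-balanced re-weighting via the "effective number of samples"
                # (Cui et al., CVPR 2019): w_c = (1 - beta1) / (1 - beta1 ** n_c),
                # then normalized so the weights sum to the number of classes.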
cls_num_list = [self.samples_old_class] * self._known_classes + [
self.samples_new_class(i)
for i in range(self._known_classes, self._total_classes)
]
effective_num = 1.0 - np.power(self.beta1, cls_num_list)
per_cls_weights = (1.0 - self.beta1) / np.array(effective_num)
per_cls_weights = (
per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
)
logging.info("per cls weights : {}".format(per_cls_weights))
self.per_cls_weights = torch.FloatTensor(per_cls_weights).to(self._device)
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, self._network.parameters()),
lr=self.args["lr"],
momentum=0.9,
weight_decay=self.args["weight_decay"],
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=self.args["boosting_epochs"]
)
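                # oofc == "az" zeroes the new-class rows of the fc weights over the
                # first backbone's feature dimensions (and their gradients during
                # boosting below), so the new classes are scored from the newly added
                # branch only; "ft" leaves the classifier untouched.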
if self.oofc == "az":
for i, p in enumerate(self._network_module_ptr.fc.parameters()):
if i == 0:
p.data[
self._known_classes :, : self._network_module_ptr.out_dim
] = torch.tensor(0.0)
elif self.oofc != "ft":
assert 0, "not implemented"
self._feature_boosting(train_loader, test_loader, optimizer, scheduler)
if self.is_teacher_wa:
self._network_module_ptr.weight_align(
self._known_classes,
self._total_classes - self._known_classes,
self.wa_value,
)
else:
logging.info("do not weight align teacher!")
cls_num_list = [self.samples_old_class] * self._known_classes + [
self.samples_new_class(i)
for i in range(self._known_classes, self._total_classes)
]
effective_num = 1.0 - np.power(self.beta2, cls_num_list)
per_cls_weights = (1.0 - self.beta2) / np.array(effective_num)
per_cls_weights = (
per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
)
logging.info("per cls weights : {}".format(per_cls_weights))
self.per_cls_weights = torch.FloatTensor(per_cls_weights).to(self._device)
self._feature_compression(train_loader, test_loader)
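    # Stage 0 (first task): plain cross-entropy training of the single-backbone
    # network; no boosting, distillation, or re-weighting is involved yet.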
def _init_train(self, train_loader, test_loader, optimizer, scheduler):
prog_bar = tqdm(range(self.args["init_epochs"]))
for _, epoch in enumerate(prog_bar):
self.train()
losses = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(
self._device, non_blocking=True
), targets.to(self._device, non_blocking=True)
logits = self._network(inputs)["logits"]
loss = F.cross_entropy(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses += loss.item()
_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
if epoch % 5 == 0:
test_acc = self._compute_accuracy(self._network, test_loader)
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["init_epochs"],
losses / len(train_loader),
train_acc,
test_acc,
)
else:
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["init_epochs"],
losses / len(train_loader),
train_acc,
)
prog_bar.set_description(info)
logging.info(info)
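    # Stage 1 (feature boosting): train the expanded two-branch network with
    #   - a class-balanced CE on the combined logits (divided by per_cls_weights),
    #   - a CE on the new branch's logits (fe_logits),
    #   - a temperature-T KD loss tying the old-class logits to the frozen old
    #     model's outputs, weighted by lambda_okd.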
def _feature_boosting(self, train_loader, test_loader, optimizer, scheduler):
prog_bar = tqdm(range(self.args["boosting_epochs"]))
for _, epoch in enumerate(prog_bar):
self.train()
losses = 0.0
losses_clf = 0.0
losses_fe = 0.0
losses_kd = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(
self._device, non_blocking=True
), targets.to(self._device, non_blocking=True)
outputs = self._network(inputs)
logits, fe_logits, old_logits = (
outputs["logits"],
outputs["fe_logits"],
outputs["old_logits"].detach(),
)
loss_clf = F.cross_entropy(logits / self.per_cls_weights, targets)
loss_fe = F.cross_entropy(fe_logits, targets)
loss_kd = self.lambda_okd * _KD_loss(
logits[:, : self._known_classes], old_logits, self.args["T"]
)
loss = loss_clf + loss_fe + loss_kd
optimizer.zero_grad()
loss.backward()
if self.oofc == "az":
for i, p in enumerate(self._network_module_ptr.fc.parameters()):
if i == 0:
p.grad.data[
self._known_classes :,
: self._network_module_ptr.out_dim,
] = torch.tensor(0.0)
elif self.oofc != "ft":
assert 0, "not implemented"
optimizer.step()
losses += loss.item()
losses_fe += loss_fe.item()
losses_clf += loss_clf.item()
losses_kd += (
self._known_classes / self._total_classes
) * loss_kd.item()
_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
if epoch % 5 == 0:
test_acc = self._compute_accuracy(self._network, test_loader)
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["boosting_epochs"],
losses / len(train_loader),
losses_clf / len(train_loader),
losses_fe / len(train_loader),
losses_kd / len(train_loader),
train_acc,
test_acc,
)
else:
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["boosting_epochs"],
losses / len(train_loader),
losses_clf / len(train_loader),
losses_fe / len(train_loader),
losses_kd / len(train_loader),
train_acc,
)
prog_bar.set_description(info)
logging.info(info)
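    # Stage 2 (feature compression): distill the boosted two-branch teacher into a
    # fresh single-backbone student (self._snet) using the balanced KD loss (BKD),
    # optionally weight-align the student, and evaluate it on the full test set.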
def _feature_compression(self, train_loader, test_loader):
self._snet = FOSTERNet(self.args, False)
self._snet.update_fc(self._total_classes)
if len(self._multiple_gpus) > 1:
self._snet = nn.DataParallel(self._snet, self._multiple_gpus)
if hasattr(self._snet, "module"):
self._snet_module_ptr = self._snet.module
else:
self._snet_module_ptr = self._snet
self._snet.to(self._device)
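        # Warm-start the student: copy the teacher's first (old-task) backbone
        # weights and its old classifier head into the student before distillation.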
self._snet_module_ptr.convnets[0].load_state_dict(
self._network_module_ptr.convnets[0].state_dict()
)
self._snet_module_ptr.copy_fc(self._network_module_ptr.oldfc)
if not self.args["attack"]:
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, self._snet.parameters()),
lr=self.args["lr"],
momentum=0.9,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=self.args["compression_epochs"]
)
self._network.eval()
prog_bar = tqdm(range(self.args["compression_epochs"]))
for _, epoch in enumerate(prog_bar):
self._snet.train()
losses = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(
self._device, non_blocking=True
), targets.to(self._device, non_blocking=True)
dark_logits = self._snet(inputs)["logits"]
with torch.no_grad():
outputs = self._network(inputs)
logits, old_logits, fe_logits = (
outputs["logits"],
outputs["old_logits"],
outputs["fe_logits"],
)
loss_dark = self.BKD(dark_logits, logits, self.args["T"])
loss = loss_dark
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses += loss.item()
_, preds = torch.max(dark_logits[: targets.shape[0]], dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
if epoch % 5 == 0:
test_acc = self._compute_accuracy(self._snet, test_loader)
info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["compression_epochs"],
losses / len(train_loader),
train_acc,
test_acc,
)
else:
info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args["compression_epochs"],
losses / len(train_loader),
train_acc,
)
prog_bar.set_description(info)
logging.info(info)
if len(self._multiple_gpus) > 1:
self._snet = self._snet.module
if self.is_student_wa:
self._snet.weight_align(
self._known_classes,
self._total_classes - self._known_classes,
self.wa_value,
)
else:
logging.info("do not weight align student!")
self._snet.eval()
y_pred, y_true = [], []
for _, (_, inputs, targets) in enumerate(test_loader):
inputs = inputs.to(self._device, non_blocking=True)
with torch.no_grad():
outputs = self._snet(inputs)["logits"]
predicts = torch.topk(
outputs, k=self.topk, dim=1, largest=True, sorted=True
)[1]
y_pred.append(predicts.cpu().numpy())
y_true.append(targets.cpu().numpy())
y_pred = np.concatenate(y_pred)
y_true = np.concatenate(y_true)
cnn_accy = self._evaluate(y_pred, y_true)
logging.info("darknet eval: ")
logging.info("CNN top1 curve: {}".format(cnn_accy["top1"]))
logging.info("CNN top5 curve: {}".format(cnn_accy["top5"]))
@property
def samples_old_class(self):
if self._fixed_memory:
return self._memory_per_class
else:
            assert self._known_classes != 0, "Known classes is 0"
            return self._memory_size // self._known_classes
def samples_new_class(self, index):
if self.args["dataset"] == "cifar100":
return 500
else:
return self.data_manager.getlen(index)
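    # Balanced knowledge distillation: the teacher's temperature-softened
    # probabilities are re-weighted per class by self.per_cls_weights and
    # re-normalized before computing the soft cross-entropy with the student.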
def BKD(self, pred, soft, T):
pred = torch.log_softmax(pred / T, dim=1)
soft = torch.softmax(soft / T, dim=1)
soft = soft * self.per_cls_weights
soft = soft / soft.sum(1)[:, None]
return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
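# Standard knowledge-distillation loss at temperature T, used above to tie the
# old-class logits to the frozen old model's outputs.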
def _KD_loss(pred, soft, T):
pred = torch.log_softmax(pred / T, dim=1)
soft = torch.softmax(soft / T, dim=1)
return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
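# A rough sketch of how this learner is typically driven (hypothetical names; the
# real entry point, config parsing, and DataManager live elsewhere in this repo):
#
#     data_manager = DataManager(...)            # per-task class splits
#     learner = FOSTER(args)
#     for _ in range(data_manager.nb_tasks):     # nb_tasks is an assumed attribute
#         learner.incremental_train(data_manager)
#         learner.after_task()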