import logging

import numpy as np
import ot
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.base import BaseLearner
from utils.inc_net import SimpleCosineIncrementalNet
from utils.toolkit import tensor2numpy

EPSILON = 1e-8

# Fixed training hyperparameters.
epochs = 100
lrate = 0.1
milestones = [40, 80]
lrate_decay = 0.1
batch_size = 32
memory_size = 2000  # total exemplar budget when memory is not fixed per class
T = 2  # softmax temperature for knowledge distillation
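
# COIL trains each incremental task with two optimal-transport (OT) branches:
# a prospective one that transports old classifier weights to initialize the
# new classes (solving_ot), and a retrospective one that transports new-class
# weights back onto the old classes as an extra distillation target
# (solving_ot_to_old).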
class COIL(BaseLearner):
def __init__(self, args):
super().__init__(args)
self._network = SimpleCosineIncrementalNet(args, False)
self.data_manager = None
self.nextperiod_initialization = None
self.sinkhorn_reg = args["sinkhorn"]
self.calibration_term = args["calibration_term"]
self.args = args
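
    # After each task: cache the OT-based initialization for the next task's
    # classifier and freeze a copy of the current network for distillation.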
def after_task(self):
self.nextperiod_initialization = self.solving_ot()
self._old_network = self._network.copy().freeze()
self._known_classes = self._total_classes
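
    # Prospective transport: build an OT plan between the prototypes of the
    # classes seen so far and those of the next period, then transport the
    # current classifier weights to initialize the next period's classifier.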
def solving_ot(self):
with torch.no_grad():
            # The final task has no next period, so there is nothing to transport.
            if self._total_classes == self.data_manager.get_total_classnum():
                logging.info("Training over, no more OT solving")
                return None
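            # Assumes every incremental task has the same number of classes as task 1.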
each_time_class_num = self.data_manager.get_task_size(1)
self._extract_class_means(
self.data_manager, 0, self._total_classes + each_time_class_num
)
former_class_means = torch.tensor(
self._ot_prototype_means[: self._total_classes]
)
next_period_class_means = torch.tensor(
self._ot_prototype_means[
self._total_classes : self._total_classes + each_time_class_num
]
)
Q_cost_matrix = torch.cdist(
former_class_means, next_period_class_means, p=self.args["norm_term"]
)
            # Solve entropic OT between uniform marginals over the two class sets.
            # Both marginals are scaled by 1 / len(former_class_means), matching
            # the rescale by len(former_class_means) applied below.
            _mu1_vec = (
                torch.ones(len(former_class_means)) / len(former_class_means) * 1.0
            )
            _mu2_vec = (
                torch.ones(len(next_period_class_means)) / len(former_class_means) * 1.0
            )
            transport_plan = ot.sinkhorn(
                _mu1_vec, _mu2_vec, Q_cost_matrix, self.sinkhorn_reg
            )
            transport_plan = torch.as_tensor(transport_plan).float().to(self._device)
            # Transport the L2-normalized old classifier weights onto the new classes.
            transformed_hat_W = torch.mm(
                transport_plan.T, F.normalize(self._network.fc.weight, p=2, dim=1)
            )
            # Calibrate the norm of the transported weights to match the old weights.
            oldnorm = torch.norm(self._network.fc.weight, p=2, dim=1)
            newnorm = torch.norm(
                transformed_hat_W * len(former_class_means), p=2, dim=1
            )
            self.calibration_term = torch.mean(oldnorm) / torch.mean(newnorm)
            self._ot_new_branch = (
                transformed_hat_W * len(former_class_means) * self.calibration_term
            )
            return self._ot_new_branch
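
    # Retrospective transport: map the new-class weights back onto the old
    # classes, producing a transported "old branch" that serves as an extra
    # distillation target during co-tuning.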
def solving_ot_to_old(self):
current_class_num = self.data_manager.get_task_size(self._cur_task)
self._extract_class_means_with_memory(
self.data_manager, self._known_classes, self._total_classes
)
former_class_means = torch.tensor(
self._ot_prototype_means[: self._known_classes]
)
next_period_class_means = torch.tensor(
self._ot_prototype_means[self._known_classes : self._total_classes]
)
Q_cost_matrix = (
torch.cdist(
next_period_class_means, former_class_means, p=self.args["norm_term"]
)
+ EPSILON
) # in case of numerical err
        # Uniform marginals, again scaled by the former-class count to match the
        # final rescale of the transported weights.
        _mu1_vec = torch.ones(len(former_class_means)) / len(former_class_means) * 1.0
        _mu2_vec = (
            torch.ones(len(next_period_class_means)) / len(former_class_means) * 1.0
        )
        transport_plan = ot.sinkhorn(_mu2_vec, _mu1_vec, Q_cost_matrix, self.sinkhorn_reg)
        transport_plan = torch.as_tensor(transport_plan).float().to(self._device)
        # Transport the L2-normalized new-class weights back onto the old classes.
        transformed_hat_W = torch.mm(
            transport_plan.T,
            F.normalize(self._network.fc.weight[-current_class_num:, :], p=2, dim=1),
        )
        return transformed_hat_W * len(former_class_means) * self.calibration_term
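
    # Standard PyCIL task flow: grow the classifier (new rows seeded with the
    # OT initialization from after_task), train, then rebuild the exemplar set.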
def incremental_train(self, data_manager):
self._cur_task += 1
self._total_classes = self._known_classes + data_manager.get_task_size(
self._cur_task
)
self._network.update_fc(self._total_classes, self.nextperiod_initialization)
self.data_manager = data_manager
logging.info(
"Learning on {}-{}".format(self._known_classes, self._total_classes)
)
self.lamda = self._known_classes / self._total_classes
# Loader
train_dataset = data_manager.get_dataset(
np.arange(self._known_classes, self._total_classes),
source="train",
mode="train",
appendent=self._get_memory(),
)
self.train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
)
test_dataset = data_manager.get_dataset(
np.arange(0, self._total_classes), source="test", mode="test"
)
self.test_loader = DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
self._train(self.train_loader, self.test_loader)
        if self.args["fixed_memory"]:
            exemplar_size = self.args["memory_per_class"]
        else:
            exemplar_size = memory_size // self._total_classes
        self._reduce_exemplar(data_manager, exemplar_size)
        self._construct_exemplar(data_manager, exemplar_size)
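
    # Optimizer and schedule use the fixed hyperparameters defined at the top.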
def _train(self, train_loader, test_loader):
self._network.to(self._device)
if self._old_network is not None:
self._old_network.to(self._device)
        optimizer = optim.SGD(
            self._network.parameters(), lr=lrate, momentum=0.9, weight_decay=5e-4
        )
scheduler = optim.lr_scheduler.MultiStepLR(
optimizer=optimizer, milestones=milestones, gamma=lrate_decay
)
self._update_representation(train_loader, test_loader, optimizer, scheduler)
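
    # Joint training: cross-entropy on all seen classes plus temperature-T
    # distillation, with the two OT branches weighted by epoch-dependent schedules.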
def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
prog_bar = tqdm(range(epochs))
        for epoch in prog_bar:
            # weight_ot_init decays to zero after the first two epochs, while
            # weight_ot_co_tuning ramps up quadratically towards 1.
            weight_ot_init = max(1.0 - (epoch / 2) ** 2, 0)
            weight_ot_co_tuning = (epoch / epochs) ** 2.0
self._network.train()
losses = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(self._device), targets.to(self._device)
output = self._network(inputs)
logits = output["logits"]
                clf_loss = F.cross_entropy(logits, targets)
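                # From the second task onward, add distillation against the
                # frozen old network.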
if self._old_network is not None:
old_logits = self._old_network(inputs)["logits"].detach()
hat_pai_k = F.softmax(old_logits / T, dim=1)
log_pai_k = F.log_softmax(
logits[:, : self._known_classes] / T, dim=1
)
distill_loss = -torch.mean(torch.sum(hat_pai_k * log_pai_k, dim=1))
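                    # First epoch: align new-class logits with predictions of
                    # the OT-initialized new branch (prospective transport).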
if epoch < 1:
features = F.normalize(output["features"], p=2, dim=1)
current_logit_new = F.log_softmax(
logits[:, self._known_classes :] / T, dim=1
)
new_logit_by_wnew_init_by_ot = F.linear(
features, F.normalize(self._ot_new_branch, p=2, dim=1)
)
new_logit_by_wnew_init_by_ot = F.softmax(
new_logit_by_wnew_init_by_ot / T, dim=1
)
new_branch_distill_loss = -torch.mean(
torch.sum(
current_logit_new * new_logit_by_wnew_init_by_ot, dim=1
)
)
loss = (
distill_loss * self.lamda
+ clf_loss * (1 - self.lamda)
+ 0.001 * (weight_ot_init * new_branch_distill_loss)
)
else:
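                        # Later epochs: periodically refresh the reverse OT plan
                        # and distill old-class predictions through the
                        # transported old branch (retrospective co-tuning).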
features = F.normalize(output["features"], p=2, dim=1)
if i % 30 == 0:
with torch.no_grad():
self._ot_old_branch = self.solving_ot_to_old()
old_logit_by_wold_init_by_ot = F.linear(
features, F.normalize(self._ot_old_branch, p=2, dim=1)
)
old_logit_by_wold_init_by_ot = F.log_softmax(
old_logit_by_wold_init_by_ot / T, dim=1
)
old_branch_distill_loss = -torch.mean(
torch.sum(hat_pai_k * old_logit_by_wold_init_by_ot, dim=1)
)
loss = (
distill_loss * self.lamda
+ clf_loss * (1 - self.lamda)
+ self.args["reg_term"]
* (weight_ot_co_tuning * old_branch_distill_loss)
)
else:
loss = clf_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses += loss.item()
_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
test_acc = self._compute_accuracy(self._network, test_loader)
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
epochs,
losses / len(train_loader),
train_acc,
test_acc,
)
prog_bar.set_description(info)
logging.info(info)
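
    # Compute an L2-normalized prototype (mean feature) for every class in
    # [low, high) from its full training data.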
def _extract_class_means(self, data_manager, low, high):
self._ot_prototype_means = np.zeros(
(data_manager.get_total_classnum(), self._network.feature_dim)
)
with torch.no_grad():
for class_idx in range(low, high):
data, targets, idx_dataset = data_manager.get_dataset(
np.arange(class_idx, class_idx + 1),
source="train",
mode="test",
ret_data=True,
)
idx_loader = DataLoader(
idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
vectors, _ = self._extract_vectors(idx_loader)
vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
class_mean = np.mean(vectors, axis=0)
class_mean = class_mean / (np.linalg.norm(class_mean))
self._ot_prototype_means[class_idx, :] = class_mean
self._network.train()
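
    # Same as _extract_class_means, except classes below `low` are represented
    # only by their stored exemplars, mirroring what is available at rehearsal time.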
def _extract_class_means_with_memory(self, data_manager, low, high):
self._ot_prototype_means = np.zeros(
(data_manager.get_total_classnum(), self._network.feature_dim)
)
memoryx, memoryy = self._data_memory, self._targets_memory
with torch.no_grad():
            # Old classes: prototypes are computed from exemplar memory only.
            for class_idx in range(0, low):
                idxes = np.where(memoryy == class_idx)[0]
                data, targets = memoryx[idxes], memoryy[idxes]
_, _, idx_dataset = data_manager.get_dataset(
[],
source="train",
appendent=(data, targets),
mode="test",
ret_data=True,
)
idx_loader = DataLoader(
idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
vectors, _ = self._extract_vectors(idx_loader)
vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
class_mean = np.mean(vectors, axis=0)
class_mean = class_mean / np.linalg.norm(class_mean)
self._ot_prototype_means[class_idx, :] = class_mean
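            # New classes: prototypes come from the full current-task training data.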
for class_idx in range(low, high):
data, targets, idx_dataset = data_manager.get_dataset(
np.arange(class_idx, class_idx + 1),
source="train",
mode="test",
ret_data=True,
)
idx_loader = DataLoader(
idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
vectors, _ = self._extract_vectors(idx_loader)
vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
class_mean = np.mean(vectors, axis=0)
class_mean = class_mean / np.linalg.norm(class_mean)
self._ot_prototype_means[class_idx, :] = class_mean
self._network.train()
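

# Minimal usage sketch (illustrative only; assumes a PyCIL-style DataManager
# and config keys -- the names and values below are examples, not part of
# this file):
#
#   from utils.data_manager import DataManager
#   args = {"sinkhorn": 0.464, "calibration_term": 1.0, "norm_term": 2,
#           "reg_term": 0.1, "fixed_memory": False, "memory_per_class": 20}
#   learner = COIL(args)
#   data_manager = DataManager("cifar100", shuffle=True, seed=1993,
#                              init_cls=10, increment=10)
#   for _ in range(data_manager.nb_tasks):
#       learner.incremental_train(data_manager)
#       learner.after_task()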