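"""RMM-style rehearsal-memory management layered over iCaRL and FOSTER.

Per task, `m_rate_list` controls how the replay budget is split between old
exemplars and new-task data, and `c_rate_list` controls how many exemplars
each new class receives (classes with above-median mean training entropy get
the smaller quota).
"""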
import logging

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from models.base import BaseLearner
from models.foster import FOSTER
from models.icarl import iCaRL
from utils.toolkit import count_parameters
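# Module-level defaults shared by the RMM variants below.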
EPSILON = 1e-8
batch_size = 32
weight_decay = 2e-4
num_workers = 8
class RMMBase(BaseLearner):
def __init__(self, args):
self._args = args
self._m_rate_list = args.get("m_rate_list", [])
self._c_rate_list = args.get("c_rate_list", [])
@property
def samples_per_class(self):
return int(self.memory_size // self._total_classes)
    @property
    def _img_per_cls(self):
        # Training images per class: 500 for CIFAR-100, 1300 otherwise.
        return 500 if self._args["dataset"] == "cifar100" else 1300

    @property
    def memory_size(self):
        # Extend the old-exemplar budget by m_rate's share of the new task's
        # images, capped at one less than the total training images seen so far.
        if self._m_rate_list[self._cur_task] != 0:
            self._memory_size = min(
                int(self._total_classes * self._img_per_cls - 1),
                self._args["memory_size"]
                + int(
                    self._m_rate_list[self._cur_task]
                    * self._args["increment"]
                    * self._img_per_cls
                ),
            )
        return self._memory_size
    @property
    def new_memory_size(self):
        # Budget of new-task training samples kept: the (1 - m_rate) share.
        return int(
            (1 - self._m_rate_list[self._cur_task])
            * self._args["increment"]
            * self._img_per_cls
        )
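    # Worked example with made-up numbers (CIFAR-100, memory_size=2000,
    # increment=10, m_rate=0.4): memory_size grows to 2000 + int(0.4*10*500)
    # = 4000 (subject to the cap above), while new_memory_size =
    # int(0.6*10*500) = 3000.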
def build_rehearsal_memory(self, data_manager, per_class):
self._reduce_exemplar(data_manager, per_class)
self._construct_exemplar(data_manager, per_class)
    def _construct_exemplar(self, data_manager, m):
        # Per-class quotas: m scaled down/up by c_rate, capped at the number of
        # training images available per class.
        ns = [
            min(self._img_per_cls, int(m * (1 - self._c_rate_list[self._cur_task]))),
            min(self._img_per_cls, int(m * (1 + self._c_rate_list[self._cur_task]))),
        ]
        logging.info(
            "Constructing exemplars... ({} or {} per class)".format(ns[0], ns[1])
        )
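        # First pass: score every new class by its mean training cross-entropy.
        # Classes above the median get the smaller quota ns[0]; the rest get
        # ns[1] (e.g. m=20, c_rate=0.5 gives quotas of 10 and 30).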
all_cls_entropies = []
ms = []
for class_idx in range(self._known_classes, self._total_classes):
data, targets, idx_dataset = data_manager.get_dataset(
np.arange(class_idx, class_idx + 1),
source="train",
mode="test",
ret_data=True,
)
idx_loader = DataLoader(
idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
with torch.no_grad():
cidx_cls_entropies = []
for idx, (_, inputs, targets) in enumerate(idx_loader):
inputs, targets = inputs.to(self._device), targets.to(self._device)
logits = self._network(inputs)["logits"]
cross_entropy = (
F.cross_entropy(logits, targets, reduction="none")
.detach()
.cpu()
.numpy()
)
cidx_cls_entropies.append(cross_entropy)
cidx_cls_entropies = np.mean(np.concatenate(cidx_cls_entropies))
all_cls_entropies.append(cidx_cls_entropies)
entropy_median = np.median(all_cls_entropies)
for the_entropy in all_cls_entropies:
if the_entropy > entropy_median:
ms.append(ns[0])
else:
ms.append(ns[1])
logging.info(f"ms: {ms}")
for class_idx in range(self._known_classes, self._total_classes):
data, targets, idx_dataset = data_manager.get_dataset(
np.arange(class_idx, class_idx + 1),
source="train",
mode="test",
ret_data=True,
)
idx_loader = DataLoader(
idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
vectors, _ = self._extract_vectors(idx_loader)
vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
class_mean = np.mean(vectors, axis=0)
            # Herding selection (iCaRL): greedily add the sample that keeps the
            # running exemplar mean closest to the true class mean.
            selected_exemplars = []
            exemplar_vectors = []  # [n, feature_dim]
            for k in range(1, ms[class_idx - self._known_classes] + 1):
                S = np.sum(
                    exemplar_vectors, axis=0
                )  # [feature_dim] sum of the exemplar vectors selected so far
                mu_p = (
                    vectors + S
                ) / k  # [n, feature_dim] candidate running means if each remaining sample were added
                i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
                selected_exemplars.append(
                    np.array(data[i])
                )  # copy to avoid passing by reference
                exemplar_vectors.append(
                    np.array(vectors[i])
                )  # copy to avoid passing by reference
                vectors = np.delete(
                    vectors, i, axis=0
                )  # remove to avoid duplicate selection
                data = np.delete(
                    data, i, axis=0
                )  # remove to avoid duplicate selection
selected_exemplars = np.array(selected_exemplars)
exemplar_targets = np.full(ms[class_idx - self._known_classes], class_idx)
self._data_memory = (
np.concatenate((self._data_memory, selected_exemplars))
if len(self._data_memory) != 0
else selected_exemplars
)
self._targets_memory = (
np.concatenate((self._targets_memory, exemplar_targets))
if len(self._targets_memory) != 0
else exemplar_targets
)
# Exemplar mean
idx_dataset = data_manager.get_dataset(
[],
source="train",
mode="test",
appendent=(selected_exemplars, exemplar_targets),
)
idx_loader = DataLoader(
idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
vectors, _ = self._extract_vectors(idx_loader)
vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
mean = np.mean(vectors, axis=0)
mean = mean / np.linalg.norm(mean)
self._class_means[class_idx, :] = mean
class RMM_iCaRL(
    RMMBase, iCaRL
):  # RMMBase must precede the original method in the MRO so its memory-management overrides take effect.
def __init__(self, args):
RMMBase.__init__(self, args)
iCaRL.__init__(self, args)
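    # Mirrors iCaRL's incremental_train, except the data manager also receives
    # m_rate (presumably to apply the RMM split to the new task's training data).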
def incremental_train(self, data_manager):
self._cur_task += 1
self._total_classes = self._known_classes + data_manager.get_task_size(
self._cur_task
)
self._network.update_fc(self._total_classes)
logging.info(
"Learning on {}-{}".format(self._known_classes, self._total_classes)
)
train_dataset = data_manager.get_dataset(
np.arange(self._known_classes, self._total_classes),
source="train",
mode="train",
appendent=self._get_memory(),
m_rate=self._m_rate_list[self._cur_task] if self._cur_task > 0 else None,
)
self.train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
)
test_dataset = data_manager.get_dataset(
np.arange(0, self._total_classes), source="test", mode="test"
)
self.test_loader = DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
if len(self._multiple_gpus) > 1:
self._network = nn.DataParallel(self._network, self._multiple_gpus)
self._train(self.train_loader, self.test_loader)
self.build_rehearsal_memory(data_manager, self.samples_per_class)
if len(self._multiple_gpus) > 1:
self._network = self._network.module
class RMM_FOSTER(RMMBase, FOSTER):
def __init__(self, args):
RMMBase.__init__(self, args)
FOSTER.__init__(self, args)
def incremental_train(self, data_manager):
self.data_manager = data_manager
self._cur_task += 1
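        # From the second expansion on, resume from the student network kept in
        # self._snet by FOSTER's compression stage.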
if self._cur_task > 1:
self._network = self._snet
self._total_classes = self._known_classes + data_manager.get_task_size(
self._cur_task
)
self._network.update_fc(self._total_classes)
self._network_module_ptr = self._network
logging.info(
"Learning on {}-{}".format(self._known_classes, self._total_classes)
)
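        # After the first task, freeze the previous backbone and its classifier
        # head; only the newly added branch remains trainable.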
if self._cur_task > 0:
for p in self._network.convnets[0].parameters():
p.requires_grad = False
for p in self._network.oldfc.parameters():
p.requires_grad = False
logging.info("All params: {}".format(count_parameters(self._network)))
logging.info(
"Trainable params: {}".format(count_parameters(self._network, True))
)
train_dataset = data_manager.get_dataset(
np.arange(self._known_classes, self._total_classes),
source="train",
mode="train",
appendent=self._get_memory(),
m_rate=self._m_rate_list[self._cur_task] if self._cur_task > 0 else None,
)
self.train_loader = DataLoader(
train_dataset,
batch_size=self.args["batch_size"],
shuffle=True,
num_workers=self.args["num_workers"],
pin_memory=True,
)
test_dataset = data_manager.get_dataset(
np.arange(0, self._total_classes), source="test", mode="test"
)
self.test_loader = DataLoader(
test_dataset,
batch_size=self.args["batch_size"],
shuffle=False,
num_workers=self.args["num_workers"],
)
if len(self._multiple_gpus) > 1:
self._network = nn.DataParallel(self._network, self._multiple_gpus)
self._train(self.train_loader, self.test_loader)
self.build_rehearsal_memory(data_manager, self.samples_per_class)
if len(self._multiple_gpus) > 1:
self._network = self._network.module
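

if __name__ == "__main__":
    # Hedged, self-contained sketch of the RMM budget arithmetic above, using
    # made-up numbers (CIFAR-100: 500 images per class, memory_size=2000,
    # increment=10). The cap of total_classes * img_per_cls - 1 applied in
    # `memory_size` is omitted for brevity.
    memory_size, increment, img_per_cls = 2000, 10, 500
    for m_rate in (0.0, 0.4, 0.8):
        old_budget = memory_size + int(m_rate * increment * img_per_cls)
        new_budget = int((1 - m_rate) * increment * img_per_cls)
        print(
            f"m_rate={m_rate}: old-exemplar budget={old_budget}, "
            f"new-data budget={new_budget}"
        )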