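"""BEEF-ISO learner.

Grows one backbone branch per task. Training after the first task runs in two
stages: an *expansion* stage that fits the newly added branch (using
pseudo-targets and an auxiliary energy loss on gradient-perturbed inputs),
and a *fusion* stage that re-learns only the per-class bias terms on a
balanced validation split. The naming (BEEFISO / BEEFISONet, expansion,
fusion, energy loss) follows the BEEF approach to class-incremental learning
(energy-based expansion and fusion).
"""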
import copy
import logging

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.base import BaseLearner
from utils.inc_net import BEEFISONet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy

EPSILON = 1e-8

class BEEFISO(BaseLearner):
    def __init__(self, args):
        super().__init__(args)
        self.args = args
        self._network = BEEFISONet(args, False)
        self._snet = None
        self.logits_alignment = args["logits_alignment"]
        self.val_loader = None
        self.reduce_batch_size = args["reduce_batch_size"]
        self.random = args.get("random", None)
        self.imbalance = args.get("imbalance", None)

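    # After finishing a task: merge the newly trained head into the unified
    # classifier and advance the known-class counter. When enabled, the batch
    # size is shrunk by a factor of (t+1)/(t+2) to offset the growing model.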
    def after_task(self):
        self._network_module_ptr.update_fc_after()
        self._known_classes = self._total_classes
        if self.reduce_batch_size and self._cur_task > 0:
            self.args["batch_size"] = (
                self.args["batch_size"] * (self._cur_task + 1) // (self._cur_task + 2)
            )
        logging.info("Exemplar size: {}".format(self.exemplar_size))

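    # One incremental step: grow the network by a branch, freeze all previous
    # branches and the old classifier, build train/test (and, from task 1 on,
    # validation) loaders, run the staged training, and refresh the rehearsal
    # memory.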
    def incremental_train(self, data_manager):
        self.data_manager = data_manager
        self._cur_task += 1
        if self._cur_task > 1 and self.args["is_compress"]:
            self._network = self._snet
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc_before(self._total_classes)
        self._network_module_ptr = self._network
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )
        if self._cur_task > 0:
            for i in range(self._cur_task):
                for p in self._network.convnets[i].parameters():
                    p.requires_grad = False
            for p in self._network.old_fc.parameters():
                p.requires_grad = False
        logging.info("All params: {}".format(count_parameters(self._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(self._network, True))
        )
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.args["batch_size"],
            shuffle=True,
            num_workers=self.args["num_workers"],
            pin_memory=True,
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=self.args["batch_size"],
            shuffle=False,
            num_workers=self.args["num_workers"],
            pin_memory=True,
        )
        if self._cur_task > 0:
            if self.random or self.imbalance:
                val_dset = data_manager.get_finetune_dataset(
                    known_classes=self._known_classes,
                    total_classes=self._total_classes,
                    source="train",
                    mode="train",
                    appendent=self._get_memory(),
                    type="ratio",
                )
            else:
                _, val_dset = data_manager.get_dataset_with_split(
                    np.arange(self._known_classes, self._total_classes),
                    source="train",
                    mode="train",
                    appendent=self._get_memory(),
                    val_samples_per_class=int(self.samples_old_class),
                )
            self.val_loader = DataLoader(
                val_dset,
                batch_size=self.args["batch_size"],
                shuffle=True,
                num_workers=self.args["num_workers"],
                pin_memory=True,
            )
        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader, self.val_loader)
        if self.random or self.imbalance:
            self.build_rehearsal_memory_imbalance(data_manager, self.samples_per_class)
        else:
            self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

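    # Custom train-mode switch: the newest backbone is put in train mode, and
    # from task 1 on the first (frozen) backbone is kept in eval mode so that
    # its BatchNorm statistics stay fixed.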
    def train(self):
        self._network_module_ptr.train()
        self._network_module_ptr.convnets[-1].train()
        if self._cur_task >= 1:
            self._network_module_ptr.convnets[0].eval()

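    # Dispatch: task 0 is plain supervised training (_init_train); later tasks
    # first run expansion (train the new branch, biases frozen), then fusion
    # (everything frozen except the biases, tuned on the balanced val loader).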
    def _train(self, train_loader, test_loader, val_loader=None):
        self._network.to(self._device)
        if hasattr(self._network, "module"):
            self._network_module_ptr = self._network.module
        if self._cur_task == 0:
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                momentum=0.9,
                lr=self.args["init_lr"],
                weight_decay=self.args["init_weight_decay"],
            )
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer, T_max=self.args["init_epochs"]
            )
            self.epochs = self.args["init_epochs"]
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            # Expansion stage: train the newly added branch with biases frozen.
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                lr=self.args["lr"],
                momentum=0.9,
                weight_decay=self.args["weight_decay"],
            )
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer, T_max=self.args["expansion_epochs"]
            )
            self.epochs = self.args["expansion_epochs"]
            self.state = "expansion"
            if len(self._multiple_gpus) > 1:
                network = self._network.module
            else:
                network = self._network
            for p in network.biases.parameters():
                p.requires_grad = False
            self._expansion(train_loader, test_loader, optimizer, scheduler)
            # Fusion stage: freeze everything except the bias parameters.
            for p in self._network_module_ptr.forward_prototypes.parameters():
                p.requires_grad = False
            for p in self._network_module_ptr.backward_prototypes.parameters():
                p.requires_grad = False
            for p in self._network_module_ptr.new_fc.parameters():
                p.requires_grad = False
            for p in self._network_module_ptr.convnets[-1].parameters():
                p.requires_grad = False
            # Go through the module pointer so this also works under
            # nn.DataParallel, matching the lines above.
            for p in self._network_module_ptr.biases.parameters():
                p.requires_grad = True
            self.state = "fusion"
            self.epochs = self.args["fusion_epochs"]
            self.per_cls_weights = torch.ones(self._total_classes).to(self._device)
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                lr=0.05,
                momentum=0.9,
                weight_decay=self.args["weight_decay"],
            )
            for n, p in self._network.named_parameters():
                if p.requires_grad:
                    print(n)  # show which parameters remain trainable in fusion
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer, T_max=self.args["fusion_epochs"]
            )
            self._fusion(val_loader, test_loader, optimizer, scheduler)

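    # Task-0 training loop: cross-entropy plus the weighted energy loss
    # computed on gradient-perturbed inputs (see get_energy_loss).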
    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self.epochs))
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            losses_en = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(
                    self._device, non_blocking=True
                ), targets.to(self._device, non_blocking=True)
                logits = self._network(inputs)["logits"]
                loss_en = self.args["energy_weight"] * self.get_energy_loss(
                    inputs, targets, targets
                )
                loss = F.cross_entropy(logits, targets) + loss_en
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                losses_en += loss_en.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.epochs,
                    losses / len(train_loader),
                    losses_en / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.epochs,
                    losses / len(train_loader),
                    losses_en / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)
            logging.info(info)

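    # Expansion stage. Labels are remapped to "pseudo-targets": every old task
    # collapses to a single prototype index (0 .. t-1), while the current
    # task's classes keep distinct indices after those prototypes. Old-task
    # prototype logits are divided by `logits_alignment` to temper their scale.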
    def _expansion(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self.epochs))
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            losses_clf = 0.0
            losses_fe = 0.0
            losses_en = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(
                    self._device, non_blocking=True
                ), targets.to(self._device, non_blocking=True)
                targets = targets.float()
                outputs = self._network(inputs)
                logits, train_logits = outputs["logits"], outputs["train_logits"]
                # Map real labels to pseudo-targets: one index per old task,
                # then the new classes in order.
                pseudo_targets = targets.clone()
                for task_id in range(self._cur_task + 1):
                    if task_id == 0:
                        pseudo_targets = torch.where(
                            targets < self.data_manager.get_accumulate_tasksize(task_id),
                            torch.Tensor([task_id]).float().to(self._device),
                            pseudo_targets,
                        )
                    elif task_id == self._cur_task:
                        pseudo_targets = torch.where(
                            targets - self._known_classes + 1 > 0,
                            targets - self._known_classes + task_id,
                            pseudo_targets,
                        )
                    else:
                        pseudo_targets = torch.where(
                            (targets < self.data_manager.get_accumulate_tasksize(task_id))
                            & (targets > self.data_manager.get_accumulate_tasksize(task_id - 1) - 1),
                            task_id,
                            pseudo_targets,
                        )
                train_logits[:, list(range(self._cur_task))] /= self.logits_alignment
                # Pseudo-targets are float tensors (for torch.where above);
                # cast back to class indices for cross-entropy.
                loss_clf = F.cross_entropy(train_logits.float(), pseudo_targets.long())
                loss_fe = torch.tensor(0.0, device=self._device)
                loss_en = self.args["energy_weight"] * self.get_energy_loss(
                    inputs, targets, pseudo_targets
                )
                loss = loss_clf + loss_fe + loss_en
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                losses_fe += loss_fe.item()
                losses_clf += loss_clf.item()
                losses_en += loss_en.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.epochs,
                    losses / len(train_loader),
                    losses_clf / len(train_loader),
                    losses_fe / len(train_loader),
                    losses_en / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.epochs,
                    losses / len(train_loader),
                    losses_clf / len(train_loader),
                    losses_fe / len(train_loader),
                    losses_en / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)
            logging.info(info)

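    # Fusion stage: with all backbones and heads frozen, only the bias
    # parameters are optimized, using plain cross-entropy on the balanced
    # validation loader (loss_fe and loss_kd are kept as zero placeholders in
    # the logs).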
    def _fusion(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self.epochs))
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            losses_clf = 0.0
            losses_fe = 0.0
            losses_kd = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(
                    self._device, non_blocking=True
                ), targets.to(self._device, non_blocking=True)
                outputs = self._network(inputs)
                logits = outputs["logits"]
                loss_clf = F.cross_entropy(logits, targets)
                loss_fe = torch.tensor(0.0, device=self._device)
                loss_kd = torch.tensor(0.0, device=self._device)
                loss = loss_clf + loss_fe + loss_kd
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                losses_fe += loss_fe.item()
                losses_clf += loss_clf.item()
                losses_kd += (
                    self._known_classes / self._total_classes
                ) * loss_kd.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.epochs,
                    losses / len(train_loader),
                    losses_clf / len(train_loader),
                    losses_fe / len(train_loader),
                    losses_kd / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.epochs,
                    losses / len(train_loader),
                    losses_clf / len(train_loader),
                    losses_fe / len(train_loader),
                    losses_kd / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)
            logging.info(info)

    # Used as an attribute in incremental_train (int(self.samples_old_class)),
    # so it must be a property.
    @property
    def samples_old_class(self):
        if self._fixed_memory:
            return self._memory_per_class
        else:
            assert self._total_classes != 0, "Total classes is 0"
            return self._memory_size // self._known_classes

    def samples_new_class(self, index):
        if self.args["dataset"] == "cifar100":
            return 500
        else:
            return self.data_manager.getlen(index)

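    # Balanced knowledge distillation: standard KD at temperature T, with the
    # soft targets re-weighted per class by self.per_cls_weights and then
    # re-normalized. (Not called by the zero-KD fusion loop above; kept as a
    # utility.)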
    def BKD(self, pred, soft, T):
        pred = torch.log_softmax(pred / T, dim=1)
        soft = torch.softmax(soft / T, dim=1)
        soft = soft * self.per_cls_weights
        soft = soft / soft.sum(1)[:, None]
        return -1 * torch.mul(soft, pred).sum() / pred.shape[0]

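    # Energy loss: perturbed inputs from sample_q are pushed toward the extra
    # "energy" classes (targets shifted past the real class range), while the
    # logit columns of the original (pseudo-)targets are suppressed.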
    def get_energy_loss(self, inputs, targets, pseudo_targets):
        inputs = self.sample_q(inputs)
        out = self._network(inputs)
        if self._cur_task == 0:
            targets = targets + self._total_classes
            train_logits, energy_logits = out["logits"], out["energy_logits"]
        else:
            targets = targets + (self._total_classes - self._known_classes) + self._cur_task
            train_logits, energy_logits = out["train_logits"], out["energy_logits"]
        logits = torch.cat([train_logits, energy_logits], dim=1)
        # Targets may arrive as float tensors from _expansion, so cast to
        # integer class indices for indexing and for cross-entropy.
        logits[:, pseudo_targets.long()] = 1e-9
        energy_loss = F.cross_entropy(logits, targets.long())
        return energy_loss

    def sample_q(self, replay_buffer, n_steps=3):
        """Generate perturbed negative samples for the energy loss.

        Takes ``replay_buffer`` as the starting point so that sampling from
        scratch (``replay_buffer == []``) remains an option; see
        test_wrn_ebm.py for an example.
        """
        self._network_copy = self._network_module_ptr.copy().freeze()
        # Start from the current batch rotated by 180 degrees.
        init_sample = torch.rot90(replay_buffer, 2, (2, 3))
        embedding_k = init_sample.clone().detach().requires_grad_(True)
        optimizer_gen = torch.optim.SGD([embedding_k], lr=1e-2)
        for k in range(1, n_steps + 1):
            out = self._network_copy(embedding_k)
            if self._cur_task == 0:
                energy_logits, train_logits = out["energy_logits"], out["logits"]
            else:
                energy_logits, train_logits = out["energy_logits"], out["train_logits"]
            num_forwards = energy_logits.shape[1]
            logits = torch.cat([train_logits, energy_logits], dim=1)
            negative_energy = torch.log(
                torch.sum(torch.softmax(logits, dim=1)[:, -num_forwards:])
            )
            # Gradient step on the negative energy, followed by a small
            # Gaussian perturbation (Langevin-style update).
            optimizer_gen.zero_grad()
            negative_energy.sum().backward()
            optimizer_gen.step()
            embedding_k.data += 1e-3 * torch.randn_like(embedding_k)
        final_samples = embedding_k.detach()
        return final_samples

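    # Rehearsal-memory variants for the random / imbalanced memory studies;
    # they mirror build_rehearsal_memory (used in the default path above) but
    # allow uneven per-class exemplar quotas.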
    def build_rehearsal_memory_imbalance(self, data_manager, per_class):
        if self._fixed_memory:
            self._construct_exemplar_unified_imbalance(
                data_manager, per_class, self.random, self.imbalance
            )
        else:
            self._reduce_exemplar_imbalance(
                data_manager, per_class, self.random, self.imbalance
            )
            self._construct_exemplar_imbalance(
                data_manager, per_class, self.random, self.imbalance
            )

    def _reduce_exemplar_imbalance(self, data_manager, m, random, imbalance):
        logging.info("Reducing exemplars...({} per class)".format(m))
        dummy_data, dummy_targets = copy.deepcopy(self._data_memory), copy.deepcopy(
            self._targets_memory
        )
        self._class_means = np.zeros((self._total_classes, self.feature_dim))
        self._data_memory, self._targets_memory = np.array([]), np.array([])
        for class_idx in range(self._known_classes):
            mask = np.where(dummy_targets == class_idx)[0]
            # `mask` holds indices, so emptiness must be checked with len(),
            # not sum() (sum of indices can be 0 for a non-empty mask).
            if len(mask) == 0:
                continue
            if random or imbalance is not None:
                dd, dt = dummy_data[mask][:-1], dummy_targets[mask][:-1]
            else:
                dd, dt = dummy_data[mask][:m], dummy_targets[mask][:m]
            self._data_memory = (
                np.concatenate((self._data_memory, dd))
                if len(self._data_memory) != 0
                else dd
            )
            self._targets_memory = (
                np.concatenate((self._targets_memory, dt))
                if len(self._targets_memory) != 0
                else dt
            )
            # Exemplar mean
            idx_dataset = data_manager.get_dataset(
                [], source="train", mode="test", appendent=(dd, dt)
            )
            idx_loader = DataLoader(
                idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4
            )
            vectors, _ = self._extract_vectors(idx_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            mean = np.mean(vectors, axis=0)
            mean = mean / np.linalg.norm(mean)
            self._class_means[class_idx, :] = mean

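    # Exemplar construction. With `random`, instances are drawn uniformly
    # across the new classes; otherwise a per-class quota vector `ms` is built
    # (uniform, half-half, or exponential) and exemplars are picked by
    # herding: greedily choosing the sample that keeps the running exemplar
    # mean closest to the true class mean in normalized feature space.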
    def _construct_exemplar_imbalance(self, data_manager, m, random=False, imbalance=None):
        increment = self._total_classes - self._known_classes
        if random:
            # Uniform random type.
            selected_exemplars = []
            selected_targets = []
            logging.info(
                "Constructing exemplars, totally random...({} total instances, {} classes)".format(
                    increment * m, increment
                )
            )
            data, targets, idx_dataset = data_manager.get_dataset(
                np.arange(self._known_classes, self._total_classes),
                source="train",
                mode="test",
                ret_data=True,
            )
            selected_indices = np.random.choice(
                list(range(len(data))), m * increment, replace=False
            )
            for idx in selected_indices:
                selected_exemplars.append(data[idx])
                selected_targets.append(targets[idx])
            selected_exemplars = np.array(selected_exemplars)[: m * increment]
            selected_targets = np.array(selected_targets)[: m * increment]
            self._data_memory = (
                np.concatenate((self._data_memory, selected_exemplars))
                if len(self._data_memory) != 0
                else selected_exemplars
            )
            self._targets_memory = (
                np.concatenate((self._targets_memory, selected_targets))
                if len(self._targets_memory) != 0
                else selected_targets
            )
        else:
            if imbalance is None:
                logging.info("Constructing exemplars...({} per class)".format(m))
                ms = np.ones(increment, dtype=int) * m
            elif imbalance >= 1:
                # Half-half type: half of the classes store m - m//imbalance
                # exemplars, the other half m + m//imbalance, shuffled.
                ms = [m for _ in range(increment)]
                for i in range(increment // 2):
                    ms[i] -= m // imbalance
                for i in range(increment // 2, increment):
                    ms[i] += m // imbalance
                np.random.shuffle(ms)
                ms = np.array(ms, dtype=int)
                logging.info(
                    "Constructing exemplars, imbalanced...({} or {} per class)".format(
                        m - m // imbalance, m + m // imbalance
                    )
                )
            elif imbalance < 1:
                # Exponential type: per-class quotas decay geometrically and
                # are rescaled to the same total budget.
                ms = np.array([imbalance ** i for i in range(increment)])
                ms = ms / ms.sum()
                tot = m * increment
                ms = (tot * ms).astype(int)
                np.random.shuffle(ms)
            else:
                assert 0, "not implemented yet"
            logging.info("ms {}".format(ms))
            for class_idx in range(self._known_classes, self._total_classes):
                data, targets, idx_dataset = data_manager.get_dataset(
                    np.arange(class_idx, class_idx + 1),
                    source="train",
                    mode="test",
                    ret_data=True,
                )
                idx_loader = DataLoader(
                    idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4
                )
                vectors, _ = self._extract_vectors(idx_loader)
                vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
                class_mean = np.mean(vectors, axis=0)
                # Herding selection
                selected_exemplars = []
                exemplar_vectors = []  # [n, feature_dim]
                for k in range(1, ms[class_idx - self._known_classes] + 1):
                    S = np.sum(exemplar_vectors, axis=0)  # [feature_dim] sum of selected exemplar vectors
                    mu_p = (vectors + S) / k  # [n, feature_dim] candidate running means
                    i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
                    selected_exemplars.append(np.array(data[i]))  # new object, to avoid passing by reference
                    exemplar_vectors.append(np.array(vectors[i]))  # new object, to avoid passing by reference
                    vectors = np.delete(vectors, i, axis=0)  # remove to avoid duplicate selection
                    data = np.delete(data, i, axis=0)  # remove to avoid duplicate selection
                selected_exemplars = np.array(selected_exemplars)
                if len(selected_exemplars) == 0:
                    continue
                exemplar_targets = np.full(ms[class_idx - self._known_classes], class_idx)
                self._data_memory = (
                    np.concatenate((self._data_memory, selected_exemplars))
                    if len(self._data_memory) != 0
                    else selected_exemplars
                )
                self._targets_memory = (
                    np.concatenate((self._targets_memory, exemplar_targets))
                    if len(self._targets_memory) != 0
                    else exemplar_targets
                )
                # Exemplar mean
                idx_dataset = data_manager.get_dataset(
                    [], source="train", mode="test",
                    appendent=(selected_exemplars, exemplar_targets),
                )
                idx_loader = DataLoader(
                    idx_dataset,
                    batch_size=self.args["batch_size"],
                    shuffle=False,
                    num_workers=4,
                    pin_memory=True,
                )
                vectors, _ = self._extract_vectors(idx_loader)
                vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
                mean = np.mean(vectors, axis=0)
                mean = mean / np.linalg.norm(mean)
                self._class_means[class_idx, :] = mean

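    # Fixed-memory variant: first recompute the class means of old exemplars
    # under the newly trained network, then construct exemplars for the new
    # classes exactly as above.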
    def _construct_exemplar_unified_imbalance(self, data_manager, m, random, imbalance):
        logging.info("Constructing exemplars for new classes...({} per class)".format(m))
        _class_means = np.zeros((self._total_classes, self.feature_dim))
        increment = self._total_classes - self._known_classes
        # Calculate the means of old classes with the newly trained network.
        for class_idx in range(self._known_classes):
            mask = np.where(self._targets_memory == class_idx)[0]
            # `mask` holds indices; len() checks emptiness, sum() would not.
            if len(mask) == 0:
                continue
            class_data, class_targets = self._data_memory[mask], self._targets_memory[mask]
            class_dset = data_manager.get_dataset(
                [], source="train", mode="test", appendent=(class_data, class_targets)
            )
            class_loader = DataLoader(
                class_dset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4
            )
            vectors, _ = self._extract_vectors(class_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            mean = np.mean(vectors, axis=0)
            mean = mean / np.linalg.norm(mean)
            _class_means[class_idx, :] = mean
        if random:
            # Uniform random type.
            selected_exemplars = []
            selected_targets = []
            logging.info(
                "Constructing exemplars, totally random...({} total instances, {} classes)".format(
                    increment * m, increment
                )
            )
            data, targets, idx_dataset = data_manager.get_dataset(
                np.arange(self._known_classes, self._total_classes),
                source="train",
                mode="test",
                ret_data=True,
            )
            selected_indices = np.random.choice(
                list(range(len(data))), m * increment, replace=False
            )
            for idx in selected_indices:
                selected_exemplars.append(data[idx])
                selected_targets.append(targets[idx])
            selected_exemplars = np.array(selected_exemplars)
            selected_targets = np.array(selected_targets)
            self._data_memory = (
                np.concatenate((self._data_memory, selected_exemplars))
                if len(self._data_memory) != 0
                else selected_exemplars
            )
            self._targets_memory = (
                np.concatenate((self._targets_memory, selected_targets))
                if len(self._targets_memory) != 0
                else selected_targets
            )
        else:
            if imbalance is None:
                logging.info("Constructing exemplars...({} per class)".format(m))
                ms = np.ones(increment, dtype=int) * m
            elif imbalance >= 1:
                # Half-half type (see _construct_exemplar_imbalance).
                ms = [m for _ in range(increment)]
                for i in range(increment // 2):
                    ms[i] -= m // imbalance
                for i in range(increment // 2, increment):
                    ms[i] += m // imbalance
                np.random.shuffle(ms)
                ms = np.array(ms, dtype=int)
                logging.info(
                    "Constructing exemplars, imbalanced...({} or {} per class)".format(
                        m - m // imbalance, m + m // imbalance
                    )
                )
            elif imbalance < 1:
                # Exponential type (see _construct_exemplar_imbalance).
                ms = np.array([imbalance ** i for i in range(increment)])
                ms = ms / ms.sum()
                tot = m * increment
                ms = (tot * ms).astype(int)
                np.random.shuffle(ms)
            else:
                assert 0, "not implemented yet"
            logging.info("ms {}".format(ms))
            # Construct exemplars for new classes and calculate the means.
            for class_idx in range(self._known_classes, self._total_classes):
                data, targets, class_dset = data_manager.get_dataset(
                    np.arange(class_idx, class_idx + 1),
                    source="train",
                    mode="test",
                    ret_data=True,
                )
                class_loader = DataLoader(
                    class_dset,
                    batch_size=self.args["batch_size"],
                    shuffle=False,
                    num_workers=4,
                    pin_memory=True,
                )
                vectors, _ = self._extract_vectors(class_loader)
                vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
                class_mean = np.mean(vectors, axis=0)
                # Herding selection
                selected_exemplars = []
                exemplar_vectors = []
                for k in range(1, ms[class_idx - self._known_classes] + 1):
                    S = np.sum(exemplar_vectors, axis=0)  # [feature_dim] sum of selected exemplar vectors
                    mu_p = (vectors + S) / k  # [n, feature_dim] candidate running means
                    i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
                    selected_exemplars.append(np.array(data[i]))  # new object, to avoid passing by reference
                    exemplar_vectors.append(np.array(vectors[i]))  # new object, to avoid passing by reference
                    vectors = np.delete(vectors, i, axis=0)  # remove to avoid duplicate selection
                    data = np.delete(data, i, axis=0)  # remove to avoid duplicate selection
                selected_exemplars = np.array(selected_exemplars)
                if len(selected_exemplars) == 0:
                    continue
                exemplar_targets = np.full(ms[class_idx - self._known_classes], class_idx)
                self._data_memory = (
                    np.concatenate((self._data_memory, selected_exemplars))
                    if len(self._data_memory) != 0
                    else selected_exemplars
                )
                self._targets_memory = (
                    np.concatenate((self._targets_memory, exemplar_targets))
                    if len(self._targets_memory) != 0
                    else exemplar_targets
                )
                # Exemplar mean
                exemplar_dset = data_manager.get_dataset(
                    [], source="train", mode="test",
                    appendent=(selected_exemplars, exemplar_targets),
                )
                exemplar_loader = DataLoader(
                    exemplar_dset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4
                )
                vectors, _ = self._extract_vectors(exemplar_loader)
                vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
                mean = np.mean(vectors, axis=0)
                mean = mean / np.linalg.norm(mean)
                _class_means[class_idx, :] = mean
        self._class_means = _class_means

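# ---------------------------------------------------------------------------
# Usage sketch (hypothetical; for illustration only). In this repository the
# learner is normally driven by the trainer loop; the DataManager import,
# constructor signature, and `nb_tasks` attribute below are assumptions based
# on PyCIL-style data managers, not guarantees about this codebase:
#
#     from utils.data_manager import DataManager  # assumed module path
#
#     data_manager = DataManager("cifar100", shuffle=True, seed=0,
#                                init_cls=10, increment=10)  # assumed signature
#     learner = BEEFISO(args)
#     for _ in range(data_manager.nb_tasks):  # assumed attribute
#         learner.incremental_train(data_manager)
#         learner.after_task()
# ---------------------------------------------------------------------------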