'''
Results on CIFAR-100 (accuracy, %). Reported numbers use a ResNet18
backbone; reproduced numbers use ResNet32.

| Protocol | Reported FC | Reported SVM | Reproduced FC | Reproduced SVM |
|----------|-------------|--------------|---------------|----------------|
| T = 5    | 64.7        | 66.3         | 65.775        | 65.375         |
| T = 10   | 63.4        | 65.2         | 64.91         | 65.10          |
| T = 60   | 50.8        | 59.8         | 62.09         | 61.72          |
'''
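# FeTrIL (Petit et al., WACV 2023) trains the backbone only on the first task
# and then freezes it. For each later task it extracts features for the new
# classes, rebuilds pseudo-features for the old classes by translating the
# features of the most similar new class, and retrains only the linear
# classifier on this feature set; a linear SVM is fit on the same features
# as a second evaluation head.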
import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from models.base import BaseLearner
from utils.inc_net import CosineIncrementalNet, FOSTERNet, IncrementalNet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy
from sklearn.svm import LinearSVC
from torchvision import datasets, transforms
from utils.autoaugment import CIFAR10Policy, ImageNetPolicy
from utils.ops import Cutout
EPSILON = 1e-8
class FeTrIL(BaseLearner):
def __init__(self, args):
super().__init__(args)
self.args = args
self._network = IncrementalNet(args, False)
self._means = []
self._svm_accs = []
def after_task(self):
self._known_classes = self._total_classes
def incremental_train(self, data_manager):
self.data_manager = data_manager
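        # Override the data manager's train-time transforms with a stronger
        # augmentation pipeline (AutoAugment ImageNet policy plus Cutout).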
self.data_manager._train_trsf = [
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=63/255),
ImageNetPolicy(),
Cutout(n_holes=1, length=16),
]
self._cur_task += 1
self._total_classes = self._known_classes + \
data_manager.get_task_size(self._cur_task)
self._network.update_fc(self._total_classes)
self._network_module_ptr = self._network
logging.info(
'Learning on {}-{}'.format(self._known_classes, self._total_classes))
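        # After the first task the backbone is frozen; only the (growing)
        # fully connected head remains trainable.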
if self._cur_task > 0:
for p in self._network.convnet.parameters():
p.requires_grad = False
logging.info('All params: {}'.format(count_parameters(self._network)))
logging.info('Trainable params: {}'.format(
count_parameters(self._network, True)))
train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train',
mode='train', appendent=self._get_memory())
self.train_loader = DataLoader(
train_dataset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
test_dataset = data_manager.get_dataset(
np.arange(0, self._total_classes), source='test', mode='test')
self.test_loader = DataLoader(
test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"])
if len(self._multiple_gpus) > 1:
self._network = nn.DataParallel(self._network, self._multiple_gpus)
self._train(self.train_loader, self.test_loader)
if len(self._multiple_gpus) > 1:
self._network = self._network.module
def _train(self, train_loader, test_loader):
self._network.to(self._device)
if hasattr(self._network, "module"):
self._network_module_ptr = self._network.module
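        # Task 0: train the full network on images. Later tasks: compute class
        # means, pair each old class with its most similar new class, build a
        # feature-level training set (real new-class features plus translated
        # pseudo-features for old classes), and retrain only the FC head.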
if self._cur_task == 0:
self._epoch_num = self.args["init_epochs"]
optimizer = optim.SGD(filter(lambda p: p.requires_grad, self._network.parameters(
)), momentum=0.9, lr=self.args["init_lr"], weight_decay=self.args["init_weight_decay"])
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=self.args["init_epochs"])
self._train_function(train_loader, test_loader, optimizer, scheduler)
self._compute_means()
self._build_feature_set()
else:
self._epoch_num = self.args["epochs"]
self._compute_means()
self._compute_relations()
self._build_feature_set()
            train_loader = DataLoader(self._feature_trainset, batch_size=self.args["batch_size"],
                                      shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
            optimizer = optim.SGD(self._network_module_ptr.fc.parameters(), momentum=0.9,
                                  lr=self.args["lr"], weight_decay=self.args["weight_decay"])
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=self.args["epochs"])
self._train_function(train_loader, test_loader, optimizer, scheduler)
            self._train_svm(self._feature_trainset, self._feature_testset)
def _compute_means(self):
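        # Store the centroid of each new class's training features; these
        # means drive both the old-to-new class pairing and the
        # pseudo-feature translation.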
with torch.no_grad():
for class_idx in range(self._known_classes, self._total_classes):
data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
mode='test', ret_data=True)
idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
vectors, _ = self._extract_vectors(idx_loader)
class_mean = np.mean(vectors, axis=0)
self._means.append(class_mean)
def _compute_relations(self):
        old_means = np.array(self._means[:self._known_classes])
        new_means = np.array(self._means[self._known_classes:])
        # Cosine similarity between old and new class centroids: each old
        # class is paired with its most similar new class, whose features
        # will later be translated to stand in for the old class.
        old_means = old_means / np.linalg.norm(old_means, axis=1)[:, None]
        new_means = new_means / np.linalg.norm(new_means, axis=1)[:, None]
        self._relations = np.argmax(old_means @ new_means.T, axis=1) + self._known_classes
def _build_feature_set(self):
self.vectors_train = []
self.labels_train = []
for class_idx in range(self._known_classes, self._total_classes):
data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
mode='test', ret_data=True)
idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
vectors, _ = self._extract_vectors(idx_loader)
self.vectors_train.append(vectors)
self.labels_train.append([class_idx]*len(vectors))
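        # Pseudo-features for old classes: take the features of the most
        # similar new class and translate them so that their mean coincides
        # with the stored old-class mean.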
        for class_idx in range(0, self._known_classes):
            new_idx = self._relations[class_idx]
            self.vectors_train.append(
                self.vectors_train[new_idx - self._known_classes]
                - self._means[new_idx] + self._means[class_idx])
            self.labels_train.append([class_idx] * len(self.vectors_train[-1]))
        self.vectors_train = np.concatenate(self.vectors_train)
        self.labels_train = np.concatenate(self.labels_train)
        self._feature_trainset = FeatureDataset(self.vectors_train, self.labels_train)
self.vectors_test = []
self.labels_test = []
for class_idx in range(0, self._total_classes):
data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='test',
mode='test', ret_data=True)
idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
vectors, _ = self._extract_vectors(idx_loader)
self.vectors_test.append(vectors)
self.labels_test.append([class_idx]*len(vectors))
        self.vectors_test = np.concatenate(self.vectors_test)
        self.labels_test = np.concatenate(self.labels_test)
        self._feature_testset = FeatureDataset(self.vectors_test, self.labels_test)
def _train_function(self, train_loader, test_loader, optimizer, scheduler):
prog_bar = tqdm(range(self._epoch_num))
        for epoch in prog_bar:
if self._cur_task == 0:
self._network.train()
else:
self._network.eval()
losses = 0.
correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs = inputs.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)
                if self._cur_task == 0:
                    # Task 0: forward images through the full network.
                    logits = self._network(inputs)['logits']
                else:
                    # Later tasks: the loader yields pre-extracted features,
                    # so only the linear head is evaluated (and trained).
                    logits = self._network_module_ptr.fc(inputs)['logits']
loss = F.cross_entropy(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses += loss.item()
_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(
correct)*100 / total, decimals=2)
if epoch % 5 != 0:
info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}'.format(
self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), train_acc)
else:
test_acc = self._compute_accuracy(self._network, test_loader)
info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), train_acc, test_acc)
prog_bar.set_description(info)
logging.info(info)
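    # Second evaluation head used alongside the FC classifier: a linear SVM
    # fit on L2-normalized features, as in the FeTrIL evaluation protocol.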
    def _train_svm(self, train_set, test_set):
        train_features = train_set.features.numpy()
        train_labels = train_set.labels.numpy()
        test_features = test_set.features.numpy()
        test_labels = test_set.labels.numpy()
        train_features = train_features / np.linalg.norm(train_features, axis=1)[:, None]
        test_features = test_features / np.linalg.norm(test_features, axis=1)[:, None]
        svm_classifier = LinearSVC(random_state=42)
        svm_classifier.fit(train_features, train_labels)
        logging.info("svm train: acc: {}".format(
            np.around(svm_classifier.score(train_features, train_labels) * 100, decimals=2)))
        acc = svm_classifier.score(test_features, test_labels)
        self._svm_accs.append(np.around(acc * 100, decimals=2))
        logging.info("svm evaluation: acc_list: {}".format(self._svm_accs))
class FeatureDataset(Dataset):
def __init__(self, features, labels):
assert len(features) == len(labels), "Data size error!"
self.features = torch.from_numpy(features)
self.labels = torch.from_numpy(labels)
def __len__(self):
return len(self.features)
def __getitem__(self, idx):
feature = self.features[idx]
label = self.labels[idx]
return idx, feature, label
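# A minimal self-contained sketch of the feature translation used above, with
# random vectors standing in for backbone features. All shapes and class
# counts here are illustrative assumptions, not values from the trainer.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_old, n_new, dim = 3, 2, 8
    old_means = rng.normal(size=(n_old, dim))          # stored old centroids
    new_feats = [rng.normal(size=(50, dim)) + rng.normal(size=dim) * 3
                 for _ in range(n_new)]                # new-class features
    new_means = np.stack([f.mean(axis=0) for f in new_feats])
    # Pair each old class with its most cosine-similar new class.
    o = old_means / np.linalg.norm(old_means, axis=1)[:, None]
    n = new_means / np.linalg.norm(new_means, axis=1)[:, None]
    relations = np.argmax(o @ n.T, axis=1)
    # Translate the chosen new-class features onto each old-class centroid.
    for c in range(n_old):
        pseudo = new_feats[relations[c]] - new_means[relations[c]] + old_means[c]
        assert np.allclose(pseudo.mean(axis=0), old_means[c])
    print("old->new class pairing:", relations)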