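"""GEM (Gradient Episodic Memory) incremental learner.

Follows Lopez-Paz & Ranzato, "Gradient Episodic Memory for Continual
Learning" (NeurIPS 2017): for every batch, reference gradients are computed
on stored exemplars of earlier tasks, and the current gradient is projected
via a small quadratic program whenever it would increase an old task's loss.
"""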
import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import IncrementalNet
from utils.toolkit import tensor2numpy
try:
    from quadprog import solve_qp
except ImportError:
    # quadprog is only needed for the gradient-projection QP that runs from
    # the second task onward; the initial task trains without it.
    pass
EPSILON = 1e-8
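# Training hyper-parameters: the init_* values apply to the first task, the
# rest to the incremental tasks.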
init_epoch = 1
init_lr = 0.1
init_milestones = [40, 60, 80]
init_lr_decay = 0.1
init_weight_decay = 0.0005
epochs = 1
lrate = 0.1
milestones = [20, 40, 60]
lrate_decay = 0.1
batch_size = 16
weight_decay = 2e-4
num_workers = 4
class GEM(BaseLearner):
def __init__(self, args):
super().__init__(args)
self._network = IncrementalNet(args, False)
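        # Exemplars of all previous tasks, cached as tensors by
        # incremental_train() so per-task reference gradients can be
        # recomputed for every batch.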
self.previous_data = None
self.previous_label = None
def after_task(self):
self._old_network = self._network.copy().freeze()
self._known_classes = self._total_classes
logging.info("Exemplar size: {}".format(self.exemplar_size))
def incremental_train(self, data_manager):
self._cur_task += 1
self._total_classes = self._known_classes + data_manager.get_task_size(
self._cur_task
)
self._network.update_fc(self._total_classes)
logging.info(
"Learning on {}-{}".format(self._known_classes, self._total_classes)
)
train_dataset = data_manager.get_dataset(
np.arange(self._known_classes, self._total_classes),
source="train",
mode="train",
)
self.train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
test_dataset = data_manager.get_dataset(
np.arange(0, self._total_classes), source="test", mode="test"
)
self.test_loader = DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
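        # GEM recomputes old-task gradients at every step, so materialize the
        # stored exemplars as tensors once per incremental task.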
if self._cur_task > 0:
previous_dataset = data_manager.get_dataset(
[], source="train", mode="train", appendent=self._get_memory()
)
self.previous_data = []
self.previous_label = []
            for _, data_, label_ in previous_dataset:
                self.previous_data.append(data_)
                self.previous_label.append(label_)
self.previous_data = torch.stack(self.previous_data)
self.previous_label = torch.tensor(self.previous_label)
        # Train on the new task, then rebuild the rehearsal memory.
if len(self._multiple_gpus) > 1:
self._network = nn.DataParallel(self._network, self._multiple_gpus)
self._train(self.train_loader, self.test_loader)
self.build_rehearsal_memory(data_manager, self.samples_per_class)
if len(self._multiple_gpus) > 1:
self._network = self._network.module
def _train(self, train_loader, test_loader):
self._network.to(self._device)
if self._old_network is not None:
self._old_network.to(self._device)
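        # The first task uses the longer init_* schedule; later tasks use the
        # shorter schedule together with GEM's projected updates.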
if self._cur_task == 0:
optimizer = optim.SGD(
self._network.parameters(),
momentum=0.9,
lr=init_lr,
weight_decay=init_weight_decay,
)
scheduler = optim.lr_scheduler.MultiStepLR(
optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
)
self._init_train(train_loader, test_loader, optimizer, scheduler)
else:
optimizer = optim.SGD(
self._network.parameters(),
lr=lrate,
momentum=0.9,
weight_decay=weight_decay,
            )
scheduler = optim.lr_scheduler.MultiStepLR(
optimizer=optimizer, milestones=milestones, gamma=lrate_decay
)
self._update_representation(train_loader, test_loader, optimizer, scheduler)
def _init_train(self, train_loader, test_loader, optimizer, scheduler):
prog_bar = tqdm(range(init_epoch))
        for epoch in prog_bar:
self._network.train()
losses = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(self._device), targets.to(self._device)
logits = self._network(inputs)["logits"]
loss = F.cross_entropy(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses += loss.item()
_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
if epoch % 5 == 0:
test_acc = self._compute_accuracy(self._network, test_loader)
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
init_epoch,
losses / len(train_loader),
train_acc,
test_acc,
)
else:
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
self._cur_task,
epoch + 1,
init_epoch,
losses / len(train_loader),
train_acc,
)
prog_bar.set_description(info)
logging.info(info)
def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
prog_bar = tqdm(range(epochs))
grad_numels = []
for params in self._network.parameters():
grad_numels.append(params.data.numel())
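        # G holds one flattened network gradient per column: columns
        # 0.._cur_task-1 are reference gradients from old-task exemplars,
        # column _cur_task is the current batch's gradient.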
G = torch.zeros((sum(grad_numels), self._cur_task + 1)).to(self._device)
        for epoch in prog_bar:
self._network.train()
losses = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
incremental_step = self._total_classes - self._known_classes
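                # Reference gradients: one masked backward pass per previous
                # task over its exemplars. The class-range masking assumes
                # every task adds incremental_step classes.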
for k in range(0, self._cur_task):
optimizer.zero_grad()
mask = torch.where(
(self.previous_label >= k * incremental_step)
& (self.previous_label < (k + 1) * incremental_step)
)[0]
data_ = self.previous_data[mask].to(self._device)
label_ = self.previous_label[mask].to(self._device)
pred_ = self._network(data_)["logits"]
pred_[:, : k * incremental_step].data.fill_(-10e10)
pred_[:, (k + 1) * incremental_step :].data.fill_(-10e10)
loss_ = F.cross_entropy(pred_, label_)
loss_.backward()
                    # Flatten each parameter's gradient into column k of G;
                    # offsets come from grad_numels, indexed over all params.
                    for j, params in enumerate(self._network.parameters()):
                        if params.grad is not None:
                            stpt = 0 if j == 0 else sum(grad_numels[:j])
                            endpt = sum(grad_numels[: j + 1])
                            G[stpt:endpt, k].data.copy_(params.grad.data.view(-1))
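                # Gradient of the current batch, with old-class logits masked
                # so the loss covers only the new classes.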
inputs, targets = inputs.to(self._device), targets.to(self._device)
logits = self._network(inputs)["logits"]
logits[:, : self._known_classes].data.fill_(-10e10)
                loss = F.cross_entropy(logits, targets)
optimizer.zero_grad()
loss.backward()
                for j, params in enumerate(self._network.parameters()):
                    if params.grad is not None:
                        stpt = 0 if j == 0 else sum(grad_numels[:j])
                        endpt = sum(grad_numels[: j + 1])
                        G[stpt:endpt, self._cur_task].data.copy_(
                            params.grad.data.view(-1)
                        )
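                # GEM constraint check: a negative dot product with any
                # reference gradient means this step would increase that old
                # task's loss.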
dotprod = torch.mm(
G[:, self._cur_task].unsqueeze(0), G[:, : self._cur_task]
)
if (dotprod < 0).sum() > 0:
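                    # Resolve the conflict with the dual QP from the GEM paper:
                    # minimize 0.5 * v^T (G_old G_old^T) v + (G_old g)^T v
                    # subject to v >= 0, then take g_new = G_old^T v + g.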
old_grad = G[:, : self._cur_task].cpu().t().double().numpy()
cur_grad = G[:, self._cur_task].cpu().contiguous().double().numpy()
                    # Small ridge (reusing EPSILON from above) keeps the Gram
                    # matrix positive definite, which quadprog requires.
                    C = old_grad @ old_grad.T + EPSILON * np.eye(old_grad.shape[0])
p = old_grad @ cur_grad
A = np.eye(old_grad.shape[0])
b = np.zeros(old_grad.shape[0])
v = solve_qp(C, -p, A, b)[0]
new_grad = old_grad.T @ v + cur_grad
new_grad = torch.tensor(new_grad).float().to(self._device)
new_dotprod = torch.mm(
new_grad.unsqueeze(0), G[:, : self._cur_task]
)
                    # The projection must leave no conflict with any reference
                    # gradient, up to a small numerical tolerance.
                    assert (new_dotprod >= -0.01).all(), "GEM projection failed"
                    # Copy the projected gradient back into the parameters so
                    # optimizer.step() applies the GEM-corrected update.
                    for j, params in enumerate(self._network.parameters()):
                        if params.grad is not None:
                            stpt = 0 if j == 0 else sum(grad_numels[:j])
                            endpt = sum(grad_numels[: j + 1])
                            params.grad.data.copy_(
                                new_grad[stpt:endpt]
                                .contiguous()
                                .view(params.grad.data.size())
                            )
optimizer.step()
losses += loss.item()
_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
if epoch % 5 == 0:
test_acc = self._compute_accuracy(self._network, test_loader)
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
epochs,
losses / len(train_loader),
train_acc,
test_acc,
)
else:
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
self._cur_task,
epoch + 1,
epochs,
losses / len(train_loader),
train_acc,
)
prog_bar.set_description(info)
logging.info(info)