HungNP
New single commit message
cb80c28
# Please note that the current implementation of DER only contains the dynamic expansion process, since masking and pruning are not implemented by the source repo.
import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import DERNet, IncrementalNet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy
# ---------------------------------------------------------------------------
# Training hyper-parameters (module-level so the learner below can read them).
# ---------------------------------------------------------------------------

# Small constant; not referenced by the learner code in this file.
EPSILON = 1e-8

# --- First task (task index 0): optimizer / schedule settings ---------------
init_epoch = 100                 # number of epochs for the first task
init_lr = 0.1                    # initial SGD learning rate
init_milestones = [40, 60, 80]   # epochs at which the LR is multiplied by the decay
init_lr_decay = 0.1              # MultiStepLR gamma
init_weight_decay = 5e-4         # SGD weight decay

# --- Subsequent tasks (task index >= 1): optimizer / schedule settings ------
epochs = 80                      # number of epochs per incremental task
lrate = 0.1                      # initial SGD learning rate
milestones = [30, 50, 70]        # epochs at which the LR is multiplied by the decay
lrate_decay = 0.1                # MultiStepLR gamma
weight_decay = 2e-4              # SGD weight decay

# --- Data loading (shared by train and test loaders) ------------------------
batch_size = 32
num_workers = 8

# NOTE(review): presumably a distillation temperature; it is not used by the
# learner code in this file -- confirm before removing.
T = 2
class DER(BaseLearner):
    """Class-incremental learner implementing the dynamic-expansion part of DER.

    Each call to :meth:`incremental_train` advances to the next task, grows the
    classifier of the underlying ``DERNet`` to cover the new classes, freezes
    the feature extractors (``convnets``) of all earlier tasks and trains the
    remaining parameters. Masking and pruning are not implemented here.
    """

    def __init__(self, args):
        """Create the learner and its expandable network.

        Args:
            args: Configuration object forwarded unchanged to ``BaseLearner``
                and ``DERNet``.
        """
        super().__init__(args)
        # NOTE(review): the meaning of the second positional argument is
        # defined by DERNet (presumably a "pretrained" flag) -- confirm.
        self._network = DERNet(args, False)

    def after_task(self):
        """Mark every class seen so far as known and log the exemplar size."""
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def incremental_train(self, data_manager):
        """Run one full incremental step on the next task.

        Advances the task counter, extends the classifier, freezes the
        convnets of earlier tasks, builds the train/test loaders, trains the
        network and finally rebuilds the rehearsal memory.

        Args:
            data_manager: Provides ``get_task_size(task_idx)`` and
                ``get_dataset(class_indices, source=..., mode=..., ...)``.
        """
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        # Freeze the feature extractors belonging to all previous tasks; only
        # the most recent convnet keeps trainable parameters. The range is
        # empty on the first task, so no extra guard is required.
        for task_idx in range(self._cur_task):
            for param in self._network.convnets[task_idx].parameters():
                param.requires_grad = False

        logging.info("All params: {}".format(count_parameters(self._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(self._network, True))
        )

        # Training data: samples of the new classes plus the stored memory.
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )

        # Test data: every class seen so far, old and new.
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        use_data_parallel = len(self._multiple_gpus) > 1
        if use_data_parallel:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)

        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)

        # Unwrap again so later code sees the bare network.
        if use_data_parallel:
            self._network = self._network.module

    def train(self):
        """Put the network in training mode, keeping old convnets in eval mode.

        Also stores the unwrapped network (the ``.module`` of the
        ``DataParallel`` wrapper when several GPUs are configured) in
        ``self._network_module_ptr``.
        """
        self._network.train()
        if len(self._multiple_gpus) > 1:
            self._network_module_ptr = self._network.module
        else:
            self._network_module_ptr = self._network

        convnets = self._network_module_ptr.convnets
        convnets[-1].train()
        # Convnets of earlier tasks are frozen, so they are switched to eval
        # mode. The range is empty on the first task.
        for task_idx in range(self._cur_task):
            convnets[task_idx].eval()

    def _train(self, train_loader, test_loader):
        """Build optimizer and LR schedule for the current task and train.

        The first task uses the ``init_*`` hyper-parameters and
        :meth:`_init_train`; later tasks use the incremental hyper-parameters
        and :meth:`_update_representation`, followed by ``weight_align`` on
        the newly added classes.

        Args:
            train_loader: Loader yielding ``(index, inputs, targets)`` batches.
            test_loader: Loader used for periodic evaluation.
        """
        self._network.to(self._device)
        trainable_params = filter(
            lambda p: p.requires_grad, self._network.parameters()
        )

        if self._cur_task == 0:
            optimizer = optim.SGD(
                trainable_params,
                momentum=0.9,
                lr=init_lr,
                weight_decay=init_weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
            return

        optimizer = optim.SGD(
            trainable_params,
            lr=lrate,
            momentum=0.9,
            weight_decay=weight_decay,
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=milestones, gamma=lrate_decay
        )
        self._update_representation(train_loader, test_loader, optimizer, scheduler)

        # weight_align is only meaningful once old classes exist, hence it is
        # applied on incremental tasks only. It receives the number of classes
        # added by the current task.
        num_new_classes = self._total_classes - self._known_classes
        if len(self._multiple_gpus) > 1:
            self._network.module.weight_align(num_new_classes)
        else:
            self._network.weight_align(num_new_classes)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        """Train on the first task with a plain cross-entropy objective.

        Runs ``init_epoch`` epochs, steps the scheduler once per epoch and
        evaluates on ``test_loader`` every fifth epoch (epochs 0, 5, 10, ...).
        The summary line of the last epoch is written to the log.

        Args:
            train_loader: Loader yielding ``(index, inputs, targets)`` batches.
            test_loader: Loader passed to ``_compute_accuracy``.
            optimizer: Optimizer over the trainable parameters.
            scheduler: LR scheduler stepped once per epoch.
        """
        prog_bar = tqdm(range(init_epoch))
        info = None  # stays None only if no epoch is run
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                init_epoch,
                losses / len(train_loader),
                train_acc,
            )
            # Evaluation is comparatively expensive, so only do it every
            # fifth epoch and append the result to the summary line.
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info += ", Test_accy {:.2f}".format(test_acc)
            prog_bar.set_description(info)

        if info is not None:
            logging.info(info)

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        """Train on an incremental task with classification + auxiliary loss.

        The total loss is the cross-entropy on ``logits`` over all classes
        plus the cross-entropy on ``aux_logits``, whose targets merge every
        previously known class into label 0 and number the new classes from 1.
        Runs ``epochs`` epochs, steps the scheduler once per epoch and
        evaluates on ``test_loader`` every fifth epoch. The summary line of
        the last epoch is written to the log.

        Args:
            train_loader: Loader yielding ``(index, inputs, targets)`` batches.
            test_loader: Loader passed to ``_compute_accuracy``.
            optimizer: Optimizer over the trainable parameters.
            scheduler: LR scheduler stepped once per epoch.
        """
        prog_bar = tqdm(range(epochs))
        info = None  # stays None only if no epoch is run
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            losses_clf = 0.0
            losses_aux = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                outputs = self._network(inputs)
                logits, aux_logits = outputs["logits"], outputs["aux_logits"]

                loss_clf = F.cross_entropy(logits, targets)
                # Old classes (label < known_classes) collapse to 0; the
                # new classes are shifted to 1, 2, ...
                aux_targets = torch.clamp(
                    targets - self._known_classes + 1, min=0
                )
                loss_aux = F.cross_entropy(aux_logits, aux_targets)
                loss = loss_clf + loss_aux

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                losses_aux += loss_aux.item()
                losses_clf += loss_clf.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            num_batches = len(train_loader)
            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                epochs,
                losses / num_batches,
                losses_clf / num_batches,
                losses_aux / num_batches,
                train_acc,
            )
            # Evaluation is comparatively expensive, so only do it every
            # fifth epoch and append the result to the summary line.
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info += ", Test_accy {:.2f}".format(test_acc)
            prog_bar.set_description(info)

        if info is not None:
            logging.info(info)