# HungNP
# New single commit message
# cb80c28
import copy
import logging
import torch
from torch import nn
from convs.cifar_resnet import resnet32
from convs.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from convs.ucir_cifar_resnet import resnet32 as cosine_resnet32
from convs.ucir_resnet import resnet18 as cosine_resnet18
from convs.ucir_resnet import resnet34 as cosine_resnet34
from convs.ucir_resnet import resnet50 as cosine_resnet50
from convs.linears import SimpleLinear, SplitCosineLinear, CosineLinear
from convs.modified_represnet import resnet18_rep,resnet34_rep
from convs.resnet_cbam import resnet18_cbam,resnet34_cbam,resnet50_cbam
from convs.memo_resnet import get_resnet18_imagenet as get_memo_resnet18 #for MEMO imagenet
from convs.memo_cifar_resnet import get_resnet32_a2fc as get_memo_resnet32 #for MEMO cifar
def get_convnet(args, pretrained=False):
    """Build the backbone selected by ``args["convnet_type"]`` (case-insensitive).

    Args:
        args: configuration mapping. ``args["convnet_type"]`` picks the
            backbone; most backbones also receive the whole mapping as ``args=``.
        pretrained: forwarded as ``pretrained=`` to every backbone except
            ``resnet32``, ``cosine_resnet32`` and the two MEMO builders, which
            are called without it.

    Returns:
        The backbone module, or for ``memo_resnet18`` / ``memo_resnet32`` a
        ``(basenet, adaptive_net)`` tuple.

    Raises:
        NotImplementedError: if the lower-cased name is not recognised.
    """
    name = args["convnet_type"].lower()

    def _as_pair(factory):
        # MEMO builders yield two networks; unpack them and return a tuple.
        basenet, adaptive_net = factory()
        return basenet, adaptive_net

    # Lambdas keep construction lazy: only the selected backbone is built.
    builders = {
        "resnet32": lambda: resnet32(),
        "resnet18": lambda: resnet18(pretrained=pretrained, args=args),
        "resnet34": lambda: resnet34(pretrained=pretrained, args=args),
        "resnet50": lambda: resnet50(pretrained=pretrained, args=args),
        "cosine_resnet18": lambda: cosine_resnet18(pretrained=pretrained, args=args),
        "cosine_resnet32": lambda: cosine_resnet32(),
        "cosine_resnet34": lambda: cosine_resnet34(pretrained=pretrained, args=args),
        "cosine_resnet50": lambda: cosine_resnet50(pretrained=pretrained, args=args),
        "resnet18_rep": lambda: resnet18_rep(pretrained=pretrained, args=args),
        "resnet18_cbam": lambda: resnet18_cbam(pretrained=pretrained, args=args),
        "resnet34_cbam": lambda: resnet34_cbam(pretrained=pretrained, args=args),
        "resnet50_cbam": lambda: resnet50_cbam(pretrained=pretrained, args=args),
        # MEMO benchmark backbones
        "memo_resnet18": lambda: _as_pair(get_memo_resnet18),
        "memo_resnet32": lambda: _as_pair(get_memo_resnet32),
    }
    build = builders.get(name)
    if build is None:
        raise NotImplementedError("Unknown type {}".format(name))
    return build()
class BaseNet(nn.Module):
    """Single-backbone network: ``self.convnet`` followed by a head ``self.fc``.

    ``fc`` starts as ``None``; ``update_fc`` / ``generate_fc`` are no-op
    placeholders here and are overridden by subclasses that build the head.
    """

    def __init__(self, args, pretrained):
        super(BaseNet, self).__init__()
        self.convnet = get_convnet(args, pretrained)
        self.fc = None

    @property
    def feature_dim(self):
        """Width of the backbone feature vector (``convnet.out_dim``)."""
        return self.convnet.out_dim

    def extract_vector(self, x):
        """Return only the ``"features"`` entry of the backbone output."""
        backbone_out = self.convnet(x)
        return backbone_out["features"]

    def forward(self, x):
        """Run backbone then head and merge both output dicts.

        The result holds the head's entries (e.g. ``'logits'``) plus every
        backbone entry (e.g. ``'fmaps'``, ``'features'``). The backbone dict
        is merged last, so its values win on a key clash.
        """
        backbone_out = self.convnet(x)
        result = self.fc(backbone_out["features"])
        result.update(backbone_out)
        return result

    def update_fc(self, nb_classes):
        # Placeholder; subclasses grow the head here.
        pass

    def generate_fc(self, in_dim, out_dim):
        # Placeholder; subclasses construct the head here.
        pass

    def copy(self):
        """Return an independent deep copy of the whole network."""
        return copy.deepcopy(self)

    def freeze(self):
        """Disable gradients for all parameters, switch to eval mode, return ``self``."""
        self.requires_grad_(False)
        self.eval()
        return self

    def load_checkpoint(self, args):
        """Load ``convnet`` and ``fc`` weights from ``checkpoints/finetune_<stem>_0.pkl``.

        When ``args["init_cls"] == 50`` the stem is built from dataset, seed,
        convnet type and ``init_cls``; otherwise it is ``args["csv_name"]``.

        Returns:
            The ``'test_acc'`` value stored in the checkpoint.
        """
        if args["init_cls"] == 50:
            stem = (
                f"{args['dataset']}_{args['seed']}_{args['convnet_type']}"
                f"_B0_Inc{args['init_cls']}"
            )
        else:
            stem = args["csv_name"]
        model_infos = torch.load(f"checkpoints/finetune_{stem}_0.pkl")
        self.convnet.load_state_dict(model_infos['convnet'])
        self.fc.load_state_dict(model_infos['fc'])
        return model_infos['test_acc']
class IncrementalNet(BaseNet):
    """Single-backbone network whose ``SimpleLinear`` head grows with each task.

    With ``gradcam=True``, hooks on ``self.convnet.last_conv`` record that
    layer's output and output-gradient. ``forward`` then returns them under
    ``'gradcam_activations'`` / ``'gradcam_gradients'``.
    """

    def __init__(self, args, pretrained, gradcam=False):
        super().__init__(args, pretrained)
        self.gradcam = gradcam
        # Always initialised so set/unset_gradcam_hook also work when gradcam
        # was disabled at construction time.
        self._gradcam_hooks = [None, None]
        if self.gradcam:
            self.set_gradcam_hook()

    def update_fc(self, nb_classes):
        """Replace the head with one of ``nb_classes`` outputs, keeping old rows.

        The previous head's weight and bias are copied into the first
        ``out_features`` rows of the new head. Remaining rows keep whatever
        ``generate_fc`` initialised them to.
        """
        fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            nb_output = self.fc.out_features
            fc.weight.data[:nb_output] = copy.deepcopy(self.fc.weight.data)
            fc.bias.data[:nb_output] = copy.deepcopy(self.fc.bias.data)
        del self.fc
        self.fc = fc

    def weight_align(self, increment):
        """Rescale the last ``increment`` head rows in place.

        ``gamma = mean(||old rows||_2) / mean(||new rows||_2)`` is printed and
        then multiplied into the new rows, so their mean norm matches the old one.
        """
        weights = self.fc.weight.data
        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
        meannew = torch.mean(newnorm)
        meanold = torch.mean(oldnorm)
        gamma = meanold / meannew
        print("alignweights,gamma=", gamma)
        self.fc.weight.data[-increment:, :] *= gamma

    def generate_fc(self, in_dim, out_dim):
        """Create a ``SimpleLinear`` head mapping ``in_dim`` to ``out_dim``."""
        fc = SimpleLinear(in_dim, out_dim)
        return fc

    def forward(self, x):
        """Backbone + head; also exposes Grad-CAM buffers when enabled."""
        x = self.convnet(x)
        out = self.fc(x["features"])
        out.update(x)
        # hasattr guard presumably protects instances created without the
        # ``gradcam`` attribute (e.g. older pickles) — confirm before removing.
        if hasattr(self, "gradcam") and self.gradcam:
            out["gradcam_gradients"] = self._gradcam_gradients
            out["gradcam_activations"] = self._gradcam_activations
        return out

    def unset_gradcam_hook(self):
        """Remove any registered Grad-CAM hooks and clear the captured tensors.

        Safe to call repeatedly, and when no hook was ever registered.
        """
        hooks = getattr(self, "_gradcam_hooks", None)
        if hooks is None:
            hooks = self._gradcam_hooks = [None, None]
        for idx, handle in enumerate(hooks):
            if handle is not None:
                handle.remove()
            hooks[idx] = None
        self._gradcam_gradients, self._gradcam_activations = [None], [None]

    def set_gradcam_hook(self):
        """Register Grad-CAM hooks on ``self.convnet.last_conv``.

        Hooks from a previous call are removed first, so calling this more
        than once does not stack duplicate hooks or leak handles.
        """
        self.unset_gradcam_hook()
        self._gradcam_gradients, self._gradcam_activations = [None], [None]

        def backward_hook(module, grad_input, grad_output):
            # Keep the gradient w.r.t. the layer's output.
            self._gradcam_gradients[0] = grad_output[0]
            return None

        def forward_hook(module, input, output):
            self._gradcam_activations[0] = output
            return None

        last_conv = self.convnet.last_conv
        # NOTE(review): register_backward_hook is deprecated in PyTorch in
        # favour of register_full_backward_hook; kept to preserve behaviour.
        self._gradcam_hooks[0] = last_conv.register_backward_hook(backward_hook)
        self._gradcam_hooks[1] = last_conv.register_forward_hook(forward_hook)
class IL2ANet(IncrementalNet):
    """IncrementalNet whose head also reserves ``num_aux`` extra output rows."""

    def update_fc(self, num_old, num_total, num_aux):
        """Replace the head with one of ``num_total + num_aux`` outputs.

        Only the first ``num_old`` rows of the previous head's weight and bias
        are carried over. All other rows keep whatever ``generate_fc``
        initialised them to.
        """
        new_head = self.generate_fc(self.feature_dim, num_total + num_aux)
        previous = self.fc
        if previous is not None:
            kept_weight = copy.deepcopy(previous.weight.data)[:num_old]
            kept_bias = copy.deepcopy(previous.bias.data)[:num_old]
            new_head.weight.data[:num_old] = kept_weight
            new_head.bias.data[:num_old] = kept_bias
        del self.fc
        self.fc = new_head
class CosineIncrementalNet(BaseNet):
    """BaseNet with a cosine head: ``CosineLinear`` first, ``SplitCosineLinear`` later."""

    def __init__(self, args, pretrained, nb_proxy=1):
        super().__init__(args, pretrained)
        self.nb_proxy = nb_proxy

    def update_fc(self, nb_classes, task_num):
        """Grow the head to ``nb_classes`` outputs, carrying over old weights and ``sigma``.

        With ``task_num == 1`` the current head's ``weight`` is reused directly
        as the new ``fc1`` weight (presumably the initial ``CosineLinear``).
        Otherwise the current head's ``fc1`` and ``fc2`` rows are stacked into
        the new ``fc1``.
        """
        new_head = self.generate_fc(self.feature_dim, nb_classes)
        old_head = self.fc
        if old_head is not None:
            if task_num == 1:
                # Assigns the old tensor itself (shared storage, no copy).
                new_head.fc1.weight.data = old_head.weight.data
            else:
                split = old_head.fc1.out_features
                new_head.fc1.weight.data[:split] = old_head.fc1.weight.data
                new_head.fc1.weight.data[split:] = old_head.fc2.weight.data
            new_head.sigma.data = old_head.sigma.data
        del self.fc
        self.fc = new_head

    def generate_fc(self, in_dim, out_dim):
        """Create the head for the next stage.

        Without an existing head this is a ``CosineLinear``. Otherwise it is a
        ``SplitCosineLinear`` whose old part has ``fc.out_features // nb_proxy``
        outputs and whose new part covers the rest of ``out_dim``.
        """
        if self.fc is None:
            return CosineLinear(in_dim, out_dim, self.nb_proxy, to_reduce=True)
        n_old = self.fc.out_features // self.nb_proxy
        return SplitCosineLinear(in_dim, n_old, out_dim - n_old, self.nb_proxy)
class BiasLayer_BIC(nn.Module):
    """Learnable affine correction ``alpha * x + beta`` on one column range.

    ``alpha`` starts at 1 and ``beta`` at 0, so the layer is initially an
    identity on the selected columns.
    """

    def __init__(self):
        super(BiasLayer_BIC, self).__init__()
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x, low_range, high_range):
        """Return a copy of ``x`` with columns ``[low_range, high_range)`` corrected."""
        corrected = self.alpha * x[:, low_range:high_range] + self.beta
        result = x.clone()
        result[:, low_range:high_range] = corrected
        return result

    def get_params(self):
        """Return ``(alpha, beta)`` as Python floats."""
        return self.alpha.item(), self.beta.item()
class IncrementalNetWithBias(BaseNet):
    """Growing ``SimpleLinear`` head plus one ``BiasLayer_BIC`` per ``update_fc`` call.

    When ``bias_correction`` is true, ``forward`` passes the logits through
    every bias layer in order. Layer ``i`` acts on the column range
    ``[sum(task_sizes[:i]), sum(task_sizes[:i + 1]))``.
    """

    def __init__(self, args, pretrained, bias_correction=False):
        super().__init__(args, pretrained)
        self.bias_correction = bias_correction
        # update_fc appends one bias layer and one task size per call.
        self.bias_layers = nn.ModuleList([])
        self.task_sizes = []

    def forward(self, x):
        backbone_out = self.convnet(x)
        out = self.fc(backbone_out["features"])
        if self.bias_correction:
            # bounds[k] == sum(task_sizes[:k]); layer k covers bounds[k]:bounds[k + 1].
            bounds = [sum(self.task_sizes[:k]) for k in range(len(self.bias_layers) + 1)]
            logits = out["logits"]
            for layer, low, high in zip(self.bias_layers, bounds, bounds[1:]):
                logits = layer(logits, low, high)
            out["logits"] = logits
        out.update(backbone_out)
        return out

    def update_fc(self, nb_classes):
        """Grow the head to ``nb_classes`` outputs and add a bias layer for the new task."""
        new_head = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            n_old = self.fc.out_features
            new_head.weight.data[:n_old] = copy.deepcopy(self.fc.weight.data)
            new_head.bias.data[:n_old] = copy.deepcopy(self.fc.bias.data)
        del self.fc
        self.fc = new_head
        self.task_sizes.append(nb_classes - sum(self.task_sizes))
        self.bias_layers.append(BiasLayer_BIC())

    def generate_fc(self, in_dim, out_dim):
        return SimpleLinear(in_dim, out_dim)

    def get_bias_params(self):
        """Return ``[(alpha, beta), ...]`` for every bias layer, in task order."""
        return [layer.get_params() for layer in self.bias_layers]

    def unfreeze(self):
        """Re-enable gradients for every parameter (training mode is not changed)."""
        self.requires_grad_(True)
class DERNet(nn.Module):
    """Expandable network: one backbone appended per task, features concatenated.

    ``fc`` classifies the concatenated features of all backbones. ``aux_fc``
    classifies only the newest backbone's features, with the new task's
    classes plus one extra output.
    """

    def __init__(self, args, pretrained):
        super(DERNet, self).__init__()
        self.convnet_type = args["convnet_type"]
        self.convnets = nn.ModuleList()
        # Stored only; get_convnet is called without it in this class.
        self.pretrained = pretrained
        self.out_dim = None
        self.fc = None
        self.aux_fc = None
        self.task_sizes = []
        self.args = args

    @property
    def feature_dim(self):
        """Width of the concatenated features (0 until ``out_dim`` is known)."""
        return 0 if self.out_dim is None else self.out_dim * len(self.convnets)

    def _concat_features(self, x):
        # Each backbone returns a dict; join their "features" along dim 1.
        return torch.cat([net(x)["features"] for net in self.convnets], 1)

    def extract_vector(self, x):
        return self._concat_features(x)

    def forward(self, x):
        """Return ``fc``'s dict extended with ``'aux_logits'`` and ``'features'``."""
        features = self._concat_features(x)
        out = self.fc(features)
        newest = features[:, -self.out_dim:]
        out.update({"aux_logits": self.aux_fc(newest)["logits"], "features": features})
        return out

    def update_fc(self, nb_classes):
        """Add a backbone and grow ``fc`` to ``nb_classes`` outputs.

        Every backbone after the first starts from a copy of the previous
        backbone's weights. Old ``fc`` weight/bias go into the top-left block
        of the new ``fc``, and ``aux_fc`` is rebuilt for the new task.
        """
        self.convnets.append(get_convnet(self.args))
        if len(self.convnets) > 1:
            self.convnets[-1].load_state_dict(self.convnets[-2].state_dict())
        if self.out_dim is None:
            self.out_dim = self.convnets[-1].out_dim

        new_fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            n_rows = self.fc.out_features
            n_cols = self.feature_dim - self.out_dim
            new_fc.weight.data[:n_rows, :n_cols] = copy.deepcopy(self.fc.weight.data)
            new_fc.bias.data[:n_rows] = copy.deepcopy(self.fc.bias.data)
        del self.fc
        self.fc = new_fc

        added = nb_classes - sum(self.task_sizes)
        self.task_sizes.append(added)
        self.aux_fc = self.generate_fc(self.out_dim, added + 1)

    def generate_fc(self, in_dim, out_dim):
        return SimpleLinear(in_dim, out_dim)

    def copy(self):
        return copy.deepcopy(self)

    def freeze(self):
        """Disable all gradients, switch to eval mode and return ``self``."""
        self.requires_grad_(False)
        self.eval()
        return self

    def freeze_conv(self):
        """Disable gradients and switch to eval mode for the backbones only."""
        self.convnets.requires_grad_(False)
        self.convnets.eval()

    def weight_align(self, increment):
        """Scale the last ``increment`` rows of ``fc.weight`` by mean(old norms)/mean(new norms)."""
        weights = self.fc.weight.data
        mean_new = torch.mean(torch.norm(weights[-increment:, :], p=2, dim=1))
        mean_old = torch.mean(torch.norm(weights[:-increment, :], p=2, dim=1))
        gamma = mean_old / mean_new
        print("alignweights,gamma=", gamma)
        self.fc.weight.data[-increment:, :] *= gamma

    def load_checkpoint(self, args):
        """Load ``checkpoints/finetune_<csv_name>_0.pkl`` into the single backbone and ``fc``.

        Returns:
            The stored ``'test_acc'`` value.
        """
        model_infos = torch.load(f"checkpoints/finetune_{args['csv_name']}_0.pkl")
        assert len(self.convnets) == 1
        self.convnets[0].load_state_dict(model_infos['convnet'])
        self.fc.load_state_dict(model_infos['fc'])
        return model_infos['test_acc']
class SimpleCosineIncrementalNet(BaseNet):
    """BaseNet with a single ``CosineLinear`` head that is widened each task."""

    def __init__(self, args, pretrained):
        super().__init__(args, pretrained)

    def update_fc(self, nb_classes, nextperiod_initialization=None):
        """Replace the head with a ``CosineLinear`` of ``nb_classes`` outputs.

        Existing weight rows and ``sigma`` are carried over. Rows for the new
        classes are ``nextperiod_initialization`` if given, otherwise zeros.
        The new head and the tensors concatenated into it are placed on CUDA
        when available and on the CPU otherwise, so the method also works on
        CPU-only machines.

        Args:
            nb_classes: total number of outputs of the new head.
            nextperiod_initialization: optional initial rows for the new
                classes. They are concatenated below the old rows, so they are
                presumably shaped ``(nb_classes - old_outputs, feature_dim)`` —
                confirm with callers.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        fc = self.generate_fc(self.feature_dim, nb_classes).to(device)
        if self.fc is not None:
            nb_output = self.fc.out_features
            weight = copy.deepcopy(self.fc.weight.data).to(device)
            fc.sigma.data = self.fc.sigma.data
            if nextperiod_initialization is not None:
                new_rows = nextperiod_initialization.to(device)
            else:
                new_rows = torch.zeros(nb_classes - nb_output, self.feature_dim, device=device)
            fc.weight = nn.Parameter(torch.cat([weight, new_rows]))
        del self.fc
        self.fc = fc

    def load_checkpoint(self, checkpoint):
        """Load ``checkpoint['convnet']`` and ``checkpoint['fc']`` state dicts."""
        self.convnet.load_state_dict(checkpoint["convnet"])
        self.fc.load_state_dict(checkpoint["fc"])

    def generate_fc(self, in_dim, out_dim):
        """Create a ``CosineLinear`` head mapping ``in_dim`` to ``out_dim``."""
        fc = CosineLinear(in_dim, out_dim)
        return fc
class FOSTERNet(nn.Module):
    """Multi-backbone network with one backbone appended per task.

    ``fc`` classifies the concatenated features of every backbone. ``fe_fc``
    classifies only the newest backbone's features over all classes.
    ``oldfc`` is the head that was replaced at the last ``update_fc`` and
    scores the older backbones' features.
    """

    def __init__(self, args, pretrained):
        super(FOSTERNet, self).__init__()
        self.convnet_type = args["convnet_type"]
        self.convnets = nn.ModuleList()
        # Stored only; get_convnet is called without it in this class.
        self.pretrained = pretrained
        self.out_dim = None
        self.fc = None
        self.fe_fc = None
        self.task_sizes = []
        self.oldfc = None
        self.args = args

    @property
    def feature_dim(self):
        """Width of the concatenated features (0 until ``out_dim`` is known)."""
        return 0 if self.out_dim is None else self.out_dim * len(self.convnets)

    def _concat_features(self, x):
        # Each backbone returns a dict; join their "features" along dim 1.
        return torch.cat([net(x)["features"] for net in self.convnets], 1)

    def extract_vector(self, x):
        return self._concat_features(x)

    def load_checkpoint(self, checkpoint):
        """Load ``checkpoint['convnet']`` into the first backbone and ``checkpoint['fc']`` into ``fc``.

        A first backbone is created if none exists yet; ``fc`` must already exist.
        """
        if not self.convnets:
            self.convnets.append(get_convnet(self.args))
        self.convnets[0].load_state_dict(checkpoint["convnet"])
        self.fc.load_state_dict(checkpoint["fc"])

    def forward(self, x):
        """Return ``fc``'s dict extended with ``fe_logits``, ``features``, optional ``old_logits`` and ``eval_logits``."""
        features = self._concat_features(x)
        out = self.fc(features)
        out["fe_logits"] = self.fe_fc(features[:, -self.out_dim:])["logits"]
        out["features"] = features
        if self.oldfc is not None:
            out["old_logits"] = self.oldfc(features[:, :-self.out_dim])["logits"]
        out["eval_logits"] = out["logits"]
        return out

    def update_fc(self, nb_classes):
        """Append a backbone, grow ``fc`` to ``nb_classes`` and keep the old head as ``oldfc``."""
        self.convnets.append(get_convnet(self.args))
        if self.out_dim is None:
            self.out_dim = self.convnets[-1].out_dim
        new_fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            n_rows = self.fc.out_features
            n_cols = self.feature_dim - self.out_dim
            new_fc.weight.data[:n_rows, :n_cols] = copy.deepcopy(self.fc.weight.data)
            new_fc.bias.data[:n_rows] = copy.deepcopy(self.fc.bias.data)
            # The new backbone starts from the previous backbone's weights.
            self.convnets[-1].load_state_dict(self.convnets[-2].state_dict())
        self.oldfc = self.fc
        self.fc = new_fc
        added = nb_classes - sum(self.task_sizes)
        self.task_sizes.append(added)
        self.fe_fc = self.generate_fc(self.out_dim, nb_classes)

    def generate_fc(self, in_dim, out_dim):
        return SimpleLinear(in_dim, out_dim)

    def copy(self):
        return copy.deepcopy(self)

    def copy_fc(self, fc):
        """Copy ``fc``'s weight and bias into the top-left corner of ``self.fc``."""
        src_weight = copy.deepcopy(fc.weight.data)
        src_bias = copy.deepcopy(fc.bias.data)
        rows, cols = src_weight.shape[:2]
        self.fc.weight.data[:rows, :cols] = src_weight
        self.fc.bias.data[:rows] = src_bias

    def freeze(self):
        """Disable all gradients, switch to eval mode and return ``self``."""
        self.requires_grad_(False)
        self.eval()
        return self

    def freeze_conv(self):
        """Disable gradients and switch to eval mode for the backbones only."""
        self.convnets.requires_grad_(False)
        self.convnets.eval()

    def weight_align(self, old, increment, value):
        """Scale the last ``increment`` rows of ``fc.weight`` in place.

        ``gamma = mean(old norms) / mean(new norms) * value ** (old / increment)``
        is logged and then applied.
        """
        weights = self.fc.weight.data
        mean_new = torch.mean(torch.norm(weights[-increment:, :], p=2, dim=1))
        mean_old = torch.mean(torch.norm(weights[:-increment, :], p=2, dim=1))
        gamma = mean_old / mean_new * (value ** (old / increment))
        logging.info("align weights, gamma = {} ".format(gamma))
        self.fc.weight.data[-increment:, :] *= gamma
class BiasLayer(nn.Module):
    """Learnable affine map ``(alpha + 1) * x + beta`` that starts as the identity.

    ``alpha`` is stored as an offset from 1, so both parameters start at 0.
    """

    def __init__(self):
        super(BiasLayer, self).__init__()
        self.alpha = nn.Parameter(torch.zeros(1, requires_grad=True))
        self.beta = nn.Parameter(torch.zeros(1, requires_grad=True))

    def forward(self, x, bias=True):
        """Scale ``x`` by ``alpha + 1`` and, when ``bias`` is true, add ``beta``.

        Returns a new tensor; ``x`` is not modified. No defensive clone is
        needed because the multiplication already allocates a fresh tensor.
        """
        scaled = (self.alpha + 1) * x
        if bias:
            return scaled + self.beta
        return scaled

    def get_params(self):
        """Return ``(alpha, beta)`` as Python floats."""
        return (self.alpha.item(), self.beta.item())
class BEEFISONet(nn.Module):
    """Multi-backbone network whose old and new heads are merged via per-task bias layers.

    Heads maintained by ``update_fc_before`` / ``update_fc_after``:
        old_fc: over all concatenated features, for the classes of earlier
            tasks (``None`` until the first ``update_fc_after``).
        new_fc: over the newest backbone's features, for the current task.
        backward_prototypes: over the newest backbone's features, one output
            per previous task.
        forward_prototypes: over the newest backbone's features, for all classes.
        biases: one ``BiasLayer`` per previous task.
    """

    def __init__(self, args, pretrained):
        super(BEEFISONet, self).__init__()
        self.convnet_type = args["convnet_type"]
        self.convnets = nn.ModuleList()
        self.pretrained = pretrained
        self.out_dim = None
        self.old_fc = None
        self.new_fc = None
        self.task_sizes = []
        self.forward_prototypes = None
        self.backward_prototypes = None
        self.args = args
        self.biases = nn.ModuleList()

    @property
    def feature_dim(self):
        """Width of the concatenated features (0 until ``out_dim`` is known)."""
        if self.out_dim is None:
            return 0
        return self.out_dim * len(self.convnets)

    def extract_vector(self, x):
        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)
        return features

    def _merged_new_head(self, proto_weight, proto_bias, new_weight, new_bias):
        """Stack the bias-corrected backward prototypes on top of the new head.

        For each previous task ``i``, prototype row ``i`` is passed through
        ``biases[i]``: scale only for the weight row, scale and shift for the
        bias entry. The result is repeated ``task_sizes[i]`` times, computed
        once and broadcast with ``expand``. The new head's rows follow.

        Returns:
            ``(weight, bias)``: ``sum(task_sizes[:-1])`` prototype-derived rows
            followed by the rows of ``new_weight`` / ``new_bias``.
        """
        weight_blocks, bias_blocks = [], []
        for i, size in enumerate(self.task_sizes[:-1]):
            w_row = self.biases[i](proto_weight[i].unsqueeze(0), bias=False)
            b_row = self.biases[i](proto_bias[i].unsqueeze(0), bias=True)
            weight_blocks.append(w_row.expand(size, -1))
            bias_blocks.append(b_row.expand(size))
        weight_blocks.append(new_weight)
        bias_blocks.append(new_bias)
        return torch.cat(weight_blocks, dim=0), torch.cat(bias_blocks)

    def forward(self, x):
        """Return ``logits``, ``train_logits``, ``eval_logits`` and ``energy_logits``.

        ``train_logits`` is computed in both branches. It uses only the newest
        backbone: raw backward-prototype rows for previous tasks (if any),
        then ``new_fc``. Zero padding is created on the old head's own device
        and dtype instead of a hard-coded CUDA device.
        """
        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)
        newest = features[:, -self.out_dim:]

        if self.old_fc is None:
            out = self.new_fc(features)
        else:
            # Merge the weights. Old classes read the older backbones through
            # old_fc, zero-padded with rows for the new classes. The newest
            # backbone's columns come from the merged new head.
            new_task_size = self.task_sizes[-1]
            old_weight = self.old_fc.weight
            padding = old_weight.new_zeros((new_task_size, self.feature_dim - self.out_dim))
            new_weight, new_bias = self._merged_new_head(
                self.backward_prototypes.weight,
                self.backward_prototypes.bias,
                self.new_fc.weight,
                self.new_fc.bias,
            )
            fc_weight = torch.cat([torch.cat([old_weight, padding], dim=0), new_weight], dim=1)
            old_bias = self.old_fc.bias
            fc_bias = torch.cat([old_bias, old_bias.new_zeros(new_task_size)]) + new_bias
            out = {"logits": features @ fc_weight.permute(1, 0) + fc_bias}

        n_prev = len(self.task_sizes) - 1
        train_weight, train_bias = self.new_fc.weight, self.new_fc.bias
        if n_prev > 0:
            train_weight = torch.cat([self.backward_prototypes.weight[:n_prev], train_weight], dim=0)
            train_bias = torch.cat([self.backward_prototypes.bias[:n_prev], train_bias])
        out["train_logits"] = newest @ train_weight.permute(1, 0) + train_bias

        out.update({"eval_logits": out["logits"], "energy_logits": self.forward_prototypes(newest)["logits"]})
        return out

    def update_fc_before(self, nb_classes):
        """Prepare a new backbone and heads for a task bringing the total to ``nb_classes``.

        Creates one fresh ``BiasLayer`` per previous task and appends a
        backbone. From the second task on, it also builds ``fe_fc`` and
        ``backward_prototypes`` and initialises the new backbone from the
        first backbone's weights.
        """
        new_task_size = nb_classes - sum(self.task_sizes)
        self.biases = nn.ModuleList([BiasLayer() for _ in range(len(self.task_sizes))])
        self.convnets.append(get_convnet(self.args))
        if self.out_dim is None:
            self.out_dim = self.convnets[-1].out_dim
        if self.new_fc is not None:
            # fe_fc is created here but not read anywhere in this class.
            self.fe_fc = self.generate_fc(self.out_dim, nb_classes)
            self.backward_prototypes = self.generate_fc(self.out_dim, len(self.task_sizes))
            self.convnets[-1].load_state_dict(self.convnets[0].state_dict())
        self.forward_prototypes = self.generate_fc(self.out_dim, nb_classes)
        self.new_fc = self.generate_fc(self.out_dim, new_task_size)
        self.task_sizes.append(new_task_size)

    def generate_fc(self, in_dim, out_dim):
        fc = SimpleLinear(in_dim, out_dim)
        return fc

    def update_fc_after(self):
        """Fold ``new_fc`` and the bias-corrected backward prototypes into ``old_fc``.

        The merged head reproduces the ``logits`` branch of ``forward``. It is
        built without autograd tracking.
        """
        if self.old_fc is None:
            self.old_fc = self.new_fc
            return
        merged = self.generate_fc(self.feature_dim, sum(self.task_sizes))
        new_task_size = self.task_sizes[-1]
        with torch.no_grad():
            old_weight = self.old_fc.weight.data
            padding = old_weight.new_zeros((new_task_size, self.feature_dim - self.out_dim))
            new_weight, new_bias = self._merged_new_head(
                self.backward_prototypes.weight.data,
                self.backward_prototypes.bias.data,
                self.new_fc.weight.data,
                self.new_fc.bias.data,
            )
            merged.weight.data = torch.cat([torch.cat([old_weight, padding], dim=0), new_weight], dim=1)
            old_bias = self.old_fc.bias.data
            merged.bias.data = torch.cat([old_bias, old_bias.new_zeros(new_task_size)]) + new_bias
        self.old_fc = merged

    def copy(self):
        return copy.deepcopy(self)

    def copy_fc(self, fc):
        # NOTE(review): BEEFISONet never assigns ``self.fc``; unless a caller
        # sets it, this raises AttributeError. Possibly a leftover from FOSTERNet.
        weight = copy.deepcopy(fc.weight.data)
        bias = copy.deepcopy(fc.bias.data)
        n, m = weight.shape[0], weight.shape[1]
        self.fc.weight.data[:n, :m] = weight
        self.fc.bias.data[:n] = bias

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        return self

    def freeze_conv(self):
        for param in self.convnets.parameters():
            param.requires_grad = False
        self.convnets.eval()

    def weight_align(self, old, increment, value):
        # NOTE(review): relies on ``self.fc``, which this class never sets
        # (see copy_fc) — confirm whether it should target ``old_fc``.
        weights = self.fc.weight.data
        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
        meannew = torch.mean(newnorm)
        meanold = torch.mean(oldnorm)
        gamma = meanold / meannew * (value ** (old / increment))
        logging.info("align weights, gamma = {} ".format(gamma))
        self.fc.weight.data[-increment:, :] *= gamma
class AdaptiveNet(nn.Module):
    """Shared generalized blocks followed by a list of specialized extractors.

    ``TaskAgnosticExtractor`` runs once per input. Every module in
    ``AdaptiveExtractors`` consumes its output, and their results are
    concatenated into the features seen by ``fc``. ``aux_fc`` scores only the
    newest extractor's features.
    """

    def __init__(self, args, pretrained):
        super(AdaptiveNet, self).__init__()
        self.convnet_type = args["convnet_type"]
        # get_convnet yields (generalized blocks, specialized blocks) here.
        self.TaskAgnosticExtractor, first_adaptive = get_convnet(args, pretrained)
        self.TaskAgnosticExtractor.train()
        self.AdaptiveExtractors = nn.ModuleList([first_adaptive])
        self.pretrained = pretrained
        # ``== True`` (not plain truthiness) is kept to preserve existing behaviour.
        if args["backbone"] is not None and pretrained == True:  # noqa: E712
            self.load_checkpoint(args)
        self.out_dim = None
        self.fc = None
        self.aux_fc = None
        self.task_sizes = []
        self.args = args

    @property
    def feature_dim(self):
        """Width of the concatenated features (0 until ``out_dim`` is known)."""
        return 0 if self.out_dim is None else self.out_dim * len(self.AdaptiveExtractors)

    def _run_extractors(self, x):
        # Returns (shared feature map, specialized features joined along dim 1).
        base_feature_map = self.TaskAgnosticExtractor(x)
        features = torch.cat([extractor(base_feature_map) for extractor in self.AdaptiveExtractors], 1)
        return base_feature_map, features

    def extract_vector(self, x):
        return self._run_extractors(x)[1]

    def forward(self, x):
        """Return ``fc``'s dict plus ``aux_logits``, ``features`` and ``base_features``."""
        base_feature_map, features = self._run_extractors(x)
        out = self.fc(features)
        aux_logits = self.aux_fc(features[:, -self.out_dim:])["logits"]
        out.update(aux_logits=aux_logits, features=features, base_features=base_feature_map)
        return out

    def update_fc(self, nb_classes):
        """Append a specialized extractor and grow ``fc`` to ``nb_classes`` outputs.

        NOTE(review): ``__init__`` already puts one extractor in
        ``AdaptiveExtractors``, so the first call here leaves two of them.
        Confirm this matches what the learner expects.
        """
        # get_convnet builds a full pair; only its specialized half is kept.
        _, fresh = get_convnet(self.args)
        self.AdaptiveExtractors.append(fresh)
        if len(self.AdaptiveExtractors) > 1:
            # Warm-start from the previous extractor's weights.
            self.AdaptiveExtractors[-1].load_state_dict(self.AdaptiveExtractors[-2].state_dict())
        if self.out_dim is None:
            logging.info(self.AdaptiveExtractors[-1])
            self.out_dim = self.AdaptiveExtractors[-1].feature_dim

        fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            n_rows = self.fc.out_features
            n_cols = self.feature_dim - self.out_dim
            fc.weight.data[:n_rows, :n_cols] = copy.deepcopy(self.fc.weight.data)
            fc.bias.data[:n_rows] = copy.deepcopy(self.fc.bias.data)
        del self.fc
        self.fc = fc

        new_task_size = nb_classes - sum(self.task_sizes)
        self.task_sizes.append(new_task_size)
        self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1)

    def generate_fc(self, in_dim, out_dim):
        return SimpleLinear(in_dim, out_dim)

    def copy(self):
        return copy.deepcopy(self)

    def weight_align(self, increment):
        """Scale the last ``increment`` rows of ``fc.weight`` by mean(old norms)/mean(new norms)."""
        weights = self.fc.weight.data
        mean_new = torch.mean(torch.norm(weights[-increment:, :], p=2, dim=1))
        mean_old = torch.mean(torch.norm(weights[:-increment, :], p=2, dim=1))
        gamma = mean_old / mean_new
        print('alignweights,gamma=', gamma)
        self.fc.weight.data[-increment:, :] *= gamma

    def load_checkpoint(self, args):
        """Partially load ``args['backbone']`` into both extractor halves.

        Each key of the checkpoint's ``'convnet'`` state dict that also exists
        in the shared extractor, or in the single adaptive extractor,
        overwrites that entry. Other keys are ignored and ``fc`` is not loaded.

        Returns:
            The checkpoint's ``'test_acc'`` value.
        """
        model_infos = torch.load(args["backbone"])
        model_dict = model_infos['convnet']
        assert len(self.AdaptiveExtractors) == 1
        base_state_dict = self.TaskAgnosticExtractor.state_dict()
        adap_state_dict = self.AdaptiveExtractors[0].state_dict()
        for key, value in model_dict.items():
            if key in base_state_dict:
                base_state_dict[key] = value
            if key in adap_state_dict:
                adap_state_dict[key] = value
        self.TaskAgnosticExtractor.load_state_dict(base_state_dict)
        self.AdaptiveExtractors[0].load_state_dict(adap_state_dict)
        return model_infos['test_acc']