import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from copy import deepcopy
from tqdm import tqdm
from timm.utils import accuracy

from .protonet import ProtoNet
from .utils import trunc_normal_, DiffAugment


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


@torch.jit.script
def entropy_loss(x):
    return torch.sum(-F.softmax(x, 1) * F.log_softmax(x, 1), 1).mean()


def unique_indices(x):
    """
    Return the sorted unique values of x and, for each unique value,
    the index of one of its occurrences.
    Ref: https://github.com/rusty1s/pytorch_unique
    """
    unique, inverse = torch.unique(x, sorted=True, return_inverse=True)
    perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
    inverse, perm = inverse.flip([0]), perm.flip([0])
    perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
    return unique, perm


class ProtoNet_Auto_Finetune(ProtoNet):
    def __init__(self, backbone, num_iters=50, aug_prob=0.9,
                 aug_types=['color', 'translation'],
                 lr_lst=[0.01, 0.001, 0.0001]):
        super().__init__(backbone)
        self.num_iters = num_iters
        self.lr_lst = lr_lst
        self.aug_types = aug_types
        self.aug_prob = aug_prob

        state_dict = backbone.state_dict()
        self.backbone_state = deepcopy(state_dict)

    def forward(self, supp_x, supp_y, qry_x):
        """
        supp_x.shape = [B, nSupp, C, H, W]
        supp_y.shape = [B, nSupp]
        qry_x.shape = [B, nQry, C, H, W]
        """
        B, nSupp, C, H, W = supp_x.shape
        num_classes = supp_y.max() + 1  # NOTE: assume B==1
        device = qry_x.device
        criterion = nn.CrossEntropyLoss()

        supp_x = supp_x.view(-1, C, H, W)
        qry_x = qry_x.view(-1, C, H, W)
        supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2)  # B, nC, nSupp
        supp_y = supp_y.view(-1)

        def single_step(z, mode=True, x=None, y=None, y_1hot=None):
            '''
            z = Aug(supp_x) or qry_x
            global vars: supp_x, supp_y, supp_y_1hot
            '''
            with torch.set_grad_enabled(mode):
                # recalculate prototypes from supp_x with updated backbone
                proto_f = self.backbone.forward(x).unsqueeze(0)

                if y_1hot is None:
                    prototypes = proto_f
                else:
                    prototypes = torch.bmm(y_1hot.float(), proto_f)  # B, nC, d
                    prototypes = prototypes / y_1hot.sum(dim=2, keepdim=True)  # NOTE: may div 0

                # compute feature for z
                feat = self.backbone.forward(z)
                feat = feat.view(B, z.shape[0], -1)  # B, nQry, d

                # classification
                logits = self.cos_classifier(prototypes, feat)  # B, nQry, nC
                loss = None

                if mode:  # if enable grad, compute loss
                    loss = criterion(logits.view(len(y), -1), y)

            return logits, loss

        # load trained weights
        self.backbone.load_state_dict(self.backbone_state, strict=True)

        #zz = DiffAugment(supp_x, ["color", "offset", "offset_h", "offset_v", "translation", "cutout"], 1., detach=True)
        proto_y, proto_i = unique_indices(supp_y)
        proto_x = supp_x[proto_i]
        zz_i = np.setdiff1d(range(len(supp_x)), proto_i.cpu().numpy())
        zz_x = supp_x[zz_i]
        zz_y = supp_y[zz_i]
        best_lr = 0
        max_acc1 = 0

        if len(zz_y) > 0:
            # eval non-finetuned weights (lr=0)
            logits, _ = single_step(zz_x, False, x=proto_x)
            max_acc1 = accuracy(logits.view(len(zz_y), -1), zz_y, topk=(1,))[0]
            print(f'## *lr = 0: acc1 = {max_acc1}\n')

            for lr in self.lr_lst:
                # create optimizer
                opt = torch.optim.Adam(self.backbone.parameters(),
                                       lr=lr,
                                       betas=(0.9, 0.999),
                                       weight_decay=0.)
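                # Short search loop: fine-tune on augmented copies of the
                # one-image-per-class prototype set (proto_x) and validate on
                # the held-out duplicate shots (zz_x, zz_y) to score this
                # candidate learning rate before the final full run below.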
                # main loop
                _num_iters = 50  # fixed search budget, independent of self.num_iters
                pbar = tqdm(range(_num_iters)) if is_main_process() else range(_num_iters)
                for i in pbar:
                    opt.zero_grad()
                    z = DiffAugment(proto_x, self.aug_types, self.aug_prob, detach=True)
                    _, loss = single_step(z, True, x=proto_x, y=proto_y)
                    loss.backward()
                    opt.step()
                    if is_main_process():
                        pbar.set_description(f' << lr = {lr}: loss = {loss.item()}')

                logits, _ = single_step(zz_x, False, x=proto_x)
                acc1 = accuracy(logits.view(len(zz_y), -1), zz_y, topk=(1,))[0]
                print(f'## *lr = {lr}: acc1 = {acc1}\n')

                if acc1 > max_acc1:
                    max_acc1 = acc1
                    best_lr = lr

                # reset backbone state
                self.backbone.load_state_dict(self.backbone_state, strict=True)

        print(f'***Best lr = {best_lr} with acc1 = {max_acc1}.\nStart final loop...\n')

        # create optimizer
        opt = torch.optim.Adam(self.backbone.parameters(),
                               lr=best_lr,
                               betas=(0.9, 0.999),
                               weight_decay=0.)

        # main loop
        pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
        for i in pbar:
            opt.zero_grad()
            z = DiffAugment(supp_x, self.aug_types, self.aug_prob, detach=True)
            _, loss = single_step(z, True, x=supp_x, y=supp_y, y_1hot=supp_y_1hot)
            loss.backward()
            opt.step()
            if is_main_process():
                pbar.set_description(f' >> lr = {best_lr}: loss = {loss.item()}')

        logits, _ = single_step(qry_x, False, x=supp_x, y_1hot=supp_y_1hot)  # supp_x has to pair with y_1hot
        return logits


class ProtoNet_Finetune(ProtoNet):
    def __init__(self, backbone, num_iters=50, lr=5e-2, aug_prob=0.9,
                 aug_types=['color', 'translation']):
        super().__init__(backbone)
        self.num_iters = num_iters
        self.lr = lr
        self.aug_types = aug_types
        self.aug_prob = aug_prob

    def load_state_dict(self, state_dict, strict=True):
        super().load_state_dict(state_dict, strict)

        state_dict = self.backbone.state_dict()
        self.backbone_state = deepcopy(state_dict)

    def forward(self, supp_x, supp_y, x):
        """
        supp_x.shape = [B, nSupp, C, H, W]
        supp_y.shape = [B, nSupp]
        x.shape = [B, nQry, C, H, W]
        """
        # reset backbone state
        self.backbone.load_state_dict(self.backbone_state, strict=True)

        if self.lr == 0:
            return super().forward(supp_x, supp_y, x)

        B, nSupp, C, H, W = supp_x.shape
        num_classes = supp_y.max() + 1  # NOTE: assume B==1
        device = x.device
        criterion = nn.CrossEntropyLoss()

        supp_x = supp_x.view(-1, C, H, W)
        x = x.view(-1, C, H, W)
        supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2)  # B, nC, nSupp
        supp_y = supp_y.view(-1)

        # create optimizer
        opt = torch.optim.Adam(self.backbone.parameters(),
                               lr=self.lr,
                               betas=(0.9, 0.999),
                               weight_decay=0.)
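        # single_step recomputes the class prototypes from the support set on
        # every call, so gradients flow through both the prototype branch and
        # the augmented-input branch of the cosine classifier.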
        def single_step(z, mode=True):
            '''
            z = Aug(supp_x) or x
            '''
            with torch.set_grad_enabled(mode):
                # recalculate prototypes from supp_x with updated backbone
                supp_f = self.backbone.forward(supp_x)
                supp_f = supp_f.view(B, nSupp, -1)
                prototypes = torch.bmm(supp_y_1hot.float(), supp_f)  # B, nC, d
                prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True)  # NOTE: may div 0

                # compute feature for z
                feat = self.backbone.forward(z)
                feat = feat.view(B, z.shape[0], -1)  # B, nQry, d

                # classification
                logits = self.cos_classifier(prototypes, feat)  # B, nQry, nC
                loss = None

                if mode:  # if enable grad, compute loss
                    loss = criterion(logits.view(B * nSupp, -1), supp_y)

            return logits, loss

        # main loop
        pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
        for i in pbar:
            opt.zero_grad()
            z = DiffAugment(supp_x, self.aug_types, self.aug_prob, detach=True)
            _, loss = single_step(z, True)
            loss.backward()
            opt.step()
            if is_main_process():
                pbar.set_description(f'lr{self.lr}, nSupp{nSupp}, nQry{x.shape[0]}: loss = {loss.item()}')

        logits, _ = single_step(x, False)
        return logits


class ProtoNet_AdaTok(ProtoNet):
    def __init__(self, backbone, num_adapters=1, num_iters=50, lr=5e-2,
                 momentum=0.9, weight_decay=0.):
        super().__init__(backbone)
        self.num_adapters = num_adapters
        self.num_iters = num_iters
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

    def forward(self, supp_x, supp_y, x):
        """
        supp_x.shape = [B, nSupp, C, H, W]
        supp_y.shape = [B, nSupp]
        x.shape = [B, nQry, C, H, W]
        """
        B, nSupp, C, H, W = supp_x.shape
        nQry = x.shape[1]
        num_classes = supp_y.max() + 1  # NOTE: assume B==1
        device = x.device
        criterion = nn.CrossEntropyLoss()

        supp_x = supp_x.view(-1, C, H, W)
        x = x.view(-1, C, H, W)
        supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2)  # B, nC, nSupp
        supp_y = supp_y.view(-1)

        # prepare adapter tokens
        ada_tokens = torch.zeros(1, self.num_adapters, self.backbone.embed_dim, device=device)
        trunc_normal_(ada_tokens, std=.02)
        ada_tokens = ada_tokens.detach().requires_grad_()
        #optimizer = torch.optim.SGD([ada_tokens],
        optimizer = torch.optim.Adadelta([ada_tokens],
                                         lr=self.lr,
                                         #momentum=self.momentum,
                                         weight_decay=self.weight_decay)

        def single_step(mode=True):
            with torch.set_grad_enabled(mode):
                supp_f = self.backbone.forward(supp_x, ada_tokens)
                supp_f = supp_f.view(B, nSupp, -1)

                # B, nC, nSupp x B, nSupp, d = B, nC, d
                prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
                prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True)  # NOTE: may div 0

                if not mode:  # no grad
                    feat = self.backbone.forward(x, ada_tokens)
                    feat = feat.view(B, nQry, -1)  # B, nQry, d
                    logits = self.cos_classifier(prototypes, feat)  # B, nQry, nC
                    loss = None
                else:
                    with torch.enable_grad():
                        logits = self.cos_classifier(prototypes, supp_f)  # B, nQry, nC
                        loss = criterion(logits.view(B * nSupp, -1), supp_y)

            return logits, loss

        pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
        for i in pbar:
            optimizer.zero_grad()
            _, loss = single_step(True)
            loss.backward()
            optimizer.step()
            if is_main_process():
                pbar.set_description(f'loss = {loss.item()}')

        logits, _ = single_step(False)
        return logits
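

# ProtoNet_AdaTok_EntMin below follows the same adapter-token recipe as
# ProtoNet_AdaTok, but instead of fitting the labeled support set it minimizes
# the prediction entropy of the query set (entropy_loss above) with plain SGD,
# so the support labels are used only to form the prototypes.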
class ProtoNet_AdaTok_EntMin(ProtoNet):
    def __init__(self, backbone, num_adapters=1, num_iters=50, lr=5e-3,
                 momentum=0.9, weight_decay=0.):
        super().__init__(backbone)
        self.num_adapters = num_adapters
        self.num_iters = num_iters
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

    def forward(self, supp_x, supp_y, x):
        """
        supp_x.shape = [B, nSupp, C, H, W]
        supp_y.shape = [B, nSupp]
        x.shape = [B, nQry, C, H, W]
        """
        B, nSupp, C, H, W = supp_x.shape
        num_classes = supp_y.max() + 1  # NOTE: assume B==1
        device = x.device
        criterion = entropy_loss

        supp_x = supp_x.view(-1, C, H, W)
        x = x.view(-1, C, H, W)
        supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2)  # B, nC, nSupp

        # adapter tokens
        ada_tokens = torch.zeros(1, self.num_adapters, self.backbone.embed_dim, device=device)
        trunc_normal_(ada_tokens, std=.02)
        ada_tokens = ada_tokens.detach().requires_grad_()
        optimizer = torch.optim.SGD([ada_tokens],
                                    lr=self.lr,
                                    momentum=self.momentum,
                                    weight_decay=self.weight_decay)

        def single_step(mode=True):
            with torch.set_grad_enabled(mode):
                supp_f = self.backbone.forward(supp_x, ada_tokens)
                supp_f = supp_f.view(B, nSupp, -1)

                # B, nC, nSupp x B, nSupp, d = B, nC, d
                prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
                prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True)  # NOTE: may div 0

                feat = self.backbone.forward(x, ada_tokens)
                feat = feat.view(B, x.shape[1], -1)  # B, nQry, d

                logits = self.cos_classifier(prototypes, feat)  # B, nQry, nC
                loss = criterion(logits.view(-1, num_classes))

            return logits, loss

        pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
        for i in pbar:
            optimizer.zero_grad()
            _, loss = single_step(True)
            loss.backward()
            optimizer.step()
            if is_main_process():
                pbar.set_description(f'loss = {loss.item()}')

        logits, _ = single_step(False)
        return logits
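

if __name__ == '__main__':
    # Minimal smoke test -- a sketch only, not part of the original module.
    # `ToyBackbone` is a hypothetical stand-in exposing the interface the
    # classes above rely on: forward(x) -> [N, embed_dim] features plus an
    # `embed_dim` attribute. A real run would use the pretrained ViT backbone
    # that .protonet expects instead.
    class ToyBackbone(nn.Module):
        def __init__(self, embed_dim=64):
            super().__init__()
            self.embed_dim = embed_dim
            self.proj = nn.Conv2d(3, embed_dim, kernel_size=8, stride=8)

        def forward(self, x, ada_tokens=None):
            # pool patch features into a single [N, embed_dim] vector per image
            return self.proj(x).mean(dim=(2, 3))

    # a 5-way 1-shot episode with 15 queries per class; B == 1 as assumed above
    B, n_way, n_qry = 1, 5, 15
    supp_x = torch.randn(B, n_way, 3, 32, 32)
    supp_y = torch.arange(n_way).view(B, n_way)
    qry_x = torch.randn(B, n_way * n_qry, 3, 32, 32)

    model = ProtoNet_Finetune(ToyBackbone(), num_iters=5, lr=1e-3)
    model.load_state_dict(model.state_dict())  # populates self.backbone_state
    logits = model(supp_x, supp_y, qry_x)
    print(logits.shape)  # expected: [1, 75, 5]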