import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from copy import deepcopy
from tqdm import tqdm
from timm.utils import accuracy

from .protonet import ProtoNet
from .utils import trunc_normal_, DiffAugment


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def entropy_loss(x):
    """Mean Shannon entropy of the softmax distribution over logits x."""
    return torch.sum(-F.softmax(x, 1) * F.log_softmax(x, 1), 1).mean()
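
# Illustrative values (not in the original source): uniform logits attain the
# maximum entropy log(C), while a confidently peaked row is near zero:
#   entropy_loss(torch.zeros(4, 5))              # -> log(5) ~= 1.609
#   entropy_loss(torch.tensor([[10., 0., 0.]]))  # -> ~= 0.001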


def unique_indices(x):
    """
    Return the sorted unique values of x and the index of each value's
    first occurrence in x.

    Ref: https://github.com/rusty1s/pytorch_unique
    """
    unique, inverse = torch.unique(x, sorted=True, return_inverse=True)
    perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
    inverse, perm = inverse.flip([0]), perm.flip([0])
    perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
    return unique, perm
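
# Illustrative example (not in the original source):
#   unique_indices(torch.tensor([2, 0, 2, 1]))
#   -> (tensor([0, 1, 2]), tensor([1, 3, 0]))
# The flip before scatter_ is what makes the surviving write for each label
# be its earliest index in x.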


class ProtoNet_Auto_Finetune(ProtoNet):
    def __init__(self, backbone, num_iters=50, aug_prob=0.9,
                 aug_types=['color', 'translation'], lr_lst=[0.01, 0.001, 0.0001]):
        super().__init__(backbone)
        self.num_iters = num_iters
        self.lr_lst = lr_lst
        self.aug_types = aug_types
        self.aug_prob = aug_prob

        state_dict = backbone.state_dict()
        self.backbone_state = deepcopy(state_dict)

    def forward(self, supp_x, supp_y, qry_x):
        """
        supp_x.shape = [B, nSupp, C, H, W]
        supp_y.shape = [B, nSupp]
        qry_x.shape = [B, nQry, C, H, W]
        """
        B, nSupp, C, H, W = supp_x.shape
        num_classes = supp_y.max() + 1  # NOTE: assumes B == 1
        device = qry_x.device
        criterion = nn.CrossEntropyLoss()

        supp_x = supp_x.view(-1, C, H, W)
        qry_x = qry_x.view(-1, C, H, W)
        supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2)  # B, nC, nSupp
        supp_y = supp_y.view(-1)

        def single_step(z, mode=True, x=None, y=None, y_1hot=None):
            '''
            z = Aug(supp_x) or qry_x
            closure vars: supp_x, supp_y, supp_y_1hot
            mode=True runs a training step (grad enabled, loss returned);
            mode=False runs inference only.
            '''
            with torch.set_grad_enabled(mode):
                # recalculate prototypes from x with the updated backbone
                proto_f = self.backbone.forward(x).unsqueeze(0)

                if y_1hot is None:
                    prototypes = proto_f
                else:
                    prototypes = torch.bmm(y_1hot.float(), proto_f)  # B, nC, d
                    prototypes = prototypes / y_1hot.sum(dim=2, keepdim=True)  # NOTE: may div 0

                # compute features for z
                feat = self.backbone.forward(z)
                feat = feat.view(B, z.shape[0], -1)  # B, nQry, d

                # classification
                logits = self.cos_classifier(prototypes, feat)  # B, nQry, nC
                loss = None

                if mode:  # grad enabled: compute the loss
                    loss = criterion(logits.view(len(y), -1), y)

            return logits, loss

        # load trained weights
        self.backbone.load_state_dict(self.backbone_state, strict=True)
        #zz = DiffAugment(supp_x, ["color", "offset", "offset_h", "offset_v", "translation", "cutout"], 1., detach=True)

        # split the support set: one image per class becomes the prototype set
        # (proto_x); the remainder (zz_x, zz_y) is held out for lr selection
        proto_y, proto_i = unique_indices(supp_y)
        proto_x = supp_x[proto_i]
        zz_i = np.setdiff1d(range(len(supp_x)), proto_i.cpu().numpy())
        zz_x = supp_x[zz_i]
        zz_y = supp_y[zz_i]
        best_lr = 0
        max_acc1 = 0

        if len(zz_y) > 0:
            # evaluate the non-finetuned weights (lr = 0) as the baseline
            logits, _ = single_step(zz_x, False, x=proto_x)
            max_acc1 = accuracy(logits.view(len(zz_y), -1), zz_y, topk=(1,))[0]
            print(f'## *lr = 0: acc1 = {max_acc1}\n')

            for lr in self.lr_lst:
                # create optimizer
                opt = torch.optim.Adam(self.backbone.parameters(),
                                       lr=lr,
                                       betas=(0.9, 0.999),
                                       weight_decay=0.)

                # finetune on the prototype images for a fixed 50-step search budget
                _num_iters = 50
                pbar = tqdm(range(_num_iters)) if is_main_process() else range(_num_iters)
                for i in pbar:
                    opt.zero_grad()
                    z = DiffAugment(proto_x, self.aug_types, self.aug_prob, detach=True)
                    _, loss = single_step(z, True, x=proto_x, y=proto_y)
                    loss.backward()
                    opt.step()
                    if is_main_process():
                        pbar.set_description(f' << lr = {lr}: loss = {loss.item()}')

                # score this lr on the held-out support images
                logits, _ = single_step(zz_x, False, x=proto_x)
                acc1 = accuracy(logits.view(len(zz_y), -1), zz_y, topk=(1,))[0]
                print(f'## *lr = {lr}: acc1 = {acc1}\n')

                if acc1 > max_acc1:
                    max_acc1 = acc1
                    best_lr = lr

                # reset backbone state before trying the next lr
                self.backbone.load_state_dict(self.backbone_state, strict=True)

        print(f'***Best lr = {best_lr} with acc1 = {max_acc1}.\nStart final loop...\n')

        # create optimizer
        opt = torch.optim.Adam(self.backbone.parameters(),
                               lr=best_lr,
                               betas=(0.9, 0.999),
                               weight_decay=0.)

        # final finetuning loop on the full support set with the selected lr
        pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
        for i in pbar:
            opt.zero_grad()
            z = DiffAugment(supp_x, self.aug_types, self.aug_prob, detach=True)
            _, loss = single_step(z, True, x=supp_x, y=supp_y, y_1hot=supp_y_1hot)
            loss.backward()
            opt.step()
            if is_main_process():
                pbar.set_description(f' >> lr = {best_lr}: loss = {loss.item()}')

        logits, _ = single_step(qry_x, False, x=supp_x, y_1hot=supp_y_1hot)  # supp_x has to pair with supp_y_1hot
        return logits
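
# Shape walk-through for ProtoNet_Auto_Finetune.forward (illustrative, for a
# 5-way 1-shot episode with 15 queries; follows the inline comments above):
#   supp_x: [1, 5, 3, H, W]  -> prototypes: [1, 5, d]
#   qry_x:  [1, 15, 3, H, W] -> feat: [1, 15, d] -> logits: [1, 15, 5]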


class ProtoNet_Finetune(ProtoNet):
    def __init__(self, backbone, num_iters=50, lr=5e-2, aug_prob=0.9,
                 aug_types=['color', 'translation']):
        super().__init__(backbone)
        self.num_iters = num_iters
        self.lr = lr
        self.aug_types = aug_types
        self.aug_prob = aug_prob

    def load_state_dict(self, state_dict, strict=True):
        super().load_state_dict(state_dict, strict)

        # cache the loaded backbone weights so each episode can restart
        # finetuning from the same initialization
        state_dict = self.backbone.state_dict()
        self.backbone_state = deepcopy(state_dict)

    def forward(self, supp_x, supp_y, x):
        """
        supp_x.shape = [B, nSupp, C, H, W]
        supp_y.shape = [B, nSupp]
        x.shape = [B, nQry, C, H, W]
        """
        # reset backbone state
        self.backbone.load_state_dict(self.backbone_state, strict=True)

        if self.lr == 0:
            return super().forward(supp_x, supp_y, x)

        B, nSupp, C, H, W = supp_x.shape
        num_classes = supp_y.max() + 1  # NOTE: assumes B == 1
        device = x.device
        criterion = nn.CrossEntropyLoss()

        supp_x = supp_x.view(-1, C, H, W)
        x = x.view(-1, C, H, W)
        supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2)  # B, nC, nSupp
        supp_y = supp_y.view(-1)

        # create optimizer
        opt = torch.optim.Adam(self.backbone.parameters(),
                               lr=self.lr,
                               betas=(0.9, 0.999),
                               weight_decay=0.)

        def single_step(z, mode=True):
            '''
            z = Aug(supp_x) or x
            '''
            with torch.set_grad_enabled(mode):
                # recalculate prototypes from supp_x with the updated backbone
                supp_f = self.backbone.forward(supp_x)
                supp_f = supp_f.view(B, nSupp, -1)
                prototypes = torch.bmm(supp_y_1hot.float(), supp_f)  # B, nC, d
                prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True)  # NOTE: may div 0

                # compute features for z
                feat = self.backbone.forward(z)
                feat = feat.view(B, z.shape[0], -1)  # B, nQry, d

                # classification
                logits = self.cos_classifier(prototypes, feat)  # B, nQry, nC
                loss = None

                if mode:  # grad enabled: z is augmented support, so nQry == nSupp here
                    loss = criterion(logits.view(B * nSupp, -1), supp_y)

            return logits, loss

        # main loop
        pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
        for i in pbar:
            opt.zero_grad()
            z = DiffAugment(supp_x, self.aug_types, self.aug_prob, detach=True)
            _, loss = single_step(z, True)
            loss.backward()
            opt.step()
            if is_main_process():
                pbar.set_description(f'lr{self.lr}, nSupp{nSupp}, nQry{x.shape[0]}: loss = {loss.item()}')

        logits, _ = single_step(x, False)
        return logits


class ProtoNet_AdaTok(ProtoNet):
    def __init__(self, backbone, num_adapters=1, num_iters=50, lr=5e-2, momentum=0.9, weight_decay=0.):
        super().__init__(backbone)
        self.num_adapters = num_adapters
        self.num_iters = num_iters
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

    def forward(self, supp_x, supp_y, x):
        """
        supp_x.shape = [B, nSupp, C, H, W]
        supp_y.shape = [B, nSupp]
        x.shape = [B, nQry, C, H, W]
        """
        B, nSupp, C, H, W = supp_x.shape
        nQry = x.shape[1]
        num_classes = supp_y.max() + 1  # NOTE: assumes B == 1
        device = x.device
        criterion = nn.CrossEntropyLoss()

        supp_x = supp_x.view(-1, C, H, W)
        x = x.view(-1, C, H, W)
        supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2)  # B, nC, nSupp
        supp_y = supp_y.view(-1)

        # prepare adapter tokens
        ada_tokens = torch.zeros(1, self.num_adapters, self.backbone.embed_dim, device=device)
        trunc_normal_(ada_tokens, std=.02)
        ada_tokens = ada_tokens.detach().requires_grad_()

        # NOTE: Adadelta is used here instead of SGD, so self.momentum is unused
        optimizer = torch.optim.Adadelta([ada_tokens],
                                         lr=self.lr,
                                         weight_decay=self.weight_decay)

        def single_step(mode=True):
            with torch.set_grad_enabled(mode):
                supp_f = self.backbone.forward(supp_x, ada_tokens)
                supp_f = supp_f.view(B, nSupp, -1)

                # B, nC, nSupp x B, nSupp, d = B, nC, d
                prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
                prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True)  # NOTE: may div 0

                if not mode:  # no grad: classify the query set
                    feat = self.backbone.forward(x, ada_tokens)
                    feat = feat.view(B, nQry, -1)  # B, nQry, d
                    logits = self.cos_classifier(prototypes, feat)  # B, nQry, nC
                    loss = None
                else:  # grad enabled: train the adapter tokens on the support set
                    with torch.enable_grad():
                        logits = self.cos_classifier(prototypes, supp_f)  # B, nSupp, nC
                        loss = criterion(logits.view(B * nSupp, -1), supp_y)

            return logits, loss

        pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
        for i in pbar:
            optimizer.zero_grad()
            _, loss = single_step(True)
            loss.backward()
            optimizer.step()
            if is_main_process():
                pbar.set_description(f'loss = {loss.item()}')

        logits, _ = single_step(False)
        return logits
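
# NOTE (illustrative sketch, not from the original source): both adapter
# classes assume the backbone's forward accepts an extra `ada_tokens`
# argument. For a ViT backbone, a hypothetical implementation could prepend
# the learned tokens to the patch sequence:
#
#   def forward(self, x, ada_tokens=None):
#       x = self.patch_embed(x)
#       cls = self.cls_token.expand(x.shape[0], -1, -1)
#       tokens = [cls, x] if ada_tokens is None else \
#                [cls, ada_tokens.expand(x.shape[0], -1, -1), x]
#       x = torch.cat(tokens, dim=1)
#       ...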


class ProtoNet_AdaTok_EntMin(ProtoNet):
    def __init__(self, backbone, num_adapters=1, num_iters=50, lr=5e-3, momentum=0.9, weight_decay=0.):
        super().__init__(backbone)
        self.num_adapters = num_adapters
        self.num_iters = num_iters
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

    def forward(self, supp_x, supp_y, x):
        """
        supp_x.shape = [B, nSupp, C, H, W]
        supp_y.shape = [B, nSupp]
        x.shape = [B, nQry, C, H, W]
        """
        B, nSupp, C, H, W = supp_x.shape
        num_classes = supp_y.max() + 1  # NOTE: assumes B == 1
        device = x.device
        criterion = entropy_loss  # unsupervised: no query labels needed

        supp_x = supp_x.view(-1, C, H, W)
        x = x.view(-1, C, H, W)
        supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2)  # B, nC, nSupp

        # adapter tokens
        ada_tokens = torch.zeros(1, self.num_adapters, self.backbone.embed_dim, device=device)
        trunc_normal_(ada_tokens, std=.02)
        ada_tokens = ada_tokens.detach().requires_grad_()

        optimizer = torch.optim.SGD([ada_tokens],
                                    lr=self.lr,
                                    momentum=self.momentum,
                                    weight_decay=self.weight_decay)

        def single_step(mode=True):
            with torch.set_grad_enabled(mode):
                supp_f = self.backbone.forward(supp_x, ada_tokens)
                supp_f = supp_f.view(B, nSupp, -1)

                # B, nC, nSupp x B, nSupp, d = B, nC, d
                prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
                prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True)  # NOTE: may div 0

                feat = self.backbone.forward(x, ada_tokens)
                feat = feat.view(B, x.shape[1], -1)  # B, nQry, d

                logits = self.cos_classifier(prototypes, feat)  # B, nQry, nC
                # minimize the entropy of the query predictions (transductive)
                loss = criterion(logits.view(-1, num_classes))

            return logits, loss

        pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
        for i in pbar:
            optimizer.zero_grad()
            _, loss = single_step(True)
            loss.backward()
            optimizer.step()
            if is_main_process():
                pbar.set_description(f'loss = {loss.item()}')

        logits, _ = single_step(False)
        return logits
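

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module). It runs ProtoNet_Finetune on a tiny random episode with a
    # hypothetical DummyBackbone; it assumes the inherited ProtoNet and its
    # cos_classifier only need backbone.forward(x) -> [N, d] features, and
    # that DiffAugment accepts small images.
    class DummyBackbone(nn.Module):
        def __init__(self, dim=32):
            super().__init__()
            self.embed_dim = dim
            self.fc = nn.Linear(3 * 16 * 16, dim)

        def forward(self, x):
            return self.fc(x.flatten(1))

    model = ProtoNet_Finetune(DummyBackbone(), num_iters=2, lr=1e-3)
    model.load_state_dict(model.state_dict())  # populates self.backbone_state

    supp_x = torch.randn(1, 10, 3, 16, 16)          # B=1, 5-way 2-shot support
    supp_y = torch.arange(5).repeat(2).view(1, 10)  # labels 0..4, twice
    qry_x = torch.randn(1, 15, 3, 16, 16)           # 15 query images

    logits = model(supp_x, supp_y, qry_x)
    print(logits.shape)  # expected: [1, 15, 5]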