# pmf_with_gis/models/deploy.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from copy import deepcopy
from tqdm import tqdm
from timm.utils import accuracy
from .protonet import ProtoNet
from .utils import trunc_normal_, DiffAugment
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
@torch.jit.script
def entropy_loss(x):
    """Mean Shannon entropy of the softmax distribution over dim 1."""
    return torch.sum(-F.softmax(x, 1) * F.log_softmax(x, 1), 1).mean()
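
# Illustrative sanity check (demonstration only, not part of the original
# module): for uniform logits the mean entropy equals ln(num_classes).
def _demo_entropy_loss():
    x = torch.zeros(4, 5)  # 4 samples, 5 classes -> uniform after softmax
    expected = torch.log(torch.tensor(5.0))
    assert torch.isclose(entropy_loss(x), expected)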
def unique_indices(x):
    """
    Return the sorted unique values of x and, for each value, the index of
    its first occurrence. Ref: https://github.com/rusty1s/pytorch_unique
    """
unique, inverse = torch.unique(x, sorted=True, return_inverse=True)
perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
inverse, perm = inverse.flip([0]), perm.flip([0])
perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
return unique, perm
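
# Illustrative example (demonstration only): `perm` holds the index of the
# *first* occurrence of each sorted unique value, which is how
# ProtoNet_Auto_Finetune picks one support image per class below.
def _demo_unique_indices():
    x = torch.tensor([2, 0, 1, 0, 2])
    unique, perm = unique_indices(x)
    assert unique.tolist() == [0, 1, 2]
    assert perm.tolist() == [1, 2, 0]  # first occurrences of 0, 1 and 2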
class ProtoNet_Auto_Finetune(ProtoNet):
    """ProtoNet that fine-tunes its backbone on each episode's support set,
    selecting the learning rate from lr_lst by accuracy on duplicate shots."""

    def __init__(self, backbone, num_iters=50, aug_prob=0.9,
                 aug_types=['color', 'translation'], lr_lst=[0.01, 0.001, 0.0001]):
super().__init__(backbone)
self.num_iters = num_iters
self.lr_lst = lr_lst
self.aug_types = aug_types
self.aug_prob = aug_prob
state_dict = backbone.state_dict()
self.backbone_state = deepcopy(state_dict)
def forward(self, supp_x, supp_y, qry_x):
"""
supp_x.shape = [B, nSupp, C, H, W]
supp_y.shape = [B, nSupp]
qry_x.shape = [B, nQry, C, H, W]
"""
B, nSupp, C, H, W = supp_x.shape
num_classes = supp_y.max() + 1 # NOTE: assume B==1
device = qry_x.device
criterion = nn.CrossEntropyLoss()
supp_x = supp_x.view(-1, C, H, W)
qry_x = qry_x.view(-1, C, H, W)
supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
supp_y = supp_y.view(-1)
        def single_step(z, mode=True, x=None, y=None, y_1hot=None):
            '''
            z = Aug(supp_x) or qry_x
            closure vars: self, B, criterion
            '''
with torch.set_grad_enabled(mode):
# recalculate prototypes from supp_x with updated backbone
proto_f = self.backbone.forward(x).unsqueeze(0)
if y_1hot is None:
prototypes = proto_f
else:
prototypes = torch.bmm(y_1hot.float(), proto_f) # B, nC, d
prototypes = prototypes / y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
# compute feature for z
feat = self.backbone.forward(z)
feat = feat.view(B, z.shape[0], -1) # B, nQry, d
# classification
logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
loss = None
                if mode:  # grad enabled: compute the loss
loss = criterion(logits.view(len(y), -1), y)
return logits, loss
# load trained weights
self.backbone.load_state_dict(self.backbone_state, strict=True)
proto_y, proto_i = unique_indices(supp_y)
proto_x = supp_x[proto_i]
zz_i = np.setdiff1d(range(len(supp_x)), proto_i.cpu().numpy())
zz_x = supp_x[zz_i]
zz_y = supp_y[zz_i]
best_lr = 0
max_acc1 = 0
if len(zz_y) > 0:
# eval non-finetuned weights (lr=0)
logits, _ = single_step(zz_x, False, x=proto_x)
max_acc1 = accuracy(logits.view(len(zz_y), -1), zz_y, topk=(1,))[0]
print(f'## *lr = 0: acc1 = {max_acc1}\n')
for lr in self.lr_lst:
# create optimizer
opt = torch.optim.Adam(self.backbone.parameters(),
lr=lr,
betas=(0.9, 0.999),
weight_decay=0.)
# main loop
            _num_iters = 50  # fixed iteration budget for the LR search (independent of self.num_iters)
pbar = tqdm(range(_num_iters)) if is_main_process() else range(_num_iters)
for i in pbar:
opt.zero_grad()
z = DiffAugment(proto_x, self.aug_types, self.aug_prob, detach=True)
_, loss = single_step(z, True, x=proto_x, y=proto_y)
loss.backward()
opt.step()
if is_main_process():
pbar.set_description(f' << lr = {lr}: loss = {loss.item()}')
logits, _ = single_step(zz_x, False, x=proto_x)
acc1 = accuracy(logits.view(len(zz_y), -1), zz_y, topk=(1,))[0]
print(f'## *lr = {lr}: acc1 = {acc1}\n')
if acc1 > max_acc1:
max_acc1 = acc1
best_lr = lr
# reset backbone state
self.backbone.load_state_dict(self.backbone_state, strict=True)
print(f'***Best lr = {best_lr} with acc1 = {max_acc1}.\nStart final loop...\n')
# create optimizer
opt = torch.optim.Adam(self.backbone.parameters(),
lr=best_lr,
betas=(0.9, 0.999),
weight_decay=0.)
# main loop
pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
for i in pbar:
opt.zero_grad()
z = DiffAugment(supp_x, self.aug_types, self.aug_prob, detach=True)
_, loss = single_step(z, True, x=supp_x, y=supp_y, y_1hot=supp_y_1hot)
loss.backward()
opt.step()
if is_main_process():
pbar.set_description(f' >> lr = {best_lr}: loss = {loss.item()}')
logits, _ = single_step(qry_x, False, x=supp_x, y_1hot=supp_y_1hot) # supp_x has to pair with y_1hot
return logits
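
# A minimal sketch (demonstration only) of the prototype computation used
# throughout this file: bmm with the one-hot label matrix sums support
# features per class; dividing by the per-class counts yields class means.
def _demo_prototypes():
    supp_y = torch.tensor([[0, 1, 0, 1]])                       # B=1, nSupp=4
    supp_f = torch.arange(8, dtype=torch.float).view(1, 4, 2)   # B, nSupp, d
    y_1hot = F.one_hot(supp_y, 2).transpose(1, 2).float()       # B, nC, nSupp
    prototypes = torch.bmm(y_1hot, supp_f) / y_1hot.sum(dim=2, keepdim=True)
    # class 0 averages rows 0 and 2; class 1 averages rows 1 and 3
    assert torch.equal(prototypes, torch.tensor([[[2., 3.], [4., 5.]]]))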
class ProtoNet_Finetune(ProtoNet):
    """ProtoNet that fine-tunes its backbone on the augmented support set
    with a fixed learning rate before classifying the query set."""

    def __init__(self, backbone, num_iters=50, lr=5e-2, aug_prob=0.9,
                 aug_types=['color', 'translation']):
super().__init__(backbone)
self.num_iters = num_iters
self.lr = lr
self.aug_types = aug_types
self.aug_prob = aug_prob
def load_state_dict(self, state_dict, strict=True):
super().load_state_dict(state_dict, strict)
state_dict = self.backbone.state_dict()
self.backbone_state = deepcopy(state_dict)
def forward(self, supp_x, supp_y, x):
"""
supp_x.shape = [B, nSupp, C, H, W]
supp_y.shape = [B, nSupp]
x.shape = [B, nQry, C, H, W]
"""
# reset backbone state
self.backbone.load_state_dict(self.backbone_state, strict=True)
if self.lr == 0:
return super().forward(supp_x, supp_y, x)
B, nSupp, C, H, W = supp_x.shape
num_classes = supp_y.max() + 1 # NOTE: assume B==1
device = x.device
criterion = nn.CrossEntropyLoss()
supp_x = supp_x.view(-1, C, H, W)
x = x.view(-1, C, H, W)
supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
supp_y = supp_y.view(-1)
# create optimizer
opt = torch.optim.Adam(self.backbone.parameters(),
lr=self.lr,
betas=(0.9, 0.999),
weight_decay=0.)
def single_step(z, mode=True):
'''
z = Aug(supp_x) or x
'''
with torch.set_grad_enabled(mode):
# recalculate prototypes from supp_x with updated backbone
supp_f = self.backbone.forward(supp_x)
supp_f = supp_f.view(B, nSupp, -1)
prototypes = torch.bmm(supp_y_1hot.float(), supp_f) # B, nC, d
prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
# compute feature for z
feat = self.backbone.forward(z)
feat = feat.view(B, z.shape[0], -1) # B, nQry, d
# classification
logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
loss = None
                if mode:  # grad enabled: compute the loss on the support set
loss = criterion(logits.view(B*nSupp, -1), supp_y)
return logits, loss
# main loop
pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
for i in pbar:
opt.zero_grad()
z = DiffAugment(supp_x, self.aug_types, self.aug_prob, detach=True)
_, loss = single_step(z, True)
loss.backward()
opt.step()
if is_main_process():
pbar.set_description(f'lr{self.lr}, nSupp{nSupp}, nQry{x.shape[0]}: loss = {loss.item()}')
logits, _ = single_step(x, False)
return logits
class ProtoNet_AdaTok(ProtoNet):
    """ProtoNet with a frozen backbone; per episode, only a small set of
    adapter tokens is optimized on the support set."""

    def __init__(self, backbone, num_adapters=1, num_iters=50, lr=5e-2, momentum=0.9, weight_decay=0.):
super().__init__(backbone)
self.num_adapters = num_adapters
self.num_iters = num_iters
self.lr = lr
self.momentum = momentum
self.weight_decay = weight_decay
def forward(self, supp_x, supp_y, x):
"""
supp_x.shape = [B, nSupp, C, H, W]
supp_y.shape = [B, nSupp]
x.shape = [B, nQry, C, H, W]
"""
B, nSupp, C, H, W = supp_x.shape
nQry = x.shape[1]
num_classes = supp_y.max() + 1 # NOTE: assume B==1
device = x.device
criterion = nn.CrossEntropyLoss()
supp_x = supp_x.view(-1, C, H, W)
x = x.view(-1, C, H, W)
supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
supp_y = supp_y.view(-1)
# prepare adapter tokens
ada_tokens = torch.zeros(1, self.num_adapters, self.backbone.embed_dim, device=device)
trunc_normal_(ada_tokens, std=.02)
ada_tokens = ada_tokens.detach().requires_grad_()
        # Adadelta replaced an earlier SGD variant here; self.momentum is
        # kept for that variant but is unused by Adadelta.
        optimizer = torch.optim.Adadelta([ada_tokens],
                                         lr=self.lr,
                                         weight_decay=self.weight_decay)
def single_step(mode=True):
with torch.set_grad_enabled(mode):
supp_f = self.backbone.forward(supp_x, ada_tokens)
supp_f = supp_f.view(B, nSupp, -1)
# B, nC, nSupp x B, nSupp, d = B, nC, d
prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
                if not mode:  # eval: classify the query set (grad disabled)
                    feat = self.backbone.forward(x, ada_tokens)
                    feat = feat.view(B, nQry, -1)  # B, nQry, d
                    logits = self.cos_classifier(prototypes, feat)  # B, nQry, nC
                    loss = None
                else:  # train: classify the support embeddings themselves
                    logits = self.cos_classifier(prototypes, supp_f)  # B, nSupp, nC
                    loss = criterion(logits.view(B*nSupp, -1), supp_y)
return logits, loss
pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
for i in pbar:
optimizer.zero_grad()
_, loss = single_step(True)
loss.backward()
optimizer.step()
if is_main_process():
pbar.set_description(f'loss = {loss.item()}')
logits, _ = single_step(False)
return logits
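
# A minimal sketch (demonstration only) of the adapter-token pattern above:
# a freshly initialized tensor is detached, marked as requiring grad, and
# handed to an optimizer, so only the tokens (never the backbone) get updates.
def _demo_learnable_tokens():
    tokens = torch.zeros(1, 2, 8)
    trunc_normal_(tokens, std=.02)
    tokens = tokens.detach().requires_grad_()
    opt = torch.optim.SGD([tokens], lr=0.1)
    loss = (tokens ** 2).sum()
    loss.backward()
    opt.step()  # tokens move toward zero; no other parameters are touched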
class ProtoNet_AdaTok_EntMin(ProtoNet):
    """Transductive variant of ProtoNet_AdaTok: the adapter tokens are
    optimized by minimizing the entropy of the query predictions."""

    def __init__(self, backbone, num_adapters=1, num_iters=50, lr=5e-3, momentum=0.9, weight_decay=0.):
super().__init__(backbone)
self.num_adapters = num_adapters
self.num_iters = num_iters
self.lr = lr
self.momentum = momentum
self.weight_decay = weight_decay
def forward(self, supp_x, supp_y, x):
"""
supp_x.shape = [B, nSupp, C, H, W]
supp_y.shape = [B, nSupp]
x.shape = [B, nQry, C, H, W]
"""
B, nSupp, C, H, W = supp_x.shape
num_classes = supp_y.max() + 1 # NOTE: assume B==1
device = x.device
criterion = entropy_loss
supp_x = supp_x.view(-1, C, H, W)
x = x.view(-1, C, H, W)
supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
# adapter tokens
ada_tokens = torch.zeros(1, self.num_adapters, self.backbone.embed_dim, device=device)
trunc_normal_(ada_tokens, std=.02)
ada_tokens = ada_tokens.detach().requires_grad_()
optimizer = torch.optim.SGD([ada_tokens],
lr=self.lr,
momentum=self.momentum,
weight_decay=self.weight_decay)
def single_step(mode=True):
with torch.set_grad_enabled(mode):
supp_f = self.backbone.forward(supp_x, ada_tokens)
supp_f = supp_f.view(B, nSupp, -1)
# B, nC, nSupp x B, nSupp, d = B, nC, d
prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
feat = self.backbone.forward(x, ada_tokens)
feat = feat.view(B, x.shape[1], -1) # B, nQry, d
logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
loss = criterion(logits.view(-1, num_classes))
return logits, loss
pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
for i in pbar:
optimizer.zero_grad()
_, loss = single_step(True)
loss.backward()
optimizer.step()
if is_main_process():
pbar.set_description(f'loss = {loss.item()}')
logits, _ = single_step(False)
return logits
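
if __name__ == '__main__':
    # A minimal smoke test (demonstration only). TinyBackbone is a
    # hypothetical stand-in for the real feature extractor: the classes above
    # only require mapping [N, C, H, W] images to [N, embed_dim] features.
    class TinyBackbone(nn.Module):
        def __init__(self, dim=16):
            super().__init__()
            self.embed_dim = dim
            self.conv = nn.Conv2d(3, dim, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x).mean(dim=(2, 3))  # N, embed_dim

    model = ProtoNet_Finetune(TinyBackbone(), num_iters=2, lr=1e-3)
    model.load_state_dict(model.state_dict())  # caches backbone_state
    supp_x = torch.randn(1, 10, 3, 32, 32)     # B=1, 5-way 2-shot support
    supp_y = torch.arange(5).repeat(2).view(1, 10)
    qry_x = torch.randn(1, 15, 3, 32, 32)      # nQry=15
    print(model(supp_x, supp_y, qry_x).shape)  # expected: [1, 15, 5]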