import numpy as np
from functools import partial
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR

from dassl.data import DataManager
from dassl.optim import build_optimizer
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms.transforms import build_transform


def custom_scheduler(iter, max_iter=None, alpha=10, beta=0.75, init_lr=0.001):
    """Custom LR annealing.

    https://arxiv.org/pdf/1409.7495.pdf
    """
    if max_iter is None:
        return init_lr
    return (1 + float(iter / max_iter) * alpha)**(-1.0 * beta)


class AAC(nn.Module):

    def forward(self, sim_mat, prob_u, prob_us):
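        # P[i, j] estimates the probability that unlabeled samples i and j
        # belong to the same class, computed from the predictions on the
        # standard and strongly augmented views; the loss below is a pairwise
        # binary cross-entropy against the {0, 1} targets in sim_mat
        # (the adversarial adaptive clustering term of CDAC).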
        P = prob_u.matmul(prob_us.t())
        loss = -(
            sim_mat * torch.log(P + 1e-7) +
            (1. - sim_mat) * torch.log(1. - P + 1e-7)
        )
        return loss.mean()


class Prototypes(nn.Module):

    def __init__(self, fdim, num_classes, temp=0.05):
        super().__init__()
        self.prototypes = nn.Linear(fdim, num_classes, bias=False)
        self.temp = temp
        self.revgrad = ReverseGrad()

    def forward(self, x, reverse=False):
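        # Prototype classifier: features are L2-normalized and scored against
        # one weight vector (prototype) per class, then the logits are
        # sharpened by the temperature. With reverse=True the gradient
        # reversal layer flips the gradient flowing back into the features.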
        if reverse:
            x = self.revgrad(x)
        x = F.normalize(x, p=2, dim=1)
        out = self.prototypes(x)
        out = out / self.temp
        return out


@TRAINER_REGISTRY.register()
class CDAC(TrainerXU):
    """Cross Domain Adaptive Clustering.

    https://arxiv.org/pdf/2104.09415.pdf
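
    Training combines supervised cross-entropy on the labeled source batch
    with three terms on unlabeled target data: an adversarial adaptive
    clustering loss, a pseudo-label cross-entropy loss and a consistency loss
    between two strong augmentations (see forward_backward).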
    """

    def __init__(self, cfg):
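        # rampup_coef / rampup_iters control the ramp-up of the consistency
        # loss, lr_multi scales the classifier's learning rate, topk sets the
        # top-k feature match used to build the similarity matrix, and
        # p_thresh is the confidence threshold for pseudo-labels.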
        self.rampup_coef = cfg.TRAINER.CDAC.RAMPUP_COEF
        self.rampup_iters = cfg.TRAINER.CDAC.RAMPUP_ITRS
        self.lr_multi = cfg.TRAINER.CDAC.CLASS_LR_MULTI
        self.topk = cfg.TRAINER.CDAC.TOPK_MATCH
        self.p_thresh = cfg.TRAINER.CDAC.P_THRESH
        self.aac_criterion = AAC()
        super().__init__(cfg)
    def check_cfg(self, cfg):
        assert len(
            cfg.TRAINER.CDAC.STRONG_TRANSFORMS
        ) > 0, "Strong augmentations are necessary to run CDAC"
        assert cfg.DATALOADER.K_TRANSFORMS == 2, "CDAC needs two strong augmentations of the same image."
    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.CDAC.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes
        self.lab2cname = self.dm.lab2cname
    def build_model(self):
        cfg = self.cfg

        if self.cfg.TRAIN.COUNT_ITER == "train_x":
            self.num_batches = len(self.train_loader_x)
        elif self.cfg.TRAIN.COUNT_ITER == "train_u":
            self.num_batches = len(self.train_loader_u)
        elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
            self.num_batches = min(
                len(self.train_loader_x), len(self.train_loader_u)
            )
        self.max_iter = self.max_epoch * self.num_batches
        print("Max Iterations: %d" % self.max_iter)
        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        custom_lr_F = partial(
            custom_scheduler, max_iter=self.max_iter, init_lr=cfg.OPTIM.LR
        )
        self.sched_F = LambdaLR(self.optim_F, custom_lr_F)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building C")
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
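        # The classifier C uses a larger base lr (scaled by CLASS_LR_MULTI)
        # but follows the same annealing schedule as F.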
        for group_param in self.optim_C.param_groups:
            group_param['lr'] *= self.lr_multi
        custom_lr_C = partial(
            custom_scheduler,
            max_iter=self.max_iter,
            init_lr=cfg.OPTIM.LR * self.lr_multi
        )
        self.sched_C = LambdaLR(self.optim_C, custom_lr_C)
        self.register_model("C", self.C, self.optim_C, self.sched_C)

    def assess_y_pred_quality(self, y_pred, y_true, mask):
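        # Diagnostics for the pseudo-labels: accuracy over the confidently
        # kept samples (acc_thre), accuracy over the whole batch (acc_raw)
        # and the fraction of samples passing the threshold (keep_rate).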
        n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
        acc_thre = n_masked_correct / (mask.sum() + 1e-5)
        acc_raw = y_pred.eq(y_true).sum() / y_pred.numel()
        keep_rate = mask.sum() / mask.numel()
        output = {
            "acc_thre": acc_thre,
            "acc_raw": acc_raw,
            "keep_rate": keep_rate
        }
        return output
    def forward_backward(self, batch_x, batch_u):
        current_itr = self.epoch * self.num_batches + self.batch_idx

        input_x, label_x, input_u, input_us, input_us2, label_u = self.parse_batch_train(
            batch_x, batch_u
        )
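
        # Supervised step: cross-entropy on the labeled source batch, with an
        # immediate backward/update before the unlabeled losses are computed.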
        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        self.model_backward_and_update(loss_x)

        feat_u = self.F(input_u)
        feat_us = self.F(input_us)
        feat_us2 = self.F(input_us2)
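
        # Adversarial adaptive clustering: the logits go through C with
        # gradient reversal, and the criterion's output is negated again, so C
        # is updated to increase the pairwise BCE while F is updated to
        # decrease it.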
        logit_u = self.C(feat_u, reverse=True)
        logit_us = self.C(feat_us, reverse=True)
        prob_u, prob_us = F.softmax(logit_u, dim=1), F.softmax(logit_us, dim=1)

        sim_mat = self.get_similarity_matrix(feat_u, self.topk, self.device)

        aac_loss = (-1. * self.aac_criterion(sim_mat, prob_u, prob_us))
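
        # Pseudo-labeling (FixMatch-style): a second, non-reversed forward
        # pass gives class probabilities; predictions on the standard view
        # that exceed p_thresh become targets for the second strong view.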
        logit_u = self.C(feat_u)
        logit_us = self.C(feat_us)
        logit_us2 = self.C(feat_us2)
        prob_u = F.softmax(logit_u, dim=1)
        prob_us = F.softmax(logit_us, dim=1)
        prob_us2 = F.softmax(logit_us2, dim=1)
        prob_u = prob_u.detach()
        max_probs, max_idx = torch.max(prob_u, dim=-1)
        mask = max_probs.ge(self.p_thresh).float()
        p_u_stats = self.assess_y_pred_quality(max_idx, label_u, mask)

        pl_loss = (
            F.cross_entropy(logit_us2, max_idx, reduction='none') * mask
        ).mean()
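
        # Consistency between the two strong views, weighted by a sigmoid
        # ramp-up (scaled by rampup_coef) so it kicks in gradually.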
        cons_multi = self.sigmoid_rampup(
            current_itr=current_itr, rampup_itr=self.rampup_iters
        ) * self.rampup_coef
        cons_loss = cons_multi * F.mse_loss(prob_us, prob_us2)

        loss_u = aac_loss + pl_loss + cons_loss

        self.model_backward_and_update(loss_u)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
            "aac_loss": aac_loss.item(),
            "pl_loss": pl_loss.item(),
            "cons_loss": cons_loss.item(),
            "p_u_pred_acc": p_u_stats["acc_raw"],
            "p_u_pred_acc_thre": p_u_stats["acc_thre"],
            "p_u_pred_keep": p_u_stats["keep_rate"]
        }

        self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
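        # batch_x carries the labeled source images. For batch_u, "img" is
        # assumed to hold views from the standard transform and "img2" the two
        # strongly augmented views produced because K_TRANSFORMS == 2; only
        # the first standard view is used.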
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]

        input_u = batch_u["img"][0]
        input_us = batch_u["img2"][0]
        input_us2 = batch_u["img2"][1]
        label_u = batch_u["label"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)

        input_u = input_u.to(self.device)
        input_us = input_us.to(self.device)
        input_us2 = input_us2.to(self.device)
        label_u = label_u.to(self.device)

        return input_x, label_x, input_u, input_us, input_us2, label_u
    def model_inference(self, input):
        return self.C(self.F(input))

    @staticmethod
    def get_similarity_matrix(feat, topk, device):
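        """Pairwise 0/1 similarity over a batch of (detached) features.

        Two samples count as similar (entry 1) exactly when the indices of
        their top-k largest feature dimensions coincide as a set.
        """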
        feat_d = feat.detach()
        feat_d = torch.sort(
            torch.argsort(feat_d, dim=1, descending=True)[:, :topk], dim=1
        )[0]
        sim_mat = torch.zeros((feat_d.shape[0], feat_d.shape[0])).to(device)
        for row in range(feat_d.shape[0]):
            sim_mat[row, torch.all(feat_d == feat_d[row, :], dim=1)] = 1
        return sim_mat
    @staticmethod
    def sigmoid_rampup(current_itr, rampup_itr):
        """Exponential rampup.

        https://arxiv.org/abs/1610.02242
        """
        if rampup_itr == 0:
            return 1.0
        else:
            var = np.clip(current_itr, 0.0, rampup_itr)
            phase = 1.0 - var / rampup_itr
            return float(np.exp(-5.0 * phase * phase))