import copy

import torch
import torch.nn as nn

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import check_isfile, count_num_param, open_specified_layers
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling import build_head


@TRAINER_REGISTRY.register()
class ADDA(TrainerXU):
    """Adversarial Discriminative Domain Adaptation.

    https://arxiv.org/abs/1702.05464.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Only the backbone (and the head, if one exists) of the target
        # model is updated during adaptation; all other layers stay frozen
        self.open_layers = ["backbone"]
        if isinstance(self.model.head, nn.Module):
            self.open_layers.append("head")

        # Keep a frozen copy of the source-trained model; it only
        # provides source features for the critic and is never updated
        self.source_model = copy.deepcopy(self.model)
        self.source_model.eval()
        for param in self.source_model.parameters():
            param.requires_grad_(False)

        self.build_critic()
        self.bce = nn.BCEWithLogitsLoss()

    def check_cfg(self, cfg):
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of the source model must be provided"

    def build_critic(self):
        cfg = self.cfg

        print("Building critic network")
        # The critic is an MLP over backbone features, topped with a
        # single linear unit that outputs a domain logit
        fdim = self.model.fdim
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim // 2],
            activation="leaky_relu",
        )
        self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)

    def forward_backward(self, batch_x, batch_u):
        # Re-open only the adaptable layers (the backbone/head) in case
        # a previous step changed their training state
        open_specified_layers(self.model, self.open_layers)

        input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
        # Domain labels: 1 for source samples, 0 for target samples
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)

        _, feat_x = self.source_model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        # Step 1: update the critic to separate source from target
        # features; detach feat_u so no gradient reaches the target model
        logit_xd = self.critic(feat_x)
        logit_ud = self.critic(feat_u.detach())
        loss_critic = self.bce(logit_xd, domain_x)
        loss_critic += self.bce(logit_ud, domain_u)
        self.model_backward_and_update(loss_critic, "critic")

        # Step 2: update the target model to fool the critic, using
        # inverted domain labels (the standard GAN-style ADDA objective)
        logit_ud = self.critic(feat_u)
        loss_model = self.bce(logit_ud, 1 - domain_u)
        self.model_backward_and_update(loss_model, "model")

        loss_summary = {
            "loss_critic": loss_critic.item(),
            "loss_model": loss_model.item(),
        }

        # Step the learning-rate schedulers once per epoch
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
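
# Usage (a minimal sketch, assuming a Dassl-style config `cfg` with
# TRAINER.NAME set to "ADDA" and MODEL.INIT_WEIGHTS pointing at the
# source-trained weights; `cfg` itself is not defined in this file):
#
#   from dassl.engine import build_trainer
#
#   trainer = build_trainer(cfg)  # resolves to this ADDA trainer via the registry
#   trainer.train()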