|
import torch |
|
|
|
from dassl.utils import check_isfile |
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU |
|
|
|
|
|
@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
    """Adaptive Batch Normalization (AdaBN).

    Reference: "Revisiting Batch Normalization For Practical Domain
    Adaptation", https://arxiv.org/abs/1603.04779.

    ``check_cfg`` requires ``cfg.MODEL.INIT_WEIGHTS`` to point at an existing
    file. That file is presumably loaded by the base trainer as the source
    model (confirm).

    Before the first epoch, the running statistics of every batch-norm layer
    are reset once. After that, each training step only runs a no-grad
    forward pass over the unlabeled batch, so the batch-norm statistics are
    re-estimated from that data. No loss is computed, and no backward pass or
    optimizer step is performed by this class.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # before_epoch() reads and flips this flag so the statistics are
        # reset exactly once, not at the start of every epoch.
        self.done_reset_bn_stats = False

    def check_cfg(self, cfg):
        """Require the source-model weights file to exist.

        Raises:
            AssertionError: if ``cfg.MODEL.INIT_WEIGHTS`` is not an existing
                file.
        """
        # Raised explicitly instead of via an ``assert`` statement, so the
        # check still runs under ``python -O``. AssertionError is kept so
        # callers that already catch it behave the same.
        if not check_isfile(cfg.MODEL.INIT_WEIGHTS):
            raise AssertionError(
                "The weights of source model must be provided"
            )

    def before_epoch(self):
        """Reset batch-norm running statistics, only before the first epoch."""
        if self.done_reset_bn_stats:
            return

        for m in self.model.modules():
            # Select layers by class name, so BatchNorm1d/2d/3d, SyncBatchNorm
            # and custom *BatchNorm* subclasses are all covered.
            if "BatchNorm" not in type(m).__name__:
                continue
            # Some look-alikes have no running statistics to reset, e.g.
            # torchvision's FrozenBatchNorm2d (no reset_running_stats
            # method). Leave those frozen layers untouched rather than
            # failing with AttributeError.
            reset = getattr(m, "reset_running_stats", None)
            if callable(reset):
                reset()

        self.done_reset_bn_stats = True

    def forward_backward(self, batch_x, batch_u):
        """Feed one unlabeled batch through the model to update BN statistics.

        Args:
            batch_x: labeled batch; not used by AdaBN.
            batch_u: unlabeled batch; its ``"img"`` entry is moved to
                ``self.device`` and passed to the model.

        Returns:
            None, since there is no loss to report. NOTE(review): this relies
            on the base trainer accepting a ``None`` loss summary; confirm.
        """
        input_u = batch_u["img"].to(self.device)

        # Gradients are not needed. Batch-norm layers in training mode update
        # their running statistics during the forward pass even under
        # no_grad. The model is presumably put in train mode by the base
        # trainer before this call (confirm). The output is discarded.
        with torch.no_grad():
            self.model(input_u)

        return None
|
|