import random
from contextlib import contextmanager

import torch
import torch.nn as nn

def deactivate_mixstyle(m):
    # For use with nn.Module.apply: switches a MixStyle submodule off.
    if type(m) == MixStyle:
        m.set_activation_status(False)


def activate_mixstyle(m):
    # For use with nn.Module.apply: switches a MixStyle submodule on.
    if type(m) == MixStyle:
        m.set_activation_status(True)


def random_mixstyle(m):
    # For use with nn.Module.apply: mix statistics across the whole batch.
    if type(m) == MixStyle:
        m.update_mix_method("random")


def crossdomain_mixstyle(m):
    # For use with nn.Module.apply: mix statistics across domains.
    if type(m) == MixStyle:
        m.update_mix_method("crossdomain")
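

# All four helpers above are meant to be passed to nn.Module.apply, which
# visits every submodule. A minimal sketch (`model` is a placeholder for any
# nn.Module containing MixStyle layers, not defined in this module):
#
#   model.apply(deactivate_mixstyle)   # turn MixStyle off everywhere
#   model.apply(crossdomain_mixstyle)  # switch every MixStyle to "crossdomain"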


@contextmanager
def run_without_mixstyle(model):
    """Temporarily deactivate all MixStyle modules in ``model``."""
    try:
        model.apply(deactivate_mixstyle)
        yield
    finally:
        model.apply(activate_mixstyle)

@contextmanager
def run_with_mixstyle(model, mix=None):
    """Temporarily activate all MixStyle modules in ``model``.

    Args:
        model (nn.Module): model containing MixStyle modules.
        mix (str, optional): if given, switch the mixing method to
            "random" or "crossdomain" first.
    """
    if mix == "random":
        model.apply(random_mixstyle)
    elif mix == "crossdomain":
        model.apply(crossdomain_mixstyle)

    try:
        model.apply(activate_mixstyle)
        yield
    finally:
        model.apply(deactivate_mixstyle)

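
# A minimal usage sketch for the two context managers (`net` and `images`
# are placeholders, not defined in this module):
#
#   with run_with_mixstyle(net, mix="crossdomain"):
#       out = net(images)  # MixStyle active, mixing across domains
#
#   with run_without_mixstyle(net):
#       out = net(images)  # MixStyle temporarily disabled, e.g. for eval
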
class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
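
    Example (a minimal sketch; the shape below is illustrative)::

        m = MixStyle(p=0.5, alpha=0.1, mix="random")
        m.train()                      # MixStyle only mixes in training mode
        x = torch.rand(8, 64, 32, 32)  # (batch, channels, height, width)
        y = m(x)                       # same shape, mixed feature statistics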
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
            p (float): probability of applying MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix, either "random" or "crossdomain".
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        # MixStyle is only applied during training, and only when activated.
        if not self.training or not self._activated:
            return x

        # Apply MixStyle stochastically with probability p.
        if random.random() > self.p:
            return x

        B = x.size(0)

        # Per-instance, per-channel style statistics over the spatial dims.
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()  # no gradients through the statistics
        x_normed = (x - mu) / sig

        # Mixing weights drawn from Beta(alpha, alpha), one per instance.
        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # Mix with a randomly shuffled batch.
            perm = torch.randperm(B)
        elif self.mix == "crossdomain":
            # Assumes the batch is [domain A | domain B]: swap the halves so
            # each instance mixes with a shuffled instance from the other domain.
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)
        else:
            raise NotImplementedError

        # Interpolate the style statistics and re-apply them to the
        # normalized features.
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix
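

# A quick smoke test, kept as a sketch (run this file directly; the seed and
# shapes below are arbitrary examples, not part of the MixStyle method).
if __name__ == "__main__":
    torch.manual_seed(0)

    layer = MixStyle(p=1.0, alpha=0.1, mix="random")  # p=1.0 so mixing always fires
    layer.train()

    x = torch.rand(8, 64, 32, 32)
    y = layer(x)
    assert y.shape == x.shape

    # In eval mode, MixStyle is the identity.
    layer.eval()
    assert torch.equal(layer(x), x)

    print(layer)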