import torch
from torch import nn
from torch.autograd import Function


class _GradientReversalFunction(Function):
    """Autograd function: identity in the forward pass, ``-alpha * grad`` backward.

    Kept under a private name so it does not collide with the public
    ``GradientReversal`` module defined below; use it through ``revgrad``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass.

        Args:
            ctx: Autograd context supplied by ``Function.apply``.
            x: Input tensor; passed through as-is.
            alpha: Scale applied to the negated gradient. May be a tensor or
                anything ``torch.as_tensor`` accepts (e.g. a Python number).
        """
        # save_for_backward only accepts tensors, so promote plain numbers.
        if not torch.is_tensor(alpha):
            alpha = torch.as_tensor(alpha)
        # Only alpha is needed in backward; saving x as well would keep the
        # activation alive for no benefit.
        ctx.save_for_backward(alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(-alpha * grad_output, None)``.

        The first entry is ``None`` when the input does not require a
        gradient. No gradient is ever produced for ``alpha``.
        """
        (alpha,) = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = -alpha * grad_output
        return grad_input, None


# Functional form: revgrad(x, alpha).
revgrad = _GradientReversalFunction.apply


class GradientReversal(nn.Module):
    """Module wrapper around ``revgrad`` with a fixed ``alpha``.

    Forward returns the input unchanged; during backpropagation the incoming
    gradient is multiplied by ``-alpha``.
    """

    def __init__(self, alpha):
        """Store ``alpha`` as a tensor that does not require grad.

        Args:
            alpha: Gradient scale; a number, array-like, or tensor.
        """
        super().__init__()
        if torch.is_tensor(alpha):
            # Copy without tracking history; torch.tensor(tensor) would warn.
            self.alpha = alpha.clone().detach().requires_grad_(False)
        else:
            self.alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gradient reversal to ``x`` using the stored ``alpha``."""
        return revgrad(x, self.alpha)