from torch.autograd import Function

class GradientReversal(Function):
    """Identity on the forward pass; scales the gradient by -alpha on the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # alpha must be a tensor so it can be stashed via save_for_backward
        ctx.save_for_backward(x, alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        _, alpha = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            # reverse (and scale) the gradient flowing back into x
            grad_input = -alpha * grad_output
        # no gradient is returned for alpha
        return grad_input, None

revgrad = GradientReversal.apply
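# --- Usage sketch (not part of the original snippet) ---
# A minimal example of how revgrad might be wrapped in an nn.Module so it can sit
# between a feature extractor and a domain classifier, DANN-style. The module name
# GradientReversalLayer and the default alpha=1.0 are assumptions for illustration.
import torch
from torch import nn

class GradientReversalLayer(nn.Module):
    def __init__(self, alpha=1.0):
        super().__init__()
        # keep alpha as a tensor, since save_for_backward only accepts tensors
        self.alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, x):
        return revgrad(x, self.alpha)

# Sanity check: forward is the identity, backward flips the gradient's sign.
x = torch.randn(4, 3, requires_grad=True)
y = GradientReversalLayer(alpha=1.0)(x)
y.sum().backward()
print(torch.allclose(x.grad, -torch.ones_like(x)))  # expected: True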