File size: 445 Bytes
3650b90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
from torch.autograd import Function
class GradientReversal(Function):
    """Gradient reversal: identity in forward, gradient scaled by ``-alpha`` in backward.

    ``forward`` returns ``x`` unchanged. ``backward`` multiplies the incoming
    gradient by ``-alpha`` and propagates no gradient to ``alpha``.

    ``alpha`` may be a tensor or a plain Python number.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass.

        Args:
            ctx: autograd context object.
            x: input tensor; returned as-is.
            alpha: reversal strength, either a tensor or a Python number.

        Returns:
            ``x``.
        """
        # save_for_backward() only accepts tensors, so a plain number is kept
        # directly on ctx instead. ``x`` is not needed in backward, so it is
        # not saved (saving it would only keep it alive longer).
        ctx.alpha_is_tensor = torch.is_tensor(alpha)
        if ctx.alpha_is_tensor:
            ctx.save_for_backward(alpha)
        else:
            ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients with respect to ``(x, alpha)``.

        The gradient for ``x`` is ``-alpha * grad_output``, computed only when
        autograd requests it (otherwise ``None``). ``alpha`` always gets ``None``.
        """
        grad_input = None
        if ctx.needs_input_grad[0]:
            alpha = ctx.saved_tensors[0] if ctx.alpha_is_tensor else ctx.alpha
            grad_input = -alpha * grad_output
        return grad_input, None
# Functional alias: ``revgrad(x, alpha)`` applies GradientReversal to ``x``.
revgrad = GradientReversal.apply