spandey8's picture
Upload 3 files
3650b90 verified
raw
history blame contribute delete
445 Bytes
import torch
from torch.autograd import Function
class GradientReversal(Function):
    """Gradient reversal: identity forward, ``-alpha``-scaled gradient backward.

    Use it via ``GradientReversal.apply(x, alpha)``. ``alpha`` may be either a
    tensor or a plain Python number. No gradient is propagated to ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            x: Input tensor.
            alpha: Reversal coefficient; a tensor or a real number.

        Returns:
            A view of ``x`` with identical values.
        """
        # save_for_backward accepts only tensors, so a plain number is kept
        # as a ctx attribute instead. ``x`` is not needed in backward, so it
        # is deliberately not saved.
        ctx.alpha_is_tensor = isinstance(alpha, torch.Tensor)
        if ctx.alpha_is_tensor:
            ctx.save_for_backward(alpha)
        else:
            ctx.alpha = alpha
        # Return a view so autograd can attach this Function's grad_fn to the
        # output; this is the same view PyTorch creates for a returned input.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Send ``-alpha * grad_output`` to ``x`` and no gradient to ``alpha``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient of the loss w.r.t. the forward output.

        Returns:
            ``(grad_x, None)``; ``grad_x`` is ``None`` when ``x`` does not
            require a gradient.
        """
        if ctx.alpha_is_tensor:
            (alpha,) = ctx.saved_tensors
        else:
            alpha = ctx.alpha
        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = -alpha * grad_output
        return grad_input, None
# Functional alias: ``revgrad(x, alpha)`` returns ``x`` in the forward pass and
# multiplies the incoming gradient by ``-alpha`` in the backward pass.
revgrad = GradientReversal.apply