spandey8's picture
Upload 3 files
3650b90 verified
raw
history blame contribute delete
286 Bytes
from .functional import revgrad
import torch
from torch import nn
class GradientReversal(nn.Module):
    """Module wrapper that applies ``revgrad`` with a fixed coefficient.

    ``forward`` passes its input together with the stored coefficient
    ``alpha`` to ``revgrad`` (imported from ``.functional``).
    NOTE(review): ``revgrad`` is not visible here. Judging by the name, it is
    presumably the identity in the forward pass and scales gradients by
    ``-alpha`` in the backward pass. Confirm against ``.functional``.

    Args:
        alpha: Scalar coefficient (Python number or tensor) forwarded to
            ``revgrad`` on every call. It is stored as a tensor that never
            requires grad.
    """

    def __init__(self, alpha):
        super().__init__()
        # Register alpha as a buffer instead of a plain attribute, so that
        # Module.to()/.cuda()/.double() move and cast it along with the rest
        # of the model.
        # persistent=False keeps it out of state_dict, so checkpoints saved
        # before this change still load with strict=True.
        # as_tensor(...).detach().clone() behaves like
        # torch.tensor(alpha, requires_grad=False): same value and dtype, an
        # independent copy, no grad. Unlike torch.tensor, it does not warn
        # when alpha is already a tensor.
        self.register_buffer(
            "alpha", torch.as_tensor(alpha).detach().clone(), persistent=False
        )

    def forward(self, x):
        """Return ``revgrad(x, self.alpha)``."""
        return revgrad(x, self.alpha)