Spaces:
Runtime error
Runtime error
File size: 498 Bytes
1ba3df3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
import torch.nn as nn
class GANHingeLoss(nn.Module):
    """Hinge adversarial loss for GAN training.

    Discriminator terms:
        real samples: mean(relu(1 - pred))
        fake samples: mean(relu(1 + pred))
    Generator term:
        -mean(pred)

    The computation lives in ``forward`` so that calling the module goes
    through ``nn.Module.__call__`` (forward hooks etc.) like any other module.
    """

    def __init__(self):
        super().__init__()
        # Parameter-free; kept as an attribute so existing code that reads
        # ``self.relu`` keeps working.
        self.relu = nn.ReLU()

    def forward(self, pred: torch.Tensor, is_real: bool, for_discriminator: bool) -> torch.Tensor:
        """Return the scalar hinge loss for a batch of discriminator outputs.

        Args:
            pred: discriminator output tensor (raw scores; reduced with
                ``.mean()`` over all elements).
            is_real: whether ``pred`` should be pushed toward "real". For the
                generator loss this must be truthy.
            for_discriminator: truthy to compute the discriminator loss,
                falsy to compute the generator loss.

        Returns:
            A 0-dim tensor holding the mean loss.

        Raises:
            ValueError: if the generator loss is requested with a falsy
                ``is_real`` (the generator always aims for "real").
        """
        if for_discriminator:
            if is_real:
                return self.relu(1 - pred).mean()
            return self.relu(1 + pred).mean()
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not is_real:
            raise ValueError("The generator's hinge loss must be aiming for real")
        return -pred.mean()