from typing import Tuple, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseAdversarialLoss:
    def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           generator: nn.Module, discriminator: nn.Module):
        """
        Prepare for the generator step
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by the generator
        :param generator: the generator module
        :param discriminator: the discriminator module
        :return: None
        """

    def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                               generator: nn.Module, discriminator: nn.Module):
        """
        Prepare for the discriminator step
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by the generator
        :param generator: the generator module
        :param discriminator: the discriminator module
        :return: None
        """

    def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                       discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                       mask: Optional[torch.Tensor] = None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Calculate the generator loss
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by the generator
        :param discr_real_pred: Tensor, discriminator output for real_batch
        :param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, the mask that was fed to the generator to produce fake_batch
        :return: total generator loss along with some values that might be interesting to log
        """
        raise NotImplementedError()

    def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                           mask: Optional[torch.Tensor] = None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Calculate the discriminator loss
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by the generator
        :param discr_real_pred: Tensor, discriminator output for real_batch
        :param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, the mask that was fed to the generator to produce fake_batch
        :return: total discriminator loss along with some values that might be interesting to log
        """
        raise NotImplementedError()

    def interpolate_mask(self, mask, shape):
        # Resize the mask to the spatial size of the discriminator output,
        # if scaling is allowed; otherwise the sizes must already match.
        assert mask is not None
        assert self.allow_scale_mask or shape == mask.shape[-2:]
        if shape != mask.shape[-2:] and self.allow_scale_mask:
            if self.mask_scale_mode == 'maxpool':
                mask = F.adaptive_max_pool2d(mask, shape)
            else:
                mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)
        return mask


def make_r1_gp(discr_real_pred, real_batch):
    # R1 gradient penalty: the mean squared L2 norm of the discriminator's
    # gradient with respect to the real samples. Requires real_batch to have
    # requires_grad=True (set in pre_discriminator_step).
    if torch.is_grad_enabled():
        grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch,
                                        create_graph=True)[0]
        grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()
    else:
        grad_penalty = 0
    real_batch.requires_grad = False

    return grad_penalty


class NonSaturatingWithR1(BaseAdversarialLoss):
    def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
                 mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
                 use_unmasked_for_gen=True, use_unmasked_for_discr=True):
        self.gp_coef = gp_coef
        self.weight = weight
        # use for discr => use for gen;
        # otherwise we would teach only the discriminator to pay attention
        # to very small differences
        assert use_unmasked_for_gen or (not use_unmasked_for_discr)
        # mask as target => use unmasked for discr:
        # if we don't care about unmasked regions at all,
        # then the value of mask_as_fake_target doesn't matter
        assert use_unmasked_for_discr or (not mask_as_fake_target)
        self.use_unmasked_for_gen = use_unmasked_for_gen
        self.use_unmasked_for_discr = use_unmasked_for_discr
        self.mask_as_fake_target = mask_as_fake_target
        self.allow_scale_mask = allow_scale_mask
        self.mask_scale_mode = mask_scale_mode
        self.extra_mask_weight_for_gen = extra_mask_weight_for_gen

    def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                       discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                       mask=None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # Non-saturating generator loss: softplus(-D(G(z)))
        fake_loss = F.softplus(-discr_fake_pred)
        if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
                not self.use_unmasked_for_gen:  # == if masked region should be treated differently
            mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
            if not self.use_unmasked_for_gen:
                fake_loss = fake_loss * mask
            else:
                pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
                fake_loss = fake_loss * pixel_weights

        return fake_loss.mean() * self.weight, dict()

    def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                               generator: nn.Module, discriminator: nn.Module):
        # Needed so that make_r1_gp can differentiate w.r.t. the real samples
        real_batch.requires_grad = True

    def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                           mask=None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        real_loss = F.softplus(-discr_real_pred)
        grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
        fake_loss = F.softplus(discr_fake_pred)

        if not self.use_unmasked_for_discr or self.mask_as_fake_target:
            # == if masked region should be treated differently
            mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
            # use_unmasked_for_discr=False only makes sense for fakes;
            # for reals there is no difference between the two regions
            fake_loss = fake_loss * mask
            if self.mask_as_fake_target:
                fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)

        sum_discr_loss = real_loss + grad_penalty + fake_loss
        metrics = dict(discr_real_out=discr_real_pred.mean(),
                       discr_fake_out=discr_fake_pred.mean(),
                       discr_real_gp=grad_penalty)
        return sum_discr_loss.mean(), metrics


class BCELoss(BaseAdversarialLoss):
    # The discriminator is trained to predict the inpainting mask:
    # all-zeros for real images, the actual mask for fakes.
    def __init__(self, weight):
        self.weight = weight
        self.bce_loss = nn.BCEWithLogitsLoss()

    def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # The generator tries to make the discriminator predict "real"
        # (an all-zeros mask) for its outputs
        real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device)
        fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
        return fake_loss, dict()

    def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                               generator: nn.Module, discriminator: nn.Module):
        real_batch.requires_grad = True

    def discriminator_loss(self,
                           mask: torch.Tensor,
                           discr_real_pred: torch.Tensor,
                           discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device)
        sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) +
                          self.bce_loss(discr_fake_pred, mask)) / 2
        metrics = dict(discr_real_out=discr_real_pred.mean(),
                       discr_fake_out=discr_fake_pred.mean(),
                       discr_real_gp=0)
        return sum_discr_loss, metrics


def make_discrim_loss(kind, **kwargs):
    if kind == 'r1':
        return NonSaturatingWithR1(**kwargs)
    elif kind == 'bce':
        return BCELoss(**kwargs)
    raise ValueError(f'Unknown adversarial loss kind {kind}')
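

# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the API above: one adversarial training
# step wiring the pre-step hooks and loss methods together. The `generator`,
# `discriminator`, the two optimizers, and the `real_batch` / `masks` tensors
# are hypothetical placeholders; only the loss-object API comes from this
# file. The discriminator is assumed to output a per-pixel logit map
# (PatchGAN-style), since the mask is interpolated to its spatial size.
# ---------------------------------------------------------------------------
def _example_training_step(generator: nn.Module, discriminator: nn.Module,
                           gen_opt: torch.optim.Optimizer, discr_opt: torch.optim.Optimizer,
                           real_batch: torch.Tensor, masks: torch.Tensor):
    adv_loss = make_discrim_loss('r1', gp_coef=5, weight=1)

    # --- discriminator step: detach fakes so no gradient reaches the generator ---
    fake_batch = generator(real_batch, masks).detach()
    # sets real_batch.requires_grad = True so the R1 penalty can be computed
    adv_loss.pre_discriminator_step(real_batch, fake_batch, generator, discriminator)
    discr_real_pred = discriminator(real_batch)
    discr_fake_pred = discriminator(fake_batch)
    d_loss, d_metrics = adv_loss.discriminator_loss(real_batch, fake_batch,
                                                    discr_real_pred, discr_fake_pred,
                                                    mask=masks)
    discr_opt.zero_grad()
    d_loss.backward()
    discr_opt.step()

    # --- generator step: fakes stay attached so gradients flow back ---
    fake_batch = generator(real_batch, masks)
    adv_loss.pre_generator_step(real_batch, fake_batch, generator, discriminator)
    discr_fake_pred = discriminator(fake_batch)
    # discr_real_pred is unused by NonSaturatingWithR1.generator_loss, so None is fine
    g_loss, g_metrics = adv_loss.generator_loss(real_batch, fake_batch,
                                                None, discr_fake_pred, mask=masks)
    gen_opt.zero_grad()
    g_loss.backward()
    gen_opt.step()

    return d_metrics, g_metrics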