from typing import Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseAdversarialLoss:
def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
"""
Prepare for generator step
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
        :param generator: the generator that produced fake_batch
        :param discriminator: the discriminator that will score both batches
:return: None
"""
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
"""
Prepare for discriminator step
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
        :param generator: the generator that produced fake_batch
        :param discriminator: the discriminator that will score both batches
:return: None
"""
def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask: Optional[torch.Tensor] = None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Calculate generator loss
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
:param discr_real_pred: Tensor, discriminator output for real_batch
:param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, the mask that was fed to the generator when producing fake_batch
:return: total generator loss along with some values that might be interesting to log
"""
        raise NotImplementedError()
def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask: Optional[torch.Tensor] = None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
        Calculate discriminator loss
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
:param discr_real_pred: Tensor, discriminator output for real_batch
:param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, the mask that was fed to the generator when producing fake_batch
:return: total discriminator loss along with some values that might be interesting to log
"""
        raise NotImplementedError()
    def interpolate_mask(self, mask, shape):
        """Resize mask to the given spatial shape (expected to match the
        discriminator output). Relies on self.allow_scale_mask and
        self.mask_scale_mode, which subclasses are expected to define."""
assert mask is not None
assert self.allow_scale_mask or shape == mask.shape[-2:]
if shape != mask.shape[-2:] and self.allow_scale_mask:
if self.mask_scale_mode == 'maxpool':
mask = F.adaptive_max_pool2d(mask, shape)
else:
mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)
return mask
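
# Hedged usage sketch for BaseAdversarialLoss.interpolate_mask (names below are
# hypothetical, not part of this module): a full-resolution inpainting mask is
# rescaled to the spatial size of a patch discriminator's output map. 'maxpool'
# marks a low-res cell as masked if any covered pixel was masked; other modes
# simply resample the mask values via F.interpolate.
#
#   mask = torch.rand(4, 1, 256, 256).round()            # B x 1 x H x W in {0, 1}
#   adv = NonSaturatingWithR1(allow_scale_mask=True, mask_scale_mode='maxpool')
#   small_mask = adv.interpolate_mask(mask, (64, 64))    # -> 4 x 1 x 64 x 64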


def make_r1_gp(discr_real_pred, real_batch):
    """R1 gradient penalty (Mescheder et al., 2018): mean squared L2 norm of
    the discriminator's gradients with respect to the real inputs."""
if torch.is_grad_enabled():
grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0]
grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()
else:
grad_penalty = 0
real_batch.requires_grad = False
return grad_penalty
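
# Minimal standalone sketch of make_r1_gp (illustrative; the conv layer is a
# stand-in for a real patch discriminator): R1 needs gradients of D's output
# with respect to its *inputs*, so real_batch must require grad before the
# forward pass -- which is what pre_discriminator_step below arranges.
#
#   discriminator = nn.Conv2d(3, 1, kernel_size=3, padding=1)
#   real_batch = torch.randn(2, 3, 64, 64, requires_grad=True)
#   discr_real_pred = discriminator(real_batch)
#   gp = make_r1_gp(discr_real_pred, real_batch)    # scalar tensor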


class NonSaturatingWithR1(BaseAdversarialLoss):
def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
use_unmasked_for_gen=True, use_unmasked_for_discr=True):
self.gp_coef = gp_coef
self.weight = weight
        # use_unmasked_for_discr implies use_unmasked_for_gen;
        # otherwise we would teach only the discriminator to pay attention
        # to very small differences in the unmasked region
assert use_unmasked_for_gen or (not use_unmasked_for_discr)
        # mask_as_fake_target implies use_unmasked_for_discr:
        # if we do not care about unmasked regions at all,
        # then the value of mask_as_fake_target is irrelevant
assert use_unmasked_for_discr or (not mask_as_fake_target)
self.use_unmasked_for_gen = use_unmasked_for_gen
self.use_unmasked_for_discr = use_unmasked_for_discr
self.mask_as_fake_target = mask_as_fake_target
self.allow_scale_mask = allow_scale_mask
self.mask_scale_mode = mask_scale_mode
self.extra_mask_weight_for_gen = extra_mask_weight_for_gen
def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask=None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
fake_loss = F.softplus(-discr_fake_pred)
if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
                not self.use_unmasked_for_gen:  # i.e. the masked region should be treated differently
mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
if not self.use_unmasked_for_gen:
fake_loss = fake_loss * mask
else:
pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
fake_loss = fake_loss * pixel_weights
return fake_loss.mean() * self.weight, dict()
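
    # Note (editorial, not in the original): F.softplus(-x) == -log(sigmoid(x)),
    # so fake_loss above is the standard non-saturating generator objective
    # -E[log D(G(z))], applied per output pixel of a patch discriminator and
    # optionally reweighted by the (rescaled) inpainting mask.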
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
real_batch.requires_grad = True
def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask=None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
real_loss = F.softplus(-discr_real_pred)
grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
fake_loss = F.softplus(discr_fake_pred)
if not self.use_unmasked_for_discr or self.mask_as_fake_target:
# == if masked region should be treated differently
mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
            # use_unmasked_for_discr=False only makes sense for fakes;
            # for real samples there is no difference between the two regions
fake_loss = fake_loss * mask
if self.mask_as_fake_target:
fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)
sum_discr_loss = real_loss + grad_penalty + fake_loss
metrics = dict(discr_real_out=discr_real_pred.mean(),
discr_fake_out=discr_fake_pred.mean(),
discr_real_gp=grad_penalty)
return sum_discr_loss.mean(), metrics
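
# Illustrative call order for one NonSaturatingWithR1 discriminator update
# (variable names are hypothetical, not part of this module):
#
#   adv = NonSaturatingWithR1(gp_coef=10)
#   adv.pre_discriminator_step(real, fake, generator, discriminator)
#   d_real = discriminator(real)            # real requires grad here, for R1
#   d_fake = discriminator(fake.detach())   # don't backprop into the generator
#   d_loss, d_metrics = adv.discriminator_loss(real, fake, d_real, d_fake, mask=mask)
#   d_loss.backward()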


class BCELoss(BaseAdversarialLoss):
def __init__(self, weight):
self.weight = weight
self.bce_loss = nn.BCEWithLogitsLoss()
    def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # note: unlike the base class, this variant only needs the fake prediction;
        # the generator wants every pixel to be classified as real (target 0)
        real_mask_gt = torch.zeros_like(discr_fake_pred)
        fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
        return fake_loss, dict()
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
real_batch.requires_grad = True
def discriminator_loss(self,
mask: torch.Tensor,
discr_real_pred: torch.Tensor,
discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # real pixels target 0 everywhere; fake pixels target the inpainting mask
        real_mask_gt = torch.zeros_like(discr_real_pred)
        sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2
metrics = dict(discr_real_out=discr_real_pred.mean(),
discr_fake_out=discr_fake_pred.mean(),
discr_real_gp=0)
return sum_discr_loss, metrics
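
# Illustrative: BCELoss trains the discriminator as a per-pixel mask detector.
# The fake target is the (already resized) inpainting mask itself, so only
# generated pixels are pushed towards "fake"; mask must match the prediction
# shape. Variable names below are hypothetical:
#
#   adv = BCELoss(weight=1)
#   d_loss, _ = adv.discriminator_loss(mask=mask_lowres,
#                                      discr_real_pred=d_real,
#                                      discr_fake_pred=d_fake)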


def make_discrim_loss(kind, **kwargs):
    """Factory: construct an adversarial loss by name ('r1' or 'bce')."""
if kind == 'r1':
return NonSaturatingWithR1(**kwargs)
elif kind == 'bce':
return BCELoss(**kwargs)
raise ValueError(f'Unknown adversarial loss kind {kind}')
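

if __name__ == '__main__':
    # Smoke test (editorial addition, not part of the original module):
    # exercise the 'r1' loss with a toy conv "discriminator" on random data.
    # Shapes and coefficients are arbitrary and only illustrate the call order.
    torch.manual_seed(0)
    discriminator = nn.Conv2d(3, 1, kernel_size=3, padding=1)
    generator = nn.Identity()  # stand-in; a real generator would be an inpainting net
    adv = make_discrim_loss('r1', gp_coef=5, weight=1)

    real = torch.randn(2, 3, 32, 32)
    fake = torch.randn(2, 3, 32, 32)

    adv.pre_discriminator_step(real, fake, generator, discriminator)
    d_loss, d_metrics = adv.discriminator_loss(real, fake,
                                               discriminator(real), discriminator(fake))
    g_loss, _ = adv.generator_loss(real, fake,
                                   discriminator(real), discriminator(fake))
    print(f'discriminator loss: {d_loss.item():.4f}, generator loss: {g_loss.item():.4f}')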