File size: 9,126 Bytes
5325fcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility module to handle adversarial losses without cluttering the main training loop.
"""
import typing as tp
import flashy
import torch
import torch.nn as nn
import torch.nn.functional as F
# Names of the adversarial loss variants accepted by the criterion getters in this module.
ADVERSARIAL_LOSSES = [
    'mse',
    'hinge',
    'hinge2',
]

# A loss computed from a single tensor (e.g. the logits of one sub-discriminator).
AdvLossType = tp.Union[
    nn.Module,
    tp.Callable[[torch.Tensor], torch.Tensor],
]

# A loss computed from a pair of inputs (fake features, real features).
FeatLossType = tp.Union[
    nn.Module,
    tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
]
class AdversarialLoss(nn.Module):
"""Adversary training wrapper.
Args:
adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
where the first item is a list of logits and the second item is a list of feature maps.
optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
loss (AdvLossType): Loss function for generator training.
loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
loss_feat (FeatLossType): Feature matching loss function for generator training.
normalize (bool): Whether to normalize by number of sub-discriminators.
Example of usage:
adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
for real in loader:
noise = torch.randn(...)
fake = model(noise)
adv_loss.train_adv(fake, real)
loss, _ = adv_loss(fake, real)
loss.backward()
"""
def __init__(self,
adversary: nn.Module,
optimizer: torch.optim.Optimizer,
loss: AdvLossType,
loss_real: AdvLossType,
loss_fake: AdvLossType,
loss_feat: tp.Optional[FeatLossType] = None,
normalize: bool = True):
super().__init__()
self.adversary: nn.Module = adversary
flashy.distrib.broadcast_model(self.adversary)
self.optimizer = optimizer
self.loss = loss
self.loss_real = loss_real
self.loss_fake = loss_fake
self.loss_feat = loss_feat
self.normalize = normalize
def _save_to_state_dict(self, destination, prefix, keep_vars):
# Add the optimizer state dict inside our own.
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'optimizer'] = self.optimizer.state_dict()
return destination
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# Load optimizer state.
self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def get_adversary_pred(self, x):
"""Run adversary model, validating expected output format."""
logits, fmaps = self.adversary(x)
assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
f'Expecting a list of tensors as logits but {type(logits)} found.'
assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
for fmap in fmaps:
assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
return logits, fmaps
def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
"""Train the adversary with the given fake and real example.
We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
The first item being the logits and second item being a list of feature maps for each sub-discriminator.
This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
and call the optimizer.
"""
loss = torch.tensor(0., device=fake.device)
all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
n_sub_adversaries = len(all_logits_fake_is_fake)
for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
if self.normalize:
loss /= n_sub_adversaries
self.optimizer.zero_grad()
with flashy.distrib.eager_sync_model(self.adversary):
loss.backward()
self.optimizer.step()
return loss
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
"""Return the loss for the generator, i.e. trying to fool the adversary,
and feature matching loss if provided.
"""
adv = torch.tensor(0., device=fake.device)
feat = torch.tensor(0., device=fake.device)
with flashy.utils.readonly(self.adversary):
all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
n_sub_adversaries = len(all_logits_fake_is_fake)
for logit_fake_is_fake in all_logits_fake_is_fake:
adv += self.loss(logit_fake_is_fake)
if self.loss_feat:
for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
feat += self.loss_feat(fmap_fake, fmap_real)
if self.normalize:
adv /= n_sub_adversaries
feat /= n_sub_adversaries
return adv, feat
def get_adv_criterion(loss_type: str) -> tp.Callable:
    """Return the criterion (`mse_loss`, `hinge_loss` or `hinge2_loss`) matching `loss_type`.

    Args:
        loss_type (str): One of the names listed in `ADVERSARIAL_LOSSES`.
    Raises:
        AssertionError: If `loss_type` is not listed in `ADVERSARIAL_LOSSES`.
        ValueError: If no criterion is associated with `loss_type`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    known_criteria = (
        ('mse', mse_loss),
        ('hinge', hinge_loss),
        ('hinge2', hinge2_loss),
    )
    for candidate, criterion in known_criteria:
        if loss_type == candidate:
            return criterion
    raise ValueError('Unsupported loss')
def get_fake_criterion(loss_type: str) -> tp.Callable:
    """Return the criterion applied to logits of fake samples for `loss_type`.

    'mse' maps to `mse_fake_loss`; both 'hinge' and 'hinge2' map to `hinge_fake_loss`.

    Raises:
        AssertionError: If `loss_type` is not listed in `ADVERSARIAL_LOSSES`.
        ValueError: If no criterion is associated with `loss_type`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_fake_loss
    # Both hinge variants share the same criterion on fake samples.
    if loss_type in ('hinge', 'hinge2'):
        return hinge_fake_loss
    raise ValueError('Unsupported loss')
def get_real_criterion(loss_type: str) -> tp.Callable:
    """Return the criterion applied to logits of real samples for `loss_type`.

    'mse' maps to `mse_real_loss`; both 'hinge' and 'hinge2' map to `hinge_real_loss`.

    Raises:
        AssertionError: If `loss_type` is not listed in `ADVERSARIAL_LOSSES`.
        ValueError: If no criterion is associated with `loss_type`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_real_loss
    # Both hinge variants share the same criterion on real samples.
    if loss_type in ('hinge', 'hinge2'):
        return hinge_real_loss
    raise ValueError('Unsupported loss')
def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between `x` and a target of ones with the same shape."""
    # Scalar target created on x's device, then broadcast to x's shape.
    target = torch.tensor(1., device=x.device).expand(x.shape)
    return F.mse_loss(x, target)
def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between `x` and a target of zeros with the same shape."""
    # Scalar target created on x's device, then broadcast to x's shape.
    target = torch.tensor(0., device=x.device).expand(x.shape)
    return F.mse_loss(x, target)
def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Return ``-mean(min(x - 1, 0))``: entries of `x` that are >= 1 contribute nothing."""
    zero = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.minimum(x - 1, zero)
    return -clipped.mean()
def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Return ``-mean(min(-x - 1, 0))``: entries of `x` that are <= -1 contribute nothing."""
    zero = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.minimum(-x - 1, zero)
    return -clipped.mean()
def mse_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between `x` and a target of ones.

    For an empty `x`, a one-element zero tensor on `x`'s device is returned
    instead of computing a mean over no elements.
    """
    if x.numel() > 0:
        target = torch.tensor(1., device=x.device).expand_as(x)
        return F.mse_loss(x, target)
    return torch.tensor([0.0], device=x.device)
def hinge_loss(x: torch.Tensor) -> torch.Tensor:
    """Return the negated mean of `x`.

    For an empty `x`, a one-element zero tensor on `x`'s device is returned
    instead of computing a mean over no elements.
    """
    if x.numel() > 0:
        return torch.neg(torch.mean(x))
    return torch.tensor([0.0], device=x.device)
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
    """Return ``-mean(min(x - 1, 0))``: entries of `x` that are >= 1 contribute nothing.

    For an empty `x`, a one-element zero tensor is returned instead of computing
    a mean over no elements.

    Args:
        x (torch.Tensor): Input tensor (any shape).
    Returns:
        torch.Tensor: The loss, located on the same device as `x`.
    """
    if x.numel() == 0:
        # Create the fallback on x's device (as `mse_loss` and `hinge_loss` do) so that
        # it can be combined with other tensors living on that device.
        return torch.tensor([0.0], device=x.device)
    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
class FeatureMatchingLoss(nn.Module):
    """Feature matching loss for adversarial training.

    Sums a base loss over paired (fake, real) feature maps, optionally averaging
    over the number of feature maps.

    Args:
        loss (nn.Module, optional): Loss to use for feature matching. When omitted,
            a new ``torch.nn.L1Loss`` is created for this instance.
        normalize (bool): Whether to normalize the loss by number of feature maps.
    """
    def __init__(self, loss: tp.Optional[nn.Module] = None, normalize: bool = True):
        super().__init__()
        # Build the default per instance: a module created in the signature would be
        # evaluated once at definition time and shared by every instance.
        self.loss = torch.nn.L1Loss() if loss is None else loss
        self.normalize = normalize

    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
        """Compute the feature matching loss between two lists of feature maps.

        Args:
            fmap_fake (List[torch.Tensor]): Feature maps obtained from fake samples.
            fmap_real (List[torch.Tensor]): Feature maps obtained from real samples,
                paired by position with `fmap_fake`.
        Returns:
            torch.Tensor: Sum of the per-pair losses, divided by the number of
            feature maps when `normalize` is set.
        Raises:
            AssertionError: If the lists are empty, differ in length, or a pair of
                feature maps differs in shape.
        """
        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
        # Accumulator lives on the device of the first fake feature map.
        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
        for feat_fake, feat_real in zip(fmap_fake, fmap_real):
            assert feat_fake.shape == feat_real.shape
            feat_loss += self.loss(feat_fake, feat_real)

        if self.normalize:
            feat_loss /= len(fmap_fake)

        return feat_loss
|