unpairedelectron07 committed on
Commit
f2a28c0
1 Parent(s): d61af5c

Upload losses.py

Browse files
Files changed (1) hide show
  1. audiocraft/adversarial/losses.py +228 -0
audiocraft/adversarial/losses.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility module to handle adversarial losses without requiring to mess up the main training loop.
9
+ """
10
+
11
+ import typing as tp
12
+
13
+ import flashy
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
# Names of the adversarial loss variants accepted by the criterion getters of this module.
ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']


# A loss mapping a single tensor of logits to a loss tensor.
AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
# A loss comparing feature maps of fake and real samples.
# NOTE(review): annotated as taking two tensors, but AdversarialLoss.forward passes it
# two lists of tensors (one list of feature maps per sub-discriminator) - confirm.
FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
24
+
25
+
26
class AdversarialLoss(nn.Module):
    """Wrapper bundling an adversary with its optimizer and loss functions.

    The adversary is expected to return ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``:
    a list of logits (one entry per sub-discriminator) followed by a list of
    feature-map lists (one list per sub-discriminator).

    Args:
        adversary (nn.Module): Module estimating the logits for fake and real samples.
        optimizer (torch.optim.Optimizer): Optimizer used to train the adversary.
        loss (AdvLossType): Loss function for generator training.
        loss_real (AdvLossType): Loss applied to the adversary logits of real samples.
        loss_fake (AdvLossType): Loss applied to the adversary logits of fake samples.
        loss_feat (FeatLossType): Optional feature matching loss for generator training.
        normalize (bool): Whether to divide the losses by the number of sub-discriminators.

    Example of usage:
        adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
        for real in loader:
            noise = torch.randn(...)
            fake = model(noise)
            adv_loss.train_adv(fake, real)
            loss, _ = adv_loss(fake, real)
            loss.backward()
    """
    def __init__(self,
                 adversary: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 loss: AdvLossType,
                 loss_real: AdvLossType,
                 loss_fake: AdvLossType,
                 loss_feat: tp.Optional[FeatLossType] = None,
                 normalize: bool = True):
        super().__init__()
        # Attribute order is kept as-is: nn.Module registers sub-modules in assignment order.
        self.adversary: nn.Module = adversary
        flashy.distrib.broadcast_model(self.adversary)
        self.optimizer = optimizer
        self.loss = loss
        self.loss_real = loss_real
        self.loss_fake = loss_fake
        self.loss_feat = loss_feat
        self.normalize = normalize

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save the module state and store the optimizer state under ``prefix + 'optimizer'``."""
        super()._save_to_state_dict(destination, prefix, keep_vars)
        optimizer_key = prefix + 'optimizer'
        destination[optimizer_key] = self.optimizer.state_dict()
        return destination

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Restore the optimizer state, then delegate the remaining entries to ``nn.Module``."""
        # The optimizer entry is popped so the parent implementation never sees it.
        optimizer_state = state_dict.pop(prefix + 'optimizer')
        self.optimizer.load_state_dict(optimizer_state)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @staticmethod
    def _is_tensor_list(candidate) -> bool:
        """Tell whether `candidate` is a list containing only tensors."""
        if not isinstance(candidate, list):
            return False
        return all(isinstance(item, torch.Tensor) for item in candidate)

    def get_adversary_pred(self, x):
        """Run the adversary on `x` and check that its output has the expected structure.

        Returns:
            The ``(logits, fmaps)`` pair produced by the adversary.
        """
        logits, fmaps = self.adversary(x)
        assert self._is_tensor_list(logits), \
            f'Expecting a list of tensors as logits but {type(logits)} found.'
        assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
        for fmap in fmaps:
            assert self._is_tensor_list(fmap), \
                f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
        return logits, fmaps

    def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
        """Run one optimization step of the adversary on a fake and a real example.

        The adversary output is assumed to follow the format
        Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: the logits first, then the
        list of feature maps of each sub-discriminator.

        Gradients are synchronized with `flashy.distrib.eager_sync_model` and the
        optimizer is stepped before returning the adversary loss.
        """
        # Inputs are detached so that no gradient flows back into the generator.
        fake_logits, _ = self.get_adversary_pred(fake.detach())
        real_logits, _ = self.get_adversary_pred(real.detach())
        n_sub_adversaries = len(fake_logits)

        loss = torch.tensor(0., device=fake.device)
        for fake_logit, real_logit in zip(fake_logits, real_logits):
            loss += self.loss_fake(fake_logit) + self.loss_real(real_logit)

        if self.normalize:
            loss /= n_sub_adversaries

        self.optimizer.zero_grad()
        with flashy.distrib.eager_sync_model(self.adversary):
            loss.backward()
        self.optimizer.step()

        return loss

    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Compute the generator losses: the adversarial term (fooling the adversary)
        and the feature matching term, which stays at zero when no `loss_feat` was given.
        """
        device = fake.device
        adv = torch.tensor(0., device=device)
        feat = torch.tensor(0., device=device)
        # NOTE(review): `flashy.utils.readonly` presumably keeps the adversary untouched
        # while the generator loss is computed - confirm against flashy.
        with flashy.utils.readonly(self.adversary):
            fake_logits, fake_fmaps = self.get_adversary_pred(fake)
            real_logits, real_fmaps = self.get_adversary_pred(real)
            n_sub_adversaries = len(fake_logits)
            for fake_logit in fake_logits:
                adv += self.loss(fake_logit)
            if self.loss_feat:
                for fmap_fake, fmap_real in zip(fake_fmaps, real_fmaps):
                    feat += self.loss_feat(fmap_fake, fmap_real)

        if self.normalize:
            adv /= n_sub_adversaries
            feat /= n_sub_adversaries

        return adv, feat
136
+
137
+
138
def get_adv_criterion(loss_type: str) -> tp.Callable:
    """Return the generator-side adversarial loss function registered under `loss_type`.

    Raises:
        ValueError: If no loss function matches `loss_type`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    known_criteria = (
        ('mse', mse_loss),
        ('hinge', hinge_loss),
        ('hinge2', hinge2_loss),
    )
    for name, criterion in known_criteria:
        if loss_type == name:
            return criterion
    raise ValueError('Unsupported loss')
147
+
148
+
149
def get_fake_criterion(loss_type: str) -> tp.Callable:
    """Return the adversary loss applied to logits of fake samples for `loss_type`.

    Both 'hinge' and 'hinge2' map to the same hinge loss on fake samples.

    Raises:
        ValueError: If no loss function matches `loss_type`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    known_criteria = (
        ('mse', mse_fake_loss),
        ('hinge', hinge_fake_loss),
        ('hinge2', hinge_fake_loss),
    )
    for name, criterion in known_criteria:
        if loss_type == name:
            return criterion
    raise ValueError('Unsupported loss')
156
+
157
+
158
def get_real_criterion(loss_type: str) -> tp.Callable:
    """Return the adversary loss applied to logits of real samples for `loss_type`.

    Both 'hinge' and 'hinge2' map to the same hinge loss on real samples.

    Raises:
        ValueError: If no loss function matches `loss_type`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    known_criteria = (
        ('mse', mse_real_loss),
        ('hinge', hinge_real_loss),
        ('hinge2', hinge_real_loss),
    )
    for name, criterion in known_criteria:
        if loss_type == name:
            return criterion
    raise ValueError('Unsupported loss')
165
+
166
+
167
def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between the logits `x` and a target of ones."""
    one = torch.tensor(1., device=x.device)
    target = one.expand_as(x)
    return F.mse_loss(x, target)
169
+
170
+
171
def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between the logits `x` and a target of zeros."""
    zero = torch.tensor(0., device=x.device)
    target = zero.expand_as(x)
    return F.mse_loss(x, target)
173
+
174
+
175
def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss on logits of real samples: the negated mean of ``min(x - 1, 0)``."""
    zeros = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.minimum(x - 1, zeros)
    return -clipped.mean()
177
+
178
+
179
def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss on logits of fake samples: the negated mean of ``min(-x - 1, 0)``."""
    zeros = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.minimum(-x - 1, zeros)
    return -clipped.mean()
181
+
182
+
183
def mse_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator MSE loss: mean squared error between the logits `x` and a target of ones.

    An empty input yields a one-element zero tensor on the device of `x`.
    """
    if x.numel() != 0:
        target = torch.tensor(1., device=x.device).expand_as(x)
        return F.mse_loss(x, target)
    return torch.tensor([0.0], device=x.device)
187
+
188
+
189
def hinge_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator hinge loss: the negated mean of the logits `x`.

    An empty input yields a one-element zero tensor on the device of `x`.
    """
    if x.numel() != 0:
        return torch.neg(torch.mean(x))
    return torch.tensor([0.0], device=x.device)
193
+
194
+
195
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator hinge loss variant: the negated mean of ``min(x - 1, 0)``.

    Args:
        x (torch.Tensor): Logits produced by the adversary.

    Returns:
        torch.Tensor: The scalar loss, or a one-element zero tensor when `x` is empty.
    """
    if x.numel() == 0:
        # Create the placeholder on the device of `x`, like the other generator losses do,
        # so it can be combined with tensors living on that device.
        return torch.tensor([0.0], device=x.device)
    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
199
+
200
+
201
class FeatureMatchingLoss(nn.Module):
    """Feature matching loss for adversarial training.

    Accumulates a distance between the feature maps obtained for fake samples and
    the ones obtained for real samples.

    Args:
        loss (nn.Module, optional): Loss applied to each pair of feature maps.
            When omitted, a new ``torch.nn.L1Loss()`` is created for this instance.
        normalize (bool): Whether to normalize the loss by the number of feature maps.
    """
    def __init__(self, loss: tp.Optional[nn.Module] = None, normalize: bool = True):
        super().__init__()
        # The default is built per instance rather than in the signature, where a single
        # module created at definition time would be shared by every instance.
        self.loss = torch.nn.L1Loss() if loss is None else loss
        self.normalize = normalize

    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
        """Compute the feature matching loss between two lists of feature maps.

        Args:
            fmap_fake (list of torch.Tensor): Feature maps computed on fake samples.
            fmap_real (list of torch.Tensor): Feature maps computed on real samples;
                must hold as many entries as `fmap_fake`, with matching shapes.

        Returns:
            torch.Tensor: Scalar tensor holding the summed loss over all pairs, divided
                by the number of pairs when `normalize` is set.
        """
        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
        n_fmaps = 0
        for feat_fake, feat_real in zip(fmap_fake, fmap_real):
            assert feat_fake.shape == feat_real.shape
            n_fmaps += 1
            feat_loss += self.loss(feat_fake, feat_real)

        if self.normalize:
            feat_loss /= n_fmaps

        return feat_loss