Mehdi Cherti committed on
Commit
e96a195
1 Parent(s): 2ab447a
EMA.py CHANGED
@@ -21,8 +21,16 @@ class EMA(Optimizer):
21
  self.optimizer = opt
22
  self.state = opt.state
23
  self.param_groups = opt.param_groups
 
24
 
25
  def step(self, *args, **kwargs):
 
 
 
 
 
 
 
26
  retval = self.optimizer.step(*args, **kwargs)
27
 
28
  # stop here if we are not applying EMA
 
21
  self.optimizer = opt
22
  self.state = opt.state
23
  self.param_groups = opt.param_groups
24
+ self.defaults = {}
25
 
26
  def step(self, *args, **kwargs):
27
+ # for group in self.optimizer.param_groups:
28
+ # group.setdefault('amsgrad', False)
29
+ # group.setdefault('maximize', False)
30
+ # group.setdefault('foreach', None)
31
+ # group.setdefault('capturable', False)
32
+ # group.setdefault('differentiable', False)
33
+ # group.setdefault('fused', False)
34
  retval = self.optimizer.step(*args, **kwargs)
35
 
36
  # stop here if we are not applying EMA
eval_all.sh CHANGED
@@ -1,7 +1,17 @@
1
  #!/bin/bash
2
- for model in ddgan_sd_v10 ddgan_laion2b_v2 ddgan_ddb_v1 ddgan_ddb_v2 ddgan_ddb_v3;do
3
- if [ "$model" == "$ddgan_ddb_v3" ]; then
 
 
 
 
4
  bs=32
 
 
 
 
 
 
5
  else
6
  bs=64
7
  fi
 
1
  #!/bin/bash
2
+ #for model in ddgan_sd_v10 ddgan_laion2b_v2 ddgan_ddb_v1 ddgan_ddb_v2 ddgan_ddb_v3 ddgan_ddb_v4;do
3
+ #for model in ddgan_ddb_v2 ddgan_ddb_v3 ddgan_ddb_v4 ddgan_ddb_v5;do
4
+ #for model in ddgan_ddb_v4 ddgan_ddb_v6 ddgan_ddb_v7 ddgan_laion_aesthetic_v15;do
5
+ #for model in ddgan_ddb_v6;do
6
+ for model in ddgan_laion_aesthetic_v15;do
7
+ if [ "$model" == "ddgan_ddb_v3" ]; then
8
  bs=32
9
+ elif [ "$model" == "ddgan_laion_aesthetic_v15" ]; then
10
+ bs=32
11
+ elif [ "$model" == "ddgan_ddb_v6" ]; then
12
+ bs=32
13
+ elif [ "$model" == "ddgan_ddb_v4" ]; then
14
+ bs=16
15
  else
16
  bs=64
17
  fi
run.py CHANGED
@@ -256,6 +256,28 @@ def ddgan_ddb_v3():
256
  cfg['model']['num_timesteps'] = 2
257
  return cfg
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  models = [
260
  ddgan_cifar10_cond17, # cifar10, cross attn for discr
261
  ddgan_cifar10_cond18, # cifar10, xl encoder
@@ -283,6 +305,7 @@ models = [
283
  ddgan_laion_aesthetic_v12,
284
  ddgan_laion_aesthetic_v13,
285
  ddgan_laion_aesthetic_v14,
 
286
  ddgan_laion2b_v1,
287
  ddgan_sd_v1,
288
  ddgan_sd_v2,
@@ -298,7 +321,11 @@ models = [
298
  ddgan_laion2b_v2,
299
  ddgan_ddb_v1,
300
  ddgan_ddb_v2,
301
- ddgan_ddb_v3
 
 
 
 
302
  ]
303
 
304
  def get_model(model_name):
 
256
  cfg['model']['num_timesteps'] = 2
257
  return cfg
258
 
259
def ddgan_ddb_v4():
    """Return the ``ddgan_ddb_v1`` config with a 256-channel DAE and 2 timesteps."""
    cfg = ddgan_ddb_v1()
    model_cfg = cfg['model']
    model_cfg['num_channels_dae'] = 256
    model_cfg['num_timesteps'] = 2
    return cfg
264
+
265
def ddgan_ddb_v5():
    """Return a config identical to the one built by ``ddgan_ddb_v2``."""
    return ddgan_ddb_v2()
268
+
269
def ddgan_ddb_v6():
    """Return a config identical to the one built by ``ddgan_ddb_v3``."""
    return ddgan_ddb_v3()
272
+
273
def ddgan_ddb_v7():
    """Return a config identical to the one built by ``ddgan_ddb_v1``."""
    return ddgan_ddb_v1()
276
+
277
def ddgan_laion_aesthetic_v15():
    """Return a config identical to the one built by ``ddgan_ddb_v3``."""
    return ddgan_ddb_v3()
280
+
281
  models = [
282
  ddgan_cifar10_cond17, # cifar10, cross attn for discr
283
  ddgan_cifar10_cond18, # cifar10, xl encoder
 
305
  ddgan_laion_aesthetic_v12,
306
  ddgan_laion_aesthetic_v13,
307
  ddgan_laion_aesthetic_v14,
308
+ ddgan_laion_aesthetic_v15,
309
  ddgan_laion2b_v1,
310
  ddgan_sd_v1,
311
  ddgan_sd_v2,
 
321
  ddgan_laion2b_v2,
322
  ddgan_ddb_v1,
323
  ddgan_ddb_v2,
324
+ ddgan_ddb_v3,
325
+ ddgan_ddb_v4,
326
+ ddgan_ddb_v5,
327
+ ddgan_ddb_v6,
328
+ ddgan_ddb_v7,
329
  ]
330
 
331
  def get_model(model_name):
score_sde/models/discriminator.py CHANGED
@@ -181,7 +181,7 @@ class SmallCondAttnDiscriminator(nn.Module):
181
  hidden_dim=t_emb_dim,
182
  output_dim=t_emb_dim,
183
  act=act,
184
- )
185
 
186
 
187
 
@@ -368,7 +368,7 @@ class CondAttnDiscriminator(nn.Module):
368
  hidden_dim=t_emb_dim,
369
  output_dim=t_emb_dim,
370
  act=act,
371
- )
372
 
373
  self.start_conv = conv2d(nc,ngf*2,1, padding=0)
374
  self.conv1 = DownConvBlock(ngf*2, ngf*4, t_emb_dim = t_emb_dim, downsample = True, act=act)
 
181
  hidden_dim=t_emb_dim,
182
  output_dim=t_emb_dim,
183
  act=act,
184
+ )
185
 
186
 
187
 
 
368
  hidden_dim=t_emb_dim,
369
  output_dim=t_emb_dim,
370
  act=act,
371
+ )
372
 
373
  self.start_conv = conv2d(nc,ngf*2,1, padding=0)
374
  self.conv1 = DownConvBlock(ngf*2, ngf*4, t_emb_dim = t_emb_dim, downsample = True, act=act)
score_sde/models/layers.py CHANGED
@@ -559,6 +559,7 @@ class CondAttnBlock(nn.Module):
559
  h = h.permute(0,2,1)
560
  h = h.contiguous()
561
  h_new = self.ca(h, cond, mask=mask)
 
562
  h_new = h_new.permute(0,2,1)
563
  h_new = h_new.contiguous()
564
  h_new = h_new.view(B, C, H, W)
 
559
  h = h.permute(0,2,1)
560
  h = h.contiguous()
561
  h_new = self.ca(h, cond, mask=mask)
562
+ # print(h_new.min(), h_new.max())
563
  h_new = h_new.permute(0,2,1)
564
  h_new = h_new.contiguous()
565
  h_new = h_new.view(B, C, H, W)
score_sde/models/projected_discriminator.py ADDED
@@ -0,0 +1,783 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ #from pg_modules.blocks import DownBlock, DownBlockPatch, conv2d
9
+ import functools
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.utils import spectral_norm
14
+ from . import layers
15
+ from .layers import CondAttnBlock
16
+ from .discriminator import *
17
+
18
+
19
def conv2d(*args, **kwargs):
    """Build an ``nn.Conv2d`` from the given arguments, wrapped in spectral norm."""
    conv = nn.Conv2d(*args, **kwargs)
    return spectral_norm(conv)
21
+
22
+
23
def convTranspose2d(*args, **kwargs):
    """Build an ``nn.ConvTranspose2d`` from the given arguments, wrapped in spectral norm."""
    deconv = nn.ConvTranspose2d(*args, **kwargs)
    return spectral_norm(deconv)
25
+
26
+
27
def embedding(*args, **kwargs):
    """Build an ``nn.Embedding`` from the given arguments, wrapped in spectral norm."""
    table = nn.Embedding(*args, **kwargs)
    return spectral_norm(table)
29
+
30
+
31
def linear(*args, **kwargs):
    """Build an ``nn.Linear`` from the given arguments, wrapped in spectral norm."""
    dense = nn.Linear(*args, **kwargs)
    return spectral_norm(dense)
33
+
34
+
35
def NormLayer(c, mode='batch'):
    """Return a normalization layer for ``c`` channels.

    Args:
        c: number of channels to normalize.
        mode: ``'batch'`` for ``nn.BatchNorm2d(c)`` or ``'group'`` for
            ``nn.GroupNorm`` with ``c // 2`` groups.

    Raises:
        ValueError: if ``mode`` is neither ``'batch'`` nor ``'group'``.
    """
    if mode == 'group':
        return nn.GroupNorm(c//2, c)
    if mode == 'batch':
        return nn.BatchNorm2d(c)
    # Previously an unknown mode fell through and returned None, which only
    # surfaced later as an obscure error when the layer was put in a Sequential.
    raise ValueError(f"Unknown normalization mode {mode!r}; expected 'batch' or 'group'")
40
+
41
+
42
+ ### Activations
43
+
44
+
45
class GLU(nn.Module):
    """Gated linear unit over dim 1: first half of the channels gated by the second half."""

    def forward(self, x):
        channels = x.size(1)
        assert channels % 2 == 0, 'channels dont divide 2!'
        half = channels // 2
        values = x[:, :half]
        gates = x[:, half:]
        return values * torch.sigmoid(gates)
51
+
52
+
53
class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, feat):
        gate = torch.sigmoid(feat)
        return feat * gate
56
+
57
+
58
+ ### Upblocks
59
+
60
+
61
class InitLayer(nn.Module):
    """Turn a flat noise vector into an initial ``sz`` x ``sz`` feature map.

    The noise is reshaped to ``(batch, -1, 1, 1)`` and passed through a
    spectral-normalized transposed convolution, a norm layer and a GLU.
    """

    def __init__(self, nz, channel, sz=4):
        super().__init__()

        stages = [
            convTranspose2d(nz, channel*2, sz, 1, 0, bias=False),
            NormLayer(channel*2),
            GLU(),
        ]
        self.init = nn.Sequential(*stages)

    def forward(self, noise):
        batch = noise.shape[0]
        as_map = noise.view(batch, -1, 1, 1)
        return self.init(as_map)
74
+
75
+
76
def UpBlockSmall(in_planes, out_planes):
    """Return a 2x nearest-neighbour upsampling block: conv -> norm -> GLU."""
    stages = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
        NormLayer(out_planes*2),
        GLU(),
    ]
    return nn.Sequential(*stages)
82
+
83
+
84
class UpBlockSmallCond(nn.Module):
    """Conditional 2x upsampling block: upsample -> conv -> conditional BN -> GLU.

    The conditioning vector ``c`` (of size ``z_dim``) is consumed by the
    ``CCBN`` layer only.
    """

    def __init__(self, in_planes, out_planes, z_dim):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)

        # conditional batchnorm whose gain/bias projections are spectral-normed linears
        self.bn = CCBN(2*out_planes, which_linear=linear, input_size=z_dim)
        self.act = GLU()

    def forward(self, x, c):
        upsampled = self.up(x)
        normed = self.bn(self.conv(upsampled), c)
        return self.act(normed)
102
+
103
+
104
def UpBlockBig(in_planes, out_planes):
    """Return a 2x upsampling block made of two conv -> noise -> norm -> GLU stages."""
    stages = [nn.Upsample(scale_factor=2, mode='nearest')]
    # first stage maps in_planes -> out_planes (the GLU halves out_planes*2)
    stages += [
        conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
        NoiseInjection(),
        NormLayer(out_planes*2),
        GLU(),
    ]
    # second stage keeps out_planes channels
    stages += [
        conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False),
        NoiseInjection(),
        NormLayer(out_planes*2),
        GLU(),
    ]
    return nn.Sequential(*stages)
115
+
116
+
117
class UpBlockBigCond(nn.Module):
    """Conditional 2x upsampling block with two conv -> noise -> CCBN -> GLU stages.

    A single ``NoiseInjection`` module (one learned weight) and a single GLU
    are shared by both stages; each stage has its own conv and conditional BN.
    """

    def __init__(self, in_planes, out_planes, z_dim):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)
        self.conv2 = conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False)

        self.bn1 = CCBN(2*out_planes, which_linear=linear, input_size=z_dim)
        self.bn2 = CCBN(2*out_planes, which_linear=linear, input_size=z_dim)
        self.act = GLU()
        self.noise = NoiseInjection()

    def forward(self, x, c):
        h = self.up(x)
        # the two stages differ only in which conv / conditional BN they use
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            h = conv(h)
            h = self.noise(h)
            h = bn(h, c)
            h = self.act(h)
        return h
147
+
148
+
149
class SEBlock(nn.Module):
    """Squeeze-and-excite style gate: scales ``feat_big`` by a gate computed from ``feat_small``.

    The gate pools ``feat_small`` to 4x4, reduces it with a 4x4 conv, applies
    Swish, a 1x1 conv and a sigmoid.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        gate_layers = [
            nn.AdaptiveAvgPool2d(4),
            conv2d(ch_in, ch_out, 4, 1, 0, bias=False),
            Swish(),
            conv2d(ch_out, ch_out, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*gate_layers)

    def forward(self, feat_small, feat_big):
        gate = self.main(feat_small)
        return feat_big * gate
162
+
163
+
164
+ ### Downblocks
165
+
166
+
167
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Both convolutions are spectral-normalized. The depthwise conv always uses
    ``padding=1`` regardless of ``kernel_size``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=False):
        super().__init__()
        # groups == in_channels makes the first conv act per channel
        self.depthwise = conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, groups=in_channels, bias=bias, padding=1,
        )
        self.pointwise = conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))
179
+
180
+
181
class DownBlock(nn.Module):
    """Downsampling block that halves the spatial resolution.

    With ``separable=False`` a stride-2 4x4 conv does the downsampling; with
    ``separable=True`` a separable 3x3 conv is followed by 2x2 average pooling.
    Both variants apply a norm layer and LeakyReLU(0.2).
    """

    def __init__(self, in_planes, out_planes, separable=False):
        super().__init__()
        if separable:
            stages = [
                SeparableConv2d(in_planes, out_planes, 3),
                NormLayer(out_planes),
                nn.LeakyReLU(0.2, inplace=True),
                nn.AvgPool2d(2, 2),
            ]
        else:
            stages = [
                conv2d(in_planes, out_planes, 4, 2, 1),
                NormLayer(out_planes),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.main = nn.Sequential(*stages)

    def forward(self, feat):
        out = self.main(feat)
        return out
200
+
201
+
202
class DownBlockPatch(nn.Module):
    """A ``DownBlock`` followed by a 1x1 conv -> norm -> LeakyReLU(0.2) stage."""

    def __init__(self, in_planes, out_planes, separable=False):
        super().__init__()
        stages = [
            DownBlock(in_planes, out_planes, separable),
            conv2d(out_planes, out_planes, 1, 1, 0, bias=False),
            NormLayer(out_planes),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, feat):
        out = self.main(feat)
        return out
214
+
215
+
216
+ ### CSM
217
+
218
+
219
class ResidualConvUnit(nn.Module):
    """3x3 convolution with a residual (skip) connection.

    ``activation`` and ``bn`` are accepted but not used by this implementation.
    """

    def __init__(self, cin, activation, bn):
        super().__init__()
        self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = self.conv(x)
        return self.skip_add.add(residual, x)
227
+
228
+
229
class FeatureFusionBlock(nn.Module):
    """Fuse up to two feature maps, upsample by 2x and project with a 1x1 conv.

    When ``expand`` equals ``True`` the 1x1 conv halves the channel count,
    otherwise it keeps ``features`` channels. ``activation``, ``bn`` and
    ``lowest`` are accepted but not used by this implementation.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False):
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.expand = expand

        out_features = features//2 if self.expand == True else features

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        fused = xs[0]
        # a second input, if given, is added as a skip connection
        if len(xs) == 2:
            fused = self.skip_add.add(fused, xs[1])

        upsampled = nn.functional.interpolate(
            fused, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )
        return self.out_conv(upsampled)
257
+
258
+
259
+ ### Misc
260
+
261
+
262
class NoiseInjection(nn.Module):
    """Add per-pixel noise scaled by a single learned weight (initialised to 0)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, feat, noise=None):
        """Return ``feat + weight * noise``.

        Args:
            feat: feature map; its shape is unpacked as (batch, _, height, width).
            noise: optional noise tensor. When omitted, standard-normal noise
                of shape (batch, 1, height, width) is sampled, so the same
                noise is broadcast over all channels.
        """
        if noise is None:
            batch, _, height, width = feat.shape
            # sample directly on the feature's device instead of sampling on
            # the CPU and copying the tensor over on every forward pass
            noise = torch.randn(batch, 1, height, width, device=feat.device)

        return feat + self.weight * noise
273
+
274
+
275
class CCBN(nn.Module):
    """Conditional batchnorm.

    The input is batch-normalized without affine parameters, then scaled and
    shifted by a gain and bias predicted from the conditioning vector.

    Args:
        output_size: number of channels of the normalized feature map.
        input_size: size of the conditioning vector.
        which_linear: callable ``(in_features, out_features) -> module`` used
            to build the gain and bias projections.
        eps: epsilon to avoid dividing by 0.
        momentum: momentum used to update the running statistics.
    """

    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1):
        super().__init__()
        self.output_size, self.input_size = output_size, input_size

        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)

        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum

        self.register_buffer('stored_mean', torch.zeros(output_size))
        self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        """Normalize ``x`` and modulate it with gain/bias computed from ``y``."""
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # Use the configured momentum; it was previously hard-coded to 0.1,
        # which silently ignored the ``momentum`` constructor argument.
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, self.momentum, self.eps)
        return out * gain + bias
300
+
301
+
302
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with a fixed target size."""

    def __init__(self, size, mode='bilinear', align_corners=False):
        """Store the interpolation settings.

        Args:
            size: target output size passed to ``interpolate``
            mode (str): interpolation mode
            align_corners (bool): forwarded to ``interpolate``
        """
        super().__init__()

        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resize ``x`` to the configured size.

        Args:
            x (tensor): input
        Returns:
            tensor: interpolated data
        """
        options = dict(size=self.size, mode=self.mode, align_corners=self.align_corners)
        return self.interp(x, **options)
334
+
335
+
336
+
337
+ #from pg_modules.projector import F_RandomProj
338
+
339
+ import torch
340
+ import torch.nn as nn
341
+ import timm
342
+ #from pg_modules.blocks import FeatureFusionBlock
343
+
344
+
345
+ def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
346
+ # shapes
347
+ out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4
348
+
349
+ scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
350
+ scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
351
+ scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
352
+ scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)
353
+
354
+ scratch.CHANNELS = out_channels
355
+
356
+ return scratch
357
+
358
+
359
def _make_scratch_csm(scratch, in_channels, cout, expand):
    """Attach four cross-scale-mixing fusion blocks (``layer{3..0}_csm``) to ``scratch``.

    Blocks are created from the deepest level (3) to the shallowest (0).
    Returns ``scratch`` with ``scratch.CHANNELS`` updated to the fused widths.
    """
    scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
    for level in (2, 1):
        block = FeatureFusionBlock(in_channels[level], nn.ReLU(False), expand=expand)
        setattr(scratch, f'layer{level}_csm', block)
    # last refinenet does not expand to save channels in higher dimensions
    scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))

    if expand:
        scratch.CHANNELS = [cout, cout, cout*2, cout*4]
    else:
        scratch.CHANNELS = [cout]*4

    return scratch
369
+
370
+
371
+ def _make_efficientnet(model):
372
+ pretrained = nn.Module()
373
+ pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2])
374
+ pretrained.layer1 = nn.Sequential(*model.blocks[2:3])
375
+ pretrained.layer2 = nn.Sequential(*model.blocks[3:5])
376
+ pretrained.layer3 = nn.Sequential(*model.blocks[5:9])
377
+ return pretrained
378
+
379
+
380
def calc_channels(pretrained, inp_res=224):
    """Return the output channel count of ``pretrained.layer0`` .. ``layer3``.

    A zero image of shape (1, 3, inp_res, inp_res) is pushed through the four
    stages in order and the channel dimension after each stage is recorded.

    The probe is run in eval mode and under ``torch.no_grad()`` so it neither
    updates BatchNorm running statistics with the dummy input nor builds an
    autograd graph. Each submodule's training flag is restored afterwards.
    """
    stages = [pretrained.layer0, pretrained.layer1, pretrained.layer2, pretrained.layer3]

    # remember every submodule's mode so the probe leaves the network untouched
    saved_modes = [
        (module, module.training)
        for stage in stages if isinstance(stage, nn.Module)
        for module in stage.modules()
    ]

    channels = []
    tmp = torch.zeros(1, 3, inp_res, inp_res)
    try:
        for module, _ in saved_modes:
            module.training = False
        with torch.no_grad():
            for stage in stages:
                tmp = stage(tmp)
                channels.append(tmp.shape[1])
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return channels
395
+
396
+
397
def _make_projector(im_res, cout, proj_type, expand=False):
    """Build the pretrained feature network and its random projection layers.

    Args:
        im_res: accepted but overridden; the resolution is fixed to 256 below.
        cout: base channel width of the projection layers.
        proj_type: 0 = no projection, 1 = cross channel mixing,
            2 = cross scale mixing.
        expand: forwarded to the CCM/CSM builders.

    Returns:
        ``(pretrained, scratch)``; ``scratch`` is ``None`` for ``proj_type == 0``.
        ``pretrained.CHANNELS`` and ``pretrained.RESOLUTIONS`` describe the
        feature maps produced for the chosen projection type.
    """
    assert proj_type in [0, 1, 2], "Invalid projection type"

    ### Build pretrained feature network
    backbone = timm.create_model('tf_efficientnet_lite0', pretrained=True)
    pretrained = _make_efficientnet(backbone)

    # determine resolution of feature maps, this is later used to calculate the number
    # of down blocks in the discriminators. Interestingly, the best results are achieved
    # by fixing this to 256, ie., we use the same number of down blocks per discriminator
    # independent of the dataset resolution
    im_res = 256
    pretrained.RESOLUTIONS = [im_res // stride for stride in (4, 8, 16, 32)]
    pretrained.CHANNELS = calc_channels(pretrained)

    if proj_type == 0:
        return pretrained, None

    ### Build CCM
    scratch = _make_scratch_ccm(nn.Module(), in_channels=pretrained.CHANNELS, cout=cout, expand=expand)
    pretrained.CHANNELS = scratch.CHANNELS

    if proj_type == 1:
        return pretrained, scratch

    ### build CSM
    scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)

    # CSM upsamples x2 so the feature map resolution doubles
    pretrained.RESOLUTIONS = [2 * res for res in pretrained.RESOLUTIONS]
    pretrained.CHANNELS = scratch.CHANNELS

    return pretrained, scratch
429
+
430
+
431
class F_RandomProj(nn.Module):
    """Feature network: a pretrained backbone plus optional random projections.

    ``forward`` returns a dict with keys ``'0'`` .. ``'3'`` (shallowest to
    deepest stage) holding the raw backbone features (``proj_type == 0``), the
    channel-mixed features (``proj_type == 1``) or the scale-mixed features
    (``proj_type == 2``).
    """

    def __init__(
        self,
        im_res=256,
        cout=64,
        expand=True,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        **kwargs,
    ):
        super().__init__()
        self.proj_type = proj_type
        self.cout = cout
        self.expand = expand

        # build pretrained feature network and random decoder (scratch)
        self.pretrained, self.scratch = _make_projector(
            im_res=im_res, cout=self.cout, proj_type=self.proj_type, expand=self.expand
        )
        self.CHANNELS = self.pretrained.CHANNELS
        self.RESOLUTIONS = self.pretrained.RESOLUTIONS

    def forward(self, x):
        backbone = self.pretrained

        # predict feature maps, feeding each stage with the previous one's output
        feats = []
        h = x
        for stage in (backbone.layer0, backbone.layer1, backbone.layer2, backbone.layer3):
            h = stage(h)
            feats.append(h)

        # start enumerating at the lowest layer (this is where we put the first discriminator)
        if self.proj_type == 0:
            return {str(i): f for i, f in enumerate(feats)}

        scratch = self.scratch
        mixers = (scratch.layer0_ccm, scratch.layer1_ccm, scratch.layer2_ccm, scratch.layer3_ccm)
        channel_mixed = [mix(f) for mix, f in zip(mixers, feats)]

        if self.proj_type == 1:
            return {str(i): f for i, f in enumerate(channel_mixed)}

        # from bottom to top: each fusion block also receives the deeper fused map
        fused3 = scratch.layer3_csm(channel_mixed[3])
        fused2 = scratch.layer2_csm(fused3, channel_mixed[2])
        fused1 = scratch.layer1_csm(fused2, channel_mixed[1])
        fused0 = scratch.layer0_csm(fused1, channel_mixed[0])

        return {'0': fused0, '1': fused1, '2': fused2, '3': fused3}
495
+
496
+
497
+ #from pg_modules.diffaug import DiffAugment
498
+ # Differentiable Augmentation for Data-Efficient GAN Training
499
+ # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
500
+ # https://arxiv.org/pdf/2006.10738
501
+
502
+ import torch
503
+ import torch.nn.functional as F
504
+
505
+
506
def DiffAugment(x, policy='', channels_first=True):
    """Apply the differentiable augmentations named in ``policy`` to ``x``.

    ``policy`` is a comma-separated list of ``AUGMENT_FNS`` keys; an empty
    policy returns ``x`` untouched. When ``channels_first`` is false the
    input is permuted to channels-first for the augmentations and back
    afterwards. The augmented result is made contiguous.
    """
    if not policy:
        return x

    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            x = augment(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()
517
+
518
+
519
def rand_brightness(x):
    """Shift each sample by one random offset drawn uniformly from [-0.5, 0.5)."""
    uniform = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    return x + (uniform - 0.5)
522
+
523
+
524
def rand_saturation(x):
    """Scale each sample's deviation from its per-pixel channel mean by a random factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean
528
+
529
+
530
def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a random factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean
534
+
535
+
536
def rand_translation(x, ratio=0.125):
    """Translate each sample by a random integer offset, filling exposed pixels with zeros.

    The maximum shift along dims 2 and 3 is ``ratio`` times the respective
    size (rounded). The shift is implemented by gathering from a copy of
    ``x`` that is zero-padded by one pixel on each spatial side.
    """
    batch, size_x, size_y = x.size(0), x.size(2), x.size(3)
    max_dx = int(size_x * ratio + 0.5)
    max_dy = int(size_y * ratio + 0.5)
    offset_x = torch.randint(-max_dx, max_dx + 1, size=[batch, 1, 1], device=x.device)
    offset_y = torch.randint(-max_dy, max_dy + 1, size=[batch, 1, 1], device=x.device)

    idx_batch, idx_x, idx_y = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(size_x, dtype=torch.long, device=x.device),
        torch.arange(size_y, dtype=torch.long, device=x.device),
    )
    # +1 accounts for the padding border; clamping sends out-of-range
    # positions onto that zero border
    idx_x = torch.clamp(idx_x + offset_x + 1, 0, size_x + 1)
    idx_y = torch.clamp(idx_y + offset_y + 1, 0, size_y + 1)

    padded = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    channels_last = padded.permute(0, 2, 3, 1).contiguous()
    return channels_last[idx_batch, idx_x, idx_y].permute(0, 3, 1, 2)
550
+
551
+
552
def rand_cutout(x, ratio=0.2):
    """Zero out one randomly placed rectangle per sample.

    The rectangle measures ``ratio`` times the size of dims 2 and 3
    (rounded); parts falling outside the image are clamped onto its border
    pixels.
    """
    batch, size_x, size_y = x.size(0), x.size(2), x.size(3)
    cut_x = int(size_x * ratio + 0.5)
    cut_y = int(size_y * ratio + 0.5)
    center_x = torch.randint(0, size_x + (1 - cut_x % 2), size=[batch, 1, 1], device=x.device)
    center_y = torch.randint(0, size_y + (1 - cut_y % 2), size=[batch, 1, 1], device=x.device)

    idx_batch, idx_x, idx_y = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(cut_x, dtype=torch.long, device=x.device),
        torch.arange(cut_y, dtype=torch.long, device=x.device),
    )
    idx_x = torch.clamp(idx_x + center_x - cut_x // 2, min=0, max=size_x - 1)
    idx_y = torch.clamp(idx_y + center_y - cut_y // 2, min=0, max=size_y - 1)

    # 1 = keep, 0 = cut; the mask is shared by all channels of a sample
    keep = torch.ones(batch, size_x, size_y, dtype=x.dtype, device=x.device)
    keep[idx_batch, idx_x, idx_y] = 0
    return x * keep.unsqueeze(1)
567
+
568
+
569
+ AUGMENT_FNS = {
570
+ 'color': [rand_brightness, rand_saturation, rand_contrast],
571
+ 'translation': [rand_translation],
572
+ 'cutout': [rand_cutout],
573
+ }
574
+
575
+
576
+
577
class SingleDisc(nn.Module):
    """Unconditional discriminator head: down blocks followed by a 4x4 conv to one logit map.

    Args:
        nc: number of input channels.
        ndf: if given, every layer uses this width instead of the
            resolution-dependent table.
        start_sz: input resolution; snapped to the nearest table entry.
        end_sz: resolution at which downsampling stops.
        head: if truthy, a 3x3 conv + LeakyReLU maps ``nc`` to the 256-entry width first.
        separable: use separable convs in the down blocks.
        patch: use ``DownBlockPatch`` instead of ``DownBlock``.
    """

    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False):
        super().__init__()
        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                        256: 32, 512: 16, 1024: 8}

        # interpolate for start sz that are not powers of two
        if start_sz not in channel_dict:
            known_sizes = np.array(list(channel_dict))
            start_sz = known_sizes[np.argmin(abs(known_sizes - start_sz))]
        self.start_sz = start_sz

        # if given ndf, allocate all layers with the same ndf
        nfc = channel_dict if ndf is None else {size: ndf for size in channel_dict}

        # for feature map discriminators with nfc not in channel_dict
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        blocks = []

        # Head if the initial input is the full modality
        if head:
            blocks.append(conv2d(nc, nfc[256], 3, 1, 1, bias=False))
            blocks.append(nn.LeakyReLU(0.2, inplace=True))

        # Down Blocks
        block_cls = DownBlockPatch if patch else DownBlock
        size = start_sz
        while size > end_sz:
            blocks.append(block_cls(nfc[size], nfc[size//2], separable=separable))
            size = size // 2

        blocks.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*blocks)

    def forward(self, x, c):
        # ``c`` is accepted for interface parity with SingleDiscCond and ignored
        logits = self.main(x)
        return logits
618
+
619
+
620
class SingleDiscCond(nn.Module):
    """Conditional discriminator head using cross-attention on the conditioning.

    Down blocks reduce the input to ``end_sz``; a 4x4 conv maps it to
    ``cmap_dim`` channels, which are combined with the output of a
    ``CondAttnBlock`` over the conditioning sequence and summed over channels.
    ``c_dim`` and ``embedding_dim`` are accepted but not used.
    """

    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128, cond_size=128):
        super().__init__()
        self.cmap_dim = cmap_dim
        self.cond_attn = CondAttnBlock(cmap_dim, cond_size, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False)
        # midas channels
        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                        256: 32, 512: 16, 1024: 8}

        # interpolate for start sz that are not powers of two
        if start_sz not in channel_dict:
            known_sizes = np.array(list(channel_dict))
            start_sz = known_sizes[np.argmin(abs(known_sizes - start_sz))]
        self.start_sz = start_sz

        # if given ndf, allocate all layers with the same ndf
        nfc = channel_dict if ndf is None else {size: ndf for size in channel_dict}

        # for feature map discriminators with nfc not in channel_dict
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        blocks = []

        # Head if the initial input is the full modality
        if head:
            blocks.append(conv2d(nc, nfc[256], 3, 1, 1, bias=False))
            blocks.append(nn.LeakyReLU(0.2, inplace=True))

        # Down Blocks
        block_cls = DownBlockPatch if patch else DownBlock
        size = start_sz
        while size > end_sz:
            blocks.append(block_cls(nfc[size], nfc[size//2], separable=separable))
            size = size // 2
        self.main = nn.Sequential(*blocks)

        # projection to cmap_dim channels used for conditioning
        self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False)

    def forward(self, x, c):
        """Score ``x``; ``c`` is a ``(cond_pooled, cond, cond_mask)`` triple (the pooled entry is unused)."""
        out = self.cls(self.main(x))
        _, cond, cond_mask = c
        cmap = self.cond_attn(out, cond, cond_mask)
        # projection-style conditioning: channel-wise product, summed and rescaled
        scale = 1 / np.sqrt(self.cmap_dim)
        return (out * cmap).sum(dim=1, keepdim=True) * scale
679
+
680
+
681
class MultiScaleD(nn.Module):
    """A set of per-scale discriminators applied to a dict of feature maps.

    Args:
        channels: input channel count for each feature level.
        resolutions: spatial resolution for each feature level.
        num_discs: how many levels (starting at the first) get a discriminator.
        proj_type: accepted for interface compatibility; not used here.
        cond: truthy selects ``SingleDiscCond``, falsy selects ``SingleDisc``.
        separable: forwarded to the discriminators.
        patch: forwarded to the discriminators; also fixes ``start_sz`` to 16.
        cond_size: forwarded to ``SingleDiscCond`` only.
    """

    def __init__(
        self,
        channels,
        resolutions,
        num_discs=1,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        cond=1,
        separable=False,
        patch=False,
        cond_size=128,
        **kwargs,
    ):
        super().__init__()

        assert num_discs in [1, 2, 3, 4]

        # the first disc is on the lowest level of the backbone
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]

        # ``cond_size`` is only understood by the conditional discriminator;
        # passing it to SingleDisc raised a TypeError whenever cond was falsy.
        if cond:
            Disc = partial(SingleDiscCond, cond_size=cond_size)
        else:
            Disc = SingleDisc

        mini_discs = {}
        for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            start_sz = res if not patch else 16
            mini_discs[str(i)] = Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)
        self.mini_discs = nn.ModuleDict(mini_discs)

    def forward(self, features, c):
        """Return the logits of all discriminators, flattened and concatenated along dim 1."""
        all_logits = []
        for k, disc in self.mini_discs.items():
            feat = features[k]
            all_logits.append(disc(feat, c).view(feat.size(0), -1))

        return torch.cat(all_logits, dim=1)
716
+
717
+
718
class ProjectedDiscriminator(torch.nn.Module):
    """Discriminator over projected features of ``x`` and ``xprev`` plus a timestep embedding.

    For each feature level below ``num_discs`` the features of ``x``, the
    features of ``xprev`` and a spatially broadcast timestep embedding are
    concatenated along the channel axis and scored by ``MultiScaleD``.

    Args:
        diffaug: apply ``DiffAugment`` (color, translation, cutout) to ``x``.
        interp224: resize the inputs to 256x256 before feature extraction.
        t_emb_dim: embedding and hidden size of the timestep embedding.
        out_dim: number of timestep-embedding channels appended to the features.
        backbone_kwargs: forwarded to both ``F_RandomProj`` and ``MultiScaleD``
            (only unpacked, never mutated).
        act: activation passed to ``TimestepEmbedding``.
        num_discs: number of feature levels that get a timestep embedding.
    """

    def __init__(
        self,
        diffaug=False,
        interp224=False,
        t_emb_dim = 128,
        out_dim=64,
        backbone_kwargs={},
        act=torch.nn.LeakyReLU(0.2),
        num_discs=1,
        **kwargs
    ):
        super().__init__()
        self.diffaug = diffaug
        self.act = act
        self.interp224 = interp224
        self.num_discs = num_discs
        self.feature_network = F_RandomProj(**backbone_kwargs)
        self.discriminator = MultiScaleD(
            # each level sees features of x and xprev (2*c) plus out_dim time channels
            channels=[c*2+out_dim for c in self.feature_network.CHANNELS],
            resolutions=self.feature_network.RESOLUTIONS,
            **backbone_kwargs,
        )
        self.t_embed = torch.nn.ModuleList([TimestepEmbedding(
            embedding_dim=t_emb_dim,
            hidden_dim=t_emb_dim,
            output_dim=out_dim,
            act=act,
        ) for _ in range(num_discs)])

    def train(self, mode=True):
        """Set the training mode, always keeping the feature network in eval mode."""
        # super().train() updates self.training and every child (including
        # t_embed), which the previous hand-written override skipped.
        super().train(mode)
        self.feature_network.train(False)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x, t, xprev, cond=None):
        """Return the concatenated logits for ``x`` given timestep ``t``, ``xprev`` and ``cond``."""
        if self.diffaug:
            x = DiffAugment(x, policy='color,translation,cutout')

        if self.interp224:
            x = F.interpolate(x, 256, mode='bilinear', align_corners=False)
            # resize xprev as well so both feature pyramids have matching
            # spatial sizes when they are concatenated below
            xprev = F.interpolate(xprev, 256, mode='bilinear', align_corners=False)

        features1 = self.feature_network(x)
        features2 = self.feature_network(xprev)
        features = {}
        for k in features1.keys():
            level = int(k)
            if level >= self.num_discs:
                continue
            tcat = self.t_embed[level](t)
            h, w = features1[k].shape[2:]
            # broadcast the embedding over the spatial grid of this level
            tcat = tcat.view(tcat.shape[0], tcat.shape[1], 1, 1).repeat(1, 1, h, w)
            features[k] = torch.cat((features1[k], features2[k], tcat), dim=1)
        logits = self.discriminator(features, cond)

        return logits
783
+
scripts/init.sh CHANGED
@@ -1,2 +1,14 @@
1
- source /p/project/laionize/miniconda/bin/activate
2
- conda activate ddgan
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Environment setup for the 2022 software stage.
# Meant to be sourced (e.g. `source scripts/init.sh`) so that the loaded
# modules and exported variables persist in the calling shell.

# Start from a clean module environment, then select the 2022 stage.
ml purge
# Quoted so the value is always passed as exactly one argument.
ml use "$OTHERSTAGES"
ml Stages/2022

# Toolchain and GPU libraries, followed by the Python ML stack.
ml GCC/11.2.0
ml OpenMPI/4.1.2
ml CUDA/11.5
ml cuDNN/8.3.1.22-CUDA-11.5
ml NCCL/2.12.7-1-CUDA-11.5
ml PyTorch/1.11-CUDA-11.5
ml Horovod/0.24
ml torchvision/0.12.0

# Python virtual environment layered on top of the loaded modules.
source /p/home/jusers/cherti1/jureca/ccstdl/code/feed_forward_vqgan_clip/envs/jureca_2022/bin/activate

export HOROVOD_CACHE_CAPACITY=4096
# Make GPUs 0-3 visible to the processes.
export CUDA_VISIBLE_DEVICES=0,1,2,3
scripts/run_jurecadc_conda.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash -x
#SBATCH --account=zam
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=24
#SBATCH --time=06:00:00
#SBATCH --gres=gpu:4
#SBATCH --partition=dc-gpu

# Slurm batch script for the dc-gpu partition: one node, four tasks, four GPUs.
# Every argument given to this script is forwarded to `python -u`, e.g.
#   sbatch scripts/run_jurecadc_conda.sh <script.py> [args...]

# Software environment: CUDA module plus the `ddgan` conda environment.
ml CUDA
source /p/project/laionize/miniconda/bin/activate
conda activate ddgan
# Alternative module-based environments (disabled):
#source scripts/init_2022.sh
#source scripts/init_2020.sh
#source scripts/init.sh

export CUDA_VISIBLE_DEVICES=0,1,2,3
echo "Job id: $SLURM_JOB_ID"
export TOKENIZERS_PARALLELISM=false

# Communication timeout / retry settings for NCCL and UCX.
#export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_IB_TIMEOUT=50
export UCX_RC_TIMEOUT=4s
export NCCL_IB_RETRY_CNT=10
export TORCH_DISTRIBUTED_DEBUG=INFO

# "$@" keeps each argument intact; an unquoted $* would re-split arguments
# containing whitespace and expand any glob characters in them.
srun python -u "$@"
scripts/run_juwelsbooster_conda.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash -x
#SBATCH --account=laionize
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=24
#SBATCH --time=06:00:00
#SBATCH --gres=gpu:4
#SBATCH --partition=booster

# Slurm batch script for the booster partition: one node, four tasks, four GPUs.
# Every argument given to this script is forwarded to `python -u`, e.g.
#   sbatch scripts/run_juwelsbooster_conda.sh <script.py> [args...]

# Software environment: CUDA module plus the `ddgan` conda environment.
ml CUDA
source /p/project/laionize/miniconda/bin/activate
conda activate ddgan

export CUDA_VISIBLE_DEVICES=0,1,2,3
echo "Job id: $SLURM_JOB_ID"
export TOKENIZERS_PARALLELISM=false

# Communication timeout / retry settings for NCCL and UCX.
#export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_IB_TIMEOUT=50
export UCX_RC_TIMEOUT=4s
export NCCL_IB_RETRY_CNT=10

# "$@" keeps each argument intact; an unquoted $* would re-split arguments
# containing whitespace and expand any glob characters in them.
srun python -u "$@"
test.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
"""Smoke test: build a ProjectedDiscriminator and run one forward pass."""
from score_sde.models.projected_discriminator import ProjectedDiscriminator
import torch

# Four discriminator heads; the backbone is told the conditioning width is 768.
discr = ProjectedDiscriminator(
    num_discs=4,
    backbone_kwargs={"cond_size": 768},
)

# A single random 224x224 RGB image, used as both halves of the input pair.
x = torch.randn(1, 3, 224, 224)

# randint's upper bound is exclusive, so the only possible timestep is 0.
t = torch.randint(0, 1, size=(1,))

# Conditioning triple: an empty first slot, a random (1, 77, 768) tensor and
# an all-True (1, 77) boolean tensor.
# NOTE(review): presumably per-token text features and their mask -- confirm
# against the discriminator's expected ``cond`` layout.
cond = (
    None,
    torch.randn(1, 77, 768),
    torch.ones(1, 77, dtype=torch.bool),
)

y = discr(x, t, x, cond=cond)
print(y.shape)
test_ddgan.py CHANGED
@@ -384,15 +384,20 @@ def sample_and_test(args):
384
  for epoch in epochs:
385
  args.epoch_id = epoch
386
  path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
387
- next_path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id+1)
388
  if not os.path.exists(path):
389
  continue
 
 
390
  print(path)
391
 
392
  #if not os.path.exists(next_path):
393
  # print(f"STOP at {epoch}")
394
  # break
395
- ckpt = torch.load(path, map_location=device)
 
 
 
396
  suffix = '_' + args.eval_name if args.eval_name else ""
397
  dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id, suffix)
398
  next_dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id+1, suffix)
 
384
  for epoch in epochs:
385
  args.epoch_id = epoch
386
  path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
387
+ next_next_path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id+2)
388
  if not os.path.exists(path):
389
  continue
390
+ if not os.path.exists(next_next_path):
391
+ break
392
  print(path)
393
 
394
  #if not os.path.exists(next_path):
395
  # print(f"STOP at {epoch}")
396
  # break
397
+ try:
398
+ ckpt = torch.load(path, map_location=device)
399
+ except Exception:
400
+ continue
401
  suffix = '_' + args.eval_name if args.eval_name else ""
402
  dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id, suffix)
403
  next_dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id+1, suffix)
train_ddgan.py CHANGED
@@ -210,6 +210,7 @@ def get_autocast(precision):
210
 
211
  def train(rank, gpu, args):
212
  from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
 
213
  from score_sde.models.ncsnpp_generator_adagn import NCSNpp
214
  from EMA import EMA
215
 
@@ -281,6 +282,12 @@ def train(rank, gpu, args):
281
  transforms.ToTensor(),
282
  transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
283
  ])
 
 
 
 
 
 
284
  shards = glob(os.path.join(args.dataset_root, "*.tar")) if os.path.isdir(args.dataset_root) else args.dataset_root
285
  pipeline = [ResampledShards2(shards)]
286
  pipeline.extend([
@@ -295,7 +302,7 @@ def train(rank, gpu, args):
295
  pipeline.extend([
296
  wds.select(filter_no_caption),
297
  wds.decode("pilrgb", handler=log_and_continue),
298
- wds.rename(image="jpg;png"),
299
  wds.map_dict(image=train_transform),
300
  wds.to_tuple("image","txt"),
301
  wds.batched(batch_size, partial=False),
@@ -361,7 +368,13 @@ def train(rank, gpu, args):
361
  t_emb_dim = args.t_emb_dim,
362
  cond_size=text_encoder.output_size,
363
  act=nn.LeakyReLU(0.2)).to(device)
364
-
 
 
 
 
 
 
365
  broadcast_params(netG.parameters())
366
  broadcast_params(netD.parameters())
367
 
@@ -387,7 +400,10 @@ def train(rank, gpu, args):
387
  netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
388
  else:
389
  netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
390
- netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
 
 
 
391
 
392
  if args.grad_checkpointing:
393
  from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
@@ -430,7 +446,7 @@ def train(rank, gpu, args):
430
  .format(checkpoint['epoch']))
431
  else:
432
  global_step, epoch, init_epoch = 0, 0, 0
433
- use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn", "large_attn_pool")
434
  for epoch in range(init_epoch, args.num_epoch+1):
435
  if args.dataset == "wds":
436
  os.environ["WDS_EPOCH"] = str(epoch)
 
210
 
211
  def train(rank, gpu, args):
212
  from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
213
+ from score_sde.models.projected_discriminator import ProjectedDiscriminator
214
  from score_sde.models.ncsnpp_generator_adagn import NCSNpp
215
  from EMA import EMA
216
 
 
282
  transforms.ToTensor(),
283
  transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
284
  ])
285
+ elif args.preprocessing == "simple_random_crop":
286
+ train_transform = transforms.Compose([
287
+ transforms.RandomCrop(args.image_size, interpolation=3),
288
+ transforms.ToTensor(),
289
+ transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
290
+ ])
291
  shards = glob(os.path.join(args.dataset_root, "*.tar")) if os.path.isdir(args.dataset_root) else args.dataset_root
292
  pipeline = [ResampledShards2(shards)]
293
  pipeline.extend([
 
302
  pipeline.extend([
303
  wds.select(filter_no_caption),
304
  wds.decode("pilrgb", handler=log_and_continue),
305
+ wds.rename(image="jpg;png;webp"),
306
  wds.map_dict(image=train_transform),
307
  wds.to_tuple("image","txt"),
308
  wds.batched(batch_size, partial=False),
 
368
  t_emb_dim = args.t_emb_dim,
369
  cond_size=text_encoder.output_size,
370
  act=nn.LeakyReLU(0.2)).to(device)
371
+ elif args.discr_type == "projected_gan":
372
+ netD = ProjectedDiscriminator(
373
+ num_discs=4,
374
+ backbone_kwargs={"cond_size": text_encoder.output_size}
375
+ )
376
+ netD = netD.to(device)
377
+
378
  broadcast_params(netG.parameters())
379
  broadcast_params(netD.parameters())
380
 
 
400
  netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
401
  else:
402
  netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
403
+ netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu], find_unused_parameters=args.discr_type=="projected_gan")
404
+ #if args.discr_type == "projected_gan":
405
+ # netD._set_static_graph()
406
+
407
 
408
  if args.grad_checkpointing:
409
  from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 
446
  .format(checkpoint['epoch']))
447
  else:
448
  global_step, epoch, init_epoch = 0, 0, 0
449
+ use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn", "large_attn_pool", "projected_gan")
450
  for epoch in range(init_epoch, args.num_epoch+1):
451
  if args.dataset == "wds":
452
  os.environ["WDS_EPOCH"] = str(epoch)