Fabrice-TIERCELIN committed on
Commit ac6e4f8
1 Parent(s): 38a87f5

Upload 3 files

sgm/modules/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .encoders.modules import GeneralConditioner
+ from .encoders.modules import GeneralConditionerWithControl
+ from .encoders.modules import PreparedConditioner
+
+ UNCONDITIONAL_CONFIG = {
+     "target": "sgm.modules.GeneralConditioner",
+     "params": {"emb_models": []},
+ }
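
A minimal sketch (not part of the commit) of how a config in the "target"/"params" layout used by UNCONDITIONAL_CONFIG is typically resolved into an object. The instantiate_from_config helper below is written for illustration only and is merely assumed to mirror the repo's own config loader:

    import importlib

    def instantiate_from_config(config: dict):
        """Resolve config["target"] to a class and construct it with config["params"]."""
        module_path, cls_name = config["target"].rsplit(".", 1)
        cls = getattr(importlib.import_module(module_path), cls_name)
        return cls(**config.get("params", {}))

    # Example: build the "empty" conditioner used when no conditioning is configured.
    # conditioner = instantiate_from_config(UNCONDITIONAL_CONFIG)
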
sgm/modules/attention.py ADDED
@@ -0,0 +1,635 @@
+ import math
+ from inspect import isfunction
+ from typing import Any, Optional
+
+ import torch
+ import torch.nn.functional as F
+ # from einops._torch_specific import allow_ops_in_compiled_graph
+ # allow_ops_in_compiled_graph()
+ from einops import rearrange, repeat
+ from packaging import version
+ from torch import nn
+
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+     SDP_IS_AVAILABLE = True
+     from torch.backends.cuda import SDPBackend, sdp_kernel
+
+     BACKEND_MAP = {
+         SDPBackend.MATH: {
+             "enable_math": True,
+             "enable_flash": False,
+             "enable_mem_efficient": False,
+         },
+         SDPBackend.FLASH_ATTENTION: {
+             "enable_math": False,
+             "enable_flash": True,
+             "enable_mem_efficient": False,
+         },
+         SDPBackend.EFFICIENT_ATTENTION: {
+             "enable_math": False,
+             "enable_flash": False,
+             "enable_mem_efficient": True,
+         },
+         None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+     }
+ else:
+     from contextlib import nullcontext
+
+     SDP_IS_AVAILABLE = False
+     sdp_kernel = nullcontext
+     BACKEND_MAP = {}
+     print(
+         f"No SDP backend available, likely because you are running PyTorch < 2.0 "
+         f"(you are using PyTorch {torch.__version__}). You might want to consider upgrading."
+     )
+
+ try:
+     import xformers
+     import xformers.ops
+
+     XFORMERS_IS_AVAILABLE = True
+ except ImportError:
+     XFORMERS_IS_AVAILABLE = False
+     print("no module 'xformers'. Processing without...")
+
+ from .diffusionmodules.util import checkpoint
+
+
+ def exists(val):
+     return val is not None
+
+
+ def uniq(arr):
+     return {el: True for el in arr}.keys()
+
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if isfunction(d) else d
+
+
+ def max_neg_value(t):
+     return -torch.finfo(t.dtype).max
+
+
+ def init_(tensor):
+     dim = tensor.shape[-1]
+     std = 1 / math.sqrt(dim)
+     tensor.uniform_(-std, std)
+     return tensor
+
+
+ # feedforward
+ class GEGLU(nn.Module):
+     def __init__(self, dim_in, dim_out):
+         super().__init__()
+         self.proj = nn.Linear(dim_in, dim_out * 2)
+
+     def forward(self, x):
+         x, gate = self.proj(x).chunk(2, dim=-1)
+         return x * F.gelu(gate)
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+         super().__init__()
+         inner_dim = int(dim * mult)
+         dim_out = default(dim_out, dim)
+         project_in = (
+             nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+             if not glu
+             else GEGLU(dim, inner_dim)
+         )
+
+         self.net = nn.Sequential(
+             project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def Normalize(in_channels):
+     return torch.nn.GroupNorm(
+         num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+     )
+
+
+ class LinearAttention(nn.Module):
+     def __init__(self, dim, heads=4, dim_head=32):
+         super().__init__()
+         self.heads = heads
+         hidden_dim = dim_head * heads
+         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         qkv = self.to_qkv(x)
+         q, k, v = rearrange(
+             qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+         )
+         k = k.softmax(dim=-1)
+         context = torch.einsum("bhdn,bhen->bhde", k, v)
+         out = torch.einsum("bhde,bhdn->bhen", context, q)
+         out = rearrange(
+             out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+         )
+         return self.to_out(out)
+
+
+ class SpatialSelfAttention(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = Normalize(in_channels)
+         self.q = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.k = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.v = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.proj_out = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+
+     def forward(self, x):
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention
+         b, c, h, w = q.shape
+         q = rearrange(q, "b c h w -> b (h w) c")
+         k = rearrange(k, "b c h w -> b c (h w)")
+         w_ = torch.einsum("bij,bjk->bik", q, k)
+
+         w_ = w_ * (int(c) ** (-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+
+         # attend to values
+         v = rearrange(v, "b c h w -> b c (h w)")
+         w_ = rearrange(w_, "b i j -> b j i")
+         h_ = torch.einsum("bij,bjk->bik", v, w_)
+         h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
+ class CrossAttention(nn.Module):
+     def __init__(
+         self,
+         query_dim,
+         context_dim=None,
+         heads=8,
+         dim_head=64,
+         dropout=0.0,
+         backend=None,
+     ):
+         super().__init__()
+         inner_dim = dim_head * heads
+         context_dim = default(context_dim, query_dim)
+
+         self.scale = dim_head**-0.5
+         self.heads = heads
+
+         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+         )
+         self.backend = backend
+
+     def forward(
+         self,
+         x,
+         context=None,
+         mask=None,
+         additional_tokens=None,
+         n_times_crossframe_attn_in_self=0,
+     ):
+         h = self.heads
+
+         if additional_tokens is not None:
+             # get the number of masked tokens at the beginning of the output sequence
+             n_tokens_to_mask = additional_tokens.shape[1]
+             # add additional token
+             x = torch.cat([additional_tokens, x], dim=1)
+
+         q = self.to_q(x)
+         context = default(context, x)
+         k = self.to_k(context)
+         v = self.to_v(context)
+
+         if n_times_crossframe_attn_in_self:
+             # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+             assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+             n_cp = x.shape[0] // n_times_crossframe_attn_in_self
+             k = repeat(
+                 k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
+             )
+             v = repeat(
+                 v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
+             )
+
+         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+         ## old
+         """
+         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+         del q, k
+
+         if exists(mask):
+             mask = rearrange(mask, 'b ... -> b (...)')
+             max_neg_value = -torch.finfo(sim.dtype).max
+             mask = repeat(mask, 'b j -> (b h) () j', h=h)
+             sim.masked_fill_(~mask, max_neg_value)
+
+         # attention, what we cannot get enough of
+         sim = sim.softmax(dim=-1)
+
+         out = einsum('b i j, b j d -> b i d', sim, v)
+         """
+         ## new
+         with sdp_kernel(**BACKEND_MAP[self.backend]):
+             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
+             out = F.scaled_dot_product_attention(
+                 q, k, v, attn_mask=mask
+             )  # scale is dim_head ** -0.5 per default
+
+         del q, k, v
+         out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+         if additional_tokens is not None:
+             # remove additional token
+             out = out[:, n_tokens_to_mask:]
+         return self.to_out(out)
+
+
+ class MemoryEfficientCrossAttention(nn.Module):
+     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+     def __init__(
+         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
+     ):
+         super().__init__()
+         print(
+             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+             f"{heads} heads with a dimension of {dim_head}."
+         )
+         inner_dim = dim_head * heads
+         context_dim = default(context_dim, query_dim)
+
+         self.heads = heads
+         self.dim_head = dim_head
+
+         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+         )
+         self.attention_op: Optional[Any] = None
+
+     def forward(
+         self,
+         x,
+         context=None,
+         mask=None,
+         additional_tokens=None,
+         n_times_crossframe_attn_in_self=0,
+     ):
+         if additional_tokens is not None:
+             # get the number of masked tokens at the beginning of the output sequence
+             n_tokens_to_mask = additional_tokens.shape[1]
+             # add additional token
+             x = torch.cat([additional_tokens, x], dim=1)
+         q = self.to_q(x)
+         context = default(context, x)
+         k = self.to_k(context)
+         v = self.to_v(context)
+
+         if n_times_crossframe_attn_in_self:
+             # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+             assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+             # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
+             k = repeat(
+                 k[::n_times_crossframe_attn_in_self],
+                 "b ... -> (b n) ...",
+                 n=n_times_crossframe_attn_in_self,
+             )
+             v = repeat(
+                 v[::n_times_crossframe_attn_in_self],
+                 "b ... -> (b n) ...",
+                 n=n_times_crossframe_attn_in_self,
+             )
+
+         b, _, _ = q.shape
+         q, k, v = map(
+             lambda t: t.unsqueeze(3)
+             .reshape(b, t.shape[1], self.heads, self.dim_head)
+             .permute(0, 2, 1, 3)
+             .reshape(b * self.heads, t.shape[1], self.dim_head)
+             .contiguous(),
+             (q, k, v),
+         )
+
+         # actually compute the attention, what we cannot get enough of
+         out = xformers.ops.memory_efficient_attention(
+             q, k, v, attn_bias=None, op=self.attention_op
+         )
+
+         # TODO: Use this directly in the attention operation, as a bias
+         if exists(mask):
+             raise NotImplementedError
+         out = (
+             out.unsqueeze(0)
+             .reshape(b, self.heads, out.shape[1], self.dim_head)
+             .permute(0, 2, 1, 3)
+             .reshape(b, out.shape[1], self.heads * self.dim_head)
+         )
+         if additional_tokens is not None:
+             # remove additional token
+             out = out[:, n_tokens_to_mask:]
+         return self.to_out(out)
+
+
+ class BasicTransformerBlock(nn.Module):
+     ATTENTION_MODES = {
+         "softmax": CrossAttention,  # vanilla attention
+         "softmax-xformers": MemoryEfficientCrossAttention,  # ampere
+     }
+
+     def __init__(
+         self,
+         dim,
+         n_heads,
+         d_head,
+         dropout=0.0,
+         context_dim=None,
+         gated_ff=True,
+         checkpoint=True,
+         disable_self_attn=False,
+         attn_mode="softmax",
+         sdp_backend=None,
+     ):
+         super().__init__()
+         assert attn_mode in self.ATTENTION_MODES
+         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
+             print(
+                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
+                 f"This is not a problem in PyTorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
+             )
+             attn_mode = "softmax"
+         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
+             print(
+                 "We do not support vanilla attention anymore, as it is too expensive. Sorry."
+             )
+             if not XFORMERS_IS_AVAILABLE:
+                 assert (
+                     False
+                 ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+             else:
+                 print("Falling back to xformers efficient attention.")
+                 attn_mode = "softmax-xformers"
+         attn_cls = self.ATTENTION_MODES[attn_mode]
+         if version.parse(torch.__version__) >= version.parse("2.0.0"):
+             assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
+         else:
+             assert sdp_backend is None
+         self.disable_self_attn = disable_self_attn
+         self.attn1 = attn_cls(
+             query_dim=dim,
+             heads=n_heads,
+             dim_head=d_head,
+             dropout=dropout,
+             context_dim=context_dim if self.disable_self_attn else None,
+             backend=sdp_backend,
+         )  # is a self-attention if not self.disable_self_attn
+         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+         self.attn2 = attn_cls(
+             query_dim=dim,
+             context_dim=context_dim,
+             heads=n_heads,
+             dim_head=d_head,
+             dropout=dropout,
+             backend=sdp_backend,
+         )  # is self-attn if context is none
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+         self.norm3 = nn.LayerNorm(dim)
+         self.checkpoint = checkpoint
+         if self.checkpoint:
+             print(f"{self.__class__.__name__} is using checkpointing")
+
+     def forward(
+         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
+     ):
+         kwargs = {"x": x}
+
+         if context is not None:
+             kwargs.update({"context": context})
+
+         if additional_tokens is not None:
+             kwargs.update({"additional_tokens": additional_tokens})
+
+         if n_times_crossframe_attn_in_self:
+             kwargs.update(
+                 {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
+             )
+
+         # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
+         return checkpoint(
+             self._forward, (x, context), self.parameters(), self.checkpoint
+         )
+
+     def _forward(
+         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
+     ):
+         x = (
+             self.attn1(
+                 self.norm1(x),
+                 context=context if self.disable_self_attn else None,
+                 additional_tokens=additional_tokens,
+                 n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
+                 if not self.disable_self_attn
+                 else 0,
+             )
+             + x
+         )
+         x = (
+             self.attn2(
+                 self.norm2(x), context=context, additional_tokens=additional_tokens
+             )
+             + x
+         )
+         x = self.ff(self.norm3(x)) + x
+         return x
+
+
+ class BasicTransformerSingleLayerBlock(nn.Module):
+     ATTENTION_MODES = {
+         "softmax": CrossAttention,  # vanilla attention
+         "softmax-xformers": MemoryEfficientCrossAttention  # on the A100s not quite as fast as the above version
+         # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
+     }
+
+     def __init__(
+         self,
+         dim,
+         n_heads,
+         d_head,
+         dropout=0.0,
+         context_dim=None,
+         gated_ff=True,
+         checkpoint=True,
+         attn_mode="softmax",
+     ):
+         super().__init__()
+         assert attn_mode in self.ATTENTION_MODES
+         attn_cls = self.ATTENTION_MODES[attn_mode]
+         self.attn1 = attn_cls(
+             query_dim=dim,
+             heads=n_heads,
+             dim_head=d_head,
+             dropout=dropout,
+             context_dim=context_dim,
+         )
+         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+         self.checkpoint = checkpoint
+
+     def forward(self, x, context=None):
+         return checkpoint(
+             self._forward, (x, context), self.parameters(), self.checkpoint
+         )
+
+     def _forward(self, x, context=None):
+         x = self.attn1(self.norm1(x), context=context) + x
+         x = self.ff(self.norm2(x)) + x
+         return x
+
+
+ class SpatialTransformer(nn.Module):
+     """
+     Transformer block for image-like data.
+     First, project the input (aka embedding)
+     and reshape to b, t, d.
+     Then apply standard transformer action.
+     Finally, reshape to image.
+     NEW: use_linear for more efficiency instead of the 1x1 convs
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         n_heads,
+         d_head,
+         depth=1,
+         dropout=0.0,
+         context_dim=None,
+         disable_self_attn=False,
+         use_linear=False,
+         attn_type="softmax",
+         use_checkpoint=True,
+         # sdp_backend=SDPBackend.FLASH_ATTENTION
+         sdp_backend=None,
+     ):
+         super().__init__()
+         print(
+             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
+         )
+         from omegaconf import ListConfig
+
+         if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
+             context_dim = [context_dim]
+         if exists(context_dim) and isinstance(context_dim, list):
+             if depth != len(context_dim):
+                 print(
+                     f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
+                 )
+                 # depth does not match context dims.
+                 assert all(
+                     map(lambda x: x == context_dim[0], context_dim)
+                 ), "need homogeneous context_dim to match depth automatically"
+                 context_dim = depth * [context_dim[0]]
+         elif context_dim is None:
+             context_dim = [None] * depth
+         self.in_channels = in_channels
+         inner_dim = n_heads * d_head
+         self.norm = Normalize(in_channels)
+         if not use_linear:
+             self.proj_in = nn.Conv2d(
+                 in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+             )
+         else:
+             self.proj_in = nn.Linear(in_channels, inner_dim)
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim,
+                     n_heads,
+                     d_head,
+                     dropout=dropout,
+                     context_dim=context_dim[d],
+                     disable_self_attn=disable_self_attn,
+                     attn_mode=attn_type,
+                     checkpoint=use_checkpoint,
+                     sdp_backend=sdp_backend,
+                 )
+                 for d in range(depth)
+             ]
+         )
+         if not use_linear:
+             self.proj_out = zero_module(
+                 nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+             )
+         else:
+             # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+             self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+         self.use_linear = use_linear
+
+     def forward(self, x, context=None):
+         # note: if no context is given, cross-attention defaults to self-attention
+         if not isinstance(context, list):
+             context = [context]
+         b, c, h, w = x.shape
+         x_in = x
+         x = self.norm(x)
+         if not self.use_linear:
+             x = self.proj_in(x)
+         x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+         if self.use_linear:
+             x = self.proj_in(x)
+         for i, block in enumerate(self.transformer_blocks):
+             if i > 0 and len(context) == 1:
+                 i = 0  # use same context for each block
+             x = block(x, context=context[i])
+         if self.use_linear:
+             x = self.proj_out(x)
+         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+         if not self.use_linear:
+             x = self.proj_out(x)
+         return x + x_in
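
A usage sketch (not part of the commit) exercising the blocks above with dummy tensors; it assumes the sgm package and its dependencies (torch >= 2.0, einops, omegaconf, packaging) are installed:

    import torch
    from sgm.modules.attention import CrossAttention, SpatialTransformer

    # Cross-attention over a 77-token text context, as used inside BasicTransformerBlock.
    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
    x = torch.randn(2, 4096, 320)      # (batch, h*w tokens, channels)
    ctx = torch.randn(2, 77, 768)      # (batch, context tokens, context dim)
    print(attn(x, context=ctx).shape)  # torch.Size([2, 4096, 320])

    # SpatialTransformer wraps the same blocks for image-shaped feature maps.
    st = SpatialTransformer(
        in_channels=320, n_heads=8, d_head=40, depth=1,
        context_dim=768, use_linear=True, use_checkpoint=False,
    )
    feat = torch.randn(2, 320, 64, 64)  # (batch, channels, height, width)
    print(st(feat, context=ctx).shape)  # torch.Size([2, 320, 64, 64])
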
sgm/modules/ema.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ from torch import nn
+
+
+ class LitEma(nn.Module):
+     def __init__(self, model, decay=0.9999, use_num_upates=True):
+         super().__init__()
+         if decay < 0.0 or decay > 1.0:
+             raise ValueError("Decay must be between 0 and 1")
+
+         self.m_name2s_name = {}
+         self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+         self.register_buffer(
+             "num_updates",
+             torch.tensor(0, dtype=torch.int)
+             if use_num_upates
+             else torch.tensor(-1, dtype=torch.int),
+         )
+
+         for name, p in model.named_parameters():
+             if p.requires_grad:
+                 # remove as '.'-character is not allowed in buffers
+                 s_name = name.replace(".", "")
+                 self.m_name2s_name.update({name: s_name})
+                 self.register_buffer(s_name, p.clone().detach().data)
+
+         self.collected_params = []
+
+     def reset_num_updates(self):
+         del self.num_updates
+         self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
+
+     def forward(self, model):
+         decay = self.decay
+
+         if self.num_updates >= 0:
+             self.num_updates += 1
+             decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+         one_minus_decay = 1.0 - decay
+
+         with torch.no_grad():
+             m_param = dict(model.named_parameters())
+             shadow_params = dict(self.named_buffers())
+
+             for key in m_param:
+                 if m_param[key].requires_grad:
+                     sname = self.m_name2s_name[key]
+                     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+                     shadow_params[sname].sub_(
+                         one_minus_decay * (shadow_params[sname] - m_param[key])
+                     )
+                 else:
+                     assert key not in self.m_name2s_name
+
+     def copy_to(self, model):
+         m_param = dict(model.named_parameters())
+         shadow_params = dict(self.named_buffers())
+         for key in m_param:
+             if m_param[key].requires_grad:
+                 m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+             else:
+                 assert key not in self.m_name2s_name
+
+     def store(self, parameters):
+         """
+         Save the current parameters for restoring later.
+         Args:
+             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                 temporarily stored.
+         """
+         self.collected_params = [param.clone() for param in parameters]
+
+     def restore(self, parameters):
+         """
+         Restore the parameters stored with the `store` method.
+         Useful to validate the model with EMA parameters without affecting the
+         original optimization process. Store the parameters before the
+         `copy_to` method. After validation (or model saving), use this to
+         restore the former parameters.
+         Args:
+             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                 updated with the stored parameters.
+         """
+         for c_param, param in zip(self.collected_params, parameters):
+             param.data.copy_(c_param.data)
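
A usage sketch (not part of the commit) of the typical LitEma lifecycle around a training step and an EMA-weight evaluation; the model and optimizer here are placeholders:

    import torch
    from torch import nn
    from sgm.modules.ema import LitEma

    model = nn.Linear(4, 4)
    ema = LitEma(model, decay=0.9999)
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)

    for _ in range(3):  # stand-in training loop
        loss = model(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema(model)  # update the shadow weights after each optimizer step

    ema.store(model.parameters())  # keep the live weights
    ema.copy_to(model)             # swap in the EMA weights for validation
    # ... run validation here ...
    ema.restore(model.parameters())  # put the live weights back
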