Fabrice-TIERCELIN committed
Commit 0051b62
1 Parent(s): 99f420c

Fix runtime

Files changed (1)
  1. sgm/modules/encoders/modules.py +1064 -1062
sgm/modules/encoders/modules.py CHANGED
@@ -32,7 +32,8 @@
     instantiate_from_config,
 )
 
-from CKPT_PTH import SDXL_CLIP1_PATH, SDXL_CLIP2_CKPT_PTH
+#from CKPT_PTH import SDXL_CLIP1_PATH, SDXL_CLIP2_CKPT_PTH
+from CKPT_PTH import SDXL_CLIP1_PATH
 
 class AbstractEmbModel(nn.Module):
     def __init__(self):
@@ -530,7 +531,8 @@ class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
         model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
-            pretrained=version if SDXL_CLIP2_CKPT_PTH is None else SDXL_CLIP2_CKPT_PTH,
+            #pretrained=version if SDXL_CLIP2_CKPT_PTH is None else SDXL_CLIP2_CKPT_PTH,
+            pretrained=version,
         )
         del model.visual
         self.model = model
1
+ from contextlib import nullcontext
2
+ from functools import partial
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import kornia
6
+ import numpy as np
7
+ import open_clip
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from omegaconf import ListConfig
12
+ from torch.utils.checkpoint import checkpoint
13
+ from transformers import (
14
+ ByT5Tokenizer,
15
+ CLIPTextModel,
16
+ CLIPTokenizer,
17
+ T5EncoderModel,
18
+ T5Tokenizer,
19
+ )
20
+
21
+ from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
22
+ from ...modules.diffusionmodules.model import Encoder
23
+ from ...modules.diffusionmodules.openaimodel import Timestep
24
+ from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
25
+ from ...modules.distributions.distributions import DiagonalGaussianDistribution
26
+ from ...util import (
27
+ autocast,
28
+ count_params,
29
+ default,
30
+ disabled_train,
31
+ expand_dims_like,
32
+ instantiate_from_config,
33
+ )
34
+
35
+ #from CKPT_PTH import SDXL_CLIP1_PATH, SDXL_CLIP2_CKPT_PTH
36
+ from CKPT_PTH import SDXL_CLIP1_PATH
37
+
38
+ class AbstractEmbModel(nn.Module):
39
+ def __init__(self):
40
+ super().__init__()
41
+ self._is_trainable = None
42
+ self._ucg_rate = None
43
+ self._input_key = None
44
+
45
+ @property
46
+ def is_trainable(self) -> bool:
47
+ return self._is_trainable
48
+
49
+ @property
50
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
51
+ return self._ucg_rate
52
+
53
+ @property
54
+ def input_key(self) -> str:
55
+ return self._input_key
56
+
57
+ @is_trainable.setter
58
+ def is_trainable(self, value: bool):
59
+ self._is_trainable = value
60
+
61
+ @ucg_rate.setter
62
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
63
+ self._ucg_rate = value
64
+
65
+ @input_key.setter
66
+ def input_key(self, value: str):
67
+ self._input_key = value
68
+
69
+ @is_trainable.deleter
70
+ def is_trainable(self):
71
+ del self._is_trainable
72
+
73
+ @ucg_rate.deleter
74
+ def ucg_rate(self):
75
+ del self._ucg_rate
76
+
77
+ @input_key.deleter
78
+ def input_key(self):
79
+ del self._input_key
80
+
81
+
82
+ class GeneralConditioner(nn.Module):
83
+ OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
84
+ KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1, 'control_vector': 1}
85
+
86
+ def __init__(self, emb_models: Union[List, ListConfig]):
87
+ super().__init__()
88
+ embedders = []
89
+ for n, embconfig in enumerate(emb_models):
90
+ embedder = instantiate_from_config(embconfig)
91
+ assert isinstance(
92
+ embedder, AbstractEmbModel
93
+ ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
94
+ embedder.is_trainable = embconfig.get("is_trainable", False)
95
+ embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
96
+ if not embedder.is_trainable:
97
+ embedder.train = disabled_train
98
+ for param in embedder.parameters():
99
+ param.requires_grad = False
100
+ embedder.eval()
101
+ print(
102
+ f"Initialized embedder #{n}: {embedder.__class__.__name__} "
103
+ f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
104
+ )
105
+
106
+ if "input_key" in embconfig:
107
+ embedder.input_key = embconfig["input_key"]
108
+ elif "input_keys" in embconfig:
109
+ embedder.input_keys = embconfig["input_keys"]
110
+ else:
111
+ raise KeyError(
112
+ f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
113
+ )
114
+
115
+ embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
116
+ if embedder.legacy_ucg_val is not None:
117
+ embedder.ucg_prng = np.random.RandomState()
118
+
119
+ embedders.append(embedder)
120
+ self.embedders = nn.ModuleList(embedders)
121
+
122
+ def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
123
+ assert embedder.legacy_ucg_val is not None
124
+ p = embedder.ucg_rate
125
+ val = embedder.legacy_ucg_val
126
+ for i in range(len(batch[embedder.input_key])):
127
+ if embedder.ucg_prng.choice(2, p=[1 - p, p]):
128
+ batch[embedder.input_key][i] = val
129
+ return batch
130
+
131
+ def forward(
132
+ self, batch: Dict, force_zero_embeddings: Optional[List] = None
133
+ ) -> Dict:
134
+ output = dict()
135
+ if force_zero_embeddings is None:
136
+ force_zero_embeddings = []
137
+ for embedder in self.embedders:
138
+ embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
139
+ with embedding_context():
140
+ if hasattr(embedder, "input_key") and (embedder.input_key is not None):
141
+ if embedder.legacy_ucg_val is not None:
142
+ batch = self.possibly_get_ucg_val(embedder, batch)
143
+ emb_out = embedder(batch[embedder.input_key])
144
+ elif hasattr(embedder, "input_keys"):
145
+ emb_out = embedder(*[batch[k] for k in embedder.input_keys])
146
+ assert isinstance(
147
+ emb_out, (torch.Tensor, list, tuple)
148
+ ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
149
+ if not isinstance(emb_out, (list, tuple)):
150
+ emb_out = [emb_out]
151
+ for emb in emb_out:
152
+ out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
153
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
154
+ emb = (
155
+ expand_dims_like(
156
+ torch.bernoulli(
157
+ (1.0 - embedder.ucg_rate)
158
+ * torch.ones(emb.shape[0], device=emb.device)
159
+ ),
160
+ emb,
161
+ )
162
+ * emb
163
+ )
164
+ if (
165
+ hasattr(embedder, "input_key")
166
+ and embedder.input_key in force_zero_embeddings
167
+ ):
168
+ emb = torch.zeros_like(emb)
169
+ if out_key in output:
170
+ output[out_key] = torch.cat(
171
+ (output[out_key], emb), self.KEY2CATDIM[out_key]
172
+ )
173
+ else:
174
+ output[out_key] = emb
175
+ return output
176
+
177
+ def get_unconditional_conditioning(
178
+ self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
179
+ ):
180
+ if force_uc_zero_embeddings is None:
181
+ force_uc_zero_embeddings = []
182
+ ucg_rates = list()
183
+ for embedder in self.embedders:
184
+ ucg_rates.append(embedder.ucg_rate)
185
+ embedder.ucg_rate = 0.0
186
+ c = self(batch_c)
187
+ uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
188
+
189
+ for embedder, rate in zip(self.embedders, ucg_rates):
190
+ embedder.ucg_rate = rate
191
+ return c, uc
192
+
193
+
194
+ class GeneralConditionerWithControl(GeneralConditioner):
195
+ def forward(
196
+ self, batch: Dict, force_zero_embeddings: Optional[List] = None
197
+ ) -> Dict:
198
+ output = dict()
199
+ if force_zero_embeddings is None:
200
+ force_zero_embeddings = []
201
+ for embedder in self.embedders:
202
+ embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
203
+ with embedding_context():
204
+ if hasattr(embedder, "input_key") and (embedder.input_key is not None):
205
+ if embedder.legacy_ucg_val is not None:
206
+ batch = self.possibly_get_ucg_val(embedder, batch)
207
+ emb_out = embedder(batch[embedder.input_key])
208
+ elif hasattr(embedder, "input_keys"):
209
+ emb_out = embedder(*[batch[k] for k in embedder.input_keys])
210
+ assert isinstance(
211
+ emb_out, (torch.Tensor, list, tuple)
212
+ ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
213
+ if not isinstance(emb_out, (list, tuple)):
214
+ emb_out = [emb_out]
215
+ for emb in emb_out:
216
+ if 'control_vector' in embedder.input_key:
217
+ out_key = 'control_vector'
218
+ else:
219
+ out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
220
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
221
+ emb = (
222
+ expand_dims_like(
223
+ torch.bernoulli(
224
+ (1.0 - embedder.ucg_rate)
225
+ * torch.ones(emb.shape[0], device=emb.device)
226
+ ),
227
+ emb,
228
+ )
229
+ * emb
230
+ )
231
+ if (
232
+ hasattr(embedder, "input_key")
233
+ and embedder.input_key in force_zero_embeddings
234
+ ):
235
+ emb = torch.zeros_like(emb)
236
+ if out_key in output:
237
+ output[out_key] = torch.cat(
238
+ (output[out_key], emb), self.KEY2CATDIM[out_key]
239
+ )
240
+ else:
241
+ output[out_key] = emb
242
+
243
+ output["control"] = batch["control"]
244
+ return output
245
+
246
+
247
+ class PreparedConditioner(nn.Module):
248
+ def __init__(self, cond_pth, un_cond_pth=None):
249
+ super().__init__()
250
+ conditions = torch.load(cond_pth)
251
+ for k, v in conditions.items():
252
+ self.register_buffer(k, v)
253
+ self.un_cond_pth = un_cond_pth
254
+ if un_cond_pth is not None:
255
+ un_conditions = torch.load(un_cond_pth)
256
+ for k, v in un_conditions.items():
257
+ self.register_buffer(k+'_uc', v)
258
+
259
+
260
+ @torch.no_grad()
261
+ def forward(
262
+ self, batch: Dict, return_uc=False
263
+ ) -> Dict:
264
+ output = dict()
265
+ for k, v in self.state_dict().items():
266
+ if not return_uc:
267
+ if k.endswith("_uc"):
268
+ continue
269
+ else:
270
+ output[k] = v.detach().clone().repeat(batch['control'].shape[0], *[1 for _ in range(v.ndim - 1)])
271
+ else:
272
+ if k.endswith("_uc"):
273
+ output[k[:-3]] = v.detach().clone().repeat(batch['control'].shape[0], *[1 for _ in range(v.ndim - 1)])
274
+ else:
275
+ continue
276
+ output["control"] = batch["control"]
277
+
278
+ for k, v in output.items():
279
+ if isinstance(v, torch.Tensor):
280
+ assert (torch.isnan(v).any()) is not None
281
+ return output
282
+
283
+ def get_unconditional_conditioning(
284
+ self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
285
+ ):
286
+ c = self(batch_c)
287
+ if self.un_cond_pth is not None:
288
+ uc = self(batch_c, return_uc=True)
289
+ else:
290
+ uc = None
291
+ return c, uc
292
+
293
+
294
+
295
+ class InceptionV3(nn.Module):
296
+ """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
297
+ port with an additional squeeze at the end"""
298
+
299
+ def __init__(self, normalize_input=False, **kwargs):
300
+ super().__init__()
301
+ from pytorch_fid import inception
302
+
303
+ kwargs["resize_input"] = True
304
+ self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
305
+
306
+ def forward(self, inp):
307
+ # inp = kornia.geometry.resize(inp, (299, 299),
308
+ # interpolation='bicubic',
309
+ # align_corners=False,
310
+ # antialias=True)
311
+ # inp = inp.clamp(min=-1, max=1)
312
+
313
+ outp = self.model(inp)
314
+
315
+ if len(outp) == 1:
316
+ return outp[0].squeeze()
317
+
318
+ return outp
319
+
320
+
321
+ class IdentityEncoder(AbstractEmbModel):
322
+ def encode(self, x):
323
+ return x
324
+
325
+ def forward(self, x):
326
+ return x
327
+
328
+
329
+ class ClassEmbedder(AbstractEmbModel):
330
+ def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
331
+ super().__init__()
332
+ self.embedding = nn.Embedding(n_classes, embed_dim)
333
+ self.n_classes = n_classes
334
+ self.add_sequence_dim = add_sequence_dim
335
+
336
+ def forward(self, c):
337
+ c = self.embedding(c)
338
+ if self.add_sequence_dim:
339
+ c = c[:, None, :]
340
+ return c
341
+
342
+ def get_unconditional_conditioning(self, bs, device="cuda"):
343
+ uc_class = (
344
+ self.n_classes - 1
345
+ ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
346
+ uc = torch.ones((bs,), device=device) * uc_class
347
+ uc = {self.key: uc.long()}
348
+ return uc
349
+
350
+
351
+ class ClassEmbedderForMultiCond(ClassEmbedder):
352
+ def forward(self, batch, key=None, disable_dropout=False):
353
+ out = batch
354
+ key = default(key, self.key)
355
+ islist = isinstance(batch[key], list)
356
+ if islist:
357
+ batch[key] = batch[key][0]
358
+ c_out = super().forward(batch, key, disable_dropout)
359
+ out[key] = [c_out] if islist else c_out
360
+ return out
361
+
362
+
363
+ class FrozenT5Embedder(AbstractEmbModel):
364
+ """Uses the T5 transformer encoder for text"""
365
+
366
+ def __init__(
367
+ self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
368
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
369
+ super().__init__()
370
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
371
+ self.transformer = T5EncoderModel.from_pretrained(version)
372
+ self.device = device
373
+ self.max_length = max_length
374
+ if freeze:
375
+ self.freeze()
376
+
377
+ def freeze(self):
378
+ self.transformer = self.transformer.eval()
379
+
380
+ for param in self.parameters():
381
+ param.requires_grad = False
382
+
383
+ # @autocast
384
+ def forward(self, text):
385
+ batch_encoding = self.tokenizer(
386
+ text,
387
+ truncation=True,
388
+ max_length=self.max_length,
389
+ return_length=True,
390
+ return_overflowing_tokens=False,
391
+ padding="max_length",
392
+ return_tensors="pt",
393
+ )
394
+ tokens = batch_encoding["input_ids"].to(self.device)
395
+ with torch.autocast("cuda", enabled=False):
396
+ outputs = self.transformer(input_ids=tokens)
397
+ z = outputs.last_hidden_state
398
+ return z
399
+
400
+ def encode(self, text):
401
+ return self(text)
402
+
403
+
404
+ class FrozenByT5Embedder(AbstractEmbModel):
405
+ """
406
+ Uses the ByT5 transformer encoder for text. Is character-aware.
407
+ """
408
+
409
+ def __init__(
410
+ self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
411
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
412
+ super().__init__()
413
+ self.tokenizer = ByT5Tokenizer.from_pretrained(version)
414
+ self.transformer = T5EncoderModel.from_pretrained(version)
415
+ self.device = device
416
+ self.max_length = max_length
417
+ if freeze:
418
+ self.freeze()
419
+
420
+ def freeze(self):
421
+ self.transformer = self.transformer.eval()
422
+
423
+ for param in self.parameters():
424
+ param.requires_grad = False
425
+
426
+ def forward(self, text):
427
+ batch_encoding = self.tokenizer(
428
+ text,
429
+ truncation=True,
430
+ max_length=self.max_length,
431
+ return_length=True,
432
+ return_overflowing_tokens=False,
433
+ padding="max_length",
434
+ return_tensors="pt",
435
+ )
436
+ tokens = batch_encoding["input_ids"].to(self.device)
437
+ with torch.autocast("cuda", enabled=False):
438
+ outputs = self.transformer(input_ids=tokens)
439
+ z = outputs.last_hidden_state
440
+ return z
441
+
442
+ def encode(self, text):
443
+ return self(text)
444
+
445
+
446
+ class FrozenCLIPEmbedder(AbstractEmbModel):
447
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
448
+
449
+ LAYERS = ["last", "pooled", "hidden"]
450
+
451
+ def __init__(
452
+ self,
453
+ version="openai/clip-vit-large-patch14",
454
+ device="cuda",
455
+ max_length=77,
456
+ freeze=True,
457
+ layer="last",
458
+ layer_idx=None,
459
+ always_return_pooled=False,
460
+ ): # clip-vit-base-patch32
461
+ super().__init__()
462
+ assert layer in self.LAYERS
463
+ self.tokenizer = CLIPTokenizer.from_pretrained(version if SDXL_CLIP1_PATH is None else SDXL_CLIP1_PATH)
464
+ self.transformer = CLIPTextModel.from_pretrained(version if SDXL_CLIP1_PATH is None else SDXL_CLIP1_PATH)
465
+ self.device = device
466
+ self.max_length = max_length
467
+ if freeze:
468
+ self.freeze()
469
+ self.layer = layer
470
+ self.layer_idx = layer_idx
471
+ self.return_pooled = always_return_pooled
472
+ if layer == "hidden":
473
+ assert layer_idx is not None
474
+ assert 0 <= abs(layer_idx) <= 12
475
+
476
+ def freeze(self):
477
+ self.transformer = self.transformer.eval()
478
+
479
+ for param in self.parameters():
480
+ param.requires_grad = False
481
+
482
+ @autocast
483
+ def forward(self, text):
484
+ batch_encoding = self.tokenizer(
485
+ text,
486
+ truncation=True,
487
+ max_length=self.max_length,
488
+ return_length=True,
489
+ return_overflowing_tokens=False,
490
+ padding="max_length",
491
+ return_tensors="pt",
492
+ )
493
+ tokens = batch_encoding["input_ids"].to(self.device)
494
+ outputs = self.transformer(
495
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
496
+ )
497
+ if self.layer == "last":
498
+ z = outputs.last_hidden_state
499
+ elif self.layer == "pooled":
500
+ z = outputs.pooler_output[:, None, :]
501
+ else:
502
+ z = outputs.hidden_states[self.layer_idx]
503
+ if self.return_pooled:
504
+ return z, outputs.pooler_output
505
+ return z
506
+
507
+ def encode(self, text):
508
+ return self(text)
509
+
510
+
511
+ class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
512
+ """
513
+ Uses the OpenCLIP transformer encoder for text
514
+ """
515
+
516
+ LAYERS = ["pooled", "last", "penultimate"]
517
+
518
+ def __init__(
519
+ self,
520
+ arch="ViT-H-14",
521
+ version="laion2b_s32b_b79k",
522
+ device="cuda",
523
+ max_length=77,
524
+ freeze=True,
525
+ layer="last",
526
+ always_return_pooled=False,
527
+ legacy=True,
528
+ ):
529
+ super().__init__()
530
+ assert layer in self.LAYERS
531
+ model, _, _ = open_clip.create_model_and_transforms(
532
+ arch,
533
+ device=torch.device("cpu"),
534
+ #pretrained=version if SDXL_CLIP2_CKPT_PTH is None else SDXL_CLIP2_CKPT_PTH,
535
+ pretrained=version,
536
+ )
537
+ del model.visual
538
+ self.model = model
539
+
540
+ self.device = device
541
+ self.max_length = max_length
542
+ self.return_pooled = always_return_pooled
543
+ if freeze:
544
+ self.freeze()
545
+ self.layer = layer
546
+ if self.layer == "last":
547
+ self.layer_idx = 0
548
+ elif self.layer == "penultimate":
549
+ self.layer_idx = 1
550
+ else:
551
+ raise NotImplementedError()
552
+ self.legacy = legacy
553
+
554
+ def freeze(self):
555
+ self.model = self.model.eval()
556
+ for param in self.parameters():
557
+ param.requires_grad = False
558
+
559
+ @autocast
560
+ def forward(self, text):
561
+ tokens = open_clip.tokenize(text)
562
+ z = self.encode_with_transformer(tokens.to(self.device))
563
+ if not self.return_pooled and self.legacy:
564
+ return z
565
+ if self.return_pooled:
566
+ assert not self.legacy
567
+ return z[self.layer], z["pooled"]
568
+ return z[self.layer]
569
+
570
+ def encode_with_transformer(self, text):
571
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
572
+ x = x + self.model.positional_embedding
573
+ x = x.permute(1, 0, 2) # NLD -> LND
574
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
575
+ if self.legacy:
576
+ x = x[self.layer]
577
+ x = self.model.ln_final(x)
578
+ return x
579
+ else:
580
+ # x is a dict and will stay a dict
581
+ o = x["last"]
582
+ o = self.model.ln_final(o)
583
+ pooled = self.pool(o, text)
584
+ x["pooled"] = pooled
585
+ return x
586
+
587
+ def pool(self, x, text):
588
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
589
+ x = (
590
+ x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
591
+ @ self.model.text_projection
592
+ )
593
+ return x
594
+
595
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
596
+ outputs = {}
597
+ for i, r in enumerate(self.model.transformer.resblocks):
598
+ if i == len(self.model.transformer.resblocks) - 1:
599
+ outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD
600
+ if (
601
+ self.model.transformer.grad_checkpointing
602
+ and not torch.jit.is_scripting()
603
+ ):
604
+ x = checkpoint(r, x, attn_mask)
605
+ else:
606
+ x = r(x, attn_mask=attn_mask)
607
+ outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
608
+ return outputs
609
+
610
+ def encode(self, text):
611
+ return self(text)
612
+
613
+
614
+ class FrozenOpenCLIPEmbedder(AbstractEmbModel):
615
+ LAYERS = [
616
+ # "pooled",
617
+ "last",
618
+ "penultimate",
619
+ ]
620
+
621
+ def __init__(
622
+ self,
623
+ arch="ViT-H-14",
624
+ version="laion2b_s32b_b79k",
625
+ device="cuda",
626
+ max_length=77,
627
+ freeze=True,
628
+ layer="last",
629
+ ):
630
+ super().__init__()
631
+ assert layer in self.LAYERS
632
+ model, _, _ = open_clip.create_model_and_transforms(
633
+ arch, device=torch.device("cpu"), pretrained=version
634
+ )
635
+ del model.visual
636
+ self.model = model
637
+
638
+ self.device = device
639
+ self.max_length = max_length
640
+ if freeze:
641
+ self.freeze()
642
+ self.layer = layer
643
+ if self.layer == "last":
644
+ self.layer_idx = 0
645
+ elif self.layer == "penultimate":
646
+ self.layer_idx = 1
647
+ else:
648
+ raise NotImplementedError()
649
+
650
+ def freeze(self):
651
+ self.model = self.model.eval()
652
+ for param in self.parameters():
653
+ param.requires_grad = False
654
+
655
+ def forward(self, text):
656
+ tokens = open_clip.tokenize(text)
657
+ z = self.encode_with_transformer(tokens.to(self.device))
658
+ return z
659
+
660
+ def encode_with_transformer(self, text):
661
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
662
+ x = x + self.model.positional_embedding
663
+ x = x.permute(1, 0, 2) # NLD -> LND
664
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
665
+ x = x.permute(1, 0, 2) # LND -> NLD
666
+ x = self.model.ln_final(x)
667
+ return x
668
+
669
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
670
+ for i, r in enumerate(self.model.transformer.resblocks):
671
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
672
+ break
673
+ if (
674
+ self.model.transformer.grad_checkpointing
675
+ and not torch.jit.is_scripting()
676
+ ):
677
+ x = checkpoint(r, x, attn_mask)
678
+ else:
679
+ x = r(x, attn_mask=attn_mask)
680
+ return x
681
+
682
+ def encode(self, text):
683
+ return self(text)
684
+
685
+
686
+ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
687
+ """
688
+ Uses the OpenCLIP vision transformer encoder for images
689
+ """
690
+
691
+ def __init__(
692
+ self,
693
+ arch="ViT-H-14",
694
+ version="laion2b_s32b_b79k",
695
+ device="cuda",
696
+ max_length=77,
697
+ freeze=True,
698
+ antialias=True,
699
+ ucg_rate=0.0,
700
+ unsqueeze_dim=False,
701
+ repeat_to_max_len=False,
702
+ num_image_crops=0,
703
+ output_tokens=False,
704
+ ):
705
+ super().__init__()
706
+ model, _, _ = open_clip.create_model_and_transforms(
707
+ arch,
708
+ device=torch.device("cpu"),
709
+ pretrained=version,
710
+ )
711
+ del model.transformer
712
+ self.model = model
713
+ self.max_crops = num_image_crops
714
+ self.pad_to_max_len = self.max_crops > 0
715
+ self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
716
+ self.device = device
717
+ self.max_length = max_length
718
+ if freeze:
719
+ self.freeze()
720
+
721
+ self.antialias = antialias
722
+
723
+ self.register_buffer(
724
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
725
+ )
726
+ self.register_buffer(
727
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
728
+ )
729
+ self.ucg_rate = ucg_rate
730
+ self.unsqueeze_dim = unsqueeze_dim
731
+ self.stored_batch = None
732
+ self.model.visual.output_tokens = output_tokens
733
+ self.output_tokens = output_tokens
734
+
735
+ def preprocess(self, x):
736
+ # normalize to [0,1]
737
+ x = kornia.geometry.resize(
738
+ x,
739
+ (224, 224),
740
+ interpolation="bicubic",
741
+ align_corners=True,
742
+ antialias=self.antialias,
743
+ )
744
+ x = (x + 1.0) / 2.0
745
+ # renormalize according to clip
746
+ x = kornia.enhance.normalize(x, self.mean, self.std)
747
+ return x
748
+
749
+ def freeze(self):
750
+ self.model = self.model.eval()
751
+ for param in self.parameters():
752
+ param.requires_grad = False
753
+
+     @autocast
+     def forward(self, image, no_dropout=False):
+         z = self.encode_with_vision_transformer(image)
+         tokens = None
+         if self.output_tokens:
+             z, tokens = z[0], z[1]
+         z = z.to(image.dtype)
+         if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
+             z = (
+                 torch.bernoulli(
+                     (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+                 )[:, None]
+                 * z
+             )
+             if tokens is not None:
+                 tokens = (
+                     expand_dims_like(
+                         torch.bernoulli(
+                             (1.0 - self.ucg_rate)
+                             * torch.ones(tokens.shape[0], device=tokens.device)
+                         ),
+                         tokens,
+                     )
+                     * tokens
+                 )
+         if self.unsqueeze_dim:
+             z = z[:, None, :]
+         if self.output_tokens:
+             assert not self.repeat_to_max_len
+             assert not self.pad_to_max_len
+             return tokens, z
+         if self.repeat_to_max_len:
+             if z.dim() == 2:
+                 z_ = z[:, None, :]
+             else:
+                 z_ = z
+             return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
+         elif self.pad_to_max_len:
+             assert z.dim() == 3
+             z_pad = torch.cat(
+                 (
+                     z,
+                     torch.zeros(
+                         z.shape[0],
+                         self.max_length - z.shape[1],
+                         z.shape[2],
+                         device=z.device,
+                     ),
+                 ),
+                 1,
+             )
+             return z_pad, z_pad[:, 0, ...]
+         return z
+
+     def encode_with_vision_transformer(self, img):
+         # if self.max_crops > 0:
+         #     img = self.preprocess_by_cropping(img)
+         if img.dim() == 5:
+             assert self.max_crops == img.shape[1]
+             img = rearrange(img, "b n c h w -> (b n) c h w")
+         img = self.preprocess(img)
+         if not self.output_tokens:
+             assert not self.model.visual.output_tokens
+             x = self.model.visual(img)
+             tokens = None
+         else:
+             assert self.model.visual.output_tokens
+             x, tokens = self.model.visual(img)
+         if self.max_crops > 0:
+             x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
+             # drop out between 0 and all along the sequence axis
+             x = (
+                 torch.bernoulli(
+                     (1.0 - self.ucg_rate)
+                     * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
+                 )
+                 * x
+             )
+             if tokens is not None:
+                 tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
+                 print(
+                     f"You are running very experimental token-concat in {self.__class__.__name__}. "
+                     f"Check what you are doing, and then remove this message."
+                 )
+         if self.output_tokens:
+             return x, tokens
+         return x
+
+     def encode(self, image):
+         return self(image)
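+     # Rough usage sketch (shapes assume the default ViT-H-14 / laion2b weights and
+     # the default flags above; not taken from the original file):
+     #   embedder = FrozenOpenCLIPImageEmbedder(device="cuda")
+     #   z = embedder(images)   # images in [-1, 1], (B, 3, H, W) -> pooled (B, 1024)
+     #   With output_tokens=True, forward instead returns (tokens, pooled).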
+
+
+ class FrozenCLIPT5Encoder(AbstractEmbModel):
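+     """
+     Pairs a frozen CLIP text encoder with a frozen T5 encoder; `forward` encodes
+     the same text with both and returns the embeddings as a list `[clip_z, t5_z]`.
+     """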
+     def __init__(
+         self,
+         clip_version="openai/clip-vit-large-patch14",
+         t5_version="google/t5-v1_1-xl",
+         device="cuda",
+         clip_max_length=77,
+         t5_max_length=77,
+     ):
+         super().__init__()
+         self.clip_encoder = FrozenCLIPEmbedder(
+             clip_version, device, max_length=clip_max_length
+         )
+         self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+         print(
+             f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+             f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
+         )
+
+     def encode(self, text):
+         return self(text)
+
+     def forward(self, text):
+         clip_z = self.clip_encoder.encode(text)
+         t5_z = self.t5_encoder.encode(text)
+         return [clip_z, t5_z]
+
+
+ class SpatialRescaler(nn.Module):
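+     """
+     Resizes spatial inputs by `multiplier` over `n_stages` successive interpolation
+     steps and optionally remaps channels with a `kernel_size` convolution afterwards.
+     """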
+     def __init__(
+         self,
+         n_stages=1,
+         method="bilinear",
+         multiplier=0.5,
+         in_channels=3,
+         out_channels=None,
+         bias=False,
+         wrap_video=False,
+         kernel_size=1,
+         remap_output=False,
+     ):
+         super().__init__()
+         self.n_stages = n_stages
+         assert self.n_stages >= 0
+         assert method in [
+             "nearest",
+             "linear",
+             "bilinear",
+             "trilinear",
+             "bicubic",
+             "area",
+         ]
+         self.multiplier = multiplier
+         self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+         self.remap_output = out_channels is not None or remap_output
+         if self.remap_output:
+             print(
+                 f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
+             )
+             self.channel_mapper = nn.Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=kernel_size,
+                 bias=bias,
+                 padding=kernel_size // 2,
+             )
+         self.wrap_video = wrap_video
+
+     def forward(self, x):
+         if self.wrap_video and x.ndim == 5:
+             B, C, T, H, W = x.shape
+             x = rearrange(x, "b c t h w -> b t c h w")
+             x = rearrange(x, "b t c h w -> (b t) c h w")
+
+         for stage in range(self.n_stages):
+             x = self.interpolator(x, scale_factor=self.multiplier)
+
+         if self.wrap_video:
+             x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
+             x = rearrange(x, "b t c h w -> b c t h w")
+         if self.remap_output:
+             x = self.channel_mapper(x)
+         return x
+
+     def encode(self, x):
+         return self(x)
+
+
+ class LowScaleEncoder(nn.Module):
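+     """
+     Encodes an input with a first-stage model, scales it, adds diffusion noise at a
+     random level below `max_noise_level`, and resizes it to `output_size`; returns
+     `(z, noise_level)`. Typically used for noise-augmented low-resolution conditioning.
+     """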
+     def __init__(
+         self,
+         model_config,
+         linear_start,
+         linear_end,
+         timesteps=1000,
+         max_noise_level=250,
+         output_size=64,
+         scale_factor=1.0,
+     ):
+         super().__init__()
+         self.max_noise_level = max_noise_level
+         self.model = instantiate_from_config(model_config)
+         self.augmentation_schedule = self.register_schedule(
+             timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
+         )
+         self.out_size = output_size
+         self.scale_factor = scale_factor
+
+     def register_schedule(
+         self,
+         beta_schedule="linear",
+         timesteps=1000,
+         linear_start=1e-4,
+         linear_end=2e-2,
+         cosine_s=8e-3,
+     ):
+         betas = make_beta_schedule(
+             beta_schedule,
+             timesteps,
+             linear_start=linear_start,
+             linear_end=linear_end,
+             cosine_s=cosine_s,
+         )
+         alphas = 1.0 - betas
+         alphas_cumprod = np.cumprod(alphas, axis=0)
+         alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+         (timesteps,) = betas.shape
+         self.num_timesteps = int(timesteps)
+         self.linear_start = linear_start
+         self.linear_end = linear_end
+         assert (
+             alphas_cumprod.shape[0] == self.num_timesteps
+         ), "alphas have to be defined for each timestep"
+
+         to_torch = partial(torch.tensor, dtype=torch.float32)
+
+         self.register_buffer("betas", to_torch(betas))
+         self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+         self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+         self.register_buffer(
+             "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+         )
+         self.register_buffer(
+             "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+         )
+         self.register_buffer(
+             "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+         )
+         self.register_buffer(
+             "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+         )
+
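+     # q_sample below implements the closed-form forward-diffusion step
+     #   x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise
+     # using the buffers registered above.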
+     def q_sample(self, x_start, t, noise=None):
+         noise = default(noise, lambda: torch.randn_like(x_start))
+         return (
+             extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+             + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+             * noise
+         )
+
+     def forward(self, x):
+         z = self.model.encode(x)
+         if isinstance(z, DiagonalGaussianDistribution):
+             z = z.sample()
+         z = z * self.scale_factor
+         noise_level = torch.randint(
+             0, self.max_noise_level, (x.shape[0],), device=x.device
+         ).long()
+         z = self.q_sample(z, noise_level)
+         if self.out_size is not None:
+             z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
+         # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
+         return z, noise_level
+
+     def decode(self, z):
+         z = z / self.scale_factor
+         return self.model.decode(z)
+
+
+ class ConcatTimestepEmbedderND(AbstractEmbModel):
+     """Embeds each scalar dimension independently and concatenates the embeddings."""
+
+     def __init__(self, outdim):
+         super().__init__()
+         self.timestep = Timestep(outdim)
+         self.outdim = outdim
+
+     def forward(self, x):
+         if x.ndim == 1:
+             x = x[:, None]
+         assert len(x.shape) == 2
+         b, dims = x.shape[0], x.shape[1]
+         x = rearrange(x, "b d -> (b d)")
+         emb = self.timestep(x)
+         emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
+         return emb
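+     # e.g. with outdim=256, a (B, 2) tensor of scalars (such as an original
+     # height/width pair) becomes a (B, 512) "vector" conditioning.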
+
+
+ class GaussianEncoder(Encoder, AbstractEmbModel):
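+     """
+     Encoder followed by a DiagonalGaussianRegularizer; `forward` returns a log dict
+     (KL loss and weight) together with the sampled latent, flattened to
+     (b, h*w, c) when `flatten_output` is set.
+     """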
+     def __init__(
+         self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
+     ):
+         super().__init__(*args, **kwargs)
+         self.posterior = DiagonalGaussianRegularizer()
+         self.weight = weight
+         self.flatten_output = flatten_output
+
+     def forward(self, x) -> Tuple[Dict, torch.Tensor]:
+         z = super().forward(x)
+         z, log = self.posterior(z)
+         log["loss"] = log["kl_loss"]
+         log["weight"] = self.weight
+         if self.flatten_output:
+             z = rearrange(z, "b c h w -> b (h w) c")
+         return log, z