multimodalart committed
Commit 097567a
1 parent: 90e5afe

Delete sgm

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. sgm/__init__.py +0 -4
  2. sgm/data/__init__.py +0 -1
  3. sgm/data/cifar10.py +0 -67
  4. sgm/data/dataset.py +0 -80
  5. sgm/data/mnist.py +0 -85
  6. sgm/inference/api.py +0 -385
  7. sgm/inference/helpers.py +0 -305
  8. sgm/lr_scheduler.py +0 -135
  9. sgm/models/__init__.py +0 -2
  10. sgm/models/autoencoder.py +0 -615
  11. sgm/models/diffusion.py +0 -341
  12. sgm/modules/__init__.py +0 -6
  13. sgm/modules/attention.py +0 -759
  14. sgm/modules/autoencoding/__init__.py +0 -0
  15. sgm/modules/autoencoding/losses/__init__.py +0 -7
  16. sgm/modules/autoencoding/losses/discriminator_loss.py +0 -306
  17. sgm/modules/autoencoding/losses/lpips.py +0 -73
  18. sgm/modules/autoencoding/lpips/__init__.py +0 -0
  19. sgm/modules/autoencoding/lpips/loss/.gitignore +0 -1
  20. sgm/modules/autoencoding/lpips/loss/LICENSE +0 -23
  21. sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
  22. sgm/modules/autoencoding/lpips/loss/lpips.py +0 -147
  23. sgm/modules/autoencoding/lpips/model/LICENSE +0 -58
  24. sgm/modules/autoencoding/lpips/model/__init__.py +0 -0
  25. sgm/modules/autoencoding/lpips/model/model.py +0 -88
  26. sgm/modules/autoencoding/lpips/util.py +0 -128
  27. sgm/modules/autoencoding/lpips/vqperceptual.py +0 -17
  28. sgm/modules/autoencoding/regularizers/__init__.py +0 -31
  29. sgm/modules/autoencoding/regularizers/base.py +0 -40
  30. sgm/modules/autoencoding/regularizers/quantize.py +0 -487
  31. sgm/modules/autoencoding/temporal_ae.py +0 -349
  32. sgm/modules/diffusionmodules/__init__.py +0 -0
  33. sgm/modules/diffusionmodules/denoiser.py +0 -75
  34. sgm/modules/diffusionmodules/denoiser_scaling.py +0 -59
  35. sgm/modules/diffusionmodules/denoiser_weighting.py +0 -24
  36. sgm/modules/diffusionmodules/discretizer.py +0 -69
  37. sgm/modules/diffusionmodules/guiders.py +0 -99
  38. sgm/modules/diffusionmodules/loss.py +0 -105
  39. sgm/modules/diffusionmodules/loss_weighting.py +0 -32
  40. sgm/modules/diffusionmodules/model.py +0 -748
  41. sgm/modules/diffusionmodules/openaimodel.py +0 -853
  42. sgm/modules/diffusionmodules/sampling.py +0 -362
  43. sgm/modules/diffusionmodules/sampling_utils.py +0 -43
  44. sgm/modules/diffusionmodules/sigma_sampling.py +0 -31
  45. sgm/modules/diffusionmodules/util.py +0 -369
  46. sgm/modules/diffusionmodules/video_model.py +0 -493
  47. sgm/modules/diffusionmodules/wrappers.py +0 -34
  48. sgm/modules/distributions/__init__.py +0 -0
  49. sgm/modules/distributions/distributions.py +0 -102
  50. sgm/modules/ema.py +0 -86
sgm/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .models import AutoencodingEngine, DiffusionEngine
- from .util import get_configs_path, instantiate_from_config
-
- __version__ = "0.1.0"
 
sgm/data/__init__.py DELETED
@@ -1 +0,0 @@
- from .dataset import StableDataModuleFromConfig
 
 
sgm/data/cifar10.py DELETED
@@ -1,67 +0,0 @@
- import pytorch_lightning as pl
- import torchvision
- from torch.utils.data import DataLoader, Dataset
- from torchvision import transforms
-
-
- class CIFAR10DataDictWrapper(Dataset):
-     def __init__(self, dset):
-         super().__init__()
-         self.dset = dset
-
-     def __getitem__(self, i):
-         x, y = self.dset[i]
-         return {"jpg": x, "cls": y}
-
-     def __len__(self):
-         return len(self.dset)
-
-
- class CIFAR10Loader(pl.LightningDataModule):
-     def __init__(self, batch_size, num_workers=0, shuffle=True):
-         super().__init__()
-
-         transform = transforms.Compose(
-             [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
-         )
-
-         self.batch_size = batch_size
-         self.num_workers = num_workers
-         self.shuffle = shuffle
-         self.train_dataset = CIFAR10DataDictWrapper(
-             torchvision.datasets.CIFAR10(
-                 root=".data/", train=True, download=True, transform=transform
-             )
-         )
-         self.test_dataset = CIFAR10DataDictWrapper(
-             torchvision.datasets.CIFAR10(
-                 root=".data/", train=False, download=True, transform=transform
-             )
-         )
-
-     def prepare_data(self):
-         pass
-
-     def train_dataloader(self):
-         return DataLoader(
-             self.train_dataset,
-             batch_size=self.batch_size,
-             shuffle=self.shuffle,
-             num_workers=self.num_workers,
-         )
-
-     def test_dataloader(self):
-         return DataLoader(
-             self.test_dataset,
-             batch_size=self.batch_size,
-             shuffle=self.shuffle,
-             num_workers=self.num_workers,
-         )
-
-     def val_dataloader(self):
-         return DataLoader(
-             self.test_dataset,
-             batch_size=self.batch_size,
-             shuffle=self.shuffle,
-             num_workers=self.num_workers,
-         )
 
sgm/data/dataset.py DELETED
@@ -1,80 +0,0 @@
- from typing import Optional
-
- import torchdata.datapipes.iter
- import webdataset as wds
- from omegaconf import DictConfig
- from pytorch_lightning import LightningDataModule
-
- try:
-     from sdata import create_dataset, create_dummy_dataset, create_loader
- except ImportError as e:
-     print("#" * 100)
-     print("Datasets not yet available")
-     print("to enable, we need to add stable-datasets as a submodule")
-     print("please use ``git submodule update --init --recursive``")
-     print("and do ``pip install -e stable-datasets/`` from the root of this repo")
-     print("#" * 100)
-     exit(1)
-
-
- class StableDataModuleFromConfig(LightningDataModule):
-     def __init__(
-         self,
-         train: DictConfig,
-         validation: Optional[DictConfig] = None,
-         test: Optional[DictConfig] = None,
-         skip_val_loader: bool = False,
-         dummy: bool = False,
-     ):
-         super().__init__()
-         self.train_config = train
-         assert (
-             "datapipeline" in self.train_config and "loader" in self.train_config
-         ), "train config requires the fields `datapipeline` and `loader`"
-
-         self.val_config = validation
-         if not skip_val_loader:
-             if self.val_config is not None:
-                 assert (
-                     "datapipeline" in self.val_config and "loader" in self.val_config
-                 ), "validation config requires the fields `datapipeline` and `loader`"
-             else:
-                 print(
-                     "Warning: No Validation datapipeline defined, using that one from training"
-                 )
-                 self.val_config = train
-
-         self.test_config = test
-         if self.test_config is not None:
-             assert (
-                 "datapipeline" in self.test_config and "loader" in self.test_config
-             ), "test config requires the fields `datapipeline` and `loader`"
-
-         self.dummy = dummy
-         if self.dummy:
-             print("#" * 100)
-             print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
-             print("#" * 100)
-
-     def setup(self, stage: str) -> None:
-         print("Preparing datasets")
-         if self.dummy:
-             data_fn = create_dummy_dataset
-         else:
-             data_fn = create_dataset
-
-         self.train_datapipeline = data_fn(**self.train_config.datapipeline)
-         if self.val_config:
-             self.val_datapipeline = data_fn(**self.val_config.datapipeline)
-         if self.test_config:
-             self.test_datapipeline = data_fn(**self.test_config.datapipeline)
-
-     def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
-         loader = create_loader(self.train_datapipeline, **self.train_config.loader)
-         return loader
-
-     def val_dataloader(self) -> wds.DataPipeline:
-         return create_loader(self.val_datapipeline, **self.val_config.loader)
-
-     def test_dataloader(self) -> wds.DataPipeline:
-         return create_loader(self.test_datapipeline, **self.test_config.loader)
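
For reference, StableDataModuleFromConfig above only asserts that each split config carries `datapipeline` and `loader` keys before handing them to the stable-datasets helpers. A minimal sketch of such a config follows; the inner field names and values are placeholders, not part of the deleted module, and instantiating the module still requires the stable-datasets submodule.

from omegaconf import OmegaConf

# Hypothetical split config: only the presence of "datapipeline" and "loader"
# is checked by StableDataModuleFromConfig; the inner fields are placeholders.
train_cfg = OmegaConf.create(
    {
        "datapipeline": {"urls": ["data/shards/{00000..00099}.tar"]},
        "loader": {"batch_size": 8, "num_workers": 4},
    }
)
# data_module = StableDataModuleFromConfig(train=train_cfg)  # needs `sdata` installed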
 
sgm/data/mnist.py DELETED
@@ -1,85 +0,0 @@
- import pytorch_lightning as pl
- import torchvision
- from torch.utils.data import DataLoader, Dataset
- from torchvision import transforms
-
-
- class MNISTDataDictWrapper(Dataset):
-     def __init__(self, dset):
-         super().__init__()
-         self.dset = dset
-
-     def __getitem__(self, i):
-         x, y = self.dset[i]
-         return {"jpg": x, "cls": y}
-
-     def __len__(self):
-         return len(self.dset)
-
-
- class MNISTLoader(pl.LightningDataModule):
-     def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
-         super().__init__()
-
-         transform = transforms.Compose(
-             [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
-         )
-
-         self.batch_size = batch_size
-         self.num_workers = num_workers
-         self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
-         self.shuffle = shuffle
-         self.train_dataset = MNISTDataDictWrapper(
-             torchvision.datasets.MNIST(
-                 root=".data/", train=True, download=True, transform=transform
-             )
-         )
-         self.test_dataset = MNISTDataDictWrapper(
-             torchvision.datasets.MNIST(
-                 root=".data/", train=False, download=True, transform=transform
-             )
-         )
-
-     def prepare_data(self):
-         pass
-
-     def train_dataloader(self):
-         return DataLoader(
-             self.train_dataset,
-             batch_size=self.batch_size,
-             shuffle=self.shuffle,
-             num_workers=self.num_workers,
-             prefetch_factor=self.prefetch_factor,
-         )
-
-     def test_dataloader(self):
-         return DataLoader(
-             self.test_dataset,
-             batch_size=self.batch_size,
-             shuffle=self.shuffle,
-             num_workers=self.num_workers,
-             prefetch_factor=self.prefetch_factor,
-         )
-
-     def val_dataloader(self):
-         return DataLoader(
-             self.test_dataset,
-             batch_size=self.batch_size,
-             shuffle=self.shuffle,
-             num_workers=self.num_workers,
-             prefetch_factor=self.prefetch_factor,
-         )
-
-
- if __name__ == "__main__":
-     dset = MNISTDataDictWrapper(
-         torchvision.datasets.MNIST(
-             root=".data/",
-             train=False,
-             download=True,
-             transform=transforms.Compose(
-                 [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
-             ),
-         )
-     )
-     ex = dset[0]
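
For orientation, a minimal sketch of how the MNISTLoader above was typically consumed; the batch size and worker count are arbitrary placeholders.

from sgm.data.mnist import MNISTLoader  # import path as it existed before this commit

data = MNISTLoader(batch_size=64, num_workers=4, shuffle=True)
batch = next(iter(data.train_dataloader()))
# Batches are dicts built by MNISTDataDictWrapper: images under "jpg"
# (rescaled to [-1, 1]) and integer labels under "cls".
print(batch["jpg"].shape, batch["cls"].shape)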
 
sgm/inference/api.py DELETED
@@ -1,385 +0,0 @@
- import pathlib
- from dataclasses import asdict, dataclass
- from enum import Enum
- from typing import Optional
-
- from omegaconf import OmegaConf
-
- from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
-                                    do_sample)
- from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
-                                                    DPMPP2SAncestralSampler,
-                                                    EulerAncestralSampler,
-                                                    EulerEDMSampler,
-                                                    HeunEDMSampler,
-                                                    LinearMultistepSampler)
- from sgm.util import load_model_from_config
-
-
- class ModelArchitecture(str, Enum):
-     SD_2_1 = "stable-diffusion-v2-1"
-     SD_2_1_768 = "stable-diffusion-v2-1-768"
-     SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
-     SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
-     SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
-     SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
-
-
- class Sampler(str, Enum):
-     EULER_EDM = "EulerEDMSampler"
-     HEUN_EDM = "HeunEDMSampler"
-     EULER_ANCESTRAL = "EulerAncestralSampler"
-     DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
-     DPMPP2M = "DPMPP2MSampler"
-     LINEAR_MULTISTEP = "LinearMultistepSampler"
-
-
- class Discretization(str, Enum):
-     LEGACY_DDPM = "LegacyDDPMDiscretization"
-     EDM = "EDMDiscretization"
-
-
- class Guider(str, Enum):
-     VANILLA = "VanillaCFG"
-     IDENTITY = "IdentityGuider"
-
-
- class Thresholder(str, Enum):
-     NONE = "None"
-
-
- @dataclass
- class SamplingParams:
-     width: int = 1024
-     height: int = 1024
-     steps: int = 50
-     sampler: Sampler = Sampler.DPMPP2M
-     discretization: Discretization = Discretization.LEGACY_DDPM
-     guider: Guider = Guider.VANILLA
-     thresholder: Thresholder = Thresholder.NONE
-     scale: float = 6.0
-     aesthetic_score: float = 5.0
-     negative_aesthetic_score: float = 5.0
-     img2img_strength: float = 1.0
-     orig_width: int = 1024
-     orig_height: int = 1024
-     crop_coords_top: int = 0
-     crop_coords_left: int = 0
-     sigma_min: float = 0.0292
-     sigma_max: float = 14.6146
-     rho: float = 3.0
-     s_churn: float = 0.0
-     s_tmin: float = 0.0
-     s_tmax: float = 999.0
-     s_noise: float = 1.0
-     eta: float = 1.0
-     order: int = 4
-
-
- @dataclass
- class SamplingSpec:
-     width: int
-     height: int
-     channels: int
-     factor: int
-     is_legacy: bool
-     config: str
-     ckpt: str
-     is_guided: bool
-
-
- model_specs = {
-     ModelArchitecture.SD_2_1: SamplingSpec(
-         height=512,
-         width=512,
-         channels=4,
-         factor=8,
-         is_legacy=True,
-         config="sd_2_1.yaml",
-         ckpt="v2-1_512-ema-pruned.safetensors",
-         is_guided=True,
-     ),
-     ModelArchitecture.SD_2_1_768: SamplingSpec(
-         height=768,
-         width=768,
-         channels=4,
-         factor=8,
-         is_legacy=True,
-         config="sd_2_1_768.yaml",
-         ckpt="v2-1_768-ema-pruned.safetensors",
-         is_guided=True,
-     ),
-     ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
-         height=1024,
-         width=1024,
-         channels=4,
-         factor=8,
-         is_legacy=False,
-         config="sd_xl_base.yaml",
-         ckpt="sd_xl_base_0.9.safetensors",
-         is_guided=True,
-     ),
-     ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
-         height=1024,
-         width=1024,
-         channels=4,
-         factor=8,
-         is_legacy=True,
-         config="sd_xl_refiner.yaml",
-         ckpt="sd_xl_refiner_0.9.safetensors",
-         is_guided=True,
-     ),
-     ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
-         height=1024,
-         width=1024,
-         channels=4,
-         factor=8,
-         is_legacy=False,
-         config="sd_xl_base.yaml",
-         ckpt="sd_xl_base_1.0.safetensors",
-         is_guided=True,
-     ),
-     ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
-         height=1024,
-         width=1024,
-         channels=4,
-         factor=8,
-         is_legacy=True,
-         config="sd_xl_refiner.yaml",
-         ckpt="sd_xl_refiner_1.0.safetensors",
-         is_guided=True,
-     ),
- }
-
-
- class SamplingPipeline:
-     def __init__(
-         self,
-         model_id: ModelArchitecture,
-         model_path="checkpoints",
-         config_path="configs/inference",
-         device="cuda",
-         use_fp16=True,
-     ) -> None:
-         if model_id not in model_specs:
-             raise ValueError(f"Model {model_id} not supported")
-         self.model_id = model_id
-         self.specs = model_specs[self.model_id]
-         self.config = str(pathlib.Path(config_path, self.specs.config))
-         self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
-         self.device = device
-         self.model = self._load_model(device=device, use_fp16=use_fp16)
-
-     def _load_model(self, device="cuda", use_fp16=True):
-         config = OmegaConf.load(self.config)
-         model = load_model_from_config(config, self.ckpt)
-         if model is None:
-             raise ValueError(f"Model {self.model_id} could not be loaded")
-         model.to(device)
-         if use_fp16:
-             model.conditioner.half()
-             model.model.half()
-         return model
-
-     def text_to_image(
-         self,
-         params: SamplingParams,
-         prompt: str,
-         negative_prompt: str = "",
-         samples: int = 1,
-         return_latents: bool = False,
-     ):
-         sampler = get_sampler_config(params)
-         value_dict = asdict(params)
-         value_dict["prompt"] = prompt
-         value_dict["negative_prompt"] = negative_prompt
-         value_dict["target_width"] = params.width
-         value_dict["target_height"] = params.height
-         return do_sample(
-             self.model,
-             sampler,
-             value_dict,
-             samples,
-             params.height,
-             params.width,
-             self.specs.channels,
-             self.specs.factor,
-             force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
-             return_latents=return_latents,
-             filter=None,
-         )
-
-     def image_to_image(
-         self,
-         params: SamplingParams,
-         image,
-         prompt: str,
-         negative_prompt: str = "",
-         samples: int = 1,
-         return_latents: bool = False,
-     ):
-         sampler = get_sampler_config(params)
-
-         if params.img2img_strength < 1.0:
-             sampler.discretization = Img2ImgDiscretizationWrapper(
-                 sampler.discretization,
-                 strength=params.img2img_strength,
-             )
-         height, width = image.shape[2], image.shape[3]
-         value_dict = asdict(params)
-         value_dict["prompt"] = prompt
-         value_dict["negative_prompt"] = negative_prompt
-         value_dict["target_width"] = width
-         value_dict["target_height"] = height
-         return do_img2img(
-             image,
-             self.model,
-             sampler,
-             value_dict,
-             samples,
-             force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
-             return_latents=return_latents,
-             filter=None,
-         )
-
-     def refiner(
-         self,
-         params: SamplingParams,
-         image,
-         prompt: str,
-         negative_prompt: Optional[str] = None,
-         samples: int = 1,
-         return_latents: bool = False,
-     ):
-         sampler = get_sampler_config(params)
-         value_dict = {
-             "orig_width": image.shape[3] * 8,
-             "orig_height": image.shape[2] * 8,
-             "target_width": image.shape[3] * 8,
-             "target_height": image.shape[2] * 8,
-             "prompt": prompt,
-             "negative_prompt": negative_prompt,
-             "crop_coords_top": 0,
-             "crop_coords_left": 0,
-             "aesthetic_score": 6.0,
-             "negative_aesthetic_score": 2.5,
-         }
-
-         return do_img2img(
-             image,
-             self.model,
-             sampler,
-             value_dict,
-             samples,
-             skip_encode=True,
-             return_latents=return_latents,
-             filter=None,
-         )
-
-
- def get_guider_config(params: SamplingParams):
-     if params.guider == Guider.IDENTITY:
-         guider_config = {
-             "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
-         }
-     elif params.guider == Guider.VANILLA:
-         scale = params.scale
-
-         thresholder = params.thresholder
-
-         if thresholder == Thresholder.NONE:
-             dyn_thresh_config = {
-                 "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
-             }
-         else:
-             raise NotImplementedError
-
-         guider_config = {
-             "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
-             "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
-         }
-     else:
-         raise NotImplementedError
-     return guider_config
-
-
- def get_discretization_config(params: SamplingParams):
-     if params.discretization == Discretization.LEGACY_DDPM:
-         discretization_config = {
-             "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
-         }
-     elif params.discretization == Discretization.EDM:
-         discretization_config = {
-             "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
-             "params": {
-                 "sigma_min": params.sigma_min,
-                 "sigma_max": params.sigma_max,
-                 "rho": params.rho,
-             },
-         }
-     else:
-         raise ValueError(f"unknown discretization {params.discretization}")
-     return discretization_config
-
-
- def get_sampler_config(params: SamplingParams):
-     discretization_config = get_discretization_config(params)
-     guider_config = get_guider_config(params)
-     sampler = None
-     if params.sampler == Sampler.EULER_EDM:
-         return EulerEDMSampler(
-             num_steps=params.steps,
-             discretization_config=discretization_config,
-             guider_config=guider_config,
-             s_churn=params.s_churn,
-             s_tmin=params.s_tmin,
-             s_tmax=params.s_tmax,
-             s_noise=params.s_noise,
-             verbose=True,
-         )
-     if params.sampler == Sampler.HEUN_EDM:
-         return HeunEDMSampler(
-             num_steps=params.steps,
-             discretization_config=discretization_config,
-             guider_config=guider_config,
-             s_churn=params.s_churn,
-             s_tmin=params.s_tmin,
-             s_tmax=params.s_tmax,
-             s_noise=params.s_noise,
-             verbose=True,
-         )
-     if params.sampler == Sampler.EULER_ANCESTRAL:
-         return EulerAncestralSampler(
-             num_steps=params.steps,
-             discretization_config=discretization_config,
-             guider_config=guider_config,
-             eta=params.eta,
-             s_noise=params.s_noise,
-             verbose=True,
-         )
-     if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
-         return DPMPP2SAncestralSampler(
-             num_steps=params.steps,
-             discretization_config=discretization_config,
-             guider_config=guider_config,
-             eta=params.eta,
-             s_noise=params.s_noise,
-             verbose=True,
-         )
-     if params.sampler == Sampler.DPMPP2M:
-         return DPMPP2MSampler(
-             num_steps=params.steps,
-             discretization_config=discretization_config,
-             guider_config=guider_config,
-             verbose=True,
-         )
-     if params.sampler == Sampler.LINEAR_MULTISTEP:
-         return LinearMultistepSampler(
-             num_steps=params.steps,
-             discretization_config=discretization_config,
-             guider_config=guider_config,
-             order=params.order,
-             verbose=True,
-         )
-
-     raise ValueError(f"unknown sampler {params.sampler}!")
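
A minimal usage sketch for the SamplingPipeline removed above. It assumes the default locations from this file ("checkpoints/" and "configs/inference/") actually contain the SDXL 1.0 base weights and config listed in model_specs; the prompt and parameter values are placeholders.

from sgm.inference.api import ModelArchitecture, SamplingParams, SamplingPipeline

# Resolves to configs/inference/sd_xl_base.yaml and
# checkpoints/sd_xl_base_1.0.safetensors per the defaults above.
pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)
params = SamplingParams(width=1024, height=1024, steps=40, scale=7.0)
samples = pipeline.text_to_image(params, prompt="a photograph of a red fox", samples=1)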
 
sgm/inference/helpers.py DELETED
@@ -1,305 +0,0 @@
1
- import math
2
- import os
3
- from typing import List, Optional, Union
4
-
5
- import numpy as np
6
- import torch
7
- from einops import rearrange
8
- from imwatermark import WatermarkEncoder
9
- from omegaconf import ListConfig
10
- from PIL import Image
11
- from torch import autocast
12
-
13
- from sgm.util import append_dims
14
-
15
-
16
- class WatermarkEmbedder:
17
- def __init__(self, watermark):
18
- self.watermark = watermark
19
- self.num_bits = len(WATERMARK_BITS)
20
- self.encoder = WatermarkEncoder()
21
- self.encoder.set_watermark("bits", self.watermark)
22
-
23
- def __call__(self, image: torch.Tensor) -> torch.Tensor:
24
- """
25
- Adds a predefined watermark to the input image
26
-
27
- Args:
28
- image: ([N,] B, RGB, H, W) in range [0, 1]
29
-
30
- Returns:
31
- same as input but watermarked
32
- """
33
- squeeze = len(image.shape) == 4
34
- if squeeze:
35
- image = image[None, ...]
36
- n = image.shape[0]
37
- image_np = rearrange(
38
- (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
39
- ).numpy()[:, :, :, ::-1]
40
- # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
41
- # watermarking libary expects input as cv2 BGR format
42
- for k in range(image_np.shape[0]):
43
- image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
44
- image = torch.from_numpy(
45
- rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
46
- ).to(image.device)
47
- image = torch.clamp(image / 255, min=0.0, max=1.0)
48
- if squeeze:
49
- image = image[0]
50
- return image
51
-
52
-
53
- # A fixed 48-bit message that was choosen at random
54
- # WATERMARK_MESSAGE = 0xB3EC907BB19E
55
- WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
56
- # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
57
- WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
58
- embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
59
-
60
-
61
- def get_unique_embedder_keys_from_conditioner(conditioner):
62
- return list({x.input_key for x in conditioner.embedders})
63
-
64
-
65
- def perform_save_locally(save_path, samples):
66
- os.makedirs(os.path.join(save_path), exist_ok=True)
67
- base_count = len(os.listdir(os.path.join(save_path)))
68
- samples = embed_watermark(samples)
69
- for sample in samples:
70
- sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
71
- Image.fromarray(sample.astype(np.uint8)).save(
72
- os.path.join(save_path, f"{base_count:09}.png")
73
- )
74
- base_count += 1
75
-
76
-
77
- class Img2ImgDiscretizationWrapper:
78
- """
79
- wraps a discretizer, and prunes the sigmas
80
- params:
81
- strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
82
- """
83
-
84
- def __init__(self, discretization, strength: float = 1.0):
85
- self.discretization = discretization
86
- self.strength = strength
87
- assert 0.0 <= self.strength <= 1.0
88
-
89
- def __call__(self, *args, **kwargs):
90
- # sigmas start large first, and decrease then
91
- sigmas = self.discretization(*args, **kwargs)
92
- print(f"sigmas after discretization, before pruning img2img: ", sigmas)
93
- sigmas = torch.flip(sigmas, (0,))
94
- sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
95
- print("prune index:", max(int(self.strength * len(sigmas)), 1))
96
- sigmas = torch.flip(sigmas, (0,))
97
- print(f"sigmas after pruning: ", sigmas)
98
- return sigmas
99
-
100
-
101
- def do_sample(
102
- model,
103
- sampler,
104
- value_dict,
105
- num_samples,
106
- H,
107
- W,
108
- C,
109
- F,
110
- force_uc_zero_embeddings: Optional[List] = None,
111
- batch2model_input: Optional[List] = None,
112
- return_latents=False,
113
- filter=None,
114
- device="cuda",
115
- ):
116
- if force_uc_zero_embeddings is None:
117
- force_uc_zero_embeddings = []
118
- if batch2model_input is None:
119
- batch2model_input = []
120
-
121
- with torch.no_grad():
122
- with autocast(device) as precision_scope:
123
- with model.ema_scope():
124
- num_samples = [num_samples]
125
- batch, batch_uc = get_batch(
126
- get_unique_embedder_keys_from_conditioner(model.conditioner),
127
- value_dict,
128
- num_samples,
129
- )
130
- for key in batch:
131
- if isinstance(batch[key], torch.Tensor):
132
- print(key, batch[key].shape)
133
- elif isinstance(batch[key], list):
134
- print(key, [len(l) for l in batch[key]])
135
- else:
136
- print(key, batch[key])
137
- c, uc = model.conditioner.get_unconditional_conditioning(
138
- batch,
139
- batch_uc=batch_uc,
140
- force_uc_zero_embeddings=force_uc_zero_embeddings,
141
- )
142
-
143
- for k in c:
144
- if not k == "crossattn":
145
- c[k], uc[k] = map(
146
- lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
147
- )
148
-
149
- additional_model_inputs = {}
150
- for k in batch2model_input:
151
- additional_model_inputs[k] = batch[k]
152
-
153
- shape = (math.prod(num_samples), C, H // F, W // F)
154
- randn = torch.randn(shape).to(device)
155
-
156
- def denoiser(input, sigma, c):
157
- return model.denoiser(
158
- model.model, input, sigma, c, **additional_model_inputs
159
- )
160
-
161
- samples_z = sampler(denoiser, randn, cond=c, uc=uc)
162
- samples_x = model.decode_first_stage(samples_z)
163
- samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
164
-
165
- if filter is not None:
166
- samples = filter(samples)
167
-
168
- if return_latents:
169
- return samples, samples_z
170
- return samples
171
-
172
-
173
- def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
174
- # Hardcoded demo setups; might undergo some changes in the future
175
-
176
- batch = {}
177
- batch_uc = {}
178
-
179
- for key in keys:
180
- if key == "txt":
181
- batch["txt"] = (
182
- np.repeat([value_dict["prompt"]], repeats=math.prod(N))
183
- .reshape(N)
184
- .tolist()
185
- )
186
- batch_uc["txt"] = (
187
- np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
188
- .reshape(N)
189
- .tolist()
190
- )
191
- elif key == "original_size_as_tuple":
192
- batch["original_size_as_tuple"] = (
193
- torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
194
- .to(device)
195
- .repeat(*N, 1)
196
- )
197
- elif key == "crop_coords_top_left":
198
- batch["crop_coords_top_left"] = (
199
- torch.tensor(
200
- [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
201
- )
202
- .to(device)
203
- .repeat(*N, 1)
204
- )
205
- elif key == "aesthetic_score":
206
- batch["aesthetic_score"] = (
207
- torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
208
- )
209
- batch_uc["aesthetic_score"] = (
210
- torch.tensor([value_dict["negative_aesthetic_score"]])
211
- .to(device)
212
- .repeat(*N, 1)
213
- )
214
-
215
- elif key == "target_size_as_tuple":
216
- batch["target_size_as_tuple"] = (
217
- torch.tensor([value_dict["target_height"], value_dict["target_width"]])
218
- .to(device)
219
- .repeat(*N, 1)
220
- )
221
- else:
222
- batch[key] = value_dict[key]
223
-
224
- for key in batch.keys():
225
- if key not in batch_uc and isinstance(batch[key], torch.Tensor):
226
- batch_uc[key] = torch.clone(batch[key])
227
- return batch, batch_uc
228
-
229
-
230
- def get_input_image_tensor(image: Image.Image, device="cuda"):
231
- w, h = image.size
232
- print(f"loaded input image of size ({w}, {h})")
233
- width, height = map(
234
- lambda x: x - x % 64, (w, h)
235
- ) # resize to integer multiple of 64
236
- image = image.resize((width, height))
237
- image_array = np.array(image.convert("RGB"))
238
- image_array = image_array[None].transpose(0, 3, 1, 2)
239
- image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
240
- return image_tensor.to(device)
241
-
242
-
243
- def do_img2img(
244
- img,
245
- model,
246
- sampler,
247
- value_dict,
248
- num_samples,
249
- force_uc_zero_embeddings=[],
250
- additional_kwargs={},
251
- offset_noise_level: float = 0.0,
252
- return_latents=False,
253
- skip_encode=False,
254
- filter=None,
255
- device="cuda",
256
- ):
257
- with torch.no_grad():
258
- with autocast(device) as precision_scope:
259
- with model.ema_scope():
260
- batch, batch_uc = get_batch(
261
- get_unique_embedder_keys_from_conditioner(model.conditioner),
262
- value_dict,
263
- [num_samples],
264
- )
265
- c, uc = model.conditioner.get_unconditional_conditioning(
266
- batch,
267
- batch_uc=batch_uc,
268
- force_uc_zero_embeddings=force_uc_zero_embeddings,
269
- )
270
-
271
- for k in c:
272
- c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
273
-
274
- for k in additional_kwargs:
275
- c[k] = uc[k] = additional_kwargs[k]
276
- if skip_encode:
277
- z = img
278
- else:
279
- z = model.encode_first_stage(img)
280
- noise = torch.randn_like(z)
281
- sigmas = sampler.discretization(sampler.num_steps)
282
- sigma = sigmas[0].to(z.device)
283
-
284
- if offset_noise_level > 0.0:
285
- noise = noise + offset_noise_level * append_dims(
286
- torch.randn(z.shape[0], device=z.device), z.ndim
287
- )
288
- noised_z = z + noise * append_dims(sigma, z.ndim)
289
- noised_z = noised_z / torch.sqrt(
290
- 1.0 + sigmas[0] ** 2.0
291
- ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
292
-
293
- def denoiser(x, sigma, c):
294
- return model.denoiser(model.model, x, sigma, c)
295
-
296
- samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
297
- samples_x = model.decode_first_stage(samples_z)
298
- samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
299
-
300
- if filter is not None:
301
- samples = filter(samples)
302
-
303
- if return_latents:
304
- return samples, samples_z
305
- return samples
 
sgm/lr_scheduler.py DELETED
@@ -1,135 +0,0 @@
- import numpy as np
-
-
- class LambdaWarmUpCosineScheduler:
-     """
-     note: use with a base_lr of 1.0
-     """
-
-     def __init__(
-         self,
-         warm_up_steps,
-         lr_min,
-         lr_max,
-         lr_start,
-         max_decay_steps,
-         verbosity_interval=0,
-     ):
-         self.lr_warm_up_steps = warm_up_steps
-         self.lr_start = lr_start
-         self.lr_min = lr_min
-         self.lr_max = lr_max
-         self.lr_max_decay_steps = max_decay_steps
-         self.last_lr = 0.0
-         self.verbosity_interval = verbosity_interval
-
-     def schedule(self, n, **kwargs):
-         if self.verbosity_interval > 0:
-             if n % self.verbosity_interval == 0:
-                 print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
-         if n < self.lr_warm_up_steps:
-             lr = (
-                 self.lr_max - self.lr_start
-             ) / self.lr_warm_up_steps * n + self.lr_start
-             self.last_lr = lr
-             return lr
-         else:
-             t = (n - self.lr_warm_up_steps) / (
-                 self.lr_max_decay_steps - self.lr_warm_up_steps
-             )
-             t = min(t, 1.0)
-             lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
-                 1 + np.cos(t * np.pi)
-             )
-             self.last_lr = lr
-             return lr
-
-     def __call__(self, n, **kwargs):
-         return self.schedule(n, **kwargs)
-
-
- class LambdaWarmUpCosineScheduler2:
-     """
-     supports repeated iterations, configurable via lists
-     note: use with a base_lr of 1.0.
-     """
-
-     def __init__(
-         self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
-     ):
-         assert (
-             len(warm_up_steps)
-             == len(f_min)
-             == len(f_max)
-             == len(f_start)
-             == len(cycle_lengths)
-         )
-         self.lr_warm_up_steps = warm_up_steps
-         self.f_start = f_start
-         self.f_min = f_min
-         self.f_max = f_max
-         self.cycle_lengths = cycle_lengths
-         self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
-         self.last_f = 0.0
-         self.verbosity_interval = verbosity_interval
-
-     def find_in_interval(self, n):
-         interval = 0
-         for cl in self.cum_cycles[1:]:
-             if n <= cl:
-                 return interval
-             interval += 1
-
-     def schedule(self, n, **kwargs):
-         cycle = self.find_in_interval(n)
-         n = n - self.cum_cycles[cycle]
-         if self.verbosity_interval > 0:
-             if n % self.verbosity_interval == 0:
-                 print(
-                     f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                     f"current cycle {cycle}"
-                 )
-         if n < self.lr_warm_up_steps[cycle]:
-             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
-                 cycle
-             ] * n + self.f_start[cycle]
-             self.last_f = f
-             return f
-         else:
-             t = (n - self.lr_warm_up_steps[cycle]) / (
-                 self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
-             )
-             t = min(t, 1.0)
-             f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
-                 1 + np.cos(t * np.pi)
-             )
-             self.last_f = f
-             return f
-
-     def __call__(self, n, **kwargs):
-         return self.schedule(n, **kwargs)
-
-
- class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
-     def schedule(self, n, **kwargs):
-         cycle = self.find_in_interval(n)
-         n = n - self.cum_cycles[cycle]
-         if self.verbosity_interval > 0:
-             if n % self.verbosity_interval == 0:
-                 print(
-                     f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                     f"current cycle {cycle}"
-                 )
-
-         if n < self.lr_warm_up_steps[cycle]:
-             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
-                 cycle
-             ] * n + self.f_start[cycle]
-             self.last_f = f
-             return f
-         else:
-             f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
-                 self.cycle_lengths[cycle] - n
-             ) / (self.cycle_lengths[cycle])
-             self.last_f = f
-             return f
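
These schedulers return learning-rate multipliers rather than absolute rates (hence the "use with a base_lr of 1.0" note in the docstrings). A minimal sketch of wiring one into torch.optim.lr_scheduler.LambdaLR follows; the step counts and rates are placeholders.

import torch
from sgm.lr_scheduler import LambdaWarmUpCosineScheduler

schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=1e-6, lr_max=1e-4, lr_start=1e-6, max_decay_steps=10_000
)
# With base_lr = 1.0, the schedule's output is the effective learning rate.
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)
for _ in range(10_000):
    optimizer.step()
    lr_scheduler.step()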
 
sgm/models/__init__.py DELETED
@@ -1,2 +0,0 @@
- from .autoencoder import AutoencodingEngine
- from .diffusion import DiffusionEngine
 
sgm/models/autoencoder.py DELETED
@@ -1,615 +0,0 @@
1
- import logging
2
- import math
3
- import re
4
- from abc import abstractmethod
5
- from contextlib import contextmanager
6
- from typing import Any, Dict, List, Optional, Tuple, Union
7
-
8
- import pytorch_lightning as pl
9
- import torch
10
- import torch.nn as nn
11
- from einops import rearrange
12
- from packaging import version
13
-
14
- from ..modules.autoencoding.regularizers import AbstractRegularizer
15
- from ..modules.ema import LitEma
16
- from ..util import (default, get_nested_attribute, get_obj_from_str,
17
- instantiate_from_config)
18
-
19
- logpy = logging.getLogger(__name__)
20
-
21
-
22
- class AbstractAutoencoder(pl.LightningModule):
23
- """
24
- This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
25
- unCLIP models, etc. Hence, it is fairly general, and specific features
26
- (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
27
- """
28
-
29
- def __init__(
30
- self,
31
- ema_decay: Union[None, float] = None,
32
- monitor: Union[None, str] = None,
33
- input_key: str = "jpg",
34
- ):
35
- super().__init__()
36
-
37
- self.input_key = input_key
38
- self.use_ema = ema_decay is not None
39
- if monitor is not None:
40
- self.monitor = monitor
41
-
42
- if self.use_ema:
43
- self.model_ema = LitEma(self, decay=ema_decay)
44
- logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
45
-
46
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
- self.automatic_optimization = False
48
-
49
- def apply_ckpt(self, ckpt: Union[None, str, dict]):
50
- if ckpt is None:
51
- return
52
- if isinstance(ckpt, str):
53
- ckpt = {
54
- "target": "sgm.modules.checkpoint.CheckpointEngine",
55
- "params": {"ckpt_path": ckpt},
56
- }
57
- engine = instantiate_from_config(ckpt)
58
- engine(self)
59
-
60
- @abstractmethod
61
- def get_input(self, batch) -> Any:
62
- raise NotImplementedError()
63
-
64
- def on_train_batch_end(self, *args, **kwargs):
65
- # for EMA computation
66
- if self.use_ema:
67
- self.model_ema(self)
68
-
69
- @contextmanager
70
- def ema_scope(self, context=None):
71
- if self.use_ema:
72
- self.model_ema.store(self.parameters())
73
- self.model_ema.copy_to(self)
74
- if context is not None:
75
- logpy.info(f"{context}: Switched to EMA weights")
76
- try:
77
- yield None
78
- finally:
79
- if self.use_ema:
80
- self.model_ema.restore(self.parameters())
81
- if context is not None:
82
- logpy.info(f"{context}: Restored training weights")
83
-
84
- @abstractmethod
85
- def encode(self, *args, **kwargs) -> torch.Tensor:
86
- raise NotImplementedError("encode()-method of abstract base class called")
87
-
88
- @abstractmethod
89
- def decode(self, *args, **kwargs) -> torch.Tensor:
90
- raise NotImplementedError("decode()-method of abstract base class called")
91
-
92
- def instantiate_optimizer_from_config(self, params, lr, cfg):
93
- logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
94
- return get_obj_from_str(cfg["target"])(
95
- params, lr=lr, **cfg.get("params", dict())
96
- )
97
-
98
- def configure_optimizers(self) -> Any:
99
- raise NotImplementedError()
100
-
101
-
102
- class AutoencodingEngine(AbstractAutoencoder):
103
- """
104
- Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
105
- (we also restore them explicitly as special cases for legacy reasons).
106
- Regularizations such as KL or VQ are moved to the regularizer class.
107
- """
108
-
109
- def __init__(
110
- self,
111
- *args,
112
- encoder_config: Dict,
113
- decoder_config: Dict,
114
- loss_config: Dict,
115
- regularizer_config: Dict,
116
- optimizer_config: Union[Dict, None] = None,
117
- lr_g_factor: float = 1.0,
118
- trainable_ae_params: Optional[List[List[str]]] = None,
119
- ae_optimizer_args: Optional[List[dict]] = None,
120
- trainable_disc_params: Optional[List[List[str]]] = None,
121
- disc_optimizer_args: Optional[List[dict]] = None,
122
- disc_start_iter: int = 0,
123
- diff_boost_factor: float = 3.0,
124
- ckpt_engine: Union[None, str, dict] = None,
125
- ckpt_path: Optional[str] = None,
126
- additional_decode_keys: Optional[List[str]] = None,
127
- **kwargs,
128
- ):
129
- super().__init__(*args, **kwargs)
130
- self.automatic_optimization = False # pytorch lightning
131
-
132
- self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
133
- self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
134
- self.loss: torch.nn.Module = instantiate_from_config(loss_config)
135
- self.regularization: AbstractRegularizer = instantiate_from_config(
136
- regularizer_config
137
- )
138
- self.optimizer_config = default(
139
- optimizer_config, {"target": "torch.optim.Adam"}
140
- )
141
- self.diff_boost_factor = diff_boost_factor
142
- self.disc_start_iter = disc_start_iter
143
- self.lr_g_factor = lr_g_factor
144
- self.trainable_ae_params = trainable_ae_params
145
- if self.trainable_ae_params is not None:
146
- self.ae_optimizer_args = default(
147
- ae_optimizer_args,
148
- [{} for _ in range(len(self.trainable_ae_params))],
149
- )
150
- assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
151
- else:
152
- self.ae_optimizer_args = [{}] # makes type consitent
153
-
154
- self.trainable_disc_params = trainable_disc_params
155
- if self.trainable_disc_params is not None:
156
- self.disc_optimizer_args = default(
157
- disc_optimizer_args,
158
- [{} for _ in range(len(self.trainable_disc_params))],
159
- )
160
- assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
161
- else:
162
- self.disc_optimizer_args = [{}] # makes type consitent
163
-
164
- if ckpt_path is not None:
165
- assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
166
- logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
167
- self.apply_ckpt(default(ckpt_path, ckpt_engine))
168
- self.additional_decode_keys = set(default(additional_decode_keys, []))
169
-
170
- def get_input(self, batch: Dict) -> torch.Tensor:
171
- # assuming unified data format, dataloader returns a dict.
172
- # image tensors should be scaled to -1 ... 1 and in channels-first
173
- # format (e.g., bchw instead if bhwc)
174
- return batch[self.input_key]
175
-
176
- def get_autoencoder_params(self) -> list:
177
- params = []
178
- if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
179
- params += list(self.loss.get_trainable_autoencoder_parameters())
180
- if hasattr(self.regularization, "get_trainable_parameters"):
181
- params += list(self.regularization.get_trainable_parameters())
182
- params = params + list(self.encoder.parameters())
183
- params = params + list(self.decoder.parameters())
184
- return params
185
-
186
- def get_discriminator_params(self) -> list:
187
- if hasattr(self.loss, "get_trainable_parameters"):
188
- params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
189
- else:
190
- params = []
191
- return params
192
-
193
- def get_last_layer(self):
194
- return self.decoder.get_last_layer()
195
-
196
- def encode(
197
- self,
198
- x: torch.Tensor,
199
- return_reg_log: bool = False,
200
- unregularized: bool = False,
201
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
202
- z = self.encoder(x)
203
- if unregularized:
204
- return z, dict()
205
- z, reg_log = self.regularization(z)
206
- if return_reg_log:
207
- return z, reg_log
208
- return z
209
-
210
- def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
211
- x = self.decoder(z, **kwargs)
212
- return x
213
-
214
- def forward(
215
- self, x: torch.Tensor, **additional_decode_kwargs
216
- ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
217
- z, reg_log = self.encode(x, return_reg_log=True)
218
- dec = self.decode(z, **additional_decode_kwargs)
219
- return z, dec, reg_log
220
-
221
- def inner_training_step(
222
- self, batch: dict, batch_idx: int, optimizer_idx: int = 0
223
- ) -> torch.Tensor:
224
- x = self.get_input(batch)
225
- additional_decode_kwargs = {
226
- key: batch[key] for key in self.additional_decode_keys.intersection(batch)
227
- }
228
- z, xrec, regularization_log = self(x, **additional_decode_kwargs)
229
- if hasattr(self.loss, "forward_keys"):
230
- extra_info = {
231
- "z": z,
232
- "optimizer_idx": optimizer_idx,
233
- "global_step": self.global_step,
234
- "last_layer": self.get_last_layer(),
235
- "split": "train",
236
- "regularization_log": regularization_log,
237
- "autoencoder": self,
238
- }
239
- extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
240
- else:
241
- extra_info = dict()
242
-
243
- if optimizer_idx == 0:
244
- # autoencode
245
- out_loss = self.loss(x, xrec, **extra_info)
246
- if isinstance(out_loss, tuple):
247
- aeloss, log_dict_ae = out_loss
248
- else:
249
- # simple loss function
250
- aeloss = out_loss
251
- log_dict_ae = {"train/loss/rec": aeloss.detach()}
252
-
253
- self.log_dict(
254
- log_dict_ae,
255
- prog_bar=False,
256
- logger=True,
257
- on_step=True,
258
- on_epoch=True,
259
- sync_dist=False,
260
- )
261
- self.log(
262
- "loss",
263
- aeloss.mean().detach(),
264
- prog_bar=True,
265
- logger=False,
266
- on_epoch=False,
267
- on_step=True,
268
- )
269
- return aeloss
270
- elif optimizer_idx == 1:
271
- # discriminator
272
- discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
273
- # -> discriminator always needs to return a tuple
274
- self.log_dict(
275
- log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
276
- )
277
- return discloss
278
- else:
279
- raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
280
-
281
- def training_step(self, batch: dict, batch_idx: int):
282
- opts = self.optimizers()
283
- if not isinstance(opts, list):
284
- # Non-adversarial case
285
- opts = [opts]
286
- optimizer_idx = batch_idx % len(opts)
287
- if self.global_step < self.disc_start_iter:
288
- optimizer_idx = 0
289
- opt = opts[optimizer_idx]
290
- opt.zero_grad()
291
- with opt.toggle_model():
292
- loss = self.inner_training_step(
293
- batch, batch_idx, optimizer_idx=optimizer_idx
294
- )
295
- self.manual_backward(loss)
296
- opt.step()
297
-
298
- def validation_step(self, batch: dict, batch_idx: int) -> Dict:
299
- log_dict = self._validation_step(batch, batch_idx)
300
- with self.ema_scope():
301
- log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
302
- log_dict.update(log_dict_ema)
303
- return log_dict
304
-
305
- def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
306
- x = self.get_input(batch)
307
-
308
- z, xrec, regularization_log = self(x)
309
- if hasattr(self.loss, "forward_keys"):
310
- extra_info = {
311
- "z": z,
312
- "optimizer_idx": 0,
313
- "global_step": self.global_step,
314
- "last_layer": self.get_last_layer(),
315
- "split": "val" + postfix,
316
- "regularization_log": regularization_log,
317
- "autoencoder": self,
318
- }
319
- extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
320
- else:
321
- extra_info = dict()
322
- out_loss = self.loss(x, xrec, **extra_info)
323
- if isinstance(out_loss, tuple):
324
- aeloss, log_dict_ae = out_loss
325
- else:
326
- # simple loss function
327
- aeloss = out_loss
328
- log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
329
- full_log_dict = log_dict_ae
330
-
331
- if "optimizer_idx" in extra_info:
332
- extra_info["optimizer_idx"] = 1
333
- discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
334
- full_log_dict.update(log_dict_disc)
335
- self.log(
336
- f"val{postfix}/loss/rec",
337
- log_dict_ae[f"val{postfix}/loss/rec"],
338
- sync_dist=True,
339
- )
340
- self.log_dict(full_log_dict, sync_dist=True)
341
- return full_log_dict
342
-
343
- def get_param_groups(
344
- self, parameter_names: List[List[str]], optimizer_args: List[dict]
345
- ) -> Tuple[List[Dict[str, Any]], int]:
346
- groups = []
347
- num_params = 0
348
- for names, args in zip(parameter_names, optimizer_args):
349
- params = []
350
- for pattern_ in names:
351
- pattern_params = []
352
- pattern = re.compile(pattern_)
353
- for p_name, param in self.named_parameters():
354
- if re.match(pattern, p_name):
355
- pattern_params.append(param)
356
- num_params += param.numel()
357
- if len(pattern_params) == 0:
358
- logpy.warn(f"Did not find parameters for pattern {pattern_}")
359
- params.extend(pattern_params)
360
- groups.append({"params": params, **args})
361
- return groups, num_params
362
-
363
- def configure_optimizers(self) -> List[torch.optim.Optimizer]:
364
- if self.trainable_ae_params is None:
365
- ae_params = self.get_autoencoder_params()
366
- else:
367
- ae_params, num_ae_params = self.get_param_groups(
368
- self.trainable_ae_params, self.ae_optimizer_args
369
- )
370
- logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
371
- if self.trainable_disc_params is None:
372
- disc_params = self.get_discriminator_params()
373
- else:
374
- disc_params, num_disc_params = self.get_param_groups(
375
- self.trainable_disc_params, self.disc_optimizer_args
376
- )
377
- logpy.info(
378
- f"Number of trainable discriminator parameters: {num_disc_params:,}"
379
- )
380
- opt_ae = self.instantiate_optimizer_from_config(
381
- ae_params,
382
- default(self.lr_g_factor, 1.0) * self.learning_rate,
383
- self.optimizer_config,
384
- )
385
- opts = [opt_ae]
386
- if len(disc_params) > 0:
387
- opt_disc = self.instantiate_optimizer_from_config(
388
- disc_params, self.learning_rate, self.optimizer_config
389
- )
390
- opts.append(opt_disc)
391
-
392
- return opts
393
-
394
- @torch.no_grad()
395
- def log_images(
396
- self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
397
- ) -> dict:
398
- log = dict()
399
- additional_decode_kwargs = {}
400
- x = self.get_input(batch)
401
- additional_decode_kwargs.update(
402
- {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
403
- )
404
-
405
- _, xrec, _ = self(x, **additional_decode_kwargs)
406
- log["inputs"] = x
407
- log["reconstructions"] = xrec
408
- diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
409
- diff.clamp_(0, 1.0)
410
- log["diff"] = 2.0 * diff - 1.0
411
- # diff_boost shows location of small errors, by boosting their
412
- # brightness.
413
- log["diff_boost"] = (
414
- 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
415
- )
416
- if hasattr(self.loss, "log_images"):
417
- log.update(self.loss.log_images(x, xrec))
418
- with self.ema_scope():
419
- _, xrec_ema, _ = self(x, **additional_decode_kwargs)
420
- log["reconstructions_ema"] = xrec_ema
421
- diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
422
- diff_ema.clamp_(0, 1.0)
423
- log["diff_ema"] = 2.0 * diff_ema - 1.0
424
- log["diff_boost_ema"] = (
425
- 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
426
- )
427
- if additional_log_kwargs:
428
- additional_decode_kwargs.update(additional_log_kwargs)
429
- _, xrec_add, _ = self(x, **additional_decode_kwargs)
430
- log_str = "reconstructions-" + "-".join(
431
- [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
432
- )
433
- log[log_str] = xrec_add
434
- return log
435
-
436
-
437
- class AutoencodingEngineLegacy(AutoencodingEngine):
438
- def __init__(self, embed_dim: int, **kwargs):
439
- self.max_batch_size = kwargs.pop("max_batch_size", None)
440
- ddconfig = kwargs.pop("ddconfig")
441
- ckpt_path = kwargs.pop("ckpt_path", None)
442
- ckpt_engine = kwargs.pop("ckpt_engine", None)
443
- super().__init__(
444
- encoder_config={
445
- "target": "sgm.modules.diffusionmodules.model.Encoder",
446
- "params": ddconfig,
447
- },
448
- decoder_config={
449
- "target": "sgm.modules.diffusionmodules.model.Decoder",
450
- "params": ddconfig,
451
- },
452
- **kwargs,
453
- )
454
- self.quant_conv = torch.nn.Conv2d(
455
- (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
456
- (1 + ddconfig["double_z"]) * embed_dim,
457
- 1,
458
- )
459
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
460
- self.embed_dim = embed_dim
461
-
462
- self.apply_ckpt(default(ckpt_path, ckpt_engine))
463
-
464
- def get_autoencoder_params(self) -> list:
465
- params = super().get_autoencoder_params()
466
- return params
467
-
468
- def encode(
469
- self, x: torch.Tensor, return_reg_log: bool = False
470
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
471
- if self.max_batch_size is None:
472
- z = self.encoder(x)
473
- z = self.quant_conv(z)
474
- else:
475
- N = x.shape[0]
476
- bs = self.max_batch_size
477
- n_batches = int(math.ceil(N / bs))
478
- z = list()
479
- for i_batch in range(n_batches):
480
- z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
481
- z_batch = self.quant_conv(z_batch)
482
- z.append(z_batch)
483
- z = torch.cat(z, 0)
484
-
485
- z, reg_log = self.regularization(z)
486
- if return_reg_log:
487
- return z, reg_log
488
- return z
489
-
490
- def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
491
- if self.max_batch_size is None:
492
- dec = self.post_quant_conv(z)
493
- dec = self.decoder(dec, **decoder_kwargs)
494
- else:
495
- N = z.shape[0]
496
- bs = self.max_batch_size
497
- n_batches = int(math.ceil(N / bs))
498
- dec = list()
499
- for i_batch in range(n_batches):
500
- dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
501
- dec_batch = self.decoder(dec_batch, **decoder_kwargs)
502
- dec.append(dec_batch)
503
- dec = torch.cat(dec, 0)
504
-
505
- return dec
506
-
507
-
508
- class AutoencoderKL(AutoencodingEngineLegacy):
509
- def __init__(self, **kwargs):
510
- if "lossconfig" in kwargs:
511
- kwargs["loss_config"] = kwargs.pop("lossconfig")
512
- super().__init__(
513
- regularizer_config={
514
- "target": (
515
- "sgm.modules.autoencoding.regularizers"
516
- ".DiagonalGaussianRegularizer"
517
- )
518
- },
519
- **kwargs,
520
- )
521
-
522
-
523
- class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
524
- def __init__(
525
- self,
526
- embed_dim: int,
527
- n_embed: int,
528
- sane_index_shape: bool = False,
529
- **kwargs,
530
- ):
531
- if "lossconfig" in kwargs:
532
- logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
533
- kwargs["loss_config"] = kwargs.pop("lossconfig")
534
- super().__init__(
535
- regularizer_config={
536
- "target": (
537
- "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
538
- ),
539
- "params": {
540
- "n_e": n_embed,
541
- "e_dim": embed_dim,
542
- "sane_index_shape": sane_index_shape,
543
- },
544
- },
545
- **kwargs,
546
- )
547
-
548
-
549
- class IdentityFirstStage(AbstractAutoencoder):
550
- def __init__(self, *args, **kwargs):
551
- super().__init__(*args, **kwargs)
552
-
553
- def get_input(self, x: Any) -> Any:
554
- return x
555
-
556
- def encode(self, x: Any, *args, **kwargs) -> Any:
557
- return x
558
-
559
- def decode(self, x: Any, *args, **kwargs) -> Any:
560
- return x
561
-
562
-
563
- class AEIntegerWrapper(nn.Module):
564
- def __init__(
565
- self,
566
- model: nn.Module,
567
- shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
568
- regularization_key: str = "regularization",
569
- encoder_kwargs: Optional[Dict[str, Any]] = None,
570
- ):
571
- super().__init__()
572
- self.model = model
573
- assert hasattr(model, "encode") and hasattr(
574
- model, "decode"
575
- ), "Need AE interface"
576
- self.regularization = get_nested_attribute(model, regularization_key)
577
- self.shape = shape
578
- self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
579
-
580
- def encode(self, x) -> torch.Tensor:
581
- assert (
582
- not self.training
583
- ), f"{self.__class__.__name__} only supports inference currently"
584
- _, log = self.model.encode(x, **self.encoder_kwargs)
585
- assert isinstance(log, dict)
586
- inds = log["min_encoding_indices"]
587
- return rearrange(inds, "b ... -> b (...)")
588
-
589
- def decode(
590
- self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
591
- ) -> torch.Tensor:
592
- # expect inds shape (b, s) with s = h*w
593
- shape = default(shape, self.shape) # Optional[(h, w)]
594
- if shape is not None:
595
- assert len(shape) == 2, f"Unhandeled shape {shape}"
596
- inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
597
- h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
598
- h = rearrange(h, "b h w c -> b c h w")
599
- return self.model.decode(h)
600
-
601
-
602
- class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
603
- def __init__(self, **kwargs):
604
- if "lossconfig" in kwargs:
605
- kwargs["loss_config"] = kwargs.pop("lossconfig")
606
- super().__init__(
607
- regularizer_config={
608
- "target": (
609
- "sgm.modules.autoencoding.regularizers"
610
- ".DiagonalGaussianRegularizer"
611
- ),
612
- "params": {"sample": False},
613
- },
614
- **kwargs,
615
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sgm/models/diffusion.py DELETED
@@ -1,341 +0,0 @@
1
- import math
2
- from contextlib import contextmanager
3
- from typing import Any, Dict, List, Optional, Tuple, Union
4
-
5
- import pytorch_lightning as pl
6
- import torch
7
- from omegaconf import ListConfig, OmegaConf
8
- from safetensors.torch import load_file as load_safetensors
9
- from torch.optim.lr_scheduler import LambdaLR
10
-
11
- from ..modules import UNCONDITIONAL_CONFIG
12
- from ..modules.autoencoding.temporal_ae import VideoDecoder
13
- from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
14
- from ..modules.ema import LitEma
15
- from ..util import (default, disabled_train, get_obj_from_str,
16
- instantiate_from_config, log_txt_as_img)
17
-
18
-
19
- class DiffusionEngine(pl.LightningModule):
20
- def __init__(
21
- self,
22
- network_config,
23
- denoiser_config,
24
- first_stage_config,
25
- conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
26
- sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
27
- optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
28
- scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
29
- loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
30
- network_wrapper: Union[None, str] = None,
31
- ckpt_path: Union[None, str] = None,
32
- use_ema: bool = False,
33
- ema_decay_rate: float = 0.9999,
34
- scale_factor: float = 1.0,
35
- disable_first_stage_autocast=False,
36
- input_key: str = "jpg",
37
- log_keys: Union[List, None] = None,
38
- no_cond_log: bool = False,
39
- compile_model: bool = False,
40
- en_and_decode_n_samples_a_time: Optional[int] = None,
41
- ):
42
- super().__init__()
43
- self.log_keys = log_keys
44
- self.input_key = input_key
45
- self.optimizer_config = default(
46
- optimizer_config, {"target": "torch.optim.AdamW"}
47
- )
48
- model = instantiate_from_config(network_config)
49
- self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
50
- model, compile_model=compile_model
51
- )
52
-
53
- self.denoiser = instantiate_from_config(denoiser_config)
54
- self.sampler = (
55
- instantiate_from_config(sampler_config)
56
- if sampler_config is not None
57
- else None
58
- )
59
- self.conditioner = instantiate_from_config(
60
- default(conditioner_config, UNCONDITIONAL_CONFIG)
61
- )
62
- self.scheduler_config = scheduler_config
63
- self._init_first_stage(first_stage_config)
64
-
65
- self.loss_fn = (
66
- instantiate_from_config(loss_fn_config)
67
- if loss_fn_config is not None
68
- else None
69
- )
70
-
71
- self.use_ema = use_ema
72
- if self.use_ema:
73
- self.model_ema = LitEma(self.model, decay=ema_decay_rate)
74
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
75
-
76
- self.scale_factor = scale_factor
77
- self.disable_first_stage_autocast = disable_first_stage_autocast
78
- self.no_cond_log = no_cond_log
79
-
80
- if ckpt_path is not None:
81
- self.init_from_ckpt(ckpt_path)
82
-
83
- self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
84
-
85
- def init_from_ckpt(
86
- self,
87
- path: str,
88
- ) -> None:
89
- if path.endswith("ckpt"):
90
- sd = torch.load(path, map_location="cpu")["state_dict"]
91
- elif path.endswith("safetensors"):
92
- sd = load_safetensors(path)
93
- else:
94
- raise NotImplementedError
95
-
96
- missing, unexpected = self.load_state_dict(sd, strict=False)
97
- print(
98
- f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
99
- )
100
- if len(missing) > 0:
101
- print(f"Missing Keys: {missing}")
102
- if len(unexpected) > 0:
103
- print(f"Unexpected Keys: {unexpected}")
104
-
105
- def _init_first_stage(self, config):
106
- model = instantiate_from_config(config).eval()
107
- model.train = disabled_train
108
- for param in model.parameters():
109
- param.requires_grad = False
110
- self.first_stage_model = model
111
-
112
- def get_input(self, batch):
113
- # assuming unified data format, dataloader returns a dict.
114
- # image tensors should be scaled to -1 ... 1 and in bchw format
115
- return batch[self.input_key]
116
-
117
- @torch.no_grad()
118
- def decode_first_stage(self, z):
119
- z = 1.0 / self.scale_factor * z
120
- n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
121
-
122
- n_rounds = math.ceil(z.shape[0] / n_samples)
123
- all_out = []
124
- with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
125
- for n in range(n_rounds):
126
- if isinstance(self.first_stage_model.decoder, VideoDecoder):
127
- kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
128
- else:
129
- kwargs = {}
130
- out = self.first_stage_model.decode(
131
- z[n * n_samples : (n + 1) * n_samples], **kwargs
132
- )
133
- all_out.append(out)
134
- out = torch.cat(all_out, dim=0)
135
- return out
136
-
137
- @torch.no_grad()
138
- def encode_first_stage(self, x):
139
- n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
140
- n_rounds = math.ceil(x.shape[0] / n_samples)
141
- all_out = []
142
- with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
143
- for n in range(n_rounds):
144
- out = self.first_stage_model.encode(
145
- x[n * n_samples : (n + 1) * n_samples]
146
- )
147
- all_out.append(out)
148
- z = torch.cat(all_out, dim=0)
149
- z = self.scale_factor * z
150
- return z
151
-
152
- def forward(self, x, batch):
153
- loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
154
- loss_mean = loss.mean()
155
- loss_dict = {"loss": loss_mean}
156
- return loss_mean, loss_dict
157
-
158
- def shared_step(self, batch: Dict) -> Any:
159
- x = self.get_input(batch)
160
- x = self.encode_first_stage(x)
161
- batch["global_step"] = self.global_step
162
- loss, loss_dict = self(x, batch)
163
- return loss, loss_dict
164
-
165
- def training_step(self, batch, batch_idx):
166
- loss, loss_dict = self.shared_step(batch)
167
-
168
- self.log_dict(
169
- loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
170
- )
171
-
172
- self.log(
173
- "global_step",
174
- self.global_step,
175
- prog_bar=True,
176
- logger=True,
177
- on_step=True,
178
- on_epoch=False,
179
- )
180
-
181
- if self.scheduler_config is not None:
182
- lr = self.optimizers().param_groups[0]["lr"]
183
- self.log(
184
- "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
185
- )
186
-
187
- return loss
188
-
189
- def on_train_start(self, *args, **kwargs):
190
- if self.sampler is None or self.loss_fn is None:
191
- raise ValueError("Sampler and loss function need to be set for training.")
192
-
193
- def on_train_batch_end(self, *args, **kwargs):
194
- if self.use_ema:
195
- self.model_ema(self.model)
196
-
197
- @contextmanager
198
- def ema_scope(self, context=None):
199
- if self.use_ema:
200
- self.model_ema.store(self.model.parameters())
201
- self.model_ema.copy_to(self.model)
202
- if context is not None:
203
- print(f"{context}: Switched to EMA weights")
204
- try:
205
- yield None
206
- finally:
207
- if self.use_ema:
208
- self.model_ema.restore(self.model.parameters())
209
- if context is not None:
210
- print(f"{context}: Restored training weights")
211
-
212
- def instantiate_optimizer_from_config(self, params, lr, cfg):
213
- return get_obj_from_str(cfg["target"])(
214
- params, lr=lr, **cfg.get("params", dict())
215
- )
216
-
217
- def configure_optimizers(self):
218
- lr = self.learning_rate
219
- params = list(self.model.parameters())
220
- for embedder in self.conditioner.embedders:
221
- if embedder.is_trainable:
222
- params = params + list(embedder.parameters())
223
- opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
224
- if self.scheduler_config is not None:
225
- scheduler = instantiate_from_config(self.scheduler_config)
226
- print("Setting up LambdaLR scheduler...")
227
- scheduler = [
228
- {
229
- "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
230
- "interval": "step",
231
- "frequency": 1,
232
- }
233
- ]
234
- return [opt], scheduler
235
- return opt
236
-
237
- @torch.no_grad()
238
- def sample(
239
- self,
240
- cond: Dict,
241
- uc: Union[Dict, None] = None,
242
- batch_size: int = 16,
243
- shape: Union[None, Tuple, List] = None,
244
- **kwargs,
245
- ):
246
- randn = torch.randn(batch_size, *shape).to(self.device)
247
-
248
- denoiser = lambda input, sigma, c: self.denoiser(
249
- self.model, input, sigma, c, **kwargs
250
- )
251
- samples = self.sampler(denoiser, randn, cond, uc=uc)
252
- return samples
253
-
254
- @torch.no_grad()
255
- def log_conditionings(self, batch: Dict, n: int) -> Dict:
256
- """
257
- Defines heuristics to log different conditionings.
258
- These can be lists of strings (text-to-image), tensors, ints, ...
259
- """
260
- image_h, image_w = batch[self.input_key].shape[2:]
261
- log = dict()
262
-
263
- for embedder in self.conditioner.embedders:
264
- if (
265
- (self.log_keys is None) or (embedder.input_key in self.log_keys)
266
- ) and not self.no_cond_log:
267
- x = batch[embedder.input_key][:n]
268
- if isinstance(x, torch.Tensor):
269
- if x.dim() == 1:
270
- # class-conditional, convert integer to string
271
- x = [str(x[i].item()) for i in range(x.shape[0])]
272
- xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
273
- elif x.dim() == 2:
274
- # size and crop cond and the like
275
- x = [
276
- "x".join([str(xx) for xx in x[i].tolist()])
277
- for i in range(x.shape[0])
278
- ]
279
- xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
280
- else:
281
- raise NotImplementedError()
282
- elif isinstance(x, (List, ListConfig)):
283
- if isinstance(x[0], str):
284
- # strings
285
- xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
286
- else:
287
- raise NotImplementedError()
288
- else:
289
- raise NotImplementedError()
290
- log[embedder.input_key] = xc
291
- return log
292
-
293
- @torch.no_grad()
294
- def log_images(
295
- self,
296
- batch: Dict,
297
- N: int = 8,
298
- sample: bool = True,
299
- ucg_keys: List[str] = None,
300
- **kwargs,
301
- ) -> Dict:
302
- conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
303
- if ucg_keys:
304
- assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
305
- "Each defined ucg key for sampling must be in the provided conditioner input keys,"
306
- f"but we have {ucg_keys} vs. {conditioner_input_keys}"
307
- )
308
- else:
309
- ucg_keys = conditioner_input_keys
310
- log = dict()
311
-
312
- x = self.get_input(batch)
313
-
314
- c, uc = self.conditioner.get_unconditional_conditioning(
315
- batch,
316
- force_uc_zero_embeddings=ucg_keys
317
- if len(self.conditioner.embedders) > 0
318
- else [],
319
- )
320
-
321
- sampling_kwargs = {}
322
-
323
- N = min(x.shape[0], N)
324
- x = x.to(self.device)[:N]
325
- log["inputs"] = x
326
- z = self.encode_first_stage(x)
327
- log["reconstructions"] = self.decode_first_stage(z)
328
- log.update(self.log_conditionings(batch, N))
329
-
330
- for k in c:
331
- if isinstance(c[k], torch.Tensor):
332
- c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
333
-
334
- if sample:
335
- with self.ema_scope("Plotting"):
336
- samples = self.sample(
337
- c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
338
- )
339
- samples = self.decode_first_stage(samples)
340
- log["samples"] = samples
341
- return log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sgm/modules/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from .encoders.modules import GeneralConditioner
2
-
3
- UNCONDITIONAL_CONFIG = {
4
- "target": "sgm.modules.GeneralConditioner",
5
- "params": {"emb_models": []},
6
- }
 
 
 
 
 
 
 
sgm/modules/attention.py DELETED
@@ -1,759 +0,0 @@
1
- import logging
2
- import math
3
- from inspect import isfunction
4
- from typing import Any, Optional
5
-
6
- import torch
7
- import torch.nn.functional as F
8
- from einops import rearrange, repeat
9
- from packaging import version
10
- from torch import nn
11
- from torch.utils.checkpoint import checkpoint
12
-
13
- logpy = logging.getLogger(__name__)
14
-
15
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
16
- SDP_IS_AVAILABLE = True
17
- from torch.backends.cuda import SDPBackend, sdp_kernel
18
-
19
- BACKEND_MAP = {
20
- SDPBackend.MATH: {
21
- "enable_math": True,
22
- "enable_flash": False,
23
- "enable_mem_efficient": False,
24
- },
25
- SDPBackend.FLASH_ATTENTION: {
26
- "enable_math": False,
27
- "enable_flash": True,
28
- "enable_mem_efficient": False,
29
- },
30
- SDPBackend.EFFICIENT_ATTENTION: {
31
- "enable_math": False,
32
- "enable_flash": False,
33
- "enable_mem_efficient": True,
34
- },
35
- None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
36
- }
37
- else:
38
- from contextlib import nullcontext
39
-
40
- SDP_IS_AVAILABLE = False
41
- sdp_kernel = nullcontext
42
- BACKEND_MAP = {}
43
- logpy.warn(
44
- f"No SDP backend available, likely because you are running in pytorch "
45
- f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
46
- f"You might want to consider upgrading."
47
- )
48
-
49
- try:
50
- import xformers
51
- import xformers.ops
52
-
53
- XFORMERS_IS_AVAILABLE = True
54
- except:
55
- XFORMERS_IS_AVAILABLE = False
56
- logpy.warn("no module 'xformers'. Processing without...")
57
-
58
- # from .diffusionmodules.util import mixed_checkpoint as checkpoint
59
-
60
-
61
- def exists(val):
62
- return val is not None
63
-
64
-
65
- def uniq(arr):
66
- return {el: True for el in arr}.keys()
67
-
68
-
69
- def default(val, d):
70
- if exists(val):
71
- return val
72
- return d() if isfunction(d) else d
73
-
74
-
75
- def max_neg_value(t):
76
- return -torch.finfo(t.dtype).max
77
-
78
-
79
- def init_(tensor):
80
- dim = tensor.shape[-1]
81
- std = 1 / math.sqrt(dim)
82
- tensor.uniform_(-std, std)
83
- return tensor
84
-
85
-
86
- # feedforward
87
- class GEGLU(nn.Module):
88
- def __init__(self, dim_in, dim_out):
89
- super().__init__()
90
- self.proj = nn.Linear(dim_in, dim_out * 2)
91
-
92
- def forward(self, x):
93
- x, gate = self.proj(x).chunk(2, dim=-1)
94
- return x * F.gelu(gate)
95
-
96
-
97
- class FeedForward(nn.Module):
98
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
99
- super().__init__()
100
- inner_dim = int(dim * mult)
101
- dim_out = default(dim_out, dim)
102
- project_in = (
103
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
104
- if not glu
105
- else GEGLU(dim, inner_dim)
106
- )
107
-
108
- self.net = nn.Sequential(
109
- project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
110
- )
111
-
112
- def forward(self, x):
113
- return self.net(x)
114
-
115
-
116
- def zero_module(module):
117
- """
118
- Zero out the parameters of a module and return it.
119
- """
120
- for p in module.parameters():
121
- p.detach().zero_()
122
- return module
123
-
124
-
125
- def Normalize(in_channels):
126
- return torch.nn.GroupNorm(
127
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
128
- )
129
-
130
-
131
- class LinearAttention(nn.Module):
132
- def __init__(self, dim, heads=4, dim_head=32):
133
- super().__init__()
134
- self.heads = heads
135
- hidden_dim = dim_head * heads
136
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
137
- self.to_out = nn.Conv2d(hidden_dim, dim, 1)
138
-
139
- def forward(self, x):
140
- b, c, h, w = x.shape
141
- qkv = self.to_qkv(x)
142
- q, k, v = rearrange(
143
- qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
144
- )
145
- k = k.softmax(dim=-1)
146
- context = torch.einsum("bhdn,bhen->bhde", k, v)
147
- out = torch.einsum("bhde,bhdn->bhen", context, q)
148
- out = rearrange(
149
- out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
150
- )
151
- return self.to_out(out)
152
-
153
-
154
- class SelfAttention(nn.Module):
155
- ATTENTION_MODES = ("xformers", "torch", "math")
156
-
157
- def __init__(
158
- self,
159
- dim: int,
160
- num_heads: int = 8,
161
- qkv_bias: bool = False,
162
- qk_scale: Optional[float] = None,
163
- attn_drop: float = 0.0,
164
- proj_drop: float = 0.0,
165
- attn_mode: str = "xformers",
166
- ):
167
- super().__init__()
168
- self.num_heads = num_heads
169
- head_dim = dim // num_heads
170
- self.scale = qk_scale or head_dim**-0.5
171
-
172
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
173
- self.attn_drop = nn.Dropout(attn_drop)
174
- self.proj = nn.Linear(dim, dim)
175
- self.proj_drop = nn.Dropout(proj_drop)
176
- assert attn_mode in self.ATTENTION_MODES
177
- self.attn_mode = attn_mode
178
-
179
- def forward(self, x: torch.Tensor) -> torch.Tensor:
180
- B, L, C = x.shape
181
-
182
- qkv = self.qkv(x)
183
- if self.attn_mode == "torch":
184
- qkv = rearrange(
185
- qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
186
- ).float()
187
- q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
188
- x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
189
- x = rearrange(x, "B H L D -> B L (H D)")
190
- elif self.attn_mode == "xformers":
191
- qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
192
- q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
193
- x = xformers.ops.memory_efficient_attention(q, k, v)
194
- x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
195
- elif self.attn_mode == "math":
196
- qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
197
- q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
198
- attn = (q @ k.transpose(-2, -1)) * self.scale
199
- attn = attn.softmax(dim=-1)
200
- attn = self.attn_drop(attn)
201
- x = (attn @ v).transpose(1, 2).reshape(B, L, C)
202
- else:
203
- raise NotImplemented
204
-
205
- x = self.proj(x)
206
- x = self.proj_drop(x)
207
- return x
208
-
209
-
210
- class SpatialSelfAttention(nn.Module):
211
- def __init__(self, in_channels):
212
- super().__init__()
213
- self.in_channels = in_channels
214
-
215
- self.norm = Normalize(in_channels)
216
- self.q = torch.nn.Conv2d(
217
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
218
- )
219
- self.k = torch.nn.Conv2d(
220
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
221
- )
222
- self.v = torch.nn.Conv2d(
223
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
224
- )
225
- self.proj_out = torch.nn.Conv2d(
226
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
227
- )
228
-
229
- def forward(self, x):
230
- h_ = x
231
- h_ = self.norm(h_)
232
- q = self.q(h_)
233
- k = self.k(h_)
234
- v = self.v(h_)
235
-
236
- # compute attention
237
- b, c, h, w = q.shape
238
- q = rearrange(q, "b c h w -> b (h w) c")
239
- k = rearrange(k, "b c h w -> b c (h w)")
240
- w_ = torch.einsum("bij,bjk->bik", q, k)
241
-
242
- w_ = w_ * (int(c) ** (-0.5))
243
- w_ = torch.nn.functional.softmax(w_, dim=2)
244
-
245
- # attend to values
246
- v = rearrange(v, "b c h w -> b c (h w)")
247
- w_ = rearrange(w_, "b i j -> b j i")
248
- h_ = torch.einsum("bij,bjk->bik", v, w_)
249
- h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
250
- h_ = self.proj_out(h_)
251
-
252
- return x + h_
253
-
254
-
255
- class CrossAttention(nn.Module):
256
- def __init__(
257
- self,
258
- query_dim,
259
- context_dim=None,
260
- heads=8,
261
- dim_head=64,
262
- dropout=0.0,
263
- backend=None,
264
- ):
265
- super().__init__()
266
- inner_dim = dim_head * heads
267
- context_dim = default(context_dim, query_dim)
268
-
269
- self.scale = dim_head**-0.5
270
- self.heads = heads
271
-
272
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
273
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
274
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
275
-
276
- self.to_out = nn.Sequential(
277
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
278
- )
279
- self.backend = backend
280
-
281
- def forward(
282
- self,
283
- x,
284
- context=None,
285
- mask=None,
286
- additional_tokens=None,
287
- n_times_crossframe_attn_in_self=0,
288
- ):
289
- h = self.heads
290
-
291
- if additional_tokens is not None:
292
- # get the number of masked tokens at the beginning of the output sequence
293
- n_tokens_to_mask = additional_tokens.shape[1]
294
- # add additional token
295
- x = torch.cat([additional_tokens, x], dim=1)
296
-
297
- q = self.to_q(x)
298
- context = default(context, x)
299
- k = self.to_k(context)
300
- v = self.to_v(context)
301
-
302
- if n_times_crossframe_attn_in_self:
303
- # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
304
- assert x.shape[0] % n_times_crossframe_attn_in_self == 0
305
- n_cp = x.shape[0] // n_times_crossframe_attn_in_self
306
- k = repeat(
307
- k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
308
- )
309
- v = repeat(
310
- v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
311
- )
312
-
313
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
314
-
315
- ## old
316
- """
317
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
318
- del q, k
319
-
320
- if exists(mask):
321
- mask = rearrange(mask, 'b ... -> b (...)')
322
- max_neg_value = -torch.finfo(sim.dtype).max
323
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
324
- sim.masked_fill_(~mask, max_neg_value)
325
-
326
- # attention, what we cannot get enough of
327
- sim = sim.softmax(dim=-1)
328
-
329
- out = einsum('b i j, b j d -> b i d', sim, v)
330
- """
331
- ## new
332
- with sdp_kernel(**BACKEND_MAP[self.backend]):
333
- # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
334
- out = F.scaled_dot_product_attention(
335
- q, k, v, attn_mask=mask
336
- ) # scale is dim_head ** -0.5 per default
337
-
338
- del q, k, v
339
- out = rearrange(out, "b h n d -> b n (h d)", h=h)
340
-
341
- if additional_tokens is not None:
342
- # remove additional token
343
- out = out[:, n_tokens_to_mask:]
344
- return self.to_out(out)
345
-
346
-
347
- class MemoryEfficientCrossAttention(nn.Module):
348
- # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
349
- def __init__(
350
- self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
351
- ):
352
- super().__init__()
353
- logpy.debug(
354
- f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
355
- f"context_dim is {context_dim} and using {heads} heads with a "
356
- f"dimension of {dim_head}."
357
- )
358
- inner_dim = dim_head * heads
359
- context_dim = default(context_dim, query_dim)
360
-
361
- self.heads = heads
362
- self.dim_head = dim_head
363
-
364
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
365
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
366
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
367
-
368
- self.to_out = nn.Sequential(
369
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
370
- )
371
- self.attention_op: Optional[Any] = None
372
-
373
- def forward(
374
- self,
375
- x,
376
- context=None,
377
- mask=None,
378
- additional_tokens=None,
379
- n_times_crossframe_attn_in_self=0,
380
- ):
381
- if additional_tokens is not None:
382
- # get the number of masked tokens at the beginning of the output sequence
383
- n_tokens_to_mask = additional_tokens.shape[1]
384
- # add additional token
385
- x = torch.cat([additional_tokens, x], dim=1)
386
- q = self.to_q(x)
387
- context = default(context, x)
388
- k = self.to_k(context)
389
- v = self.to_v(context)
390
-
391
- if n_times_crossframe_attn_in_self:
392
- # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
393
- assert x.shape[0] % n_times_crossframe_attn_in_self == 0
394
- # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
395
- k = repeat(
396
- k[::n_times_crossframe_attn_in_self],
397
- "b ... -> (b n) ...",
398
- n=n_times_crossframe_attn_in_self,
399
- )
400
- v = repeat(
401
- v[::n_times_crossframe_attn_in_self],
402
- "b ... -> (b n) ...",
403
- n=n_times_crossframe_attn_in_self,
404
- )
405
-
406
- b, _, _ = q.shape
407
- q, k, v = map(
408
- lambda t: t.unsqueeze(3)
409
- .reshape(b, t.shape[1], self.heads, self.dim_head)
410
- .permute(0, 2, 1, 3)
411
- .reshape(b * self.heads, t.shape[1], self.dim_head)
412
- .contiguous(),
413
- (q, k, v),
414
- )
415
-
416
- # actually compute the attention, what we cannot get enough of
417
- if version.parse(xformers.__version__) >= version.parse("0.0.21"):
418
- # NOTE: workaround for
419
- # https://github.com/facebookresearch/xformers/issues/845
420
- max_bs = 32768
421
- N = q.shape[0]
422
- n_batches = math.ceil(N / max_bs)
423
- out = list()
424
- for i_batch in range(n_batches):
425
- batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
426
- out.append(
427
- xformers.ops.memory_efficient_attention(
428
- q[batch],
429
- k[batch],
430
- v[batch],
431
- attn_bias=None,
432
- op=self.attention_op,
433
- )
434
- )
435
- out = torch.cat(out, 0)
436
- else:
437
- out = xformers.ops.memory_efficient_attention(
438
- q, k, v, attn_bias=None, op=self.attention_op
439
- )
440
-
441
- # TODO: Use this directly in the attention operation, as a bias
442
- if exists(mask):
443
- raise NotImplementedError
444
- out = (
445
- out.unsqueeze(0)
446
- .reshape(b, self.heads, out.shape[1], self.dim_head)
447
- .permute(0, 2, 1, 3)
448
- .reshape(b, out.shape[1], self.heads * self.dim_head)
449
- )
450
- if additional_tokens is not None:
451
- # remove additional token
452
- out = out[:, n_tokens_to_mask:]
453
- return self.to_out(out)
454
-
455
-
456
- class BasicTransformerBlock(nn.Module):
457
- ATTENTION_MODES = {
458
- "softmax": CrossAttention, # vanilla attention
459
- "softmax-xformers": MemoryEfficientCrossAttention, # ampere
460
- }
461
-
462
- def __init__(
463
- self,
464
- dim,
465
- n_heads,
466
- d_head,
467
- dropout=0.0,
468
- context_dim=None,
469
- gated_ff=True,
470
- checkpoint=True,
471
- disable_self_attn=False,
472
- attn_mode="softmax",
473
- sdp_backend=None,
474
- ):
475
- super().__init__()
476
- assert attn_mode in self.ATTENTION_MODES
477
- if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
478
- logpy.warn(
479
- f"Attention mode '{attn_mode}' is not available. Falling "
480
- f"back to native attention. This is not a problem in "
481
- f"Pytorch >= 2.0. FYI, you are running with PyTorch "
482
- f"version {torch.__version__}."
483
- )
484
- attn_mode = "softmax"
485
- elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
486
- logpy.warn(
487
- "We do not support vanilla attention anymore, as it is too "
488
- "expensive. Sorry."
489
- )
490
- if not XFORMERS_IS_AVAILABLE:
491
- assert (
492
- False
493
- ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
494
- else:
495
- logpy.info("Falling back to xformers efficient attention.")
496
- attn_mode = "softmax-xformers"
497
- attn_cls = self.ATTENTION_MODES[attn_mode]
498
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
499
- assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
500
- else:
501
- assert sdp_backend is None
502
- self.disable_self_attn = disable_self_attn
503
- self.attn1 = attn_cls(
504
- query_dim=dim,
505
- heads=n_heads,
506
- dim_head=d_head,
507
- dropout=dropout,
508
- context_dim=context_dim if self.disable_self_attn else None,
509
- backend=sdp_backend,
510
- ) # is a self-attention if not self.disable_self_attn
511
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
512
- self.attn2 = attn_cls(
513
- query_dim=dim,
514
- context_dim=context_dim,
515
- heads=n_heads,
516
- dim_head=d_head,
517
- dropout=dropout,
518
- backend=sdp_backend,
519
- ) # is self-attn if context is none
520
- self.norm1 = nn.LayerNorm(dim)
521
- self.norm2 = nn.LayerNorm(dim)
522
- self.norm3 = nn.LayerNorm(dim)
523
- self.checkpoint = checkpoint
524
- if self.checkpoint:
525
- logpy.debug(f"{self.__class__.__name__} is using checkpointing")
526
-
527
- def forward(
528
- self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
529
- ):
530
- kwargs = {"x": x}
531
-
532
- if context is not None:
533
- kwargs.update({"context": context})
534
-
535
- if additional_tokens is not None:
536
- kwargs.update({"additional_tokens": additional_tokens})
537
-
538
- if n_times_crossframe_attn_in_self:
539
- kwargs.update(
540
- {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
541
- )
542
-
543
- # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
544
- if self.checkpoint:
545
- # inputs = {"x": x, "context": context}
546
- return checkpoint(self._forward, x, context)
547
- # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
548
- else:
549
- return self._forward(**kwargs)
550
-
551
- def _forward(
552
- self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
553
- ):
554
- x = (
555
- self.attn1(
556
- self.norm1(x),
557
- context=context if self.disable_self_attn else None,
558
- additional_tokens=additional_tokens,
559
- n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
560
- if not self.disable_self_attn
561
- else 0,
562
- )
563
- + x
564
- )
565
- x = (
566
- self.attn2(
567
- self.norm2(x), context=context, additional_tokens=additional_tokens
568
- )
569
- + x
570
- )
571
- x = self.ff(self.norm3(x)) + x
572
- return x
573
-
574
-
575
- class BasicTransformerSingleLayerBlock(nn.Module):
576
- ATTENTION_MODES = {
577
- "softmax": CrossAttention, # vanilla attention
578
- "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
579
- # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
580
- }
581
-
582
- def __init__(
583
- self,
584
- dim,
585
- n_heads,
586
- d_head,
587
- dropout=0.0,
588
- context_dim=None,
589
- gated_ff=True,
590
- checkpoint=True,
591
- attn_mode="softmax",
592
- ):
593
- super().__init__()
594
- assert attn_mode in self.ATTENTION_MODES
595
- attn_cls = self.ATTENTION_MODES[attn_mode]
596
- self.attn1 = attn_cls(
597
- query_dim=dim,
598
- heads=n_heads,
599
- dim_head=d_head,
600
- dropout=dropout,
601
- context_dim=context_dim,
602
- )
603
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
604
- self.norm1 = nn.LayerNorm(dim)
605
- self.norm2 = nn.LayerNorm(dim)
606
- self.checkpoint = checkpoint
607
-
608
- def forward(self, x, context=None):
609
- # inputs = {"x": x, "context": context}
610
- # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
611
- return checkpoint(self._forward, x, context)
612
-
613
- def _forward(self, x, context=None):
614
- x = self.attn1(self.norm1(x), context=context) + x
615
- x = self.ff(self.norm2(x)) + x
616
- return x
617
-
618
-
619
- class SpatialTransformer(nn.Module):
620
- """
621
- Transformer block for image-like data.
622
- First, project the input (aka embedding)
623
- and reshape to b, t, d.
624
- Then apply standard transformer action.
625
- Finally, reshape to image
626
- NEW: use_linear for more efficiency instead of the 1x1 convs
627
- """
628
-
629
- def __init__(
630
- self,
631
- in_channels,
632
- n_heads,
633
- d_head,
634
- depth=1,
635
- dropout=0.0,
636
- context_dim=None,
637
- disable_self_attn=False,
638
- use_linear=False,
639
- attn_type="softmax",
640
- use_checkpoint=True,
641
- # sdp_backend=SDPBackend.FLASH_ATTENTION
642
- sdp_backend=None,
643
- ):
644
- super().__init__()
645
- logpy.debug(
646
- f"constructing {self.__class__.__name__} of depth {depth} w/ "
647
- f"{in_channels} channels and {n_heads} heads."
648
- )
649
-
650
- if exists(context_dim) and not isinstance(context_dim, list):
651
- context_dim = [context_dim]
652
- if exists(context_dim) and isinstance(context_dim, list):
653
- if depth != len(context_dim):
654
- logpy.warn(
655
- f"{self.__class__.__name__}: Found context dims "
656
- f"{context_dim} of depth {len(context_dim)}, which does not "
657
- f"match the specified 'depth' of {depth}. Setting context_dim "
658
- f"to {depth * [context_dim[0]]} now."
659
- )
660
- # depth does not match context dims.
661
- assert all(
662
- map(lambda x: x == context_dim[0], context_dim)
663
- ), "need homogenous context_dim to match depth automatically"
664
- context_dim = depth * [context_dim[0]]
665
- elif context_dim is None:
666
- context_dim = [None] * depth
667
- self.in_channels = in_channels
668
- inner_dim = n_heads * d_head
669
- self.norm = Normalize(in_channels)
670
- if not use_linear:
671
- self.proj_in = nn.Conv2d(
672
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
673
- )
674
- else:
675
- self.proj_in = nn.Linear(in_channels, inner_dim)
676
-
677
- self.transformer_blocks = nn.ModuleList(
678
- [
679
- BasicTransformerBlock(
680
- inner_dim,
681
- n_heads,
682
- d_head,
683
- dropout=dropout,
684
- context_dim=context_dim[d],
685
- disable_self_attn=disable_self_attn,
686
- attn_mode=attn_type,
687
- checkpoint=use_checkpoint,
688
- sdp_backend=sdp_backend,
689
- )
690
- for d in range(depth)
691
- ]
692
- )
693
- if not use_linear:
694
- self.proj_out = zero_module(
695
- nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
696
- )
697
- else:
698
- # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
699
- self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
700
- self.use_linear = use_linear
701
-
702
- def forward(self, x, context=None):
703
- # note: if no context is given, cross-attention defaults to self-attention
704
- if not isinstance(context, list):
705
- context = [context]
706
- b, c, h, w = x.shape
707
- x_in = x
708
- x = self.norm(x)
709
- if not self.use_linear:
710
- x = self.proj_in(x)
711
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
712
- if self.use_linear:
713
- x = self.proj_in(x)
714
- for i, block in enumerate(self.transformer_blocks):
715
- if i > 0 and len(context) == 1:
716
- i = 0 # use same context for each block
717
- x = block(x, context=context[i])
718
- if self.use_linear:
719
- x = self.proj_out(x)
720
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
721
- if not self.use_linear:
722
- x = self.proj_out(x)
723
- return x + x_in
724
-
725
-
726
- class SimpleTransformer(nn.Module):
727
- def __init__(
728
- self,
729
- dim: int,
730
- depth: int,
731
- heads: int,
732
- dim_head: int,
733
- context_dim: Optional[int] = None,
734
- dropout: float = 0.0,
735
- checkpoint: bool = True,
736
- ):
737
- super().__init__()
738
- self.layers = nn.ModuleList([])
739
- for _ in range(depth):
740
- self.layers.append(
741
- BasicTransformerBlock(
742
- dim,
743
- heads,
744
- dim_head,
745
- dropout=dropout,
746
- context_dim=context_dim,
747
- attn_mode="softmax-xformers",
748
- checkpoint=checkpoint,
749
- )
750
- )
751
-
752
- def forward(
753
- self,
754
- x: torch.Tensor,
755
- context: Optional[torch.Tensor] = None,
756
- ) -> torch.Tensor:
757
- for layer in self.layers:
758
- x = layer(x, context)
759
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sgm/modules/autoencoding/__init__.py DELETED
File without changes
sgm/modules/autoencoding/losses/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- __all__ = [
2
- "GeneralLPIPSWithDiscriminator",
3
- "LatentLPIPS",
4
- ]
5
-
6
- from .discriminator_loss import GeneralLPIPSWithDiscriminator
7
- from .lpips import LatentLPIPS
 
 
 
 
 
 
 
 
sgm/modules/autoencoding/losses/discriminator_loss.py DELETED
@@ -1,306 +0,0 @@
1
- from typing import Dict, Iterator, List, Optional, Tuple, Union
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torchvision
7
- from einops import rearrange
8
- from matplotlib import colormaps
9
- from matplotlib import pyplot as plt
10
-
11
- from ....util import default, instantiate_from_config
12
- from ..lpips.loss.lpips import LPIPS
13
- from ..lpips.model.model import weights_init
14
- from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
15
-
16
-
17
- class GeneralLPIPSWithDiscriminator(nn.Module):
18
- def __init__(
19
- self,
20
- disc_start: int,
21
- logvar_init: float = 0.0,
22
- disc_num_layers: int = 3,
23
- disc_in_channels: int = 3,
24
- disc_factor: float = 1.0,
25
- disc_weight: float = 1.0,
26
- perceptual_weight: float = 1.0,
27
- disc_loss: str = "hinge",
28
- scale_input_to_tgt_size: bool = False,
29
- dims: int = 2,
30
- learn_logvar: bool = False,
31
- regularization_weights: Union[None, Dict[str, float]] = None,
32
- additional_log_keys: Optional[List[str]] = None,
33
- discriminator_config: Optional[Dict] = None,
34
- ):
35
- super().__init__()
36
- self.dims = dims
37
- if self.dims > 2:
38
- print(
39
- f"running with dims={dims}. This means that for perceptual loss "
40
- f"calculation, the LPIPS loss will be applied to each frame "
41
- f"independently."
42
- )
43
- self.scale_input_to_tgt_size = scale_input_to_tgt_size
44
- assert disc_loss in ["hinge", "vanilla"]
45
- self.perceptual_loss = LPIPS().eval()
46
- self.perceptual_weight = perceptual_weight
47
- # output log variance
48
- self.logvar = nn.Parameter(
49
- torch.full((), logvar_init), requires_grad=learn_logvar
50
- )
51
- self.learn_logvar = learn_logvar
52
-
53
- discriminator_config = default(
54
- discriminator_config,
55
- {
56
- "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
57
- "params": {
58
- "input_nc": disc_in_channels,
59
- "n_layers": disc_num_layers,
60
- "use_actnorm": False,
61
- },
62
- },
63
- )
64
-
65
- self.discriminator = instantiate_from_config(discriminator_config).apply(
66
- weights_init
67
- )
68
- self.discriminator_iter_start = disc_start
69
- self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
70
- self.disc_factor = disc_factor
71
- self.discriminator_weight = disc_weight
72
- self.regularization_weights = default(regularization_weights, {})
73
-
74
- self.forward_keys = [
75
- "optimizer_idx",
76
- "global_step",
77
- "last_layer",
78
- "split",
79
- "regularization_log",
80
- ]
81
-
82
- self.additional_log_keys = set(default(additional_log_keys, []))
83
- self.additional_log_keys.update(set(self.regularization_weights.keys()))
84
-
85
- def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
86
- return self.discriminator.parameters()
87
-
88
- def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
89
- if self.learn_logvar:
90
- yield self.logvar
91
- yield from ()
92
-
93
- @torch.no_grad()
94
- def log_images(
95
- self, inputs: torch.Tensor, reconstructions: torch.Tensor
96
- ) -> Dict[str, torch.Tensor]:
97
- # calc logits of real/fake
98
- logits_real = self.discriminator(inputs.contiguous().detach())
99
- if len(logits_real.shape) < 4:
100
- # Non patch-discriminator
101
- return dict()
102
- logits_fake = self.discriminator(reconstructions.contiguous().detach())
103
- # -> (b, 1, h, w)
104
-
105
- # parameters for colormapping
106
- high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
107
- cmap = colormaps["PiYG"] # diverging colormap
108
-
109
- def to_colormap(logits: torch.Tensor) -> torch.Tensor:
110
- """(b, 1, ...) -> (b, 3, ...)"""
111
- logits = (logits + high) / (2 * high)
112
- logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
113
- # -> (b, 1, ..., 3)
114
- logits = torch.from_numpy(logits_np).to(logits.device)
115
- return rearrange(logits, "b 1 ... c -> b c ...")
116
-
117
- logits_real = torch.nn.functional.interpolate(
118
- logits_real,
119
- size=inputs.shape[-2:],
120
- mode="nearest",
121
- antialias=False,
122
- )
123
- logits_fake = torch.nn.functional.interpolate(
124
- logits_fake,
125
- size=reconstructions.shape[-2:],
126
- mode="nearest",
127
- antialias=False,
128
- )
129
-
130
- # alpha value of logits for overlay
131
- alpha_real = torch.abs(logits_real) / high
132
- alpha_fake = torch.abs(logits_fake) / high
133
- # -> (b, 1, h, w) in range [0, 0.5]
134
- # alpha value of lines don't really matter, since the values are the same
135
- # for both images and logits anyway
136
- grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
137
- grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
138
- grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
139
- # -> (1, h, w)
140
- # blend logits and images together
141
-
142
- # prepare logits for plotting
143
- logits_real = to_colormap(logits_real)
144
- logits_fake = to_colormap(logits_fake)
145
- # resize logits
146
- # -> (b, 3, h, w)
147
-
148
- # make some grids
149
- # add all logits to one plot
150
- logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
151
- logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
152
- # I just love how torchvision calls the number of columns `nrow`
153
- grid_logits = torch.cat((logits_real, logits_fake), dim=1)
154
- # -> (3, h, w)
155
-
156
- grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
157
- grid_images_fake = torchvision.utils.make_grid(
158
- 0.5 * reconstructions + 0.5, nrow=4
159
- )
160
- grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
161
- # -> (3, h, w) in range [0, 1]
162
-
163
- grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
164
-
165
- # Create labeled colorbar
166
- dpi = 100
167
- height = 128 / dpi
168
- width = grid_logits.shape[2] / dpi
169
- fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
170
- img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
171
- plt.colorbar(
172
- img,
173
- cax=ax,
174
- orientation="horizontal",
175
- fraction=0.9,
176
- aspect=width / height,
177
- pad=0.0,
178
- )
179
- img.set_visible(False)
180
- fig.tight_layout()
181
- fig.canvas.draw()
182
- # manually convert figure to numpy
183
- cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
184
- cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
185
- cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
186
- cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
187
-
188
- # Add colorbar to plot
189
- annotated_grid = torch.cat((grid_logits, cbar), dim=1)
190
- blended_grid = torch.cat((grid_blend, cbar), dim=1)
191
- return {
192
- "vis_logits": 2 * annotated_grid[None, ...] - 1,
193
- "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
194
- }
195
-
196
- def calculate_adaptive_weight(
197
- self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
198
- ) -> torch.Tensor:
199
- nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
200
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
201
-
202
- d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
203
- d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
204
- d_weight = d_weight * self.discriminator_weight
205
- return d_weight
206
-
207
- def forward(
208
- self,
209
- inputs: torch.Tensor,
210
- reconstructions: torch.Tensor,
211
- *, # added because I changed the order here
212
- regularization_log: Dict[str, torch.Tensor],
213
- optimizer_idx: int,
214
- global_step: int,
215
- last_layer: torch.Tensor,
216
- split: str = "train",
217
- weights: Union[None, float, torch.Tensor] = None,
218
- ) -> Tuple[torch.Tensor, dict]:
219
- if self.scale_input_to_tgt_size:
220
- inputs = torch.nn.functional.interpolate(
221
- inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
222
- )
223
-
224
- if self.dims > 2:
225
- inputs, reconstructions = map(
226
- lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
227
- (inputs, reconstructions),
228
- )
229
-
230
- rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
231
- if self.perceptual_weight > 0:
232
- p_loss = self.perceptual_loss(
233
- inputs.contiguous(), reconstructions.contiguous()
234
- )
235
- rec_loss = rec_loss + self.perceptual_weight * p_loss
236
-
237
- nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
238
-
239
- # now the GAN part
240
- if optimizer_idx == 0:
241
- # generator update
242
- if global_step >= self.discriminator_iter_start or not self.training:
243
- logits_fake = self.discriminator(reconstructions.contiguous())
244
- g_loss = -torch.mean(logits_fake)
245
- if self.training:
246
- d_weight = self.calculate_adaptive_weight(
247
- nll_loss, g_loss, last_layer=last_layer
248
- )
249
- else:
250
- d_weight = torch.tensor(1.0)
251
- else:
252
- d_weight = torch.tensor(0.0)
253
- g_loss = torch.tensor(0.0, requires_grad=True)
254
-
255
- loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
256
- log = dict()
257
- for k in regularization_log:
258
- if k in self.regularization_weights:
259
- loss = loss + self.regularization_weights[k] * regularization_log[k]
260
- if k in self.additional_log_keys:
261
- log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
262
-
263
- log.update(
264
- {
265
- f"{split}/loss/total": loss.clone().detach().mean(),
266
- f"{split}/loss/nll": nll_loss.detach().mean(),
267
- f"{split}/loss/rec": rec_loss.detach().mean(),
268
- f"{split}/loss/g": g_loss.detach().mean(),
269
- f"{split}/scalars/logvar": self.logvar.detach(),
270
- f"{split}/scalars/d_weight": d_weight.detach(),
271
- }
272
- )
273
-
274
- return loss, log
275
- elif optimizer_idx == 1:
276
- # second pass for discriminator update
277
- logits_real = self.discriminator(inputs.contiguous().detach())
278
- logits_fake = self.discriminator(reconstructions.contiguous().detach())
279
-
280
- if global_step >= self.discriminator_iter_start or not self.training:
281
- d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
282
- else:
283
- d_loss = torch.tensor(0.0, requires_grad=True)
284
-
285
- log = {
286
- f"{split}/loss/disc": d_loss.clone().detach().mean(),
287
- f"{split}/logits/real": logits_real.detach().mean(),
288
- f"{split}/logits/fake": logits_fake.detach().mean(),
289
- }
290
- return d_loss, log
291
- else:
292
- raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
293
-
294
- def get_nll_loss(
295
- self,
296
- rec_loss: torch.Tensor,
297
- weights: Optional[Union[float, torch.Tensor]] = None,
298
- ) -> Tuple[torch.Tensor, torch.Tensor]:
299
- nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
300
- weighted_nll_loss = nll_loss
301
- if weights is not None:
302
- weighted_nll_loss = weights * nll_loss
303
- weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
304
- nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
305
-
306
- return nll_loss, weighted_nll_loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sgm/modules/autoencoding/losses/lpips.py DELETED
@@ -1,73 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from ....util import default, instantiate_from_config
5
- from ..lpips.loss.lpips import LPIPS
6
-
7
-
8
- class LatentLPIPS(nn.Module):
9
- def __init__(
10
- self,
11
- decoder_config,
12
- perceptual_weight=1.0,
13
- latent_weight=1.0,
14
- scale_input_to_tgt_size=False,
15
- scale_tgt_to_input_size=False,
16
- perceptual_weight_on_inputs=0.0,
17
- ):
18
- super().__init__()
19
- self.scale_input_to_tgt_size = scale_input_to_tgt_size
20
- self.scale_tgt_to_input_size = scale_tgt_to_input_size
21
- self.init_decoder(decoder_config)
22
- self.perceptual_loss = LPIPS().eval()
23
- self.perceptual_weight = perceptual_weight
24
- self.latent_weight = latent_weight
25
- self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
26
-
27
- def init_decoder(self, config):
28
- self.decoder = instantiate_from_config(config)
29
- if hasattr(self.decoder, "encoder"):
30
- del self.decoder.encoder
31
-
32
- def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
33
- log = dict()
34
- loss = (latent_inputs - latent_predictions) ** 2
35
- log[f"{split}/latent_l2_loss"] = loss.mean().detach()
36
- image_reconstructions = None
37
- if self.perceptual_weight > 0.0:
38
- image_reconstructions = self.decoder.decode(latent_predictions)
39
- image_targets = self.decoder.decode(latent_inputs)
40
- perceptual_loss = self.perceptual_loss(
41
- image_targets.contiguous(), image_reconstructions.contiguous()
42
- )
43
- loss = (
44
- self.latent_weight * loss.mean()
45
- + self.perceptual_weight * perceptual_loss.mean()
46
- )
47
- log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
48
-
49
- if self.perceptual_weight_on_inputs > 0.0:
50
- image_reconstructions = default(
51
- image_reconstructions, self.decoder.decode(latent_predictions)
52
- )
53
- if self.scale_input_to_tgt_size:
54
- image_inputs = torch.nn.functional.interpolate(
55
- image_inputs,
56
- image_reconstructions.shape[2:],
57
- mode="bicubic",
58
- antialias=True,
59
- )
60
- elif self.scale_tgt_to_input_size:
61
- image_reconstructions = torch.nn.functional.interpolate(
62
- image_reconstructions,
63
- image_inputs.shape[2:],
64
- mode="bicubic",
65
- antialias=True,
66
- )
67
-
68
- perceptual_loss2 = self.perceptual_loss(
69
- image_inputs.contiguous(), image_reconstructions.contiguous()
70
- )
71
- loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
72
- log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
73
- return loss, log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sgm/modules/autoencoding/lpips/__init__.py DELETED
File without changes
sgm/modules/autoencoding/lpips/loss/.gitignore DELETED
@@ -1 +0,0 @@
1
- vgg.pth
 
 
sgm/modules/autoencoding/lpips/loss/LICENSE DELETED
@@ -1,23 +0,0 @@
1
- Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2
- All rights reserved.
3
-
4
- Redistribution and use in source and binary forms, with or without
5
- modification, are permitted provided that the following conditions are met:
6
-
7
- * Redistributions of source code must retain the above copyright notice, this
8
- list of conditions and the following disclaimer.
9
-
10
- * Redistributions in binary form must reproduce the above copyright notice,
11
- this list of conditions and the following disclaimer in the documentation
12
- and/or other materials provided with the distribution.
13
-
14
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sgm/modules/autoencoding/lpips/loss/__init__.py DELETED
File without changes
sgm/modules/autoencoding/lpips/loss/lpips.py DELETED
@@ -1,147 +0,0 @@
1
- """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
-
3
- from collections import namedtuple
4
-
5
- import torch
6
- import torch.nn as nn
7
- from torchvision import models
8
-
9
- from ..util import get_ckpt_path
10
-
11
-
12
- class LPIPS(nn.Module):
13
- # Learned perceptual metric
14
- def __init__(self, use_dropout=True):
15
- super().__init__()
16
- self.scaling_layer = ScalingLayer()
17
- self.chns = [64, 128, 256, 512, 512] # vg16 features
18
- self.net = vgg16(pretrained=True, requires_grad=False)
19
- self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20
- self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21
- self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22
- self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23
- self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24
- self.load_from_pretrained()
25
- for param in self.parameters():
26
- param.requires_grad = False
27
-
28
- def load_from_pretrained(self, name="vgg_lpips"):
29
- ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
30
- self.load_state_dict(
31
- torch.load(ckpt, map_location=torch.device("cpu")), strict=False
32
- )
33
- print("loaded pretrained LPIPS loss from {}".format(ckpt))
34
-
35
- @classmethod
36
- def from_pretrained(cls, name="vgg_lpips"):
37
- if name != "vgg_lpips":
38
- raise NotImplementedError
39
- model = cls()
40
- ckpt = get_ckpt_path(name)
41
- model.load_state_dict(
42
- torch.load(ckpt, map_location=torch.device("cpu")), strict=False
43
- )
44
- return model
45
-
46
- def forward(self, input, target):
47
- in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
48
- outs0, outs1 = self.net(in0_input), self.net(in1_input)
49
- feats0, feats1, diffs = {}, {}, {}
50
- lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
51
- for kk in range(len(self.chns)):
52
- feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
53
- outs1[kk]
54
- )
55
- diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
56
-
57
- res = [
58
- spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
59
- for kk in range(len(self.chns))
60
- ]
61
- val = res[0]
62
- for l in range(1, len(self.chns)):
63
- val += res[l]
64
- return val
65
-
66
-
67
- class ScalingLayer(nn.Module):
68
- def __init__(self):
69
- super(ScalingLayer, self).__init__()
70
- self.register_buffer(
71
- "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
72
- )
73
- self.register_buffer(
74
- "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
75
- )
76
-
77
- def forward(self, inp):
78
- return (inp - self.shift) / self.scale
79
-
80
-
81
- class NetLinLayer(nn.Module):
82
- """A single linear layer which does a 1x1 conv"""
83
-
84
- def __init__(self, chn_in, chn_out=1, use_dropout=False):
85
- super(NetLinLayer, self).__init__()
86
- layers = (
87
- [
88
- nn.Dropout(),
89
- ]
90
- if (use_dropout)
91
- else []
92
- )
93
- layers += [
94
- nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
95
- ]
96
- self.model = nn.Sequential(*layers)
97
-
98
-
99
- class vgg16(torch.nn.Module):
100
- def __init__(self, requires_grad=False, pretrained=True):
101
- super(vgg16, self).__init__()
102
- vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
103
- self.slice1 = torch.nn.Sequential()
104
- self.slice2 = torch.nn.Sequential()
105
- self.slice3 = torch.nn.Sequential()
106
- self.slice4 = torch.nn.Sequential()
107
- self.slice5 = torch.nn.Sequential()
108
- self.N_slices = 5
109
- for x in range(4):
110
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
111
- for x in range(4, 9):
112
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
113
- for x in range(9, 16):
114
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
115
- for x in range(16, 23):
116
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
117
- for x in range(23, 30):
118
- self.slice5.add_module(str(x), vgg_pretrained_features[x])
119
- if not requires_grad:
120
- for param in self.parameters():
121
- param.requires_grad = False
122
-
123
- def forward(self, X):
124
- h = self.slice1(X)
125
- h_relu1_2 = h
126
- h = self.slice2(h)
127
- h_relu2_2 = h
128
- h = self.slice3(h)
129
- h_relu3_3 = h
130
- h = self.slice4(h)
131
- h_relu4_3 = h
132
- h = self.slice5(h)
133
- h_relu5_3 = h
134
- vgg_outputs = namedtuple(
135
- "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
136
- )
137
- out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
138
- return out
139
-
140
-
141
- def normalize_tensor(x, eps=1e-10):
142
- norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
143
- return x / (norm_factor + eps)
144
-
145
-
146
- def spatial_average(x, keepdim=True):
147
- return x.mean([2, 3], keepdim=keepdim)
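Note: the deleted LPIPS module scores perceptual distance by unit-normalizing VGG16 feature maps along the channel axis, squaring the differences, weighting them with learned 1x1 convolutions, averaging spatially and summing over the five layers. A toy, self-contained restatement of that reduction on random "feature maps"; the real module uses pretrained VGG activations and the downloaded linear weights rather than the random tensors assumed here:

    import torch

    def normalize_tensor(x, eps=1e-10):
        return x / (x.pow(2).sum(dim=1, keepdim=True).sqrt() + eps)

    def lpips_like(feats0, feats1, layer_weights):
        # feats*: lists of (B, C_l, H_l, W_l) feature maps; layer_weights: list of (C_l,) tensors
        val = 0.0
        for f0, f1, w in zip(feats0, feats1, layer_weights):
            diff = (normalize_tensor(f0) - normalize_tensor(f1)) ** 2
            # a learned 1x1 conv is a per-channel weighting followed by a channel sum
            per_layer = (diff * w[None, :, None, None]).sum(dim=1, keepdim=True)
            val = val + per_layer.mean([2, 3], keepdim=True)  # spatial average
        return val  # shape (B, 1, 1, 1)

    chns = [64, 128, 256, 512, 512]
    feats0 = [torch.rand(2, c, 8, 8) for c in chns]
    feats1 = [torch.rand(2, c, 8, 8) for c in chns]
    weights = [torch.rand(c) for c in chns]
    print(lpips_like(feats0, feats1, weights).shape)  # torch.Size([2, 1, 1, 1])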
 
sgm/modules/autoencoding/lpips/model/LICENSE DELETED
@@ -1,58 +0,0 @@
1
- Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2
- All rights reserved.
3
-
4
- Redistribution and use in source and binary forms, with or without
5
- modification, are permitted provided that the following conditions are met:
6
-
7
- * Redistributions of source code must retain the above copyright notice, this
8
- list of conditions and the following disclaimer.
9
-
10
- * Redistributions in binary form must reproduce the above copyright notice,
11
- this list of conditions and the following disclaimer in the documentation
12
- and/or other materials provided with the distribution.
13
-
14
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
-
25
-
26
- --------------------------- LICENSE FOR pix2pix --------------------------------
27
- BSD License
28
-
29
- For pix2pix software
30
- Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31
- All rights reserved.
32
-
33
- Redistribution and use in source and binary forms, with or without
34
- modification, are permitted provided that the following conditions are met:
35
-
36
- * Redistributions of source code must retain the above copyright notice, this
37
- list of conditions and the following disclaimer.
38
-
39
- * Redistributions in binary form must reproduce the above copyright notice,
40
- this list of conditions and the following disclaimer in the documentation
41
- and/or other materials provided with the distribution.
42
-
43
- ----------------------------- LICENSE FOR DCGAN --------------------------------
44
- BSD License
45
-
46
- For dcgan.torch software
47
-
48
- Copyright (c) 2015, Facebook, Inc. All rights reserved.
49
-
50
- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51
-
52
- Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53
-
54
- Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55
-
56
- Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57
-
58
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
sgm/modules/autoencoding/lpips/model/__init__.py DELETED
File without changes
sgm/modules/autoencoding/lpips/model/model.py DELETED
@@ -1,88 +0,0 @@
1
- import functools
2
-
3
- import torch.nn as nn
4
-
5
- from ..util import ActNorm
6
-
7
-
8
- def weights_init(m):
9
- classname = m.__class__.__name__
10
- if classname.find("Conv") != -1:
11
- nn.init.normal_(m.weight.data, 0.0, 0.02)
12
- elif classname.find("BatchNorm") != -1:
13
- nn.init.normal_(m.weight.data, 1.0, 0.02)
14
- nn.init.constant_(m.bias.data, 0)
15
-
16
-
17
- class NLayerDiscriminator(nn.Module):
18
- """Defines a PatchGAN discriminator as in Pix2Pix
19
- --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
- """
21
-
22
- def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
23
- """Construct a PatchGAN discriminator
24
- Parameters:
25
- input_nc (int) -- the number of channels in input images
26
- ndf (int) -- the number of filters in the last conv layer
27
- n_layers (int) -- the number of conv layers in the discriminator
28
- norm_layer -- normalization layer
29
- """
30
- super(NLayerDiscriminator, self).__init__()
31
- if not use_actnorm:
32
- norm_layer = nn.BatchNorm2d
33
- else:
34
- norm_layer = ActNorm
35
- if (
36
- type(norm_layer) == functools.partial
37
- ): # no need to use bias as BatchNorm2d has affine parameters
38
- use_bias = norm_layer.func != nn.BatchNorm2d
39
- else:
40
- use_bias = norm_layer != nn.BatchNorm2d
41
-
42
- kw = 4
43
- padw = 1
44
- sequence = [
45
- nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
46
- nn.LeakyReLU(0.2, True),
47
- ]
48
- nf_mult = 1
49
- nf_mult_prev = 1
50
- for n in range(1, n_layers): # gradually increase the number of filters
51
- nf_mult_prev = nf_mult
52
- nf_mult = min(2**n, 8)
53
- sequence += [
54
- nn.Conv2d(
55
- ndf * nf_mult_prev,
56
- ndf * nf_mult,
57
- kernel_size=kw,
58
- stride=2,
59
- padding=padw,
60
- bias=use_bias,
61
- ),
62
- norm_layer(ndf * nf_mult),
63
- nn.LeakyReLU(0.2, True),
64
- ]
65
-
66
- nf_mult_prev = nf_mult
67
- nf_mult = min(2**n_layers, 8)
68
- sequence += [
69
- nn.Conv2d(
70
- ndf * nf_mult_prev,
71
- ndf * nf_mult,
72
- kernel_size=kw,
73
- stride=1,
74
- padding=padw,
75
- bias=use_bias,
76
- ),
77
- norm_layer(ndf * nf_mult),
78
- nn.LeakyReLU(0.2, True),
79
- ]
80
-
81
- sequence += [
82
- nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
83
- ] # output 1 channel prediction map
84
- self.main = nn.Sequential(*sequence)
85
-
86
- def forward(self, input):
87
- """Standard forward."""
88
- return self.main(input)
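Note: the deleted NLayerDiscriminator is the standard PatchGAN, a stack of strided 4x4 convolutions that maps an image to a grid of per-patch real/fake logits instead of a single scalar. With the defaults (ndf=64, n_layers=3) the net downsamples by 8x before two stride-1 convolutions, so a 256x256 input yields a 30x30 logit map. A self-contained sketch that mirrors those defaults just to show the shape arithmetic (not the deleted class itself):

    import torch
    import torch.nn as nn

    def patchgan(input_nc=3, ndf=64, n_layers=3):
        layers = [nn.Conv2d(input_nc, ndf, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        nf = 1
        for n in range(1, n_layers):  # gradually widen while downsampling
            nf_prev, nf = nf, min(2 ** n, 8)
            layers += [nn.Conv2d(ndf * nf_prev, ndf * nf, 4, stride=2, padding=1, bias=False),
                       nn.BatchNorm2d(ndf * nf), nn.LeakyReLU(0.2, True)]
        nf_prev, nf = nf, min(2 ** n_layers, 8)
        layers += [nn.Conv2d(ndf * nf_prev, ndf * nf, 4, stride=1, padding=1, bias=False),
                   nn.BatchNorm2d(ndf * nf), nn.LeakyReLU(0.2, True),
                   nn.Conv2d(ndf * nf, 1, 4, stride=1, padding=1)]  # 1-channel prediction map
        return nn.Sequential(*layers)

    logits = patchgan()(torch.randn(1, 3, 256, 256))
    print(logits.shape)  # torch.Size([1, 1, 30, 30]): one real/fake logit per overlapping patch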
 
sgm/modules/autoencoding/lpips/util.py DELETED
@@ -1,128 +0,0 @@
1
- import hashlib
2
- import os
3
-
4
- import requests
5
- import torch
6
- import torch.nn as nn
7
- from tqdm import tqdm
8
-
9
- URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
10
-
11
- CKPT_MAP = {"vgg_lpips": "vgg.pth"}
12
-
13
- MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
14
-
15
-
16
- def download(url, local_path, chunk_size=1024):
17
- os.makedirs(os.path.split(local_path)[0], exist_ok=True)
18
- with requests.get(url, stream=True) as r:
19
- total_size = int(r.headers.get("content-length", 0))
20
- with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
21
- with open(local_path, "wb") as f:
22
- for data in r.iter_content(chunk_size=chunk_size):
23
- if data:
24
- f.write(data)
25
- pbar.update(chunk_size)
26
-
27
-
28
- def md5_hash(path):
29
- with open(path, "rb") as f:
30
- content = f.read()
31
- return hashlib.md5(content).hexdigest()
32
-
33
-
34
- def get_ckpt_path(name, root, check=False):
35
- assert name in URL_MAP
36
- path = os.path.join(root, CKPT_MAP[name])
37
- if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
38
- print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
39
- download(URL_MAP[name], path)
40
- md5 = md5_hash(path)
41
- assert md5 == MD5_MAP[name], md5
42
- return path
43
-
44
-
45
- class ActNorm(nn.Module):
46
- def __init__(
47
- self, num_features, logdet=False, affine=True, allow_reverse_init=False
48
- ):
49
- assert affine
50
- super().__init__()
51
- self.logdet = logdet
52
- self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
53
- self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
54
- self.allow_reverse_init = allow_reverse_init
55
-
56
- self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
57
-
58
- def initialize(self, input):
59
- with torch.no_grad():
60
- flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
61
- mean = (
62
- flatten.mean(1)
63
- .unsqueeze(1)
64
- .unsqueeze(2)
65
- .unsqueeze(3)
66
- .permute(1, 0, 2, 3)
67
- )
68
- std = (
69
- flatten.std(1)
70
- .unsqueeze(1)
71
- .unsqueeze(2)
72
- .unsqueeze(3)
73
- .permute(1, 0, 2, 3)
74
- )
75
-
76
- self.loc.data.copy_(-mean)
77
- self.scale.data.copy_(1 / (std + 1e-6))
78
-
79
- def forward(self, input, reverse=False):
80
- if reverse:
81
- return self.reverse(input)
82
- if len(input.shape) == 2:
83
- input = input[:, :, None, None]
84
- squeeze = True
85
- else:
86
- squeeze = False
87
-
88
- _, _, height, width = input.shape
89
-
90
- if self.training and self.initialized.item() == 0:
91
- self.initialize(input)
92
- self.initialized.fill_(1)
93
-
94
- h = self.scale * (input + self.loc)
95
-
96
- if squeeze:
97
- h = h.squeeze(-1).squeeze(-1)
98
-
99
- if self.logdet:
100
- log_abs = torch.log(torch.abs(self.scale))
101
- logdet = height * width * torch.sum(log_abs)
102
- logdet = logdet * torch.ones(input.shape[0]).to(input)
103
- return h, logdet
104
-
105
- return h
106
-
107
- def reverse(self, output):
108
- if self.training and self.initialized.item() == 0:
109
- if not self.allow_reverse_init:
110
- raise RuntimeError(
111
- "Initializing ActNorm in reverse direction is "
112
- "disabled by default. Use allow_reverse_init=True to enable."
113
- )
114
- else:
115
- self.initialize(output)
116
- self.initialized.fill_(1)
117
-
118
- if len(output.shape) == 2:
119
- output = output[:, :, None, None]
120
- squeeze = True
121
- else:
122
- squeeze = False
123
-
124
- h = output / self.scale - self.loc
125
-
126
- if squeeze:
127
- h = h.squeeze(-1).squeeze(-1)
128
- return h
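Note: ActNorm, deleted above, replaces batch norm in the discriminator when use_actnorm=True. On the first training batch it sets a per-channel shift to the negative channel mean and a per-channel scale to the inverse channel std, so the first normalized batch comes out roughly zero-mean and unit-variance; afterwards loc and scale are trained as ordinary parameters. A small sketch of just that data-dependent initialization, with made-up activation statistics:

    import torch

    x = torch.randn(8, 16, 32, 32) * 3.0 + 1.5            # un-normalized activations
    flat = x.permute(1, 0, 2, 3).reshape(16, -1)           # (C, B*H*W)
    loc = -flat.mean(1).view(1, -1, 1, 1)                  # shift = -per-channel mean
    scale = 1.0 / (flat.std(1).view(1, -1, 1, 1) + 1e-6)   # scale = 1 / per-channel std

    h = scale * (x + loc)                                  # ActNorm forward pass
    print(h.mean().item(), h.std().item())                 # ~0.0 and ~1.0 on the init batch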
 
sgm/modules/autoencoding/lpips/vqperceptual.py DELETED
@@ -1,17 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
-
4
-
5
- def hinge_d_loss(logits_real, logits_fake):
6
- loss_real = torch.mean(F.relu(1.0 - logits_real))
7
- loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8
- d_loss = 0.5 * (loss_real + loss_fake)
9
- return d_loss
10
-
11
-
12
- def vanilla_d_loss(logits_real, logits_fake):
13
- d_loss = 0.5 * (
14
- torch.mean(torch.nn.functional.softplus(-logits_real))
15
- + torch.mean(torch.nn.functional.softplus(logits_fake))
16
- )
17
- return d_loss
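Note: the two deleted discriminator objectives are the hinge loss, 0.5 * (mean(relu(1 - D(x_real))) + mean(relu(1 + D(x_fake)))), and the non-saturating "vanilla" loss built from softplus. A quick numerical check of the hinge variant, which only penalizes real logits below +1 and fake logits above -1 (example logits are made up):

    import torch
    import torch.nn.functional as F

    logits_real = torch.tensor([2.0, 0.5, -1.0])   # confident, weak, wrong
    logits_fake = torch.tensor([-2.0, -0.5, 1.0])

    d_loss = 0.5 * (F.relu(1.0 - logits_real).mean() + F.relu(1.0 + logits_fake).mean())
    print(d_loss.item())  # 0.5 * ((0 + 0.5 + 2)/3 + (0 + 0.5 + 2)/3) = 0.8333...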
 
sgm/modules/autoencoding/regularizers/__init__.py DELETED
@@ -1,31 +0,0 @@
1
- from abc import abstractmethod
2
- from typing import Any, Tuple
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- from ....modules.distributions.distributions import \
9
- DiagonalGaussianDistribution
10
- from .base import AbstractRegularizer
11
-
12
-
13
- class DiagonalGaussianRegularizer(AbstractRegularizer):
14
- def __init__(self, sample: bool = True):
15
- super().__init__()
16
- self.sample = sample
17
-
18
- def get_trainable_parameters(self) -> Any:
19
- yield from ()
20
-
21
- def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
22
- log = dict()
23
- posterior = DiagonalGaussianDistribution(z)
24
- if self.sample:
25
- z = posterior.sample()
26
- else:
27
- z = posterior.mode()
28
- kl_loss = posterior.kl()
29
- kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
30
- log["kl_loss"] = kl_loss
31
- return z, log
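Note: the diagonal-Gaussian regularizer deleted above splits the encoder output into mean and log-variance, samples z with the reparameterization trick, and returns the KL divergence to a standard normal averaged over the batch. The closed form per element is 0.5 * (mu^2 + sigma^2 - 1 - log sigma^2). A minimal sketch, assuming the moments are already split out (the deleted DiagonalGaussianDistribution does the chunking internally, on 4D latents rather than the 2D toy tensors used here):

    import torch

    mean, logvar = torch.randn(4, 8), torch.randn(4, 8).clamp(-5, 5)
    std = torch.exp(0.5 * logvar)

    z = mean + std * torch.randn_like(std)                   # reparameterized sample
    kl = 0.5 * (mean.pow(2) + logvar.exp() - 1.0 - logvar)   # KL(q(z|x) || N(0, I)) per element
    kl_loss = kl.sum(dim=1).mean()                           # sum over latent dims, mean over batch
    print(z.shape, kl_loss.item())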
 
sgm/modules/autoencoding/regularizers/base.py DELETED
@@ -1,40 +0,0 @@
1
- from abc import abstractmethod
2
- from typing import Any, Tuple
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from torch import nn
7
-
8
-
9
- class AbstractRegularizer(nn.Module):
10
- def __init__(self):
11
- super().__init__()
12
-
13
- def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
14
- raise NotImplementedError()
15
-
16
- @abstractmethod
17
- def get_trainable_parameters(self) -> Any:
18
- raise NotImplementedError()
19
-
20
-
21
- class IdentityRegularizer(AbstractRegularizer):
22
- def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
23
- return z, dict()
24
-
25
- def get_trainable_parameters(self) -> Any:
26
- yield from ()
27
-
28
-
29
- def measure_perplexity(
30
- predicted_indices: torch.Tensor, num_centroids: int
31
- ) -> Tuple[torch.Tensor, torch.Tensor]:
32
- # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
33
- # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
34
- encodings = (
35
- F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
36
- )
37
- avg_probs = encodings.mean(0)
38
- perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
39
- cluster_use = torch.sum(avg_probs > 0)
40
- return perplexity, cluster_use
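Note: measure_perplexity, deleted above, reports how evenly a quantizer uses its codebook. Perplexity is exp of the entropy of the average code distribution, so it equals num_centroids when every code is used equally often and collapses toward 1 when a single code dominates. A tiny worked example with hand-picked indices:

    import torch
    import torch.nn.functional as F

    num_centroids = 4
    indices = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])                  # perfectly uniform usage
    avg_probs = F.one_hot(indices, num_centroids).float().mean(0)
    perplexity = torch.exp(-(avg_probs * torch.log(avg_probs + 1e-10)).sum())
    print(perplexity.item())                                           # ~4.0: all 4 codes used equally

    collapsed = torch.zeros(8, dtype=torch.long)                       # every token maps to code 0
    avg_probs = F.one_hot(collapsed, num_centroids).float().mean(0)
    print(torch.exp(-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).item())  # ~1.0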
 
sgm/modules/autoencoding/regularizers/quantize.py DELETED
@@ -1,487 +0,0 @@
1
- import logging
2
- from abc import abstractmethod
3
- from typing import Dict, Iterator, Literal, Optional, Tuple, Union
4
-
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from einops import rearrange
10
- from torch import einsum
11
-
12
- from .base import AbstractRegularizer, measure_perplexity
13
-
14
- logpy = logging.getLogger(__name__)
15
-
16
-
17
- class AbstractQuantizer(AbstractRegularizer):
18
- def __init__(self):
19
- super().__init__()
20
- # Define these in your init
21
- # shape (N,)
22
- self.used: Optional[torch.Tensor]
23
- self.re_embed: int
24
- self.unknown_index: Union[Literal["random"], int]
25
-
26
- def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
27
- assert self.used is not None, "You need to define used indices for remap"
28
- ishape = inds.shape
29
- assert len(ishape) > 1
30
- inds = inds.reshape(ishape[0], -1)
31
- used = self.used.to(inds)
32
- match = (inds[:, :, None] == used[None, None, ...]).long()
33
- new = match.argmax(-1)
34
- unknown = match.sum(2) < 1
35
- if self.unknown_index == "random":
36
- new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
37
- device=new.device
38
- )
39
- else:
40
- new[unknown] = self.unknown_index
41
- return new.reshape(ishape)
42
-
43
- def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
44
- assert self.used is not None, "You need to define used indices for remap"
45
- ishape = inds.shape
46
- assert len(ishape) > 1
47
- inds = inds.reshape(ishape[0], -1)
48
- used = self.used.to(inds)
49
- if self.re_embed > self.used.shape[0]: # extra token
50
- inds[inds >= self.used.shape[0]] = 0 # simply set to zero
51
- back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
52
- return back.reshape(ishape)
53
-
54
- @abstractmethod
55
- def get_codebook_entry(
56
- self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
57
- ) -> torch.Tensor:
58
- raise NotImplementedError()
59
-
60
- def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
61
- yield from self.parameters()
62
-
63
-
64
- class GumbelQuantizer(AbstractQuantizer):
65
- """
66
- credit to @karpathy:
67
- https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
68
- Gumbel Softmax trick quantizer
69
- Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
70
- https://arxiv.org/abs/1611.01144
71
- """
72
-
73
- def __init__(
74
- self,
75
- num_hiddens: int,
76
- embedding_dim: int,
77
- n_embed: int,
78
- straight_through: bool = True,
79
- kl_weight: float = 5e-4,
80
- temp_init: float = 1.0,
81
- remap: Optional[str] = None,
82
- unknown_index: str = "random",
83
- loss_key: str = "loss/vq",
84
- ) -> None:
85
- super().__init__()
86
-
87
- self.loss_key = loss_key
88
- self.embedding_dim = embedding_dim
89
- self.n_embed = n_embed
90
-
91
- self.straight_through = straight_through
92
- self.temperature = temp_init
93
- self.kl_weight = kl_weight
94
-
95
- self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
96
- self.embed = nn.Embedding(n_embed, embedding_dim)
97
-
98
- self.remap = remap
99
- if self.remap is not None:
100
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
101
- self.re_embed = self.used.shape[0]
102
- else:
103
- self.used = None
104
- self.re_embed = n_embed
105
- if unknown_index == "extra":
106
- self.unknown_index = self.re_embed
107
- self.re_embed = self.re_embed + 1
108
- else:
109
- assert unknown_index == "random" or isinstance(
110
- unknown_index, int
111
- ), "unknown index needs to be 'random', 'extra' or any integer"
112
- self.unknown_index = unknown_index # "random" or "extra" or integer
113
- if self.remap is not None:
114
- logpy.info(
115
- f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
116
- f"Using {self.unknown_index} for unknown indices."
117
- )
118
-
119
- def forward(
120
- self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
121
- ) -> Tuple[torch.Tensor, Dict]:
122
- # force hard = True when we are in eval mode, as we must quantize.
123
- # actually, always true seems to work
124
- hard = self.straight_through if self.training else True
125
- temp = self.temperature if temp is None else temp
126
- out_dict = {}
127
- logits = self.proj(z)
128
- if self.remap is not None:
129
- # continue only with used logits
130
- full_zeros = torch.zeros_like(logits)
131
- logits = logits[:, self.used, ...]
132
-
133
- soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
134
- if self.remap is not None:
135
- # go back to all entries but unused set to zero
136
- full_zeros[:, self.used, ...] = soft_one_hot
137
- soft_one_hot = full_zeros
138
- z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
139
-
140
- # + kl divergence to the prior loss
141
- qy = F.softmax(logits, dim=1)
142
- diff = (
143
- self.kl_weight
144
- * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
145
- )
146
- out_dict[self.loss_key] = diff
147
-
148
- ind = soft_one_hot.argmax(dim=1)
149
- out_dict["indices"] = ind
150
- if self.remap is not None:
151
- ind = self.remap_to_used(ind)
152
-
153
- if return_logits:
154
- out_dict["logits"] = logits
155
-
156
- return z_q, out_dict
157
-
158
- def get_codebook_entry(self, indices, shape):
159
- # TODO: shape not yet optional
160
- b, h, w, c = shape
161
- assert b * h * w == indices.shape[0]
162
- indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
163
- if self.remap is not None:
164
- indices = self.unmap_to_all(indices)
165
- one_hot = (
166
- F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
167
- )
168
- z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
169
- return z_q
170
-
171
-
172
- class VectorQuantizer(AbstractQuantizer):
173
- """
174
- ____________________________________________
175
- Discretization bottleneck part of the VQ-VAE.
176
- Inputs:
177
- - n_e : number of embeddings
178
- - e_dim : dimension of embedding
179
- - beta : commitment cost used in loss term,
180
- beta * ||z_e(x)-sg[e]||^2
181
- _____________________________________________
182
- """
183
-
184
- def __init__(
185
- self,
186
- n_e: int,
187
- e_dim: int,
188
- beta: float = 0.25,
189
- remap: Optional[str] = None,
190
- unknown_index: str = "random",
191
- sane_index_shape: bool = False,
192
- log_perplexity: bool = False,
193
- embedding_weight_norm: bool = False,
194
- loss_key: str = "loss/vq",
195
- ):
196
- super().__init__()
197
- self.n_e = n_e
198
- self.e_dim = e_dim
199
- self.beta = beta
200
- self.loss_key = loss_key
201
-
202
- if not embedding_weight_norm:
203
- self.embedding = nn.Embedding(self.n_e, self.e_dim)
204
- self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
205
- else:
206
- self.embedding = torch.nn.utils.weight_norm(
207
- nn.Embedding(self.n_e, self.e_dim), dim=1
208
- )
209
-
210
- self.remap = remap
211
- if self.remap is not None:
212
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
213
- self.re_embed = self.used.shape[0]
214
- else:
215
- self.used = None
216
- self.re_embed = n_e
217
- if unknown_index == "extra":
218
- self.unknown_index = self.re_embed
219
- self.re_embed = self.re_embed + 1
220
- else:
221
- assert unknown_index == "random" or isinstance(
222
- unknown_index, int
223
- ), "unknown index needs to be 'random', 'extra' or any integer"
224
- self.unknown_index = unknown_index # "random" or "extra" or integer
225
- if self.remap is not None:
226
- logpy.info(
227
- f"Remapping {self.n_e} indices to {self.re_embed} indices. "
228
- f"Using {self.unknown_index} for unknown indices."
229
- )
230
-
231
- self.sane_index_shape = sane_index_shape
232
- self.log_perplexity = log_perplexity
233
-
234
- def forward(
235
- self,
236
- z: torch.Tensor,
237
- ) -> Tuple[torch.Tensor, Dict]:
238
- do_reshape = z.ndim == 4
239
- if do_reshape:
240
- # # reshape z -> (batch, height, width, channel) and flatten
241
- z = rearrange(z, "b c h w -> b h w c").contiguous()
242
-
243
- else:
244
- assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
245
- z = z.contiguous()
246
-
247
- z_flattened = z.view(-1, self.e_dim)
248
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
249
-
250
- d = (
251
- torch.sum(z_flattened**2, dim=1, keepdim=True)
252
- + torch.sum(self.embedding.weight**2, dim=1)
253
- - 2
254
- * torch.einsum(
255
- "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
256
- )
257
- )
258
-
259
- min_encoding_indices = torch.argmin(d, dim=1)
260
- z_q = self.embedding(min_encoding_indices).view(z.shape)
261
- loss_dict = {}
262
- if self.log_perplexity:
263
- perplexity, cluster_usage = measure_perplexity(
264
- min_encoding_indices.detach(), self.n_e
265
- )
266
- loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
267
-
268
- # compute loss for embedding
269
- loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
270
- (z_q - z.detach()) ** 2
271
- )
272
- loss_dict[self.loss_key] = loss
273
-
274
- # preserve gradients
275
- z_q = z + (z_q - z).detach()
276
-
277
- # reshape back to match original input shape
278
- if do_reshape:
279
- z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
280
-
281
- if self.remap is not None:
282
- min_encoding_indices = min_encoding_indices.reshape(
283
- z.shape[0], -1
284
- ) # add batch axis
285
- min_encoding_indices = self.remap_to_used(min_encoding_indices)
286
- min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
287
-
288
- if self.sane_index_shape:
289
- if do_reshape:
290
- min_encoding_indices = min_encoding_indices.reshape(
291
- z_q.shape[0], z_q.shape[2], z_q.shape[3]
292
- )
293
- else:
294
- min_encoding_indices = rearrange(
295
- min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
296
- )
297
-
298
- loss_dict["min_encoding_indices"] = min_encoding_indices
299
-
300
- return z_q, loss_dict
301
-
302
- def get_codebook_entry(
303
- self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
304
- ) -> torch.Tensor:
305
- # shape specifying (batch, height, width, channel)
306
- if self.remap is not None:
307
- assert shape is not None, "Need to give shape for remap"
308
- indices = indices.reshape(shape[0], -1) # add batch axis
309
- indices = self.unmap_to_all(indices)
310
- indices = indices.reshape(-1) # flatten again
311
-
312
- # get quantized latent vectors
313
- z_q = self.embedding(indices)
314
-
315
- if shape is not None:
316
- z_q = z_q.view(shape)
317
- # reshape back to match original input shape
318
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
319
-
320
- return z_q
321
-
322
-
323
- class EmbeddingEMA(nn.Module):
324
- def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
325
- super().__init__()
326
- self.decay = decay
327
- self.eps = eps
328
- weight = torch.randn(num_tokens, codebook_dim)
329
- self.weight = nn.Parameter(weight, requires_grad=False)
330
- self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
331
- self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
332
- self.update = True
333
-
334
- def forward(self, embed_id):
335
- return F.embedding(embed_id, self.weight)
336
-
337
- def cluster_size_ema_update(self, new_cluster_size):
338
- self.cluster_size.data.mul_(self.decay).add_(
339
- new_cluster_size, alpha=1 - self.decay
340
- )
341
-
342
- def embed_avg_ema_update(self, new_embed_avg):
343
- self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
344
-
345
- def weight_update(self, num_tokens):
346
- n = self.cluster_size.sum()
347
- smoothed_cluster_size = (
348
- (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
349
- )
350
- # normalize embedding average with smoothed cluster size
351
- embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
352
- self.weight.data.copy_(embed_normalized)
353
-
354
-
355
- class EMAVectorQuantizer(AbstractQuantizer):
356
- def __init__(
357
- self,
358
- n_embed: int,
359
- embedding_dim: int,
360
- beta: float,
361
- decay: float = 0.99,
362
- eps: float = 1e-5,
363
- remap: Optional[str] = None,
364
- unknown_index: str = "random",
365
- loss_key: str = "loss/vq",
366
- ):
367
- super().__init__()
368
- self.codebook_dim = embedding_dim
369
- self.num_tokens = n_embed
370
- self.beta = beta
371
- self.loss_key = loss_key
372
-
373
- self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
374
-
375
- self.remap = remap
376
- if self.remap is not None:
377
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
378
- self.re_embed = self.used.shape[0]
379
- else:
380
- self.used = None
381
- self.re_embed = n_embed
382
- if unknown_index == "extra":
383
- self.unknown_index = self.re_embed
384
- self.re_embed = self.re_embed + 1
385
- else:
386
- assert unknown_index == "random" or isinstance(
387
- unknown_index, int
388
- ), "unknown index needs to be 'random', 'extra' or any integer"
389
- self.unknown_index = unknown_index # "random" or "extra" or integer
390
- if self.remap is not None:
391
- logpy.info(
392
- f"Remapping {self.num_tokens} indices to {self.re_embed} indices. "
393
- f"Using {self.unknown_index} for unknown indices."
394
- )
395
-
396
- def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
397
- # reshape z -> (batch, height, width, channel) and flatten
398
- # z, 'b c h w -> b h w c'
399
- z = rearrange(z, "b c h w -> b h w c")
400
- z_flattened = z.reshape(-1, self.codebook_dim)
401
-
402
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
403
- d = (
404
- z_flattened.pow(2).sum(dim=1, keepdim=True)
405
- + self.embedding.weight.pow(2).sum(dim=1)
406
- - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
407
- ) # 'n d -> d n'
408
-
409
- encoding_indices = torch.argmin(d, dim=1)
410
-
411
- z_q = self.embedding(encoding_indices).view(z.shape)
412
- encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
413
- avg_probs = torch.mean(encodings, dim=0)
414
- perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
415
-
416
- if self.training and self.embedding.update:
417
- # EMA cluster size
418
- encodings_sum = encodings.sum(0)
419
- self.embedding.cluster_size_ema_update(encodings_sum)
420
- # EMA embedding average
421
- embed_sum = encodings.transpose(0, 1) @ z_flattened
422
- self.embedding.embed_avg_ema_update(embed_sum)
423
- # normalize embed_avg and update weight
424
- self.embedding.weight_update(self.num_tokens)
425
-
426
- # compute loss for embedding
427
- loss = self.beta * F.mse_loss(z_q.detach(), z)
428
-
429
- # preserve gradients
430
- z_q = z + (z_q - z).detach()
431
-
432
- # reshape back to match original input shape
433
- # z_q, 'b h w c -> b c h w'
434
- z_q = rearrange(z_q, "b h w c -> b c h w")
435
-
436
- out_dict = {
437
- self.loss_key: loss,
438
- "encodings": encodings,
439
- "encoding_indices": encoding_indices,
440
- "perplexity": perplexity,
441
- }
442
-
443
- return z_q, out_dict
444
-
445
-
446
- class VectorQuantizerWithInputProjection(VectorQuantizer):
447
- def __init__(
448
- self,
449
- input_dim: int,
450
- n_codes: int,
451
- codebook_dim: int,
452
- beta: float = 1.0,
453
- output_dim: Optional[int] = None,
454
- **kwargs,
455
- ):
456
- super().__init__(n_codes, codebook_dim, beta, **kwargs)
457
- self.proj_in = nn.Linear(input_dim, codebook_dim)
458
- self.output_dim = output_dim
459
- if output_dim is not None:
460
- self.proj_out = nn.Linear(codebook_dim, output_dim)
461
- else:
462
- self.proj_out = nn.Identity()
463
-
464
- def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
465
- rearr = False
466
- in_shape = z.shape
467
-
468
- if z.ndim > 3:
469
- rearr = self.output_dim is not None
470
- z = rearrange(z, "b c ... -> b (...) c")
471
- z = self.proj_in(z)
472
- z_q, loss_dict = super().forward(z)
473
-
474
- z_q = self.proj_out(z_q)
475
- if rearr:
476
- if len(in_shape) == 4:
477
- z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
478
- elif len(in_shape) == 5:
479
- z_q = rearrange(
480
- z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
481
- )
482
- else:
483
- raise NotImplementedError(
484
- f"rearranging not available for {len(in_shape)}-dimensional input."
485
- )
486
-
487
- return z_q, loss_dict
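Note: the quantizers in the deleted file share the same core mechanics as VectorQuantizer: find the nearest codebook entry by squared distance, copy gradients through the non-differentiable lookup with the straight-through trick z + (z_q - z).detach(), and add a commitment term beta * ||sg[z_q] - z||^2 plus either a codebook term or, in EMAVectorQuantizer, an EMA codebook update. A compact, self-contained sketch of that loop on flattened vectors (shapes and sizes are made up):

    import torch
    import torch.nn.functional as F

    n_e, e_dim, beta = 16, 8, 0.25
    codebook = torch.nn.Embedding(n_e, e_dim)
    z = torch.randn(4, e_dim, requires_grad=True)          # flattened encoder output

    # nearest codebook entry per vector: ||z - e||^2 = z^2 + e^2 - 2 z.e
    d = (z.pow(2).sum(1, keepdim=True)
         + codebook.weight.pow(2).sum(1)
         - 2 * z @ codebook.weight.t())
    idx = d.argmin(dim=1)
    z_q = codebook(idx)

    loss = beta * F.mse_loss(z_q.detach(), z) + F.mse_loss(z_q, z.detach())  # commitment + codebook
    z_q_st = z + (z_q - z).detach()     # straight-through: decoder sees z_q, gradients flow to z
    loss.backward()
    print(idx, z.grad.shape)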
 
sgm/modules/autoencoding/temporal_ae.py DELETED
@@ -1,349 +0,0 @@
1
- from typing import Callable, Iterable, Union
2
-
3
- import torch
4
- from einops import rearrange, repeat
5
-
6
- from sgm.modules.diffusionmodules.model import (
7
- XFORMERS_IS_AVAILABLE,
8
- AttnBlock,
9
- Decoder,
10
- MemoryEfficientAttnBlock,
11
- ResnetBlock,
12
- )
13
- from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
14
- from sgm.modules.video_attention import VideoTransformerBlock
15
- from sgm.util import partialclass
16
-
17
-
18
- class VideoResBlock(ResnetBlock):
19
- def __init__(
20
- self,
21
- out_channels,
22
- *args,
23
- dropout=0.0,
24
- video_kernel_size=3,
25
- alpha=0.0,
26
- merge_strategy="learned",
27
- **kwargs,
28
- ):
29
- super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
30
- if video_kernel_size is None:
31
- video_kernel_size = [3, 1, 1]
32
- self.time_stack = ResBlock(
33
- channels=out_channels,
34
- emb_channels=0,
35
- dropout=dropout,
36
- dims=3,
37
- use_scale_shift_norm=False,
38
- use_conv=False,
39
- up=False,
40
- down=False,
41
- kernel_size=video_kernel_size,
42
- use_checkpoint=False,
43
- skip_t_emb=True,
44
- )
45
-
46
- self.merge_strategy = merge_strategy
47
- if self.merge_strategy == "fixed":
48
- self.register_buffer("mix_factor", torch.Tensor([alpha]))
49
- elif self.merge_strategy == "learned":
50
- self.register_parameter(
51
- "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
52
- )
53
- else:
54
- raise ValueError(f"unknown merge strategy {self.merge_strategy}")
55
-
56
- def get_alpha(self, bs):
57
- if self.merge_strategy == "fixed":
58
- return self.mix_factor
59
- elif self.merge_strategy == "learned":
60
- return torch.sigmoid(self.mix_factor)
61
- else:
62
- raise NotImplementedError()
63
-
64
- def forward(self, x, temb, skip_video=False, timesteps=None):
65
- if timesteps is None:
66
- timesteps = self.timesteps
67
-
68
- b, c, h, w = x.shape
69
-
70
- x = super().forward(x, temb)
71
-
72
- if not skip_video:
73
- x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
74
-
75
- x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
76
-
77
- x = self.time_stack(x, temb)
78
-
79
- alpha = self.get_alpha(bs=b // timesteps)
80
- x = alpha * x + (1.0 - alpha) * x_mix
81
-
82
- x = rearrange(x, "b c t h w -> (b t) c h w")
83
- return x
84
-
85
-
86
- class AE3DConv(torch.nn.Conv2d):
87
- def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
88
- super().__init__(in_channels, out_channels, *args, **kwargs)
89
- if isinstance(video_kernel_size, Iterable):
90
- padding = [int(k // 2) for k in video_kernel_size]
91
- else:
92
- padding = int(video_kernel_size // 2)
93
-
94
- self.time_mix_conv = torch.nn.Conv3d(
95
- in_channels=out_channels,
96
- out_channels=out_channels,
97
- kernel_size=video_kernel_size,
98
- padding=padding,
99
- )
100
-
101
- def forward(self, input, timesteps, skip_video=False):
102
- x = super().forward(input)
103
- if skip_video:
104
- return x
105
- x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
106
- x = self.time_mix_conv(x)
107
- return rearrange(x, "b c t h w -> (b t) c h w")
108
-
109
-
110
- class VideoBlock(AttnBlock):
111
- def __init__(
112
- self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
113
- ):
114
- super().__init__(in_channels)
115
- # no context, single headed, as in base class
116
- self.time_mix_block = VideoTransformerBlock(
117
- dim=in_channels,
118
- n_heads=1,
119
- d_head=in_channels,
120
- checkpoint=False,
121
- ff_in=True,
122
- attn_mode="softmax",
123
- )
124
-
125
- time_embed_dim = self.in_channels * 4
126
- self.video_time_embed = torch.nn.Sequential(
127
- torch.nn.Linear(self.in_channels, time_embed_dim),
128
- torch.nn.SiLU(),
129
- torch.nn.Linear(time_embed_dim, self.in_channels),
130
- )
131
-
132
- self.merge_strategy = merge_strategy
133
- if self.merge_strategy == "fixed":
134
- self.register_buffer("mix_factor", torch.Tensor([alpha]))
135
- elif self.merge_strategy == "learned":
136
- self.register_parameter(
137
- "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
138
- )
139
- else:
140
- raise ValueError(f"unknown merge strategy {self.merge_strategy}")
141
-
142
- def forward(self, x, timesteps, skip_video=False):
143
- if skip_video:
144
- return super().forward(x)
145
-
146
- x_in = x
147
- x = self.attention(x)
148
- h, w = x.shape[2:]
149
- x = rearrange(x, "b c h w -> b (h w) c")
150
-
151
- x_mix = x
152
- num_frames = torch.arange(timesteps, device=x.device)
153
- num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
154
- num_frames = rearrange(num_frames, "b t -> (b t)")
155
- t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
156
- emb = self.video_time_embed(t_emb) # b, n_channels
157
- emb = emb[:, None, :]
158
- x_mix = x_mix + emb
159
-
160
- alpha = self.get_alpha()
161
- x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
162
- x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
163
-
164
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
165
- x = self.proj_out(x)
166
-
167
- return x_in + x
168
-
169
- def get_alpha(
170
- self,
171
- ):
172
- if self.merge_strategy == "fixed":
173
- return self.mix_factor
174
- elif self.merge_strategy == "learned":
175
- return torch.sigmoid(self.mix_factor)
176
- else:
177
- raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
178
-
179
-
180
- class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
181
- def __init__(
182
- self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
183
- ):
184
- super().__init__(in_channels)
185
- # no context, single headed, as in base class
186
- self.time_mix_block = VideoTransformerBlock(
187
- dim=in_channels,
188
- n_heads=1,
189
- d_head=in_channels,
190
- checkpoint=False,
191
- ff_in=True,
192
- attn_mode="softmax-xformers",
193
- )
194
-
195
- time_embed_dim = self.in_channels * 4
196
- self.video_time_embed = torch.nn.Sequential(
197
- torch.nn.Linear(self.in_channels, time_embed_dim),
198
- torch.nn.SiLU(),
199
- torch.nn.Linear(time_embed_dim, self.in_channels),
200
- )
201
-
202
- self.merge_strategy = merge_strategy
203
- if self.merge_strategy == "fixed":
204
- self.register_buffer("mix_factor", torch.Tensor([alpha]))
205
- elif self.merge_strategy == "learned":
206
- self.register_parameter(
207
- "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
208
- )
209
- else:
210
- raise ValueError(f"unknown merge strategy {self.merge_strategy}")
211
-
212
- def forward(self, x, timesteps, skip_time_block=False):
213
- if skip_time_block:
214
- return super().forward(x)
215
-
216
- x_in = x
217
- x = self.attention(x)
218
- h, w = x.shape[2:]
219
- x = rearrange(x, "b c h w -> b (h w) c")
220
-
221
- x_mix = x
222
- num_frames = torch.arange(timesteps, device=x.device)
223
- num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
224
- num_frames = rearrange(num_frames, "b t -> (b t)")
225
- t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
226
- emb = self.video_time_embed(t_emb) # b, n_channels
227
- emb = emb[:, None, :]
228
- x_mix = x_mix + emb
229
-
230
- alpha = self.get_alpha()
231
- x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
232
- x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
233
-
234
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
235
- x = self.proj_out(x)
236
-
237
- return x_in + x
238
-
239
- def get_alpha(
240
- self,
241
- ):
242
- if self.merge_strategy == "fixed":
243
- return self.mix_factor
244
- elif self.merge_strategy == "learned":
245
- return torch.sigmoid(self.mix_factor)
246
- else:
247
- raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
248
-
249
-
250
- def make_time_attn(
251
- in_channels,
252
- attn_type="vanilla",
253
- attn_kwargs=None,
254
- alpha: float = 0,
255
- merge_strategy: str = "learned",
256
- ):
257
- assert attn_type in [
258
- "vanilla",
259
- "vanilla-xformers",
260
- ], f"attn_type {attn_type} not supported for spatio-temporal attention"
261
- print(
262
- f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
263
- )
264
- if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
265
- print(
266
- f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
267
- f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
268
- )
269
- attn_type = "vanilla"
270
-
271
- if attn_type == "vanilla":
272
- assert attn_kwargs is None
273
- return partialclass(
274
- VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
275
- )
276
- elif attn_type == "vanilla-xformers":
277
- print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
278
- return partialclass(
279
- MemoryEfficientVideoBlock,
280
- in_channels,
281
- alpha=alpha,
282
- merge_strategy=merge_strategy,
283
- )
284
- else:
285
- return NotImplementedError()
286
-
287
-
288
- class Conv2DWrapper(torch.nn.Conv2d):
289
- def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
290
- return super().forward(input)
291
-
292
-
293
- class VideoDecoder(Decoder):
294
- available_time_modes = ["all", "conv-only", "attn-only"]
295
-
296
- def __init__(
297
- self,
298
- *args,
299
- video_kernel_size: Union[int, list] = 3,
300
- alpha: float = 0.0,
301
- merge_strategy: str = "learned",
302
- time_mode: str = "conv-only",
303
- **kwargs,
304
- ):
305
- self.video_kernel_size = video_kernel_size
306
- self.alpha = alpha
307
- self.merge_strategy = merge_strategy
308
- self.time_mode = time_mode
309
- assert (
310
- self.time_mode in self.available_time_modes
311
- ), f"time_mode parameter has to be in {self.available_time_modes}"
312
- super().__init__(*args, **kwargs)
313
-
314
- def get_last_layer(self, skip_time_mix=False, **kwargs):
315
- if self.time_mode == "attn-only":
316
- raise NotImplementedError("TODO")
317
- else:
318
- return (
319
- self.conv_out.time_mix_conv.weight
320
- if not skip_time_mix
321
- else self.conv_out.weight
322
- )
323
-
324
- def _make_attn(self) -> Callable:
325
- if self.time_mode not in ["conv-only", "only-last-conv"]:
326
- return partialclass(
327
- make_time_attn,
328
- alpha=self.alpha,
329
- merge_strategy=self.merge_strategy,
330
- )
331
- else:
332
- return super()._make_attn()
333
-
334
- def _make_conv(self) -> Callable:
335
- if self.time_mode != "attn-only":
336
- return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
337
- else:
338
- return Conv2DWrapper
339
-
340
- def _make_resblock(self) -> Callable:
341
- if self.time_mode not in ["attn-only", "only-last-conv"]:
342
- return partialclass(
343
- VideoResBlock,
344
- video_kernel_size=self.video_kernel_size,
345
- alpha=self.alpha,
346
- merge_strategy=self.merge_strategy,
347
- )
348
- else:
349
- return super()._make_resblock()
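Note: the recurring pattern in the deleted temporal blocks (VideoResBlock, VideoBlock, MemoryEfficientVideoBlock) is a learned blend between the per-frame spatial path and a new temporal path: a scalar mix_factor is squashed through a sigmoid and used as alpha in alpha * temporal + (1 - alpha) * spatial. With the default initialization of 0.0 the learned variant starts at an even 0.5/0.5 mix, while merge_strategy="fixed" uses the raw value directly. A stripped-down sketch of that merge with dummy feature tensors:

    import torch

    mix_factor = torch.nn.Parameter(torch.tensor([0.0]))  # merge_strategy="learned", alpha init 0.0

    x_spatial = torch.randn(2, 4, 8, 8)    # output of the per-frame (spatial) block
    x_temporal = torch.randn(2, 4, 8, 8)   # output of the time-mixing block

    alpha = torch.sigmoid(mix_factor)      # 0.5 for mix_factor = 0
    x = alpha * x_temporal + (1.0 - alpha) * x_spatial
    print(alpha.item(), x.shape)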
 
sgm/modules/diffusionmodules/__init__.py DELETED
File without changes
sgm/modules/diffusionmodules/denoiser.py DELETED
@@ -1,75 +0,0 @@
1
- from typing import Dict, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
-
6
- from ...util import append_dims, instantiate_from_config
7
- from .denoiser_scaling import DenoiserScaling
8
- from .discretizer import Discretization
9
-
10
-
11
- class Denoiser(nn.Module):
12
- def __init__(self, scaling_config: Dict):
13
- super().__init__()
14
-
15
- self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
16
-
17
- def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
18
- return sigma
19
-
20
- def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
21
- return c_noise
22
-
23
- def forward(
24
- self,
25
- network: nn.Module,
26
- input: torch.Tensor,
27
- sigma: torch.Tensor,
28
- cond: Dict,
29
- **additional_model_inputs,
30
- ) -> torch.Tensor:
31
- sigma = self.possibly_quantize_sigma(sigma)
32
- sigma_shape = sigma.shape
33
- sigma = append_dims(sigma, input.ndim)
34
- c_skip, c_out, c_in, c_noise = self.scaling(sigma)
35
- c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
36
- return (
37
- network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
38
- + input * c_skip
39
- )
40
-
41
-
42
- class DiscreteDenoiser(Denoiser):
43
- def __init__(
44
- self,
45
- scaling_config: Dict,
46
- num_idx: int,
47
- discretization_config: Dict,
48
- do_append_zero: bool = False,
49
- quantize_c_noise: bool = True,
50
- flip: bool = True,
51
- ):
52
- super().__init__(scaling_config)
53
- self.discretization: Discretization = instantiate_from_config(
54
- discretization_config
55
- )
56
- sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
57
- self.register_buffer("sigmas", sigmas)
58
- self.quantize_c_noise = quantize_c_noise
59
- self.num_idx = num_idx
60
-
61
- def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
62
- dists = sigma - self.sigmas[:, None]
63
- return dists.abs().argmin(dim=0).view(sigma.shape)
64
-
65
- def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
66
- return self.sigmas[idx]
67
-
68
- def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
69
- return self.idx_to_sigma(self.sigma_to_idx(sigma))
70
-
71
- def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
72
- if self.quantize_c_noise:
73
- return self.sigma_to_idx(c_noise)
74
- else:
75
- return c_noise
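Note: the deleted Denoiser wraps the raw network F with the usual preconditioning D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise), where the four coefficients come from the scaling classes in the next file (EDM, eps, or v-parameterization). A minimal sketch with a dummy network and eps-style scalings, just to show how the pieces compose (the lambda network and shapes are placeholders, not the real UNet wrapper):

    import torch

    def eps_scaling(sigma):
        # matches the deleted EpsScaling: c_skip=1, c_out=-sigma, c_in=1/sqrt(sigma^2+1), c_noise=sigma
        return torch.ones_like(sigma), -sigma, 1.0 / (sigma**2 + 1.0) ** 0.5, sigma.clone()

    network = lambda x, c_noise: x * 0.0      # dummy F(x, t) that predicts zero "noise"
    x = torch.randn(2, 4, 8, 8)
    sigma = torch.full((2, 1, 1, 1), 2.0)     # sigma broadcast to input rank (append_dims)

    c_skip, c_out, c_in, c_noise = eps_scaling(sigma)
    denoised = network(x * c_in, c_noise) * c_out + x * c_skip
    print(denoised.shape)                      # with the zero network, D(x) == x here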
 
sgm/modules/diffusionmodules/denoiser_scaling.py DELETED
@@ -1,59 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- from typing import Tuple
3
-
4
- import torch
5
-
6
-
7
- class DenoiserScaling(ABC):
8
- @abstractmethod
9
- def __call__(
10
- self, sigma: torch.Tensor
11
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
12
- pass
13
-
14
-
15
- class EDMScaling:
16
- def __init__(self, sigma_data: float = 0.5):
17
- self.sigma_data = sigma_data
18
-
19
- def __call__(
20
- self, sigma: torch.Tensor
21
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
22
- c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
23
- c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
24
- c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
25
- c_noise = 0.25 * sigma.log()
26
- return c_skip, c_out, c_in, c_noise
27
-
28
-
29
- class EpsScaling:
30
- def __call__(
31
- self, sigma: torch.Tensor
32
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
33
- c_skip = torch.ones_like(sigma, device=sigma.device)
34
- c_out = -sigma
35
- c_in = 1 / (sigma**2 + 1.0) ** 0.5
36
- c_noise = sigma.clone()
37
- return c_skip, c_out, c_in, c_noise
38
-
39
-
40
- class VScaling:
41
- def __call__(
42
- self, sigma: torch.Tensor
43
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
44
- c_skip = 1.0 / (sigma**2 + 1.0)
45
- c_out = -sigma / (sigma**2 + 1.0) ** 0.5
46
- c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
47
- c_noise = sigma.clone()
48
- return c_skip, c_out, c_in, c_noise
49
-
50
-
51
- class VScalingWithEDMcNoise(DenoiserScaling):
52
- def __call__(
53
- self, sigma: torch.Tensor
54
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
55
- c_skip = 1.0 / (sigma**2 + 1.0)
56
- c_out = -sigma / (sigma**2 + 1.0) ** 0.5
57
- c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
58
- c_noise = 0.25 * sigma.log()
59
- return c_skip, c_out, c_in, c_noise
 
sgm/modules/diffusionmodules/denoiser_weighting.py DELETED
@@ -1,24 +0,0 @@
1
- import torch
2
-
3
-
4
- class UnitWeighting:
5
- def __call__(self, sigma):
6
- return torch.ones_like(sigma, device=sigma.device)
7
-
8
-
9
- class EDMWeighting:
10
- def __init__(self, sigma_data=0.5):
11
- self.sigma_data = sigma_data
12
-
13
- def __call__(self, sigma):
14
- return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
15
-
16
-
17
- class VWeighting(EDMWeighting):
18
- def __init__(self):
19
- super().__init__(sigma_data=1.0)
20
-
21
-
22
- class EpsWeighting:
23
- def __call__(self, sigma):
24
- return sigma**-2.0
 
sgm/modules/diffusionmodules/discretizer.py DELETED
@@ -1,69 +0,0 @@
1
- from abc import abstractmethod
2
- from functools import partial
3
-
4
- import numpy as np
5
- import torch
6
-
7
- from ...modules.diffusionmodules.util import make_beta_schedule
8
- from ...util import append_zero
9
-
10
-
11
- def generate_roughly_equally_spaced_steps(
12
- num_substeps: int, max_step: int
13
- ) -> np.ndarray:
14
- return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
15
-
16
-
17
- class Discretization:
18
- def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
19
- sigmas = self.get_sigmas(n, device=device)
20
- sigmas = append_zero(sigmas) if do_append_zero else sigmas
21
- return sigmas if not flip else torch.flip(sigmas, (0,))
22
-
23
- @abstractmethod
24
- def get_sigmas(self, n, device):
25
- pass
26
-
27
-
28
- class EDMDiscretization(Discretization):
29
- def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
30
- self.sigma_min = sigma_min
31
- self.sigma_max = sigma_max
32
- self.rho = rho
33
-
34
- def get_sigmas(self, n, device="cpu"):
35
- ramp = torch.linspace(0, 1, n, device=device)
36
- min_inv_rho = self.sigma_min ** (1 / self.rho)
37
- max_inv_rho = self.sigma_max ** (1 / self.rho)
38
- sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
39
- return sigmas
40
-
41
-
42
- class LegacyDDPMDiscretization(Discretization):
43
- def __init__(
44
- self,
45
- linear_start=0.00085,
46
- linear_end=0.0120,
47
- num_timesteps=1000,
48
- ):
49
- super().__init__()
50
- self.num_timesteps = num_timesteps
51
- betas = make_beta_schedule(
52
- "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
53
- )
54
- alphas = 1.0 - betas
55
- self.alphas_cumprod = np.cumprod(alphas, axis=0)
56
- self.to_torch = partial(torch.tensor, dtype=torch.float32)
57
-
58
- def get_sigmas(self, n, device="cpu"):
59
- if n < self.num_timesteps:
60
- timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
61
- alphas_cumprod = self.alphas_cumprod[timesteps]
62
- elif n == self.num_timesteps:
63
- alphas_cumprod = self.alphas_cumprod
64
- else:
65
- raise ValueError
66
-
67
- to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
68
- sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
69
- return torch.flip(sigmas, (0,))
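Note: EDMDiscretization, deleted above, produces the Karras-style rho schedule: interpolate linearly between sigma_max^(1/rho) and sigma_min^(1/rho) and raise the result back to the rho-th power, which concentrates sampling steps near the low-noise end. A quick standalone restatement of that schedule:

    import torch

    def edm_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        ramp = torch.linspace(0, 1, n)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

    print(edm_sigmas(5))  # monotonically decreasing from 80.0 down to 0.002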
 
sgm/modules/diffusionmodules/guiders.py DELETED
@@ -1,99 +0,0 @@
1
- import logging
2
- from abc import ABC, abstractmethod
3
- from typing import Dict, List, Optional, Tuple, Union
4
-
5
- import torch
6
- from einops import rearrange, repeat
7
-
8
- from ...util import append_dims, default
9
-
10
- logpy = logging.getLogger(__name__)
11
-
12
-
13
- class Guider(ABC):
14
- @abstractmethod
15
- def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
16
- pass
17
-
18
- def prepare_inputs(
19
- self, x: torch.Tensor, s: float, c: Dict, uc: Dict
20
- ) -> Tuple[torch.Tensor, float, Dict]:
21
- pass
22
-
23
-
24
- class VanillaCFG(Guider):
25
- def __init__(self, scale: float):
26
- self.scale = scale
27
-
28
- def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
29
- x_u, x_c = x.chunk(2)
30
- x_pred = x_u + self.scale * (x_c - x_u)
31
- return x_pred
32
-
33
- def prepare_inputs(self, x, s, c, uc):
34
- c_out = dict()
35
-
36
- for k in c:
37
- if k in ["vector", "crossattn", "concat"]:
38
- c_out[k] = torch.cat((uc[k], c[k]), 0)
39
- else:
40
- assert c[k] == uc[k]
41
- c_out[k] = c[k]
42
- return torch.cat([x] * 2), torch.cat([s] * 2), c_out
43
-
44
-
45
- class IdentityGuider(Guider):
46
- def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
47
- return x
48
-
49
- def prepare_inputs(
50
- self, x: torch.Tensor, s: float, c: Dict, uc: Dict
51
- ) -> Tuple[torch.Tensor, float, Dict]:
52
- c_out = dict()
53
-
54
- for k in c:
55
- c_out[k] = c[k]
56
-
57
- return x, s, c_out
58
-
59
-
60
- class LinearPredictionGuider(Guider):
61
- def __init__(
62
- self,
63
- max_scale: float,
64
- num_frames: int,
65
- min_scale: float = 1.0,
66
- additional_cond_keys: Optional[Union[List[str], str]] = None,
67
- ):
68
- self.min_scale = min_scale
69
- self.max_scale = max_scale
70
- self.num_frames = num_frames
71
- self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
72
-
73
- additional_cond_keys = default(additional_cond_keys, [])
74
- if isinstance(additional_cond_keys, str):
75
- additional_cond_keys = [additional_cond_keys]
76
- self.additional_cond_keys = additional_cond_keys
77
-
78
- def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
79
- x_u, x_c = x.chunk(2)
80
-
81
- x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
82
- x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
83
- scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
84
- scale = append_dims(scale, x_u.ndim).to(x_u.device)
85
-
86
- return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
87
-
88
- def prepare_inputs(
89
- self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
90
- ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
91
- c_out = dict()
92
-
93
- for k in c:
94
- if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
95
- c_out[k] = torch.cat((uc[k], c[k]), 0)
96
- else:
97
- assert c[k] == uc[k]
98
- c_out[k] = c[k]
99
- return torch.cat([x] * 2), torch.cat([s] * 2), c_out
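To make the calling convention of these guiders concrete, here is a hedged sketch of one classifier-free-guidance step. The denoiser callable and its signature are assumptions for illustration; the actual wiring lives in the sampler classes, not in this file:

    import torch

    guider = VanillaCFG(scale=7.5)

    def guided_denoise(denoiser, x, sigma, c, uc):
        # prepare_inputs doubles the batch: first half unconditional, second half conditional
        x_in, s_in, c_in = guider.prepare_inputs(x, sigma, c, uc)
        out = denoiser(x_in, s_in, c_in)  # hypothetical denoiser signature, for illustration only
        # __call__ splits the halves again and returns x_u + scale * (x_c - x_u)
        return guider(out, s_in)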
 
sgm/modules/diffusionmodules/loss.py DELETED
@@ -1,105 +0,0 @@
1
- from typing import Dict, List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
-
6
- from ...modules.autoencoding.lpips.loss.lpips import LPIPS
7
- from ...modules.encoders.modules import GeneralConditioner
8
- from ...util import append_dims, instantiate_from_config
9
- from .denoiser import Denoiser
10
-
11
-
12
- class StandardDiffusionLoss(nn.Module):
13
- def __init__(
14
- self,
15
- sigma_sampler_config: dict,
16
- loss_weighting_config: dict,
17
- loss_type: str = "l2",
18
- offset_noise_level: float = 0.0,
19
- batch2model_keys: Optional[Union[str, List[str]]] = None,
20
- ):
21
- super().__init__()
22
-
23
- assert loss_type in ["l2", "l1", "lpips"]
24
-
25
- self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
26
- self.loss_weighting = instantiate_from_config(loss_weighting_config)
27
-
28
- self.loss_type = loss_type
29
- self.offset_noise_level = offset_noise_level
30
-
31
- if loss_type == "lpips":
32
- self.lpips = LPIPS().eval()
33
-
34
- if not batch2model_keys:
35
- batch2model_keys = []
36
-
37
- if isinstance(batch2model_keys, str):
38
- batch2model_keys = [batch2model_keys]
39
-
40
- self.batch2model_keys = set(batch2model_keys)
41
-
42
- def get_noised_input(
43
- self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
44
- ) -> torch.Tensor:
45
- noised_input = input + noise * sigmas_bc
46
- return noised_input
47
-
48
- def forward(
49
- self,
50
- network: nn.Module,
51
- denoiser: Denoiser,
52
- conditioner: GeneralConditioner,
53
- input: torch.Tensor,
54
- batch: Dict,
55
- ) -> torch.Tensor:
56
- cond = conditioner(batch)
57
- return self._forward(network, denoiser, cond, input, batch)
58
-
59
- def _forward(
60
- self,
61
- network: nn.Module,
62
- denoiser: Denoiser,
63
- cond: Dict,
64
- input: torch.Tensor,
65
- batch: Dict,
66
- ) -> Tuple[torch.Tensor, Dict]:
67
- additional_model_inputs = {
68
- key: batch[key] for key in self.batch2model_keys.intersection(batch)
69
- }
70
- sigmas = self.sigma_sampler(input.shape[0]).to(input)
71
-
72
- noise = torch.randn_like(input)
73
- if self.offset_noise_level > 0.0:
74
- offset_shape = (
75
- (input.shape[0], 1, input.shape[2])
76
- if self.n_frames is not None
77
- else (input.shape[0], input.shape[1])
78
- )
79
- noise = noise + self.offset_noise_level * append_dims(
80
- torch.randn(offset_shape, device=input.device),
81
- input.ndim,
82
- )
83
- sigmas_bc = append_dims(sigmas, input.ndim)
84
- noised_input = self.get_noised_input(sigmas_bc, noise, input)
85
-
86
- model_output = denoiser(
87
- network, noised_input, sigmas, cond, **additional_model_inputs
88
- )
89
- w = append_dims(self.loss_weighting(sigmas), input.ndim)
90
- return self.get_loss(model_output, input, w)
91
-
92
- def get_loss(self, model_output, target, w):
93
- if self.loss_type == "l2":
94
- return torch.mean(
95
- (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
96
- )
97
- elif self.loss_type == "l1":
98
- return torch.mean(
99
- (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
100
- )
101
- elif self.loss_type == "lpips":
102
- loss = self.lpips(model_output, target).reshape(-1)
103
- return loss
104
- else:
105
- raise NotImplementedError(f"Unknown loss type {self.loss_type}")
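As a reading aid, the default "l2" branch above computes a per-sample, sigma-weighted reconstruction error: roughly the mean over elements of w(sigma) * (denoiser(x + sigma * eps, sigma, cond) - x)**2. A toy standalone sketch of that computation, with a placeholder denoiser and unit weighting standing in for the configured components (every name below is illustrative, not one of the repository's configs):

    import torch

    def toy_denoiser(noised, sigmas):
        return noised / (1.0 + sigmas[:, None])  # placeholder network, illustration only

    x = torch.randn(4, 8)                        # clean "latents"
    sigmas = torch.rand(4) * 10                  # sampled noise levels
    noise = torch.randn_like(x)
    noised = x + noise * sigmas[:, None]         # same construction as get_noised_input above
    w = torch.ones_like(sigmas)[:, None]         # UnitWeighting; EDMWeighting would depend on sigma
    pred = toy_denoiser(noised, sigmas)
    loss = torch.mean((w * (pred - x) ** 2).reshape(x.shape[0], -1), 1)  # per-sample l2, as in get_loss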
 
sgm/modules/diffusionmodules/loss_weighting.py DELETED
@@ -1,32 +0,0 @@
1
- from abc import ABC, abstractmethod
2
-
3
- import torch
4
-
5
-
6
- class DiffusionLossWeighting(ABC):
7
- @abstractmethod
8
- def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
9
- pass
10
-
11
-
12
- class UnitWeighting(DiffusionLossWeighting):
13
- def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
14
- return torch.ones_like(sigma, device=sigma.device)
15
-
16
-
17
- class EDMWeighting(DiffusionLossWeighting):
18
- def __init__(self, sigma_data: float = 0.5):
19
- self.sigma_data = sigma_data
20
-
21
- def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
22
- return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
23
-
24
-
25
- class VWeighting(EDMWeighting):
26
- def __init__(self):
27
- super().__init__(sigma_data=1.0)
28
-
29
-
30
- class EpsWeighting(DiffusionLossWeighting):
31
- def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
32
- return sigma**-2.0
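For intuition, the EDM weighting is w(sigma) = (sigma**2 + sigma_data**2) / (sigma * sigma_data)**2, which behaves like 1/sigma**2 for small sigma and flattens to 1/sigma_data**2 for large sigma, while EpsWeighting is exactly 1/sigma**2. A quick standalone check of a few values, using the same formulas as above:

    import torch

    sigma = torch.tensor([0.1, 0.5, 1.0, 10.0])
    sigma_data = 0.5
    edm_w = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2   # as in EDMWeighting
    eps_w = sigma**-2.0                                              # as in EpsWeighting
    # edm_w is approximately [104.0, 8.0, 5.0, 4.01]; eps_w is [100.0, 4.0, 1.0, 0.01]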
 
sgm/modules/diffusionmodules/model.py DELETED
@@ -1,748 +0,0 @@
1
- # pytorch_diffusion + derived encoder decoder
2
- import logging
3
- import math
4
- from typing import Any, Callable, Optional
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- from einops import rearrange
10
- from packaging import version
11
-
12
- logpy = logging.getLogger(__name__)
13
-
14
- try:
15
- import xformers
16
- import xformers.ops
17
-
18
- XFORMERS_IS_AVAILABLE = True
19
- except ImportError:
20
- XFORMERS_IS_AVAILABLE = False
21
- logpy.warning("no module 'xformers'. Processing without...")
22
-
23
- from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
24
-
25
-
26
- def get_timestep_embedding(timesteps, embedding_dim):
27
- """
28
- This matches the implementation in Denoising Diffusion Probabilistic Models:
29
- From Fairseq.
30
- Build sinusoidal embeddings.
31
- This matches the implementation in tensor2tensor, but differs slightly
32
- from the description in Section 3.5 of "Attention Is All You Need".
33
- """
34
- assert len(timesteps.shape) == 1
35
-
36
- half_dim = embedding_dim // 2
37
- emb = math.log(10000) / (half_dim - 1)
38
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
39
- emb = emb.to(device=timesteps.device)
40
- emb = timesteps.float()[:, None] * emb[None, :]
41
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
42
- if embedding_dim % 2 == 1: # zero pad
43
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
44
- return emb
45
-
46
-
47
- def nonlinearity(x):
48
- # swish
49
- return x * torch.sigmoid(x)
50
-
51
-
52
- def Normalize(in_channels, num_groups=32):
53
- return torch.nn.GroupNorm(
54
- num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
55
- )
56
-
57
-
58
- class Upsample(nn.Module):
59
- def __init__(self, in_channels, with_conv):
60
- super().__init__()
61
- self.with_conv = with_conv
62
- if self.with_conv:
63
- self.conv = torch.nn.Conv2d(
64
- in_channels, in_channels, kernel_size=3, stride=1, padding=1
65
- )
66
-
67
- def forward(self, x):
68
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
69
- if self.with_conv:
70
- x = self.conv(x)
71
- return x
72
-
73
-
74
- class Downsample(nn.Module):
75
- def __init__(self, in_channels, with_conv):
76
- super().__init__()
77
- self.with_conv = with_conv
78
- if self.with_conv:
79
- # no asymmetric padding in torch conv, must do it ourselves
80
- self.conv = torch.nn.Conv2d(
81
- in_channels, in_channels, kernel_size=3, stride=2, padding=0
82
- )
83
-
84
- def forward(self, x):
85
- if self.with_conv:
86
- pad = (0, 1, 0, 1)
87
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
88
- x = self.conv(x)
89
- else:
90
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
91
- return x
92
-
93
-
94
- class ResnetBlock(nn.Module):
95
- def __init__(
96
- self,
97
- *,
98
- in_channels,
99
- out_channels=None,
100
- conv_shortcut=False,
101
- dropout,
102
- temb_channels=512,
103
- ):
104
- super().__init__()
105
- self.in_channels = in_channels
106
- out_channels = in_channels if out_channels is None else out_channels
107
- self.out_channels = out_channels
108
- self.use_conv_shortcut = conv_shortcut
109
-
110
- self.norm1 = Normalize(in_channels)
111
- self.conv1 = torch.nn.Conv2d(
112
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
113
- )
114
- if temb_channels > 0:
115
- self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
116
- self.norm2 = Normalize(out_channels)
117
- self.dropout = torch.nn.Dropout(dropout)
118
- self.conv2 = torch.nn.Conv2d(
119
- out_channels, out_channels, kernel_size=3, stride=1, padding=1
120
- )
121
- if self.in_channels != self.out_channels:
122
- if self.use_conv_shortcut:
123
- self.conv_shortcut = torch.nn.Conv2d(
124
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
125
- )
126
- else:
127
- self.nin_shortcut = torch.nn.Conv2d(
128
- in_channels, out_channels, kernel_size=1, stride=1, padding=0
129
- )
130
-
131
- def forward(self, x, temb):
132
- h = x
133
- h = self.norm1(h)
134
- h = nonlinearity(h)
135
- h = self.conv1(h)
136
-
137
- if temb is not None:
138
- h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
139
-
140
- h = self.norm2(h)
141
- h = nonlinearity(h)
142
- h = self.dropout(h)
143
- h = self.conv2(h)
144
-
145
- if self.in_channels != self.out_channels:
146
- if self.use_conv_shortcut:
147
- x = self.conv_shortcut(x)
148
- else:
149
- x = self.nin_shortcut(x)
150
-
151
- return x + h
152
-
153
-
154
- class LinAttnBlock(LinearAttention):
155
- """to match AttnBlock usage"""
156
-
157
- def __init__(self, in_channels):
158
- super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
159
-
160
-
161
- class AttnBlock(nn.Module):
162
- def __init__(self, in_channels):
163
- super().__init__()
164
- self.in_channels = in_channels
165
-
166
- self.norm = Normalize(in_channels)
167
- self.q = torch.nn.Conv2d(
168
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
169
- )
170
- self.k = torch.nn.Conv2d(
171
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
172
- )
173
- self.v = torch.nn.Conv2d(
174
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
175
- )
176
- self.proj_out = torch.nn.Conv2d(
177
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
178
- )
179
-
180
- def attention(self, h_: torch.Tensor) -> torch.Tensor:
181
- h_ = self.norm(h_)
182
- q = self.q(h_)
183
- k = self.k(h_)
184
- v = self.v(h_)
185
-
186
- b, c, h, w = q.shape
187
- q, k, v = map(
188
- lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
189
- )
190
- h_ = torch.nn.functional.scaled_dot_product_attention(
191
- q, k, v
192
- ) # scale is dim ** -0.5 per default
193
- # compute attention
194
-
195
- return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
196
-
197
- def forward(self, x, **kwargs):
198
- h_ = x
199
- h_ = self.attention(h_)
200
- h_ = self.proj_out(h_)
201
- return x + h_
202
-
203
-
204
- class MemoryEfficientAttnBlock(nn.Module):
205
- """
206
- Uses xformers efficient implementation,
207
- see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
208
- Note: this is a single-head self-attention operation
209
- """
210
-
211
- #
212
- def __init__(self, in_channels):
213
- super().__init__()
214
- self.in_channels = in_channels
215
-
216
- self.norm = Normalize(in_channels)
217
- self.q = torch.nn.Conv2d(
218
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
219
- )
220
- self.k = torch.nn.Conv2d(
221
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
222
- )
223
- self.v = torch.nn.Conv2d(
224
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
225
- )
226
- self.proj_out = torch.nn.Conv2d(
227
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
228
- )
229
- self.attention_op: Optional[Any] = None
230
-
231
- def attention(self, h_: torch.Tensor) -> torch.Tensor:
232
- h_ = self.norm(h_)
233
- q = self.q(h_)
234
- k = self.k(h_)
235
- v = self.v(h_)
236
-
237
- # compute attention
238
- B, C, H, W = q.shape
239
- q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
240
-
241
- q, k, v = map(
242
- lambda t: t.unsqueeze(3)
243
- .reshape(B, t.shape[1], 1, C)
244
- .permute(0, 2, 1, 3)
245
- .reshape(B * 1, t.shape[1], C)
246
- .contiguous(),
247
- (q, k, v),
248
- )
249
- out = xformers.ops.memory_efficient_attention(
250
- q, k, v, attn_bias=None, op=self.attention_op
251
- )
252
-
253
- out = (
254
- out.unsqueeze(0)
255
- .reshape(B, 1, out.shape[1], C)
256
- .permute(0, 2, 1, 3)
257
- .reshape(B, out.shape[1], C)
258
- )
259
- return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
260
-
261
- def forward(self, x, **kwargs):
262
- h_ = x
263
- h_ = self.attention(h_)
264
- h_ = self.proj_out(h_)
265
- return x + h_
266
-
267
-
268
- class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
269
- def forward(self, x, context=None, mask=None, **unused_kwargs):
270
- b, c, h, w = x.shape
271
- x = rearrange(x, "b c h w -> b (h w) c")
272
- out = super().forward(x, context=context, mask=mask)
273
- out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
274
- return x + out
275
-
276
-
277
- def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
278
- assert attn_type in [
279
- "vanilla",
280
- "vanilla-xformers",
281
- "memory-efficient-cross-attn",
282
- "linear",
283
- "none",
284
- ], f"attn_type {attn_type} unknown"
285
- if (
286
- version.parse(torch.__version__) < version.parse("2.0.0")
287
- and attn_type != "none"
288
- ):
289
- assert XFORMERS_IS_AVAILABLE, (
290
- f"We do not support vanilla attention in {torch.__version__} anymore, "
291
- f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
292
- )
293
- attn_type = "vanilla-xformers"
294
- logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
295
- if attn_type == "vanilla":
296
- assert attn_kwargs is None
297
- return AttnBlock(in_channels)
298
- elif attn_type == "vanilla-xformers":
299
- logpy.info(
300
- f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
301
- )
302
- return MemoryEfficientAttnBlock(in_channels)
303
- elif attn_type == "memory-efficient-cross-attn":
304
- attn_kwargs["query_dim"] = in_channels
305
- return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
306
- elif attn_type == "none":
307
- return nn.Identity(in_channels)
308
- else:
309
- return LinAttnBlock(in_channels)
310
-
311
-
312
- class Model(nn.Module):
313
- def __init__(
314
- self,
315
- *,
316
- ch,
317
- out_ch,
318
- ch_mult=(1, 2, 4, 8),
319
- num_res_blocks,
320
- attn_resolutions,
321
- dropout=0.0,
322
- resamp_with_conv=True,
323
- in_channels,
324
- resolution,
325
- use_timestep=True,
326
- use_linear_attn=False,
327
- attn_type="vanilla",
328
- ):
329
- super().__init__()
330
- if use_linear_attn:
331
- attn_type = "linear"
332
- self.ch = ch
333
- self.temb_ch = self.ch * 4
334
- self.num_resolutions = len(ch_mult)
335
- self.num_res_blocks = num_res_blocks
336
- self.resolution = resolution
337
- self.in_channels = in_channels
338
-
339
- self.use_timestep = use_timestep
340
- if self.use_timestep:
341
- # timestep embedding
342
- self.temb = nn.Module()
343
- self.temb.dense = nn.ModuleList(
344
- [
345
- torch.nn.Linear(self.ch, self.temb_ch),
346
- torch.nn.Linear(self.temb_ch, self.temb_ch),
347
- ]
348
- )
349
-
350
- # downsampling
351
- self.conv_in = torch.nn.Conv2d(
352
- in_channels, self.ch, kernel_size=3, stride=1, padding=1
353
- )
354
-
355
- curr_res = resolution
356
- in_ch_mult = (1,) + tuple(ch_mult)
357
- self.down = nn.ModuleList()
358
- for i_level in range(self.num_resolutions):
359
- block = nn.ModuleList()
360
- attn = nn.ModuleList()
361
- block_in = ch * in_ch_mult[i_level]
362
- block_out = ch * ch_mult[i_level]
363
- for i_block in range(self.num_res_blocks):
364
- block.append(
365
- ResnetBlock(
366
- in_channels=block_in,
367
- out_channels=block_out,
368
- temb_channels=self.temb_ch,
369
- dropout=dropout,
370
- )
371
- )
372
- block_in = block_out
373
- if curr_res in attn_resolutions:
374
- attn.append(make_attn(block_in, attn_type=attn_type))
375
- down = nn.Module()
376
- down.block = block
377
- down.attn = attn
378
- if i_level != self.num_resolutions - 1:
379
- down.downsample = Downsample(block_in, resamp_with_conv)
380
- curr_res = curr_res // 2
381
- self.down.append(down)
382
-
383
- # middle
384
- self.mid = nn.Module()
385
- self.mid.block_1 = ResnetBlock(
386
- in_channels=block_in,
387
- out_channels=block_in,
388
- temb_channels=self.temb_ch,
389
- dropout=dropout,
390
- )
391
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
392
- self.mid.block_2 = ResnetBlock(
393
- in_channels=block_in,
394
- out_channels=block_in,
395
- temb_channels=self.temb_ch,
396
- dropout=dropout,
397
- )
398
-
399
- # upsampling
400
- self.up = nn.ModuleList()
401
- for i_level in reversed(range(self.num_resolutions)):
402
- block = nn.ModuleList()
403
- attn = nn.ModuleList()
404
- block_out = ch * ch_mult[i_level]
405
- skip_in = ch * ch_mult[i_level]
406
- for i_block in range(self.num_res_blocks + 1):
407
- if i_block == self.num_res_blocks:
408
- skip_in = ch * in_ch_mult[i_level]
409
- block.append(
410
- ResnetBlock(
411
- in_channels=block_in + skip_in,
412
- out_channels=block_out,
413
- temb_channels=self.temb_ch,
414
- dropout=dropout,
415
- )
416
- )
417
- block_in = block_out
418
- if curr_res in attn_resolutions:
419
- attn.append(make_attn(block_in, attn_type=attn_type))
420
- up = nn.Module()
421
- up.block = block
422
- up.attn = attn
423
- if i_level != 0:
424
- up.upsample = Upsample(block_in, resamp_with_conv)
425
- curr_res = curr_res * 2
426
- self.up.insert(0, up) # prepend to get consistent order
427
-
428
- # end
429
- self.norm_out = Normalize(block_in)
430
- self.conv_out = torch.nn.Conv2d(
431
- block_in, out_ch, kernel_size=3, stride=1, padding=1
432
- )
433
-
434
- def forward(self, x, t=None, context=None):
435
- # assert x.shape[2] == x.shape[3] == self.resolution
436
- if context is not None:
437
- # assume aligned context, cat along channel axis
438
- x = torch.cat((x, context), dim=1)
439
- if self.use_timestep:
440
- # timestep embedding
441
- assert t is not None
442
- temb = get_timestep_embedding(t, self.ch)
443
- temb = self.temb.dense[0](temb)
444
- temb = nonlinearity(temb)
445
- temb = self.temb.dense[1](temb)
446
- else:
447
- temb = None
448
-
449
- # downsampling
450
- hs = [self.conv_in(x)]
451
- for i_level in range(self.num_resolutions):
452
- for i_block in range(self.num_res_blocks):
453
- h = self.down[i_level].block[i_block](hs[-1], temb)
454
- if len(self.down[i_level].attn) > 0:
455
- h = self.down[i_level].attn[i_block](h)
456
- hs.append(h)
457
- if i_level != self.num_resolutions - 1:
458
- hs.append(self.down[i_level].downsample(hs[-1]))
459
-
460
- # middle
461
- h = hs[-1]
462
- h = self.mid.block_1(h, temb)
463
- h = self.mid.attn_1(h)
464
- h = self.mid.block_2(h, temb)
465
-
466
- # upsampling
467
- for i_level in reversed(range(self.num_resolutions)):
468
- for i_block in range(self.num_res_blocks + 1):
469
- h = self.up[i_level].block[i_block](
470
- torch.cat([h, hs.pop()], dim=1), temb
471
- )
472
- if len(self.up[i_level].attn) > 0:
473
- h = self.up[i_level].attn[i_block](h)
474
- if i_level != 0:
475
- h = self.up[i_level].upsample(h)
476
-
477
- # end
478
- h = self.norm_out(h)
479
- h = nonlinearity(h)
480
- h = self.conv_out(h)
481
- return h
482
-
483
- def get_last_layer(self):
484
- return self.conv_out.weight
485
-
486
-
487
- class Encoder(nn.Module):
488
- def __init__(
489
- self,
490
- *,
491
- ch,
492
- out_ch,
493
- ch_mult=(1, 2, 4, 8),
494
- num_res_blocks,
495
- attn_resolutions,
496
- dropout=0.0,
497
- resamp_with_conv=True,
498
- in_channels,
499
- resolution,
500
- z_channels,
501
- double_z=True,
502
- use_linear_attn=False,
503
- attn_type="vanilla",
504
- **ignore_kwargs,
505
- ):
506
- super().__init__()
507
- if use_linear_attn:
508
- attn_type = "linear"
509
- self.ch = ch
510
- self.temb_ch = 0
511
- self.num_resolutions = len(ch_mult)
512
- self.num_res_blocks = num_res_blocks
513
- self.resolution = resolution
514
- self.in_channels = in_channels
515
-
516
- # downsampling
517
- self.conv_in = torch.nn.Conv2d(
518
- in_channels, self.ch, kernel_size=3, stride=1, padding=1
519
- )
520
-
521
- curr_res = resolution
522
- in_ch_mult = (1,) + tuple(ch_mult)
523
- self.in_ch_mult = in_ch_mult
524
- self.down = nn.ModuleList()
525
- for i_level in range(self.num_resolutions):
526
- block = nn.ModuleList()
527
- attn = nn.ModuleList()
528
- block_in = ch * in_ch_mult[i_level]
529
- block_out = ch * ch_mult[i_level]
530
- for i_block in range(self.num_res_blocks):
531
- block.append(
532
- ResnetBlock(
533
- in_channels=block_in,
534
- out_channels=block_out,
535
- temb_channels=self.temb_ch,
536
- dropout=dropout,
537
- )
538
- )
539
- block_in = block_out
540
- if curr_res in attn_resolutions:
541
- attn.append(make_attn(block_in, attn_type=attn_type))
542
- down = nn.Module()
543
- down.block = block
544
- down.attn = attn
545
- if i_level != self.num_resolutions - 1:
546
- down.downsample = Downsample(block_in, resamp_with_conv)
547
- curr_res = curr_res // 2
548
- self.down.append(down)
549
-
550
- # middle
551
- self.mid = nn.Module()
552
- self.mid.block_1 = ResnetBlock(
553
- in_channels=block_in,
554
- out_channels=block_in,
555
- temb_channels=self.temb_ch,
556
- dropout=dropout,
557
- )
558
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
559
- self.mid.block_2 = ResnetBlock(
560
- in_channels=block_in,
561
- out_channels=block_in,
562
- temb_channels=self.temb_ch,
563
- dropout=dropout,
564
- )
565
-
566
- # end
567
- self.norm_out = Normalize(block_in)
568
- self.conv_out = torch.nn.Conv2d(
569
- block_in,
570
- 2 * z_channels if double_z else z_channels,
571
- kernel_size=3,
572
- stride=1,
573
- padding=1,
574
- )
575
-
576
- def forward(self, x):
577
- # timestep embedding
578
- temb = None
579
-
580
- # downsampling
581
- hs = [self.conv_in(x)]
582
- for i_level in range(self.num_resolutions):
583
- for i_block in range(self.num_res_blocks):
584
- h = self.down[i_level].block[i_block](hs[-1], temb)
585
- if len(self.down[i_level].attn) > 0:
586
- h = self.down[i_level].attn[i_block](h)
587
- hs.append(h)
588
- if i_level != self.num_resolutions - 1:
589
- hs.append(self.down[i_level].downsample(hs[-1]))
590
-
591
- # middle
592
- h = hs[-1]
593
- h = self.mid.block_1(h, temb)
594
- h = self.mid.attn_1(h)
595
- h = self.mid.block_2(h, temb)
596
-
597
- # end
598
- h = self.norm_out(h)
599
- h = nonlinearity(h)
600
- h = self.conv_out(h)
601
- return h
602
-
603
-
604
- class Decoder(nn.Module):
605
- def __init__(
606
- self,
607
- *,
608
- ch,
609
- out_ch,
610
- ch_mult=(1, 2, 4, 8),
611
- num_res_blocks,
612
- attn_resolutions,
613
- dropout=0.0,
614
- resamp_with_conv=True,
615
- in_channels,
616
- resolution,
617
- z_channels,
618
- give_pre_end=False,
619
- tanh_out=False,
620
- use_linear_attn=False,
621
- attn_type="vanilla",
622
- **ignorekwargs,
623
- ):
624
- super().__init__()
625
- if use_linear_attn:
626
- attn_type = "linear"
627
- self.ch = ch
628
- self.temb_ch = 0
629
- self.num_resolutions = len(ch_mult)
630
- self.num_res_blocks = num_res_blocks
631
- self.resolution = resolution
632
- self.in_channels = in_channels
633
- self.give_pre_end = give_pre_end
634
- self.tanh_out = tanh_out
635
-
636
- # compute in_ch_mult, block_in and curr_res at lowest res
637
- in_ch_mult = (1,) + tuple(ch_mult)
638
- block_in = ch * ch_mult[self.num_resolutions - 1]
639
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
640
- self.z_shape = (1, z_channels, curr_res, curr_res)
641
- logpy.info(
642
- "Working with z of shape {} = {} dimensions.".format(
643
- self.z_shape, np.prod(self.z_shape)
644
- )
645
- )
646
-
647
- make_attn_cls = self._make_attn()
648
- make_resblock_cls = self._make_resblock()
649
- make_conv_cls = self._make_conv()
650
- # z to block_in
651
- self.conv_in = torch.nn.Conv2d(
652
- z_channels, block_in, kernel_size=3, stride=1, padding=1
653
- )
654
-
655
- # middle
656
- self.mid = nn.Module()
657
- self.mid.block_1 = make_resblock_cls(
658
- in_channels=block_in,
659
- out_channels=block_in,
660
- temb_channels=self.temb_ch,
661
- dropout=dropout,
662
- )
663
- self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
664
- self.mid.block_2 = make_resblock_cls(
665
- in_channels=block_in,
666
- out_channels=block_in,
667
- temb_channels=self.temb_ch,
668
- dropout=dropout,
669
- )
670
-
671
- # upsampling
672
- self.up = nn.ModuleList()
673
- for i_level in reversed(range(self.num_resolutions)):
674
- block = nn.ModuleList()
675
- attn = nn.ModuleList()
676
- block_out = ch * ch_mult[i_level]
677
- for i_block in range(self.num_res_blocks + 1):
678
- block.append(
679
- make_resblock_cls(
680
- in_channels=block_in,
681
- out_channels=block_out,
682
- temb_channels=self.temb_ch,
683
- dropout=dropout,
684
- )
685
- )
686
- block_in = block_out
687
- if curr_res in attn_resolutions:
688
- attn.append(make_attn_cls(block_in, attn_type=attn_type))
689
- up = nn.Module()
690
- up.block = block
691
- up.attn = attn
692
- if i_level != 0:
693
- up.upsample = Upsample(block_in, resamp_with_conv)
694
- curr_res = curr_res * 2
695
- self.up.insert(0, up) # prepend to get consistent order
696
-
697
- # end
698
- self.norm_out = Normalize(block_in)
699
- self.conv_out = make_conv_cls(
700
- block_in, out_ch, kernel_size=3, stride=1, padding=1
701
- )
702
-
703
- def _make_attn(self) -> Callable:
704
- return make_attn
705
-
706
- def _make_resblock(self) -> Callable:
707
- return ResnetBlock
708
-
709
- def _make_conv(self) -> Callable:
710
- return torch.nn.Conv2d
711
-
712
- def get_last_layer(self, **kwargs):
713
- return self.conv_out.weight
714
-
715
- def forward(self, z, **kwargs):
716
- # assert z.shape[1:] == self.z_shape[1:]
717
- self.last_z_shape = z.shape
718
-
719
- # timestep embedding
720
- temb = None
721
-
722
- # z to block_in
723
- h = self.conv_in(z)
724
-
725
- # middle
726
- h = self.mid.block_1(h, temb, **kwargs)
727
- h = self.mid.attn_1(h, **kwargs)
728
- h = self.mid.block_2(h, temb, **kwargs)
729
-
730
- # upsampling
731
- for i_level in reversed(range(self.num_resolutions)):
732
- for i_block in range(self.num_res_blocks + 1):
733
- h = self.up[i_level].block[i_block](h, temb, **kwargs)
734
- if len(self.up[i_level].attn) > 0:
735
- h = self.up[i_level].attn[i_block](h, **kwargs)
736
- if i_level != 0:
737
- h = self.up[i_level].upsample(h)
738
-
739
- # end
740
- if self.give_pre_end:
741
- return h
742
-
743
- h = self.norm_out(h)
744
- h = nonlinearity(h)
745
- h = self.conv_out(h, **kwargs)
746
- if self.tanh_out:
747
- h = torch.tanh(h)
748
- return h
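To show how these blocks compose, here is a hedged construction sketch for the Encoder/Decoder pair. The hyperparameters are made up for the example (not a shipped autoencoder config), and a PyTorch >= 2.0 install is assumed so the vanilla attention path above is usable without xformers:

    import torch

    enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
                  attn_resolutions=[16], in_channels=3, resolution=64, z_channels=4)
    dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
                  attn_resolutions=[16], in_channels=3, resolution=64, z_channels=4)

    x = torch.randn(1, 3, 64, 64)
    moments = enc(x)        # (1, 8, 16, 16); 2 * z_channels because double_z=True
    z = moments[:, :4]      # take the "mean" half as a stand-in for sampling the posterior
    recon = dec(z)          # (1, 3, 64, 64)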
 
sgm/modules/diffusionmodules/openaimodel.py DELETED
@@ -1,853 +0,0 @@
1
- import logging
2
- import math
3
- from abc import abstractmethod
4
- from typing import Iterable, List, Optional, Tuple, Union
5
-
6
- import torch as th
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from einops import rearrange
10
- from torch.utils.checkpoint import checkpoint
11
-
12
- from ...modules.attention import SpatialTransformer
13
- from ...modules.diffusionmodules.util import (avg_pool_nd, conv_nd, linear,
14
- normalization,
15
- timestep_embedding, zero_module)
16
- from ...modules.video_attention import SpatialVideoTransformer
17
- from ...util import exists
18
-
19
- logpy = logging.getLogger(__name__)
20
-
21
-
22
- class AttentionPool2d(nn.Module):
23
- """
24
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
25
- """
26
-
27
- def __init__(
28
- self,
29
- spacial_dim: int,
30
- embed_dim: int,
31
- num_heads_channels: int,
32
- output_dim: Optional[int] = None,
33
- ):
34
- super().__init__()
35
- self.positional_embedding = nn.Parameter(
36
- th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
37
- )
38
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
39
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
40
- self.num_heads = embed_dim // num_heads_channels
41
- self.attention = QKVAttention(self.num_heads)
42
-
43
- def forward(self, x: th.Tensor) -> th.Tensor:
44
- b, c, _ = x.shape
45
- x = x.reshape(b, c, -1)
46
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
47
- x = x + self.positional_embedding[None, :, :].to(x.dtype)
48
- x = self.qkv_proj(x)
49
- x = self.attention(x)
50
- x = self.c_proj(x)
51
- return x[:, :, 0]
52
-
53
-
54
- class TimestepBlock(nn.Module):
55
- """
56
- Any module where forward() takes timestep embeddings as a second argument.
57
- """
58
-
59
- @abstractmethod
60
- def forward(self, x: th.Tensor, emb: th.Tensor):
61
- """
62
- Apply the module to `x` given `emb` timestep embeddings.
63
- """
64
-
65
-
66
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
67
- """
68
- A sequential module that passes timestep embeddings to the children that
69
- support it as an extra input.
70
- """
71
-
72
- def forward(
73
- self,
74
- x: th.Tensor,
75
- emb: th.Tensor,
76
- context: Optional[th.Tensor] = None,
77
- image_only_indicator: Optional[th.Tensor] = None,
78
- time_context: Optional[int] = None,
79
- num_video_frames: Optional[int] = None,
80
- ):
81
- from ...modules.diffusionmodules.video_model import VideoResBlock
82
-
83
- for layer in self:
84
- module = layer
85
-
86
- if isinstance(module, TimestepBlock) and not isinstance(
87
- module, VideoResBlock
88
- ):
89
- x = layer(x, emb)
90
- elif isinstance(module, VideoResBlock):
91
- x = layer(x, emb, num_video_frames, image_only_indicator)
92
- elif isinstance(module, SpatialVideoTransformer):
93
- x = layer(
94
- x,
95
- context,
96
- time_context,
97
- num_video_frames,
98
- image_only_indicator,
99
- )
100
- elif isinstance(module, SpatialTransformer):
101
- x = layer(x, context)
102
- else:
103
- x = layer(x)
104
- return x
105
-
106
-
107
- class Upsample(nn.Module):
108
- """
109
- An upsampling layer with an optional convolution.
110
- :param channels: channels in the inputs and outputs.
111
- :param use_conv: a bool determining if a convolution is applied.
112
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
113
- upsampling occurs in the inner-two dimensions.
114
- """
115
-
116
- def __init__(
117
- self,
118
- channels: int,
119
- use_conv: bool,
120
- dims: int = 2,
121
- out_channels: Optional[int] = None,
122
- padding: int = 1,
123
- third_up: bool = False,
124
- kernel_size: int = 3,
125
- scale_factor: int = 2,
126
- ):
127
- super().__init__()
128
- self.channels = channels
129
- self.out_channels = out_channels or channels
130
- self.use_conv = use_conv
131
- self.dims = dims
132
- self.third_up = third_up
133
- self.scale_factor = scale_factor
134
- if use_conv:
135
- self.conv = conv_nd(
136
- dims, self.channels, self.out_channels, kernel_size, padding=padding
137
- )
138
-
139
- def forward(self, x: th.Tensor) -> th.Tensor:
140
- assert x.shape[1] == self.channels
141
-
142
- if self.dims == 3:
143
- t_factor = 1 if not self.third_up else self.scale_factor
144
- x = F.interpolate(
145
- x,
146
- (
147
- t_factor * x.shape[2],
148
- x.shape[3] * self.scale_factor,
149
- x.shape[4] * self.scale_factor,
150
- ),
151
- mode="nearest",
152
- )
153
- else:
154
- x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
155
- if self.use_conv:
156
- x = self.conv(x)
157
- return x
158
-
159
-
160
- class Downsample(nn.Module):
161
- """
162
- A downsampling layer with an optional convolution.
163
- :param channels: channels in the inputs and outputs.
164
- :param use_conv: a bool determining if a convolution is applied.
165
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
166
- downsampling occurs in the inner-two dimensions.
167
- """
168
-
169
- def __init__(
170
- self,
171
- channels: int,
172
- use_conv: bool,
173
- dims: int = 2,
174
- out_channels: Optional[int] = None,
175
- padding: int = 1,
176
- third_down: bool = False,
177
- ):
178
- super().__init__()
179
- self.channels = channels
180
- self.out_channels = out_channels or channels
181
- self.use_conv = use_conv
182
- self.dims = dims
183
- stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
184
- if use_conv:
185
- logpy.info(f"Building a Downsample layer with {dims} dims.")
186
- logpy.info(
187
- f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
188
- f"kernel-size: 3, stride: {stride}, padding: {padding}"
189
- )
190
- if dims == 3:
191
- logpy.info(f" --> Downsampling third axis (time): {third_down}")
192
- self.op = conv_nd(
193
- dims,
194
- self.channels,
195
- self.out_channels,
196
- 3,
197
- stride=stride,
198
- padding=padding,
199
- )
200
- else:
201
- assert self.channels == self.out_channels
202
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
203
-
204
- def forward(self, x: th.Tensor) -> th.Tensor:
205
- assert x.shape[1] == self.channels
206
-
207
- return self.op(x)
208
-
209
-
210
- class ResBlock(TimestepBlock):
211
- """
212
- A residual block that can optionally change the number of channels.
213
- :param channels: the number of input channels.
214
- :param emb_channels: the number of timestep embedding channels.
215
- :param dropout: the rate of dropout.
216
- :param out_channels: if specified, the number of out channels.
217
- :param use_conv: if True and out_channels is specified, use a spatial
218
- convolution instead of a smaller 1x1 convolution to change the
219
- channels in the skip connection.
220
- :param dims: determines if the signal is 1D, 2D, or 3D.
221
- :param use_checkpoint: if True, use gradient checkpointing on this module.
222
- :param up: if True, use this block for upsampling.
223
- :param down: if True, use this block for downsampling.
224
- """
225
-
226
- def __init__(
227
- self,
228
- channels: int,
229
- emb_channels: int,
230
- dropout: float,
231
- out_channels: Optional[int] = None,
232
- use_conv: bool = False,
233
- use_scale_shift_norm: bool = False,
234
- dims: int = 2,
235
- use_checkpoint: bool = False,
236
- up: bool = False,
237
- down: bool = False,
238
- kernel_size: int = 3,
239
- exchange_temb_dims: bool = False,
240
- skip_t_emb: bool = False,
241
- ):
242
- super().__init__()
243
- self.channels = channels
244
- self.emb_channels = emb_channels
245
- self.dropout = dropout
246
- self.out_channels = out_channels or channels
247
- self.use_conv = use_conv
248
- self.use_checkpoint = use_checkpoint
249
- self.use_scale_shift_norm = use_scale_shift_norm
250
- self.exchange_temb_dims = exchange_temb_dims
251
-
252
- if isinstance(kernel_size, Iterable):
253
- padding = [k // 2 for k in kernel_size]
254
- else:
255
- padding = kernel_size // 2
256
-
257
- self.in_layers = nn.Sequential(
258
- normalization(channels),
259
- nn.SiLU(),
260
- conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
261
- )
262
-
263
- self.updown = up or down
264
-
265
- if up:
266
- self.h_upd = Upsample(channels, False, dims)
267
- self.x_upd = Upsample(channels, False, dims)
268
- elif down:
269
- self.h_upd = Downsample(channels, False, dims)
270
- self.x_upd = Downsample(channels, False, dims)
271
- else:
272
- self.h_upd = self.x_upd = nn.Identity()
273
-
274
- self.skip_t_emb = skip_t_emb
275
- self.emb_out_channels = (
276
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels
277
- )
278
- if self.skip_t_emb:
279
- logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}")
280
- assert not self.use_scale_shift_norm
281
- self.emb_layers = None
282
- self.exchange_temb_dims = False
283
- else:
284
- self.emb_layers = nn.Sequential(
285
- nn.SiLU(),
286
- linear(
287
- emb_channels,
288
- self.emb_out_channels,
289
- ),
290
- )
291
-
292
- self.out_layers = nn.Sequential(
293
- normalization(self.out_channels),
294
- nn.SiLU(),
295
- nn.Dropout(p=dropout),
296
- zero_module(
297
- conv_nd(
298
- dims,
299
- self.out_channels,
300
- self.out_channels,
301
- kernel_size,
302
- padding=padding,
303
- )
304
- ),
305
- )
306
-
307
- if self.out_channels == channels:
308
- self.skip_connection = nn.Identity()
309
- elif use_conv:
310
- self.skip_connection = conv_nd(
311
- dims, channels, self.out_channels, kernel_size, padding=padding
312
- )
313
- else:
314
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
315
-
316
- def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
317
- """
318
- Apply the block to a Tensor, conditioned on a timestep embedding.
319
- :param x: an [N x C x ...] Tensor of features.
320
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
321
- :return: an [N x C x ...] Tensor of outputs.
322
- """
323
- if self.use_checkpoint:
324
- return checkpoint(self._forward, x, emb)
325
- else:
326
- return self._forward(x, emb)
327
-
328
- def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
329
- if self.updown:
330
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
331
- h = in_rest(x)
332
- h = self.h_upd(h)
333
- x = self.x_upd(x)
334
- h = in_conv(h)
335
- else:
336
- h = self.in_layers(x)
337
-
338
- if self.skip_t_emb:
339
- emb_out = th.zeros_like(h)
340
- else:
341
- emb_out = self.emb_layers(emb).type(h.dtype)
342
- while len(emb_out.shape) < len(h.shape):
343
- emb_out = emb_out[..., None]
344
- if self.use_scale_shift_norm:
345
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
346
- scale, shift = th.chunk(emb_out, 2, dim=1)
347
- h = out_norm(h) * (1 + scale) + shift
348
- h = out_rest(h)
349
- else:
350
- if self.exchange_temb_dims:
351
- emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
352
- h = h + emb_out
353
- h = self.out_layers(h)
354
- return self.skip_connection(x) + h
355
-
356
-
357
- class AttentionBlock(nn.Module):
358
- """
359
- An attention block that allows spatial positions to attend to each other.
360
- Originally ported from here, but adapted to the N-d case.
361
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
362
- """
363
-
364
- def __init__(
365
- self,
366
- channels: int,
367
- num_heads: int = 1,
368
- num_head_channels: int = -1,
369
- use_checkpoint: bool = False,
370
- use_new_attention_order: bool = False,
371
- ):
372
- super().__init__()
373
- self.channels = channels
374
- if num_head_channels == -1:
375
- self.num_heads = num_heads
376
- else:
377
- assert (
378
- channels % num_head_channels == 0
379
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
380
- self.num_heads = channels // num_head_channels
381
- self.use_checkpoint = use_checkpoint
382
- self.norm = normalization(channels)
383
- self.qkv = conv_nd(1, channels, channels * 3, 1)
384
- if use_new_attention_order:
385
- # split qkv before split heads
386
- self.attention = QKVAttention(self.num_heads)
387
- else:
388
- # split heads before split qkv
389
- self.attention = QKVAttentionLegacy(self.num_heads)
390
-
391
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
392
-
393
- def forward(self, x: th.Tensor, **kwargs) -> th.Tensor:
394
- return checkpoint(self._forward, x)
395
-
396
- def _forward(self, x: th.Tensor) -> th.Tensor:
397
- b, c, *spatial = x.shape
398
- x = x.reshape(b, c, -1)
399
- qkv = self.qkv(self.norm(x))
400
- h = self.attention(qkv)
401
- h = self.proj_out(h)
402
- return (x + h).reshape(b, c, *spatial)
403
-
404
-
405
- class QKVAttentionLegacy(nn.Module):
406
- """
407
- A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
408
- """
409
-
410
- def __init__(self, n_heads: int):
411
- super().__init__()
412
- self.n_heads = n_heads
413
-
414
- def forward(self, qkv: th.Tensor) -> th.Tensor:
415
- """
416
- Apply QKV attention.
417
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
418
- :return: an [N x (H * C) x T] tensor after attention.
419
- """
420
- bs, width, length = qkv.shape
421
- assert width % (3 * self.n_heads) == 0
422
- ch = width // (3 * self.n_heads)
423
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
424
- scale = 1 / math.sqrt(math.sqrt(ch))
425
- weight = th.einsum(
426
- "bct,bcs->bts", q * scale, k * scale
427
- ) # More stable with f16 than dividing afterwards
428
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
429
- a = th.einsum("bts,bcs->bct", weight, v)
430
- return a.reshape(bs, -1, length)
431
-
432
-
433
- class QKVAttention(nn.Module):
434
- """
435
- A module which performs QKV attention and splits in a different order.
436
- """
437
-
438
- def __init__(self, n_heads: int):
439
- super().__init__()
440
- self.n_heads = n_heads
441
-
442
- def forward(self, qkv: th.Tensor) -> th.Tensor:
443
- """
444
- Apply QKV attention.
445
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
446
- :return: an [N x (H * C) x T] tensor after attention.
447
- """
448
- bs, width, length = qkv.shape
449
- assert width % (3 * self.n_heads) == 0
450
- ch = width // (3 * self.n_heads)
451
- q, k, v = qkv.chunk(3, dim=1)
452
- scale = 1 / math.sqrt(math.sqrt(ch))
453
- weight = th.einsum(
454
- "bct,bcs->bts",
455
- (q * scale).view(bs * self.n_heads, ch, length),
456
- (k * scale).view(bs * self.n_heads, ch, length),
457
- ) # More stable with f16 than dividing afterwards
458
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
459
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
460
- return a.reshape(bs, -1, length)
461
-
462
-
463
- class Timestep(nn.Module):
464
- def __init__(self, dim: int):
465
- super().__init__()
466
- self.dim = dim
467
-
468
- def forward(self, t: th.Tensor) -> th.Tensor:
469
- return timestep_embedding(t, self.dim)
470
-
471
-
472
- class UNetModel(nn.Module):
473
- """
474
- The full UNet model with attention and timestep embedding.
475
- :param in_channels: channels in the input Tensor.
476
- :param model_channels: base channel count for the model.
477
- :param out_channels: channels in the output Tensor.
478
- :param num_res_blocks: number of residual blocks per downsample.
479
- :param attention_resolutions: a collection of downsample rates at which
480
- attention will take place. May be a set, list, or tuple.
481
- For example, if this contains 4, then at 4x downsampling, attention
482
- will be used.
483
- :param dropout: the dropout probability.
484
- :param channel_mult: channel multiplier for each level of the UNet.
485
- :param conv_resample: if True, use learned convolutions for upsampling and
486
- downsampling.
487
- :param dims: determines if the signal is 1D, 2D, or 3D.
488
- :param num_classes: if specified (as an int), then this model will be
489
- class-conditional with `num_classes` classes.
490
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
491
- :param num_heads: the number of attention heads in each attention layer.
492
- :param num_heads_channels: if specified, ignore num_heads and instead use
493
- a fixed channel width per attention head.
494
- :param num_heads_upsample: works with num_heads to set a different number
495
- of heads for upsampling. Deprecated.
496
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
497
- :param resblock_updown: use residual blocks for up/downsampling.
498
- :param use_new_attention_order: use a different attention pattern for potentially
499
- increased efficiency.
500
- """
501
-
502
- def __init__(
503
- self,
504
- in_channels: int,
505
- model_channels: int,
506
- out_channels: int,
507
- num_res_blocks: int,
508
- attention_resolutions: int,
509
- dropout: float = 0.0,
510
- channel_mult: Union[List, Tuple] = (1, 2, 4, 8),
511
- conv_resample: bool = True,
512
- dims: int = 2,
513
- num_classes: Optional[Union[int, str]] = None,
514
- use_checkpoint: bool = False,
515
- num_heads: int = -1,
516
- num_head_channels: int = -1,
517
- num_heads_upsample: int = -1,
518
- use_scale_shift_norm: bool = False,
519
- resblock_updown: bool = False,
520
- transformer_depth: int = 1,
521
- context_dim: Optional[int] = None,
522
- disable_self_attentions: Optional[List[bool]] = None,
523
- num_attention_blocks: Optional[List[int]] = None,
524
- disable_middle_self_attn: bool = False,
525
- disable_middle_transformer: bool = False,
526
- use_linear_in_transformer: bool = False,
527
- spatial_transformer_attn_type: str = "softmax",
528
- adm_in_channels: Optional[int] = None,
529
- ):
530
- super().__init__()
531
-
532
- if num_heads_upsample == -1:
533
- num_heads_upsample = num_heads
534
-
535
- if num_heads == -1:
536
- assert (
537
- num_head_channels != -1
538
- ), "Either num_heads or num_head_channels has to be set"
539
-
540
- if num_head_channels == -1:
541
- assert (
542
- num_heads != -1
543
- ), "Either num_heads or num_head_channels has to be set"
544
-
545
- self.in_channels = in_channels
546
- self.model_channels = model_channels
547
- self.out_channels = out_channels
548
- if isinstance(transformer_depth, int):
549
- transformer_depth = len(channel_mult) * [transformer_depth]
550
- transformer_depth_middle = transformer_depth[-1]
551
-
552
- if isinstance(num_res_blocks, int):
553
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
554
- else:
555
- if len(num_res_blocks) != len(channel_mult):
556
- raise ValueError(
557
- "provide num_res_blocks either as an int (globally constant) or "
558
- "as a list/tuple (per-level) with the same length as channel_mult"
559
- )
560
- self.num_res_blocks = num_res_blocks
561
-
562
- if disable_self_attentions is not None:
563
- assert len(disable_self_attentions) == len(channel_mult)
564
- if num_attention_blocks is not None:
565
- assert len(num_attention_blocks) == len(self.num_res_blocks)
566
- assert all(
567
- map(
568
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
569
- range(len(num_attention_blocks)),
570
- )
571
- )
572
- logpy.info(
573
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
574
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
575
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
576
- f"attention will still not be set."
577
- )
578
-
579
- self.attention_resolutions = attention_resolutions
580
- self.dropout = dropout
581
- self.channel_mult = channel_mult
582
- self.conv_resample = conv_resample
583
- self.num_classes = num_classes
584
- self.use_checkpoint = use_checkpoint
585
- self.num_heads = num_heads
586
- self.num_head_channels = num_head_channels
587
- self.num_heads_upsample = num_heads_upsample
588
-
589
- time_embed_dim = model_channels * 4
590
- self.time_embed = nn.Sequential(
591
- linear(model_channels, time_embed_dim),
592
- nn.SiLU(),
593
- linear(time_embed_dim, time_embed_dim),
594
- )
595
-
596
- if self.num_classes is not None:
597
- if isinstance(self.num_classes, int):
598
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
599
- elif self.num_classes == "continuous":
600
- logpy.info("setting up linear c_adm embedding layer")
601
- self.label_emb = nn.Linear(1, time_embed_dim)
602
- elif self.num_classes == "timestep":
603
- self.label_emb = nn.Sequential(
604
- Timestep(model_channels),
605
- nn.Sequential(
606
- linear(model_channels, time_embed_dim),
607
- nn.SiLU(),
608
- linear(time_embed_dim, time_embed_dim),
609
- ),
610
- )
611
- elif self.num_classes == "sequential":
612
- assert adm_in_channels is not None
613
- self.label_emb = nn.Sequential(
614
- nn.Sequential(
615
- linear(adm_in_channels, time_embed_dim),
616
- nn.SiLU(),
617
- linear(time_embed_dim, time_embed_dim),
618
- )
619
- )
620
- else:
621
- raise ValueError(f"unsupported num_classes: {self.num_classes}")
622
-
623
- self.input_blocks = nn.ModuleList(
624
- [
625
- TimestepEmbedSequential(
626
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
627
- )
628
- ]
629
- )
630
- self._feature_size = model_channels
631
- input_block_chans = [model_channels]
632
- ch = model_channels
633
- ds = 1
634
- for level, mult in enumerate(channel_mult):
635
- for nr in range(self.num_res_blocks[level]):
636
- layers = [
637
- ResBlock(
638
- ch,
639
- time_embed_dim,
640
- dropout,
641
- out_channels=mult * model_channels,
642
- dims=dims,
643
- use_checkpoint=use_checkpoint,
644
- use_scale_shift_norm=use_scale_shift_norm,
645
- )
646
- ]
647
- ch = mult * model_channels
648
- if ds in attention_resolutions:
649
- if num_head_channels == -1:
650
- dim_head = ch // num_heads
651
- else:
652
- num_heads = ch // num_head_channels
653
- dim_head = num_head_channels
654
-
655
- if context_dim is not None and exists(disable_self_attentions):
656
- disabled_sa = disable_self_attentions[level]
657
- else:
658
- disabled_sa = False
659
-
660
- if (
661
- not exists(num_attention_blocks)
662
- or nr < num_attention_blocks[level]
663
- ):
664
- layers.append(
665
- SpatialTransformer(
666
- ch,
667
- num_heads,
668
- dim_head,
669
- depth=transformer_depth[level],
670
- context_dim=context_dim,
671
- disable_self_attn=disabled_sa,
672
- use_linear=use_linear_in_transformer,
673
- attn_type=spatial_transformer_attn_type,
674
- use_checkpoint=use_checkpoint,
675
- )
676
- )
677
- self.input_blocks.append(TimestepEmbedSequential(*layers))
678
- self._feature_size += ch
679
- input_block_chans.append(ch)
680
- if level != len(channel_mult) - 1:
681
- out_ch = ch
682
- self.input_blocks.append(
683
- TimestepEmbedSequential(
684
- ResBlock(
685
- ch,
686
- time_embed_dim,
687
- dropout,
688
- out_channels=out_ch,
689
- dims=dims,
690
- use_checkpoint=use_checkpoint,
691
- use_scale_shift_norm=use_scale_shift_norm,
692
- down=True,
693
- )
694
- if resblock_updown
695
- else Downsample(
696
- ch, conv_resample, dims=dims, out_channels=out_ch
697
- )
698
- )
699
- )
700
- ch = out_ch
701
- input_block_chans.append(ch)
702
- ds *= 2
703
- self._feature_size += ch
704
-
705
- if num_head_channels == -1:
706
- dim_head = ch // num_heads
707
- else:
708
- num_heads = ch // num_head_channels
709
- dim_head = num_head_channels
710
-
711
- self.middle_block = TimestepEmbedSequential(
712
- ResBlock(
713
- ch,
714
- time_embed_dim,
715
- dropout,
716
- out_channels=ch,
717
- dims=dims,
718
- use_checkpoint=use_checkpoint,
719
- use_scale_shift_norm=use_scale_shift_norm,
720
- ),
721
- SpatialTransformer(
722
- ch,
723
- num_heads,
724
- dim_head,
725
- depth=transformer_depth_middle,
726
- context_dim=context_dim,
727
- disable_self_attn=disable_middle_self_attn,
728
- use_linear=use_linear_in_transformer,
729
- attn_type=spatial_transformer_attn_type,
730
- use_checkpoint=use_checkpoint,
731
- )
732
- if not disable_middle_transformer
733
- else th.nn.Identity(),
734
- ResBlock(
735
- ch,
736
- time_embed_dim,
737
- dropout,
738
- dims=dims,
739
- use_checkpoint=use_checkpoint,
740
- use_scale_shift_norm=use_scale_shift_norm,
741
- ),
742
- )
743
- self._feature_size += ch
744
-
745
- self.output_blocks = nn.ModuleList([])
746
- for level, mult in list(enumerate(channel_mult))[::-1]:
747
- for i in range(self.num_res_blocks[level] + 1):
748
- ich = input_block_chans.pop()
749
- layers = [
750
- ResBlock(
751
- ch + ich,
752
- time_embed_dim,
753
- dropout,
754
- out_channels=model_channels * mult,
755
- dims=dims,
756
- use_checkpoint=use_checkpoint,
757
- use_scale_shift_norm=use_scale_shift_norm,
758
- )
759
- ]
760
- ch = model_channels * mult
761
- if ds in attention_resolutions:
762
- if num_head_channels == -1:
763
- dim_head = ch // num_heads
764
- else:
765
- num_heads = ch // num_head_channels
766
- dim_head = num_head_channels
767
-
768
- if exists(disable_self_attentions):
769
- disabled_sa = disable_self_attentions[level]
770
- else:
771
- disabled_sa = False
772
-
773
- if (
774
- not exists(num_attention_blocks)
775
- or i < num_attention_blocks[level]
776
- ):
777
- layers.append(
778
- SpatialTransformer(
779
- ch,
780
- num_heads,
781
- dim_head,
782
- depth=transformer_depth[level],
783
- context_dim=context_dim,
784
- disable_self_attn=disabled_sa,
785
- use_linear=use_linear_in_transformer,
786
- attn_type=spatial_transformer_attn_type,
787
- use_checkpoint=use_checkpoint,
788
- )
789
- )
790
- if level and i == self.num_res_blocks[level]:
791
- out_ch = ch
792
- layers.append(
793
- ResBlock(
794
- ch,
795
- time_embed_dim,
796
- dropout,
797
- out_channels=out_ch,
798
- dims=dims,
799
- use_checkpoint=use_checkpoint,
800
- use_scale_shift_norm=use_scale_shift_norm,
801
- up=True,
802
- )
803
- if resblock_updown
804
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
805
- )
806
- ds //= 2
807
- self.output_blocks.append(TimestepEmbedSequential(*layers))
808
- self._feature_size += ch
809
-
810
- self.out = nn.Sequential(
811
- normalization(ch),
812
- nn.SiLU(),
813
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
814
- )
815
-
816
- def forward(
817
- self,
818
- x: th.Tensor,
819
- timesteps: Optional[th.Tensor] = None,
820
- context: Optional[th.Tensor] = None,
821
- y: Optional[th.Tensor] = None,
822
- **kwargs,
823
- ) -> th.Tensor:
824
- """
825
- Apply the model to an input batch.
826
- :param x: an [N x C x ...] Tensor of inputs.
827
- :param timesteps: a 1-D batch of timesteps.
828
- :param context: conditioning plugged in via crossattn
829
- :param y: an [N] Tensor of labels, if class-conditional.
830
- :return: an [N x C x ...] Tensor of outputs.
831
- """
832
- assert (y is not None) == (
833
- self.num_classes is not None
834
- ), "must specify y if and only if the model is class-conditional"
835
- hs = []
836
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
837
- emb = self.time_embed(t_emb)
838
-
839
- if self.num_classes is not None:
840
- assert y.shape[0] == x.shape[0]
841
- emb = emb + self.label_emb(y)
842
-
843
- h = x
844
- for module in self.input_blocks:
845
- h = module(h, emb, context)
846
- hs.append(h)
847
- h = self.middle_block(h, emb, context)
848
- for module in self.output_blocks:
849
- h = th.cat([h, hs.pop()], dim=1)
850
- h = module(h, emb, context)
851
- h = h.type(x.dtype)
852
-
853
- return self.out(h)
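For orientation, the forward pass removed above is the standard U-Net pattern: the input blocks push features onto the hs stack, the output blocks pop and channel-concatenate them as skip connections, and y is required exactly when the model is class-conditional. A call is therefore of the form below (a minimal sketch; the tensor names are placeholders, not identifiers from this repository):

    out = unet(x, timesteps=t, context=ctx, y=class_vec)  # x: [N, C, H, W] noised latent; y only if class-conditional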
sgm/modules/diffusionmodules/sampling.py DELETED
@@ -1,362 +0,0 @@
1
- """
2
- Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
3
- """
4
-
5
-
6
- from typing import Dict, Union
7
-
8
- import torch
9
- from omegaconf import ListConfig, OmegaConf
10
- from tqdm import tqdm
11
-
12
- from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step,
13
- linear_multistep_coeff,
14
- to_d, to_neg_log_sigma,
15
- to_sigma)
16
- from ...util import append_dims, default, instantiate_from_config
17
-
18
- DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
19
-
20
-
21
- class BaseDiffusionSampler:
22
- def __init__(
23
- self,
24
- discretization_config: Union[Dict, ListConfig, OmegaConf],
25
- num_steps: Union[int, None] = None,
26
- guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
27
- verbose: bool = False,
28
- device: str = "cuda",
29
- ):
30
- self.num_steps = num_steps
31
- self.discretization = instantiate_from_config(discretization_config)
32
- self.guider = instantiate_from_config(
33
- default(
34
- guider_config,
35
- DEFAULT_GUIDER,
36
- )
37
- )
38
- self.verbose = verbose
39
- self.device = device
40
-
41
- def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
42
- sigmas = self.discretization(
43
- self.num_steps if num_steps is None else num_steps, device=self.device
44
- )
45
- uc = default(uc, cond)
46
-
47
- x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
48
- num_sigmas = len(sigmas)
49
-
50
- s_in = x.new_ones([x.shape[0]])
51
-
52
- return x, s_in, sigmas, num_sigmas, cond, uc
53
-
54
- def denoise(self, x, denoiser, sigma, cond, uc):
55
- denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
56
- denoised = self.guider(denoised, sigma)
57
- return denoised
58
-
59
- def get_sigma_gen(self, num_sigmas):
60
- sigma_generator = range(num_sigmas - 1)
61
- if self.verbose:
62
- print("#" * 30, " Sampling setting ", "#" * 30)
63
- print(f"Sampler: {self.__class__.__name__}")
64
- print(f"Discretization: {self.discretization.__class__.__name__}")
65
- print(f"Guider: {self.guider.__class__.__name__}")
66
- sigma_generator = tqdm(
67
- sigma_generator,
68
- total=num_sigmas,
69
- desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
70
- )
71
- return sigma_generator
72
-
73
-
74
- class SingleStepDiffusionSampler(BaseDiffusionSampler):
75
- def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
76
- raise NotImplementedError
77
-
78
- def euler_step(self, x, d, dt):
79
- return x + dt * d
80
-
81
-
82
- class EDMSampler(SingleStepDiffusionSampler):
83
- def __init__(
84
- self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
85
- ):
86
- super().__init__(*args, **kwargs)
87
-
88
- self.s_churn = s_churn
89
- self.s_tmin = s_tmin
90
- self.s_tmax = s_tmax
91
- self.s_noise = s_noise
92
-
93
- def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
94
- sigma_hat = sigma * (gamma + 1.0)
95
- if gamma > 0:
96
- eps = torch.randn_like(x) * self.s_noise
97
- x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
98
-
99
- denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
100
- d = to_d(x, sigma_hat, denoised)
101
- dt = append_dims(next_sigma - sigma_hat, x.ndim)
102
-
103
- euler_step = self.euler_step(x, d, dt)
104
- x = self.possible_correction_step(
105
- euler_step, x, d, dt, next_sigma, denoiser, cond, uc
106
- )
107
- return x
108
-
109
- def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
110
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
111
- x, cond, uc, num_steps
112
- )
113
-
114
- for i in self.get_sigma_gen(num_sigmas):
115
- gamma = (
116
- min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
117
- if self.s_tmin <= sigmas[i] <= self.s_tmax
118
- else 0.0
119
- )
120
- x = self.sampler_step(
121
- s_in * sigmas[i],
122
- s_in * sigmas[i + 1],
123
- denoiser,
124
- x,
125
- cond,
126
- uc,
127
- gamma,
128
- )
129
-
130
- return x
131
-
132
-
133
- class AncestralSampler(SingleStepDiffusionSampler):
134
- def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
135
- super().__init__(*args, **kwargs)
136
-
137
- self.eta = eta
138
- self.s_noise = s_noise
139
- self.noise_sampler = lambda x: torch.randn_like(x)
140
-
141
- def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
142
- d = to_d(x, sigma, denoised)
143
- dt = append_dims(sigma_down - sigma, x.ndim)
144
-
145
- return self.euler_step(x, d, dt)
146
-
147
- def ancestral_step(self, x, sigma, next_sigma, sigma_up):
148
- x = torch.where(
149
- append_dims(next_sigma, x.ndim) > 0.0,
150
- x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
151
- x,
152
- )
153
- return x
154
-
155
- def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
156
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
157
- x, cond, uc, num_steps
158
- )
159
-
160
- for i in self.get_sigma_gen(num_sigmas):
161
- x = self.sampler_step(
162
- s_in * sigmas[i],
163
- s_in * sigmas[i + 1],
164
- denoiser,
165
- x,
166
- cond,
167
- uc,
168
- )
169
-
170
- return x
171
-
172
-
173
- class LinearMultistepSampler(BaseDiffusionSampler):
174
- def __init__(
175
- self,
176
- order=4,
177
- *args,
178
- **kwargs,
179
- ):
180
- super().__init__(*args, **kwargs)
181
-
182
- self.order = order
183
-
184
- def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
185
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
186
- x, cond, uc, num_steps
187
- )
188
-
189
- ds = []
190
- sigmas_cpu = sigmas.detach().cpu().numpy()
191
- for i in self.get_sigma_gen(num_sigmas):
192
- sigma = s_in * sigmas[i]
193
- denoised = denoiser(
194
- *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
195
- )
196
- denoised = self.guider(denoised, sigma)
197
- d = to_d(x, sigma, denoised)
198
- ds.append(d)
199
- if len(ds) > self.order:
200
- ds.pop(0)
201
- cur_order = min(i + 1, self.order)
202
- coeffs = [
203
- linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
204
- for j in range(cur_order)
205
- ]
206
- x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
207
-
208
- return x
209
-
210
-
211
- class EulerEDMSampler(EDMSampler):
212
- def possible_correction_step(
213
- self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
214
- ):
215
- return euler_step
216
-
217
-
218
- class HeunEDMSampler(EDMSampler):
219
- def possible_correction_step(
220
- self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
221
- ):
222
- if torch.sum(next_sigma) < 1e-14:
223
- # Save a network evaluation if all noise levels are 0
224
- return euler_step
225
- else:
226
- denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
227
- d_new = to_d(euler_step, next_sigma, denoised)
228
- d_prime = (d + d_new) / 2.0
229
-
230
- # apply correction if noise level is not 0
231
- x = torch.where(
232
- append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
233
- )
234
- return x
235
-
236
-
237
- class EulerAncestralSampler(AncestralSampler):
238
- def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
239
- sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
240
- denoised = self.denoise(x, denoiser, sigma, cond, uc)
241
- x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
242
- x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
243
-
244
- return x
245
-
246
-
247
- class DPMPP2SAncestralSampler(AncestralSampler):
248
- def get_variables(self, sigma, sigma_down):
249
- t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
250
- h = t_next - t
251
- s = t + 0.5 * h
252
- return h, s, t, t_next
253
-
254
- def get_mult(self, h, s, t, t_next):
255
- mult1 = to_sigma(s) / to_sigma(t)
256
- mult2 = (-0.5 * h).expm1()
257
- mult3 = to_sigma(t_next) / to_sigma(t)
258
- mult4 = (-h).expm1()
259
-
260
- return mult1, mult2, mult3, mult4
261
-
262
- def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
263
- sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
264
- denoised = self.denoise(x, denoiser, sigma, cond, uc)
265
- x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
266
-
267
- if torch.sum(sigma_down) < 1e-14:
268
- # Save a network evaluation if all noise levels are 0
269
- x = x_euler
270
- else:
271
- h, s, t, t_next = self.get_variables(sigma, sigma_down)
272
- mult = [
273
- append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
274
- ]
275
-
276
- x2 = mult[0] * x - mult[1] * denoised
277
- denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
278
- x_dpmpp2s = mult[2] * x - mult[3] * denoised2
279
-
280
- # apply correction if noise level is not 0
281
- x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
282
-
283
- x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
284
- return x
285
-
286
-
287
- class DPMPP2MSampler(BaseDiffusionSampler):
288
- def get_variables(self, sigma, next_sigma, previous_sigma=None):
289
- t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
290
- h = t_next - t
291
-
292
- if previous_sigma is not None:
293
- h_last = t - to_neg_log_sigma(previous_sigma)
294
- r = h_last / h
295
- return h, r, t, t_next
296
- else:
297
- return h, None, t, t_next
298
-
299
- def get_mult(self, h, r, t, t_next, previous_sigma):
300
- mult1 = to_sigma(t_next) / to_sigma(t)
301
- mult2 = (-h).expm1()
302
-
303
- if previous_sigma is not None:
304
- mult3 = 1 + 1 / (2 * r)
305
- mult4 = 1 / (2 * r)
306
- return mult1, mult2, mult3, mult4
307
- else:
308
- return mult1, mult2
309
-
310
- def sampler_step(
311
- self,
312
- old_denoised,
313
- previous_sigma,
314
- sigma,
315
- next_sigma,
316
- denoiser,
317
- x,
318
- cond,
319
- uc=None,
320
- ):
321
- denoised = self.denoise(x, denoiser, sigma, cond, uc)
322
-
323
- h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
324
- mult = [
325
- append_dims(mult, x.ndim)
326
- for mult in self.get_mult(h, r, t, t_next, previous_sigma)
327
- ]
328
-
329
- x_standard = mult[0] * x - mult[1] * denoised
330
- if old_denoised is None or torch.sum(next_sigma) < 1e-14:
331
- # Save a network evaluation if all noise levels are 0 or on the first step
332
- return x_standard, denoised
333
- else:
334
- denoised_d = mult[2] * denoised - mult[3] * old_denoised
335
- x_advanced = mult[0] * x - mult[1] * denoised_d
336
-
337
- # apply correction if noise level is not 0 and not first step
338
- x = torch.where(
339
- append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
340
- )
341
-
342
- return x, denoised
343
-
344
- def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
345
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
346
- x, cond, uc, num_steps
347
- )
348
-
349
- old_denoised = None
350
- for i in self.get_sigma_gen(num_sigmas):
351
- x, old_denoised = self.sampler_step(
352
- old_denoised,
353
- None if i == 0 else s_in * sigmas[i - 1],
354
- s_in * sigmas[i],
355
- s_in * sigmas[i + 1],
356
- denoiser,
357
- x,
358
- cond,
359
- uc=uc,
360
- )
361
-
362
- return x
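As a rough illustration of how the deleted samplers are driven, the call pattern is sampler(denoiser, x, cond, uc). The sketch below is an assumption-laden toy, not configuration taken from this commit: the EDMDiscretization target path is assumed to exist in the discretizer.py module deleted alongside this file, the IdentityGuider fallback is assumed to pass its inputs straight through, and the zero-returning denoiser merely stands in for DiffusionEngine's real denoiser.

    import torch
    from sgm.modules.diffusionmodules.sampling import EulerEDMSampler  # the module shown above

    # Hypothetical discretization config; the target path is an assumption.
    discretization_config = {
        "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization"
    }

    sampler = EulerEDMSampler(
        discretization_config=discretization_config,
        num_steps=30,
        guider_config=None,   # falls back to DEFAULT_GUIDER (IdentityGuider)
        device="cpu",
    )

    def toy_denoiser(x, sigma, c):
        # Stand-in: a real denoiser predicts the denoised sample from the
        # noised input, the noise level sigma and the conditioning dict c.
        return torch.zeros_like(x)

    noise = torch.randn(1, 4, 8, 8)   # latent-shaped Gaussian noise
    samples = sampler(toy_denoiser, noise, cond={}, uc={})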
sgm/modules/diffusionmodules/sampling_utils.py DELETED
@@ -1,43 +0,0 @@
1
- import torch
2
- from scipy import integrate
3
-
4
- from ...util import append_dims
5
-
6
-
7
- def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
8
- if order - 1 > i:
9
- raise ValueError(f"Order {order} too high for step {i}")
10
-
11
- def fn(tau):
12
- prod = 1.0
13
- for k in range(order):
14
- if j == k:
15
- continue
16
- prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
17
- return prod
18
-
19
- return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
20
-
21
-
22
- def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
23
- if not eta:
24
- return sigma_to, 0.0
25
- sigma_up = torch.minimum(
26
- sigma_to,
27
- eta
28
- * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
29
- )
30
- sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
31
- return sigma_down, sigma_up
32
-
33
-
34
- def to_d(x, sigma, denoised):
35
- return (x - denoised) / append_dims(sigma, x.ndim)
36
-
37
-
38
- def to_neg_log_sigma(sigma):
39
- return sigma.log().neg()
40
-
41
-
42
- def to_sigma(neg_log_sigma):
43
- return neg_log_sigma.neg().exp()
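Restated in LaTeX, with no content beyond what the helpers above compute: get_ancestral_step returns

    \sigma_{\text{up}} = \min\!\Big(\sigma_{\text{to}},\ \eta\,\sqrt{\tfrac{\sigma_{\text{to}}^{2}\,(\sigma_{\text{from}}^{2}-\sigma_{\text{to}}^{2})}{\sigma_{\text{from}}^{2}}}\Big),
    \qquad
    \sigma_{\text{down}} = \sqrt{\sigma_{\text{to}}^{2}-\sigma_{\text{up}}^{2}},

to_d computes the ODE direction d = (x - \hat{x}_0)/\sigma with \hat{x}_0 the denoiser output, and to_neg_log_sigma / to_sigma are the inverse pair t = -\log\sigma, \sigma = e^{-t}.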
sgm/modules/diffusionmodules/sigma_sampling.py DELETED
@@ -1,31 +0,0 @@
1
- import torch
2
-
3
- from ...util import default, instantiate_from_config
4
-
5
-
6
- class EDMSampling:
7
- def __init__(self, p_mean=-1.2, p_std=1.2):
8
- self.p_mean = p_mean
9
- self.p_std = p_std
10
-
11
- def __call__(self, n_samples, rand=None):
12
- log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
13
- return log_sigma.exp()
14
-
15
-
16
- class DiscreteSampling:
17
- def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
18
- self.num_idx = num_idx
19
- self.sigmas = instantiate_from_config(discretization_config)(
20
- num_idx, do_append_zero=do_append_zero, flip=flip
21
- )
22
-
23
- def idx_to_sigma(self, idx):
24
- return self.sigmas[idx]
25
-
26
- def __call__(self, n_samples, rand=None):
27
- idx = default(
28
- rand,
29
- torch.randint(0, self.num_idx, (n_samples,)),
30
- )
31
- return self.idx_to_sigma(idx)
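In formula form (restating the code only): EDMSampling draws log-normal noise levels, \sigma = \exp(p_{\text{mean}} + p_{\text{std}}\, z) with z \sim \mathcal{N}(0, 1), while DiscreteSampling picks a uniformly random index into the fixed sigma table produced by the configured discretization.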
sgm/modules/diffusionmodules/util.py DELETED
@@ -1,369 +0,0 @@
1
- """
2
- partially adopted from
3
- https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
4
- and
5
- https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
6
- and
7
- https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
8
-
9
- thanks!
10
- """
11
-
12
- import math
13
- from typing import Optional
14
-
15
- import torch
16
- import torch.nn as nn
17
- from einops import rearrange, repeat
18
-
19
-
20
- def make_beta_schedule(
21
- schedule,
22
- n_timestep,
23
- linear_start=1e-4,
24
- linear_end=2e-2,
25
- ):
26
- if schedule == "linear":
27
- betas = (
28
- torch.linspace(
29
- linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
30
- )
31
- ** 2
32
- )
33
- return betas.numpy()
34
-
35
-
36
- def extract_into_tensor(a, t, x_shape):
37
- b, *_ = t.shape
38
- out = a.gather(-1, t)
39
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
40
-
41
-
42
- def mixed_checkpoint(func, inputs: dict, params, flag):
43
- """
44
- Evaluate a function without caching intermediate activations, allowing for
45
- reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
46
- borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
47
- it also works with non-tensor inputs
48
- :param func: the function to evaluate.
49
- :param inputs: the argument dictionary to pass to `func`.
50
- :param params: a sequence of parameters `func` depends on but does not
51
- explicitly take as arguments.
52
- :param flag: if False, disable gradient checkpointing.
53
- """
54
- if flag:
55
- tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
56
- tensor_inputs = [
57
- inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
58
- ]
59
- non_tensor_keys = [
60
- key for key in inputs if not isinstance(inputs[key], torch.Tensor)
61
- ]
62
- non_tensor_inputs = [
63
- inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
64
- ]
65
- args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
66
- return MixedCheckpointFunction.apply(
67
- func,
68
- len(tensor_inputs),
69
- len(non_tensor_inputs),
70
- tensor_keys,
71
- non_tensor_keys,
72
- *args,
73
- )
74
- else:
75
- return func(**inputs)
76
-
77
-
78
- class MixedCheckpointFunction(torch.autograd.Function):
79
- @staticmethod
80
- def forward(
81
- ctx,
82
- run_function,
83
- length_tensors,
84
- length_non_tensors,
85
- tensor_keys,
86
- non_tensor_keys,
87
- *args,
88
- ):
89
- ctx.end_tensors = length_tensors
90
- ctx.end_non_tensors = length_tensors + length_non_tensors
91
- ctx.gpu_autocast_kwargs = {
92
- "enabled": torch.is_autocast_enabled(),
93
- "dtype": torch.get_autocast_gpu_dtype(),
94
- "cache_enabled": torch.is_autocast_cache_enabled(),
95
- }
96
- assert (
97
- len(tensor_keys) == length_tensors
98
- and len(non_tensor_keys) == length_non_tensors
99
- )
100
-
101
- ctx.input_tensors = {
102
- key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
103
- }
104
- ctx.input_non_tensors = {
105
- key: val
106
- for (key, val) in zip(
107
- non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
108
- )
109
- }
110
- ctx.run_function = run_function
111
- ctx.input_params = list(args[ctx.end_non_tensors :])
112
-
113
- with torch.no_grad():
114
- output_tensors = ctx.run_function(
115
- **ctx.input_tensors, **ctx.input_non_tensors
116
- )
117
- return output_tensors
118
-
119
- @staticmethod
120
- def backward(ctx, *output_grads):
121
- # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
122
- ctx.input_tensors = {
123
- key: ctx.input_tensors[key].detach().requires_grad_(True)
124
- for key in ctx.input_tensors
125
- }
126
-
127
- with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
128
- # Fixes a bug where the first op in run_function modifies the
129
- # Tensor storage in place, which is not allowed for detach()'d
130
- # Tensors.
131
- shallow_copies = {
132
- key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
133
- for key in ctx.input_tensors
134
- }
135
- # shallow_copies.update(additional_args)
136
- output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
137
- input_grads = torch.autograd.grad(
138
- output_tensors,
139
- list(ctx.input_tensors.values()) + ctx.input_params,
140
- output_grads,
141
- allow_unused=True,
142
- )
143
- del ctx.input_tensors
144
- del ctx.input_params
145
- del output_tensors
146
- return (
147
- (None, None, None, None, None)
148
- + input_grads[: ctx.end_tensors]
149
- + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
150
- + input_grads[ctx.end_tensors :]
151
- )
152
-
153
-
154
- def checkpoint(func, inputs, params, flag):
155
- """
156
- Evaluate a function without caching intermediate activations, allowing for
157
- reduced memory at the expense of extra compute in the backward pass.
158
- :param func: the function to evaluate.
159
- :param inputs: the argument sequence to pass to `func`.
160
- :param params: a sequence of parameters `func` depends on but does not
161
- explicitly take as arguments.
162
- :param flag: if False, disable gradient checkpointing.
163
- """
164
- if flag:
165
- args = tuple(inputs) + tuple(params)
166
- return CheckpointFunction.apply(func, len(inputs), *args)
167
- else:
168
- return func(*inputs)
169
-
170
-
171
- class CheckpointFunction(torch.autograd.Function):
172
- @staticmethod
173
- def forward(ctx, run_function, length, *args):
174
- ctx.run_function = run_function
175
- ctx.input_tensors = list(args[:length])
176
- ctx.input_params = list(args[length:])
177
- ctx.gpu_autocast_kwargs = {
178
- "enabled": torch.is_autocast_enabled(),
179
- "dtype": torch.get_autocast_gpu_dtype(),
180
- "cache_enabled": torch.is_autocast_cache_enabled(),
181
- }
182
- with torch.no_grad():
183
- output_tensors = ctx.run_function(*ctx.input_tensors)
184
- return output_tensors
185
-
186
- @staticmethod
187
- def backward(ctx, *output_grads):
188
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
189
- with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
190
- # Fixes a bug where the first op in run_function modifies the
191
- # Tensor storage in place, which is not allowed for detach()'d
192
- # Tensors.
193
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
194
- output_tensors = ctx.run_function(*shallow_copies)
195
- input_grads = torch.autograd.grad(
196
- output_tensors,
197
- ctx.input_tensors + ctx.input_params,
198
- output_grads,
199
- allow_unused=True,
200
- )
201
- del ctx.input_tensors
202
- del ctx.input_params
203
- del output_tensors
204
- return (None, None) + input_grads
205
-
206
-
207
- def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
208
- """
209
- Create sinusoidal timestep embeddings.
210
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
211
- These may be fractional.
212
- :param dim: the dimension of the output.
213
- :param max_period: controls the minimum frequency of the embeddings.
214
- :return: an [N x dim] Tensor of positional embeddings.
215
- """
216
- if not repeat_only:
217
- half = dim // 2
218
- freqs = torch.exp(
219
- -math.log(max_period)
220
- * torch.arange(start=0, end=half, dtype=torch.float32)
221
- / half
222
- ).to(device=timesteps.device)
223
- args = timesteps[:, None].float() * freqs[None]
224
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
225
- if dim % 2:
226
- embedding = torch.cat(
227
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
228
- )
229
- else:
230
- embedding = repeat(timesteps, "b -> b d", d=dim)
231
- return embedding
232
-
233
-
234
- def zero_module(module):
235
- """
236
- Zero out the parameters of a module and return it.
237
- """
238
- for p in module.parameters():
239
- p.detach().zero_()
240
- return module
241
-
242
-
243
- def scale_module(module, scale):
244
- """
245
- Scale the parameters of a module and return it.
246
- """
247
- for p in module.parameters():
248
- p.detach().mul_(scale)
249
- return module
250
-
251
-
252
- def mean_flat(tensor):
253
- """
254
- Take the mean over all non-batch dimensions.
255
- """
256
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
257
-
258
-
259
- def normalization(channels):
260
- """
261
- Make a standard normalization layer.
262
- :param channels: number of input channels.
263
- :return: an nn.Module for normalization.
264
- """
265
- return GroupNorm32(32, channels)
266
-
267
-
268
- # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
269
- class SiLU(nn.Module):
270
- def forward(self, x):
271
- return x * torch.sigmoid(x)
272
-
273
-
274
- class GroupNorm32(nn.GroupNorm):
275
- def forward(self, x):
276
- return super().forward(x.float()).type(x.dtype)
277
-
278
-
279
- def conv_nd(dims, *args, **kwargs):
280
- """
281
- Create a 1D, 2D, or 3D convolution module.
282
- """
283
- if dims == 1:
284
- return nn.Conv1d(*args, **kwargs)
285
- elif dims == 2:
286
- return nn.Conv2d(*args, **kwargs)
287
- elif dims == 3:
288
- return nn.Conv3d(*args, **kwargs)
289
- raise ValueError(f"unsupported dimensions: {dims}")
290
-
291
-
292
- def linear(*args, **kwargs):
293
- """
294
- Create a linear module.
295
- """
296
- return nn.Linear(*args, **kwargs)
297
-
298
-
299
- def avg_pool_nd(dims, *args, **kwargs):
300
- """
301
- Create a 1D, 2D, or 3D average pooling module.
302
- """
303
- if dims == 1:
304
- return nn.AvgPool1d(*args, **kwargs)
305
- elif dims == 2:
306
- return nn.AvgPool2d(*args, **kwargs)
307
- elif dims == 3:
308
- return nn.AvgPool3d(*args, **kwargs)
309
- raise ValueError(f"unsupported dimensions: {dims}")
310
-
311
-
312
- class AlphaBlender(nn.Module):
313
- strategies = ["learned", "fixed", "learned_with_images"]
314
-
315
- def __init__(
316
- self,
317
- alpha: float,
318
- merge_strategy: str = "learned_with_images",
319
- rearrange_pattern: str = "b t -> (b t) 1 1",
320
- ):
321
- super().__init__()
322
- self.merge_strategy = merge_strategy
323
- self.rearrange_pattern = rearrange_pattern
324
-
325
- assert (
326
- merge_strategy in self.strategies
327
- ), f"merge_strategy needs to be in {self.strategies}"
328
-
329
- if self.merge_strategy == "fixed":
330
- self.register_buffer("mix_factor", torch.Tensor([alpha]))
331
- elif (
332
- self.merge_strategy == "learned"
333
- or self.merge_strategy == "learned_with_images"
334
- ):
335
- self.register_parameter(
336
- "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
337
- )
338
- else:
339
- raise ValueError(f"unknown merge strategy {self.merge_strategy}")
340
-
341
- def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
342
- if self.merge_strategy == "fixed":
343
- alpha = self.mix_factor
344
- elif self.merge_strategy == "learned":
345
- alpha = torch.sigmoid(self.mix_factor)
346
- elif self.merge_strategy == "learned_with_images":
347
- assert image_only_indicator is not None, "need image_only_indicator ..."
348
- alpha = torch.where(
349
- image_only_indicator.bool(),
350
- torch.ones(1, 1, device=image_only_indicator.device),
351
- rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
352
- )
353
- alpha = rearrange(alpha, self.rearrange_pattern)
354
- else:
355
- raise NotImplementedError
356
- return alpha
357
-
358
- def forward(
359
- self,
360
- x_spatial: torch.Tensor,
361
- x_temporal: torch.Tensor,
362
- image_only_indicator: Optional[torch.Tensor] = None,
363
- ) -> torch.Tensor:
364
- alpha = self.get_alpha(image_only_indicator)
365
- x = (
366
- alpha.to(x_spatial.dtype) * x_spatial
367
- + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
368
- )
369
- return x
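Two of the helpers above read more easily as formulas (again only restating the code). timestep_embedding builds the usual sinusoidal features with frequencies

    \omega_k = \exp\!\big(-k\,\ln(\text{max\_period})/\lfloor d/2\rfloor\big),\quad k=0,\dots,\lfloor d/2\rfloor-1,
    \qquad
    \text{emb}(t) = \big[\cos(t\omega_0),\dots,\cos(t\omega_{\lfloor d/2\rfloor-1}),\ \sin(t\omega_0),\dots,\sin(t\omega_{\lfloor d/2\rfloor-1})\big]

(zero-padded by one column when d is odd), and AlphaBlender mixes the spatial and temporal branches as x = \alpha\,x_{\text{spatial}} + (1-\alpha)\,x_{\text{temporal}}, where \alpha is fixed, a learned sigmoid, or forced to 1 for image-only frames under the "learned_with_images" strategy.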
sgm/modules/diffusionmodules/video_model.py DELETED
@@ -1,493 +0,0 @@
1
- from functools import partial
2
- from typing import List, Optional, Union
3
-
4
- from einops import rearrange
5
-
6
- from ...modules.diffusionmodules.openaimodel import *
7
- from ...modules.video_attention import SpatialVideoTransformer
8
- from ...util import default
9
- from .util import AlphaBlender
10
-
11
-
12
- class VideoResBlock(ResBlock):
13
- def __init__(
14
- self,
15
- channels: int,
16
- emb_channels: int,
17
- dropout: float,
18
- video_kernel_size: Union[int, List[int]] = 3,
19
- merge_strategy: str = "fixed",
20
- merge_factor: float = 0.5,
21
- out_channels: Optional[int] = None,
22
- use_conv: bool = False,
23
- use_scale_shift_norm: bool = False,
24
- dims: int = 2,
25
- use_checkpoint: bool = False,
26
- up: bool = False,
27
- down: bool = False,
28
- ):
29
- super().__init__(
30
- channels,
31
- emb_channels,
32
- dropout,
33
- out_channels=out_channels,
34
- use_conv=use_conv,
35
- use_scale_shift_norm=use_scale_shift_norm,
36
- dims=dims,
37
- use_checkpoint=use_checkpoint,
38
- up=up,
39
- down=down,
40
- )
41
-
42
- self.time_stack = ResBlock(
43
- default(out_channels, channels),
44
- emb_channels,
45
- dropout=dropout,
46
- dims=3,
47
- out_channels=default(out_channels, channels),
48
- use_scale_shift_norm=False,
49
- use_conv=False,
50
- up=False,
51
- down=False,
52
- kernel_size=video_kernel_size,
53
- use_checkpoint=use_checkpoint,
54
- exchange_temb_dims=True,
55
- )
56
- self.time_mixer = AlphaBlender(
57
- alpha=merge_factor,
58
- merge_strategy=merge_strategy,
59
- rearrange_pattern="b t -> b 1 t 1 1",
60
- )
61
-
62
- def forward(
63
- self,
64
- x: th.Tensor,
65
- emb: th.Tensor,
66
- num_video_frames: int,
67
- image_only_indicator: Optional[th.Tensor] = None,
68
- ) -> th.Tensor:
69
- x = super().forward(x, emb)
70
-
71
- x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
72
- x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
73
-
74
- x = self.time_stack(
75
- x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
76
- )
77
- x = self.time_mixer(
78
- x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
79
- )
80
- x = rearrange(x, "b c t h w -> (b t) c h w")
81
- return x
82
-
83
-
84
- class VideoUNet(nn.Module):
85
- def __init__(
86
- self,
87
- in_channels: int,
88
- model_channels: int,
89
- out_channels: int,
90
- num_res_blocks: int,
91
- attention_resolutions: int,
92
- dropout: float = 0.0,
93
- channel_mult: List[int] = (1, 2, 4, 8),
94
- conv_resample: bool = True,
95
- dims: int = 2,
96
- num_classes: Optional[int] = None,
97
- use_checkpoint: bool = False,
98
- num_heads: int = -1,
99
- num_head_channels: int = -1,
100
- num_heads_upsample: int = -1,
101
- use_scale_shift_norm: bool = False,
102
- resblock_updown: bool = False,
103
- transformer_depth: Union[List[int], int] = 1,
104
- transformer_depth_middle: Optional[int] = None,
105
- context_dim: Optional[int] = None,
106
- time_downup: bool = False,
107
- time_context_dim: Optional[int] = None,
108
- extra_ff_mix_layer: bool = False,
109
- use_spatial_context: bool = False,
110
- merge_strategy: str = "fixed",
111
- merge_factor: float = 0.5,
112
- spatial_transformer_attn_type: str = "softmax",
113
- video_kernel_size: Union[int, List[int]] = 3,
114
- use_linear_in_transformer: bool = False,
115
- adm_in_channels: Optional[int] = None,
116
- disable_temporal_crossattention: bool = False,
117
- max_ddpm_temb_period: int = 10000,
118
- ):
119
- super().__init__()
120
- assert context_dim is not None
121
-
122
- if num_heads_upsample == -1:
123
- num_heads_upsample = num_heads
124
-
125
- if num_heads == -1:
126
- assert num_head_channels != -1
127
-
128
- if num_head_channels == -1:
129
- assert num_heads != -1
130
-
131
- self.in_channels = in_channels
132
- self.model_channels = model_channels
133
- self.out_channels = out_channels
134
- if isinstance(transformer_depth, int):
135
- transformer_depth = len(channel_mult) * [transformer_depth]
136
- transformer_depth_middle = default(
137
- transformer_depth_middle, transformer_depth[-1]
138
- )
139
-
140
- self.num_res_blocks = num_res_blocks
141
- self.attention_resolutions = attention_resolutions
142
- self.dropout = dropout
143
- self.channel_mult = channel_mult
144
- self.conv_resample = conv_resample
145
- self.num_classes = num_classes
146
- self.use_checkpoint = use_checkpoint
147
- self.num_heads = num_heads
148
- self.num_head_channels = num_head_channels
149
- self.num_heads_upsample = num_heads_upsample
150
-
151
- time_embed_dim = model_channels * 4
152
- self.time_embed = nn.Sequential(
153
- linear(model_channels, time_embed_dim),
154
- nn.SiLU(),
155
- linear(time_embed_dim, time_embed_dim),
156
- )
157
-
158
- if self.num_classes is not None:
159
- if isinstance(self.num_classes, int):
160
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
161
- elif self.num_classes == "continuous":
162
- print("setting up linear c_adm embedding layer")
163
- self.label_emb = nn.Linear(1, time_embed_dim)
164
- elif self.num_classes == "timestep":
165
- self.label_emb = nn.Sequential(
166
- Timestep(model_channels),
167
- nn.Sequential(
168
- linear(model_channels, time_embed_dim),
169
- nn.SiLU(),
170
- linear(time_embed_dim, time_embed_dim),
171
- ),
172
- )
173
-
174
- elif self.num_classes == "sequential":
175
- assert adm_in_channels is not None
176
- self.label_emb = nn.Sequential(
177
- nn.Sequential(
178
- linear(adm_in_channels, time_embed_dim),
179
- nn.SiLU(),
180
- linear(time_embed_dim, time_embed_dim),
181
- )
182
- )
183
- else:
184
- raise ValueError()
185
-
186
- self.input_blocks = nn.ModuleList(
187
- [
188
- TimestepEmbedSequential(
189
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
190
- )
191
- ]
192
- )
193
- self._feature_size = model_channels
194
- input_block_chans = [model_channels]
195
- ch = model_channels
196
- ds = 1
197
-
198
- def get_attention_layer(
199
- ch,
200
- num_heads,
201
- dim_head,
202
- depth=1,
203
- context_dim=None,
204
- use_checkpoint=False,
205
- disabled_sa=False,
206
- ):
207
- return SpatialVideoTransformer(
208
- ch,
209
- num_heads,
210
- dim_head,
211
- depth=depth,
212
- context_dim=context_dim,
213
- time_context_dim=time_context_dim,
214
- dropout=dropout,
215
- ff_in=extra_ff_mix_layer,
216
- use_spatial_context=use_spatial_context,
217
- merge_strategy=merge_strategy,
218
- merge_factor=merge_factor,
219
- checkpoint=use_checkpoint,
220
- use_linear=use_linear_in_transformer,
221
- attn_mode=spatial_transformer_attn_type,
222
- disable_self_attn=disabled_sa,
223
- disable_temporal_crossattention=disable_temporal_crossattention,
224
- max_time_embed_period=max_ddpm_temb_period,
225
- )
226
-
227
- def get_resblock(
228
- merge_factor,
229
- merge_strategy,
230
- video_kernel_size,
231
- ch,
232
- time_embed_dim,
233
- dropout,
234
- out_ch,
235
- dims,
236
- use_checkpoint,
237
- use_scale_shift_norm,
238
- down=False,
239
- up=False,
240
- ):
241
- return VideoResBlock(
242
- merge_factor=merge_factor,
243
- merge_strategy=merge_strategy,
244
- video_kernel_size=video_kernel_size,
245
- channels=ch,
246
- emb_channels=time_embed_dim,
247
- dropout=dropout,
248
- out_channels=out_ch,
249
- dims=dims,
250
- use_checkpoint=use_checkpoint,
251
- use_scale_shift_norm=use_scale_shift_norm,
252
- down=down,
253
- up=up,
254
- )
255
-
256
- for level, mult in enumerate(channel_mult):
257
- for _ in range(num_res_blocks):
258
- layers = [
259
- get_resblock(
260
- merge_factor=merge_factor,
261
- merge_strategy=merge_strategy,
262
- video_kernel_size=video_kernel_size,
263
- ch=ch,
264
- time_embed_dim=time_embed_dim,
265
- dropout=dropout,
266
- out_ch=mult * model_channels,
267
- dims=dims,
268
- use_checkpoint=use_checkpoint,
269
- use_scale_shift_norm=use_scale_shift_norm,
270
- )
271
- ]
272
- ch = mult * model_channels
273
- if ds in attention_resolutions:
274
- if num_head_channels == -1:
275
- dim_head = ch // num_heads
276
- else:
277
- num_heads = ch // num_head_channels
278
- dim_head = num_head_channels
279
-
280
- layers.append(
281
- get_attention_layer(
282
- ch,
283
- num_heads,
284
- dim_head,
285
- depth=transformer_depth[level],
286
- context_dim=context_dim,
287
- use_checkpoint=use_checkpoint,
288
- disabled_sa=False,
289
- )
290
- )
291
- self.input_blocks.append(TimestepEmbedSequential(*layers))
292
- self._feature_size += ch
293
- input_block_chans.append(ch)
294
- if level != len(channel_mult) - 1:
295
- ds *= 2
296
- out_ch = ch
297
- self.input_blocks.append(
298
- TimestepEmbedSequential(
299
- get_resblock(
300
- merge_factor=merge_factor,
301
- merge_strategy=merge_strategy,
302
- video_kernel_size=video_kernel_size,
303
- ch=ch,
304
- time_embed_dim=time_embed_dim,
305
- dropout=dropout,
306
- out_ch=out_ch,
307
- dims=dims,
308
- use_checkpoint=use_checkpoint,
309
- use_scale_shift_norm=use_scale_shift_norm,
310
- down=True,
311
- )
312
- if resblock_updown
313
- else Downsample(
314
- ch,
315
- conv_resample,
316
- dims=dims,
317
- out_channels=out_ch,
318
- third_down=time_downup,
319
- )
320
- )
321
- )
322
- ch = out_ch
323
- input_block_chans.append(ch)
324
-
325
- self._feature_size += ch
326
-
327
- if num_head_channels == -1:
328
- dim_head = ch // num_heads
329
- else:
330
- num_heads = ch // num_head_channels
331
- dim_head = num_head_channels
332
-
333
- self.middle_block = TimestepEmbedSequential(
334
- get_resblock(
335
- merge_factor=merge_factor,
336
- merge_strategy=merge_strategy,
337
- video_kernel_size=video_kernel_size,
338
- ch=ch,
339
- time_embed_dim=time_embed_dim,
340
- out_ch=None,
341
- dropout=dropout,
342
- dims=dims,
343
- use_checkpoint=use_checkpoint,
344
- use_scale_shift_norm=use_scale_shift_norm,
345
- ),
346
- get_attention_layer(
347
- ch,
348
- num_heads,
349
- dim_head,
350
- depth=transformer_depth_middle,
351
- context_dim=context_dim,
352
- use_checkpoint=use_checkpoint,
353
- ),
354
- get_resblock(
355
- merge_factor=merge_factor,
356
- merge_strategy=merge_strategy,
357
- video_kernel_size=video_kernel_size,
358
- ch=ch,
359
- out_ch=None,
360
- time_embed_dim=time_embed_dim,
361
- dropout=dropout,
362
- dims=dims,
363
- use_checkpoint=use_checkpoint,
364
- use_scale_shift_norm=use_scale_shift_norm,
365
- ),
366
- )
367
- self._feature_size += ch
368
-
369
- self.output_blocks = nn.ModuleList([])
370
- for level, mult in list(enumerate(channel_mult))[::-1]:
371
- for i in range(num_res_blocks + 1):
372
- ich = input_block_chans.pop()
373
- layers = [
374
- get_resblock(
375
- merge_factor=merge_factor,
376
- merge_strategy=merge_strategy,
377
- video_kernel_size=video_kernel_size,
378
- ch=ch + ich,
379
- time_embed_dim=time_embed_dim,
380
- dropout=dropout,
381
- out_ch=model_channels * mult,
382
- dims=dims,
383
- use_checkpoint=use_checkpoint,
384
- use_scale_shift_norm=use_scale_shift_norm,
385
- )
386
- ]
387
- ch = model_channels * mult
388
- if ds in attention_resolutions:
389
- if num_head_channels == -1:
390
- dim_head = ch // num_heads
391
- else:
392
- num_heads = ch // num_head_channels
393
- dim_head = num_head_channels
394
-
395
- layers.append(
396
- get_attention_layer(
397
- ch,
398
- num_heads,
399
- dim_head,
400
- depth=transformer_depth[level],
401
- context_dim=context_dim,
402
- use_checkpoint=use_checkpoint,
403
- disabled_sa=False,
404
- )
405
- )
406
- if level and i == num_res_blocks:
407
- out_ch = ch
408
- ds //= 2
409
- layers.append(
410
- get_resblock(
411
- merge_factor=merge_factor,
412
- merge_strategy=merge_strategy,
413
- video_kernel_size=video_kernel_size,
414
- ch=ch,
415
- time_embed_dim=time_embed_dim,
416
- dropout=dropout,
417
- out_ch=out_ch,
418
- dims=dims,
419
- use_checkpoint=use_checkpoint,
420
- use_scale_shift_norm=use_scale_shift_norm,
421
- up=True,
422
- )
423
- if resblock_updown
424
- else Upsample(
425
- ch,
426
- conv_resample,
427
- dims=dims,
428
- out_channels=out_ch,
429
- third_up=time_downup,
430
- )
431
- )
432
-
433
- self.output_blocks.append(TimestepEmbedSequential(*layers))
434
- self._feature_size += ch
435
-
436
- self.out = nn.Sequential(
437
- normalization(ch),
438
- nn.SiLU(),
439
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
440
- )
441
-
442
- def forward(
443
- self,
444
- x: th.Tensor,
445
- timesteps: th.Tensor,
446
- context: Optional[th.Tensor] = None,
447
- y: Optional[th.Tensor] = None,
448
- time_context: Optional[th.Tensor] = None,
449
- num_video_frames: Optional[int] = None,
450
- image_only_indicator: Optional[th.Tensor] = None,
451
- ):
452
- assert (y is not None) == (
453
- self.num_classes is not None
454
- ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
455
- hs = []
456
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
457
- emb = self.time_embed(t_emb)
458
-
459
- if self.num_classes is not None:
460
- assert y.shape[0] == x.shape[0]
461
- emb = emb + self.label_emb(y)
462
-
463
- h = x
464
- for module in self.input_blocks:
465
- h = module(
466
- h,
467
- emb,
468
- context=context,
469
- image_only_indicator=image_only_indicator,
470
- time_context=time_context,
471
- num_video_frames=num_video_frames,
472
- )
473
- hs.append(h)
474
- h = self.middle_block(
475
- h,
476
- emb,
477
- context=context,
478
- image_only_indicator=image_only_indicator,
479
- time_context=time_context,
480
- num_video_frames=num_video_frames,
481
- )
482
- for module in self.output_blocks:
483
- h = th.cat([h, hs.pop()], dim=1)
484
- h = module(
485
- h,
486
- emb,
487
- context=context,
488
- image_only_indicator=image_only_indicator,
489
- time_context=time_context,
490
- num_video_frames=num_video_frames,
491
- )
492
- h = h.type(x.dtype)
493
- return self.out(h)
sgm/modules/diffusionmodules/wrappers.py DELETED
@@ -1,34 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from packaging import version
4
-
5
- OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
6
-
7
-
8
- class IdentityWrapper(nn.Module):
9
- def __init__(self, diffusion_model, compile_model: bool = False):
10
- super().__init__()
11
- compile = (
12
- torch.compile
13
- if (version.parse(torch.__version__) >= version.parse("2.0.0"))
14
- and compile_model
15
- else lambda x: x
16
- )
17
- self.diffusion_model = compile(diffusion_model)
18
-
19
- def forward(self, *args, **kwargs):
20
- return self.diffusion_model(*args, **kwargs)
21
-
22
-
23
- class OpenAIWrapper(IdentityWrapper):
24
- def forward(
25
- self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
26
- ) -> torch.Tensor:
27
- x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
28
- return self.diffusion_model(
29
- x,
30
- timesteps=t,
31
- context=c.get("crossattn", None),
32
- y=c.get("vector", None),
33
- **kwargs,
34
- )
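OpenAIWrapper above routes a conditioning dict into the three UNet entry points; a call looks roughly like the sketch below. Only the keys concat, crossattn and vector come from the code — the tensor names are placeholders.

    c = {
        "crossattn": text_embeddings,    # forwarded to the UNet as `context`
        "vector":    pooled_embeddings,  # forwarded as `y` (ADM embedding)
        "concat":    extra_latents,      # channel-concatenated onto the input x
    }
    out = wrapper(noised_latent, t, c)   # any missing key falls back to its default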
sgm/modules/distributions/__init__.py DELETED
File without changes
sgm/modules/distributions/distributions.py DELETED
@@ -1,102 +0,0 @@
1
- import numpy as np
2
- import torch
3
-
4
-
5
- class AbstractDistribution:
6
- def sample(self):
7
- raise NotImplementedError()
8
-
9
- def mode(self):
10
- raise NotImplementedError()
11
-
12
-
13
- class DiracDistribution(AbstractDistribution):
14
- def __init__(self, value):
15
- self.value = value
16
-
17
- def sample(self):
18
- return self.value
19
-
20
- def mode(self):
21
- return self.value
22
-
23
-
24
- class DiagonalGaussianDistribution(object):
25
- def __init__(self, parameters, deterministic=False):
26
- self.parameters = parameters
27
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
- self.deterministic = deterministic
30
- self.std = torch.exp(0.5 * self.logvar)
31
- self.var = torch.exp(self.logvar)
32
- if self.deterministic:
33
- self.var = self.std = torch.zeros_like(self.mean).to(
34
- device=self.parameters.device
35
- )
36
-
37
- def sample(self):
38
- x = self.mean + self.std * torch.randn(self.mean.shape).to(
39
- device=self.parameters.device
40
- )
41
- return x
42
-
43
- def kl(self, other=None):
44
- if self.deterministic:
45
- return torch.Tensor([0.0])
46
- else:
47
- if other is None:
48
- return 0.5 * torch.sum(
49
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50
- dim=[1, 2, 3],
51
- )
52
- else:
53
- return 0.5 * torch.sum(
54
- torch.pow(self.mean - other.mean, 2) / other.var
55
- + self.var / other.var
56
- - 1.0
57
- - self.logvar
58
- + other.logvar,
59
- dim=[1, 2, 3],
60
- )
61
-
62
- def nll(self, sample, dims=[1, 2, 3]):
63
- if self.deterministic:
64
- return torch.Tensor([0.0])
65
- logtwopi = np.log(2.0 * np.pi)
66
- return 0.5 * torch.sum(
67
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68
- dim=dims,
69
- )
70
-
71
- def mode(self):
72
- return self.mean
73
-
74
-
75
- def normal_kl(mean1, logvar1, mean2, logvar2):
76
- """
77
- source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78
- Compute the KL divergence between two gaussians.
79
- Shapes are automatically broadcasted, so batches can be compared to
80
- scalars, among other use cases.
81
- """
82
- tensor = None
83
- for obj in (mean1, logvar1, mean2, logvar2):
84
- if isinstance(obj, torch.Tensor):
85
- tensor = obj
86
- break
87
- assert tensor is not None, "at least one argument must be a Tensor"
88
-
89
- # Force variances to be Tensors. Broadcasting helps convert scalars to
90
- # Tensors, but it does not work for torch.exp().
91
- logvar1, logvar2 = [
92
- x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
- for x in (logvar1, logvar2)
94
- ]
95
-
96
- return 0.5 * (
97
- -1.0
98
- + logvar2
99
- - logvar1
100
- + torch.exp(logvar1 - logvar2)
101
- + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
- )
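For reference, DiagonalGaussianDistribution.kl with other=None is the closed-form KL to a standard normal and nll is the Gaussian negative log-likelihood, both exactly as implemented above:

    \mathrm{KL}\big(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(0,I)\big) = \tfrac{1}{2}\sum\big(\mu^{2}+\sigma^{2}-1-\log\sigma^{2}\big),
    \qquad
    -\log p(x) = \tfrac{1}{2}\sum\Big(\log 2\pi + \log\sigma^{2} + \tfrac{(x-\mu)^{2}}{\sigma^{2}}\Big),

with the sums taken over the non-batch dimensions.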
sgm/modules/ema.py DELETED
@@ -1,86 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
- class LitEma(nn.Module):
6
- def __init__(self, model, decay=0.9999, use_num_upates=True):
7
- super().__init__()
8
- if decay < 0.0 or decay > 1.0:
9
- raise ValueError("Decay must be between 0 and 1")
10
-
11
- self.m_name2s_name = {}
12
- self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13
- self.register_buffer(
14
- "num_updates",
15
- torch.tensor(0, dtype=torch.int)
16
- if use_num_upates
17
- else torch.tensor(-1, dtype=torch.int),
18
- )
19
-
20
- for name, p in model.named_parameters():
21
- if p.requires_grad:
22
- # remove as '.'-character is not allowed in buffers
23
- s_name = name.replace(".", "")
24
- self.m_name2s_name.update({name: s_name})
25
- self.register_buffer(s_name, p.clone().detach().data)
26
-
27
- self.collected_params = []
28
-
29
- def reset_num_updates(self):
30
- del self.num_updates
31
- self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
32
-
33
- def forward(self, model):
34
- decay = self.decay
35
-
36
- if self.num_updates >= 0:
37
- self.num_updates += 1
38
- decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
39
-
40
- one_minus_decay = 1.0 - decay
41
-
42
- with torch.no_grad():
43
- m_param = dict(model.named_parameters())
44
- shadow_params = dict(self.named_buffers())
45
-
46
- for key in m_param:
47
- if m_param[key].requires_grad:
48
- sname = self.m_name2s_name[key]
49
- shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
50
- shadow_params[sname].sub_(
51
- one_minus_decay * (shadow_params[sname] - m_param[key])
52
- )
53
- else:
54
- assert not key in self.m_name2s_name
55
-
56
- def copy_to(self, model):
57
- m_param = dict(model.named_parameters())
58
- shadow_params = dict(self.named_buffers())
59
- for key in m_param:
60
- if m_param[key].requires_grad:
61
- m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
62
- else:
63
- assert not key in self.m_name2s_name
64
-
65
- def store(self, parameters):
66
- """
67
- Save the current parameters for restoring later.
68
- Args:
69
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
70
- temporarily stored.
71
- """
72
- self.collected_params = [param.clone() for param in parameters]
73
-
74
- def restore(self, parameters):
75
- """
76
- Restore the parameters stored with the `store` method.
77
- Useful to validate the model with EMA parameters without affecting the
78
- original optimization process. Store the parameters before the
79
- `copy_to` method. After validation (or model saving), use this to
80
- restore the former parameters.
81
- Args:
82
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
83
- updated with the stored parameters.
84
- """
85
- for c_param, param in zip(self.collected_params, parameters):
86
- param.data.copy_(c_param.data)
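The EMA update implemented by LitEma.forward is, per step t,

    d_t = \min\!\big(d,\ \tfrac{1+t}{10+t}\big), \qquad \bar\theta \leftarrow d_t\,\bar\theta + (1-d_t)\,\theta,

and a typical usage pattern of the class looks like the following minimal sketch (assuming an ordinary PyTorch training loop; model, loader, opt, loss_fn and validate are placeholders, not objects from this repository):

    from sgm.modules.ema import LitEma  # the module shown above

    ema = LitEma(model, decay=0.9999)

    for batch in loader:
        opt.zero_grad()
        loss_fn(model(batch)).backward()
        opt.step()
        ema(model)                        # update the shadow weights

    # Evaluate with EMA weights without losing the training weights.
    ema.store(model.parameters())
    ema.copy_to(model)
    validate(model)
    ema.restore(model.parameters())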