ML-INTA committed
Commit 7f43c1b
1 Parent(s): 42c7345

Upload 358 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. diffusers/__init__.py +204 -0
  2. diffusers/__pycache__/__init__.cpython-39.pyc +0 -0
  3. diffusers/__pycache__/configuration_utils.cpython-39.pyc +0 -0
  4. diffusers/commands/__init__.py +27 -0
  5. diffusers/commands/__pycache__/__init__.cpython-311.pyc +0 -0
  6. diffusers/commands/__pycache__/diffusers_cli.cpython-311.pyc +0 -0
  7. diffusers/commands/__pycache__/env.cpython-311.pyc +0 -0
  8. diffusers/commands/diffusers_cli.py +41 -0
  9. diffusers/commands/env.py +84 -0
  10. diffusers/configuration_utils.py +615 -0
  11. diffusers/dependency_versions_check.py +47 -0
  12. diffusers/dependency_versions_table.py +35 -0
  13. diffusers/experimental/__init__.py +1 -0
  14. diffusers/experimental/__pycache__/__init__.cpython-311.pyc +0 -0
  15. diffusers/experimental/rl/__init__.py +1 -0
  16. diffusers/experimental/rl/__pycache__/__init__.cpython-311.pyc +0 -0
  17. diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-311.pyc +0 -0
  18. diffusers/experimental/rl/value_guided_sampling.py +152 -0
  19. diffusers/loaders.py +243 -0
  20. diffusers/models/__init__.py +32 -0
  21. diffusers/models/__pycache__/__init__.cpython-311.pyc +0 -0
  22. diffusers/models/__pycache__/attention.cpython-311.pyc +0 -0
  23. diffusers/models/__pycache__/attention_flax.cpython-311.pyc +0 -0
  24. diffusers/models/__pycache__/autoencoder_kl.cpython-311.pyc +0 -0
  25. diffusers/models/__pycache__/controlnet.cpython-311.pyc +0 -0
  26. diffusers/models/__pycache__/cross_attention.cpython-311.pyc +0 -0
  27. diffusers/models/__pycache__/dual_transformer_2d.cpython-311.pyc +0 -0
  28. diffusers/models/__pycache__/embeddings.cpython-311.pyc +0 -0
  29. diffusers/models/__pycache__/embeddings_flax.cpython-311.pyc +0 -0
  30. diffusers/models/__pycache__/modeling_flax_pytorch_utils.cpython-311.pyc +0 -0
  31. diffusers/models/__pycache__/modeling_flax_utils.cpython-311.pyc +0 -0
  32. diffusers/models/__pycache__/modeling_pytorch_flax_utils.cpython-311.pyc +0 -0
  33. diffusers/models/__pycache__/modeling_utils.cpython-311.pyc +0 -0
  34. diffusers/models/__pycache__/prior_transformer.cpython-311.pyc +0 -0
  35. diffusers/models/__pycache__/resnet.cpython-311.pyc +0 -0
  36. diffusers/models/__pycache__/resnet_flax.cpython-311.pyc +0 -0
  37. diffusers/models/__pycache__/transformer_2d.cpython-311.pyc +0 -0
  38. diffusers/models/__pycache__/unet_1d.cpython-311.pyc +0 -0
  39. diffusers/models/__pycache__/unet_1d_blocks.cpython-311.pyc +0 -0
  40. diffusers/models/__pycache__/unet_2d.cpython-311.pyc +0 -0
  41. diffusers/models/__pycache__/unet_2d_blocks.cpython-311.pyc +0 -0
  42. diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-311.pyc +0 -0
  43. diffusers/models/__pycache__/unet_2d_condition.cpython-311.pyc +0 -0
  44. diffusers/models/__pycache__/unet_2d_condition_flax.cpython-311.pyc +0 -0
  45. diffusers/models/__pycache__/vae.cpython-311.pyc +0 -0
  46. diffusers/models/__pycache__/vae_flax.cpython-311.pyc +0 -0
  47. diffusers/models/__pycache__/vq_model.cpython-311.pyc +0 -0
  48. diffusers/models/attention.py +517 -0
  49. diffusers/models/attention_flax.py +302 -0
  50. diffusers/models/autoencoder_kl.py +320 -0
diffusers/__init__.py ADDED
@@ -0,0 +1,204 @@
+ __version__ = "0.14.0"
+
+ from .configuration_utils import ConfigMixin
+ from .utils import (
+     OptionalDependencyNotAvailable,
+     is_flax_available,
+     is_inflect_available,
+     is_k_diffusion_available,
+     is_k_diffusion_version,
+     is_librosa_available,
+     is_onnx_available,
+     is_scipy_available,
+     is_torch_available,
+     is_transformers_available,
+     is_transformers_version,
+     is_unidecode_available,
+     logging,
+ )
+
+
+ try:
+     if not is_onnx_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils.dummy_onnx_objects import *  # noqa F403
+ else:
+     from .pipelines import OnnxRuntimeModel
+
+ try:
+     if not is_torch_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils.dummy_pt_objects import *  # noqa F403
+ else:
+     from .models import (
+         AutoencoderKL,
+         ControlNetModel,
+         ModelMixin,
+         PriorTransformer,
+         Transformer2DModel,
+         UNet1DModel,
+         UNet2DConditionModel,
+         UNet2DModel,
+         VQModel,
+     )
+     from .optimization import (
+         get_constant_schedule,
+         get_constant_schedule_with_warmup,
+         get_cosine_schedule_with_warmup,
+         get_cosine_with_hard_restarts_schedule_with_warmup,
+         get_linear_schedule_with_warmup,
+         get_polynomial_decay_schedule_with_warmup,
+         get_scheduler,
+     )
+     from .pipelines import (
+         AudioPipelineOutput,
+         DanceDiffusionPipeline,
+         DDIMPipeline,
+         DDPMPipeline,
+         DiffusionPipeline,
+         DiTPipeline,
+         ImagePipelineOutput,
+         KarrasVePipeline,
+         LDMPipeline,
+         LDMSuperResolutionPipeline,
+         PNDMPipeline,
+         RePaintPipeline,
+         ScoreSdeVePipeline,
+     )
+     from .schedulers import (
+         DDIMInverseScheduler,
+         DDIMScheduler,
+         DDPMScheduler,
+         DEISMultistepScheduler,
+         DPMSolverMultistepScheduler,
+         DPMSolverSinglestepScheduler,
+         EulerAncestralDiscreteScheduler,
+         EulerDiscreteScheduler,
+         HeunDiscreteScheduler,
+         IPNDMScheduler,
+         KarrasVeScheduler,
+         KDPM2AncestralDiscreteScheduler,
+         KDPM2DiscreteScheduler,
+         PNDMScheduler,
+         RePaintScheduler,
+         SchedulerMixin,
+         ScoreSdeVeScheduler,
+         UnCLIPScheduler,
+         UniPCMultistepScheduler,
+         VQDiffusionScheduler,
+     )
+     from .training_utils import EMAModel
+
+ try:
+     if not (is_torch_available() and is_scipy_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils.dummy_torch_and_scipy_objects import *  # noqa F403
+ else:
+     from .schedulers import LMSDiscreteScheduler
+
+
+ try:
+     if not (is_torch_available() and is_transformers_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
+ else:
+     from .pipelines import (
+         AltDiffusionImg2ImgPipeline,
+         AltDiffusionPipeline,
+         CycleDiffusionPipeline,
+         LDMTextToImagePipeline,
+         PaintByExamplePipeline,
+         SemanticStableDiffusionPipeline,
+         StableDiffusionAttendAndExcitePipeline,
+         StableDiffusionControlNetPipeline,
+         StableDiffusionDepth2ImgPipeline,
+         StableDiffusionImageVariationPipeline,
+         StableDiffusionImg2ImgPipeline,
+         StableDiffusionInpaintPipeline,
+         StableDiffusionInpaintPipelineLegacy,
+         StableDiffusionInstructPix2PixPipeline,
+         StableDiffusionLatentUpscalePipeline,
+         StableDiffusionPanoramaPipeline,
+         StableDiffusionPipeline,
+         StableDiffusionPipelineSafe,
+         StableDiffusionPix2PixZeroPipeline,
+         StableDiffusionSAGPipeline,
+         StableDiffusionUpscalePipeline,
+         StableUnCLIPImg2ImgPipeline,
+         StableUnCLIPPipeline,
+         UnCLIPImageVariationPipeline,
+         UnCLIPPipeline,
+         VersatileDiffusionDualGuidedPipeline,
+         VersatileDiffusionImageVariationPipeline,
+         VersatileDiffusionPipeline,
+         VersatileDiffusionTextToImagePipeline,
+         VQDiffusionPipeline,
+     )
+
+ try:
+     if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
+ else:
+     from .pipelines import StableDiffusionKDiffusionPipeline
+
+ try:
+     if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
+ else:
+     from .pipelines import (
+         OnnxStableDiffusionImg2ImgPipeline,
+         OnnxStableDiffusionInpaintPipeline,
+         OnnxStableDiffusionInpaintPipelineLegacy,
+         OnnxStableDiffusionPipeline,
+         StableDiffusionOnnxPipeline,
+     )
+
+ try:
+     if not (is_torch_available() and is_librosa_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils.dummy_torch_and_librosa_objects import *  # noqa F403
+ else:
+     from .pipelines import AudioDiffusionPipeline, Mel
+
+ try:
+     if not is_flax_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils.dummy_flax_objects import *  # noqa F403
+ else:
+     from .models.modeling_flax_utils import FlaxModelMixin
+     from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
+     from .models.vae_flax import FlaxAutoencoderKL
+     from .pipelines import FlaxDiffusionPipeline
+     from .schedulers import (
+         FlaxDDIMScheduler,
+         FlaxDDPMScheduler,
+         FlaxDPMSolverMultistepScheduler,
+         FlaxKarrasVeScheduler,
+         FlaxLMSDiscreteScheduler,
+         FlaxPNDMScheduler,
+         FlaxSchedulerMixin,
+         FlaxScoreSdeVeScheduler,
+     )
+
+
+ try:
+     if not (is_flax_available() and is_transformers_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
+ else:
+     from .pipelines import (
+         FlaxStableDiffusionImg2ImgPipeline,
+         FlaxStableDiffusionInpaintPipeline,
+         FlaxStableDiffusionPipeline,
+     )
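Note: each import group above is gated on its optional backends; when a backend is missing, the star-import pulls in dummy placeholder objects instead of the real classes. A minimal sketch of reusing the same guard from user code (the helper names come from the file above; the exact behavior of the dummy objects is an assumption based on their module names, not shown in this diff):

```python
# Sketch only: mirrors the optional-dependency pattern from diffusers/__init__.py.
from diffusers.utils import OptionalDependencyNotAvailable, is_torch_available

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Torch-backed names would resolve to placeholders from dummy_pt_objects.
    print("PyTorch not installed; skipping torch-only classes.")
else:
    from diffusers import DDPMScheduler  # one of the torch-gated exports listed above

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    print(type(scheduler).__name__)
```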
diffusers/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (5.5 kB).
 
diffusers/__pycache__/configuration_utils.cpython-39.pyc ADDED
Binary file (22.1 kB).
 
diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from argparse import ArgumentParser
+
+
+ class BaseDiffusersCLICommand(ABC):
+     @staticmethod
+     @abstractmethod
+     def register_subcommand(parser: ArgumentParser):
+         raise NotImplementedError()
+
+     @abstractmethod
+     def run(self):
+         raise NotImplementedError()
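Note: the base class only requires a static `register_subcommand` and an instance `run` method. A hypothetical subcommand built on it might look like the sketch below; the `hello` command, its factory lambda, and its `--name` flag are invented for illustration, only the base class and the `ArgumentParser` usage come from this commit:

```python
# Hypothetical custom CLI command on top of BaseDiffusersCLICommand (sketch only).
from argparse import ArgumentParser

from diffusers.commands import BaseDiffusersCLICommand


class HelloCommand(BaseDiffusersCLICommand):
    def __init__(self, name: str):
        self.name = name  # hypothetical argument, not part of diffusers

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser("hello")
        hello_parser.add_argument("--name", type=str, default="world")
        hello_parser.set_defaults(func=lambda args: HelloCommand(args.name))

    def run(self):
        print(f"Hello, {self.name}!")
```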
diffusers/commands/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.11 kB).
 
diffusers/commands/__pycache__/diffusers_cli.cpython-311.pyc ADDED
Binary file (1.28 kB).
 
diffusers/commands/__pycache__/env.cpython-311.pyc ADDED
Binary file (3.65 kB).
 
diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,41 @@
+ #!/usr/bin/env python
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from argparse import ArgumentParser
+
+ from .env import EnvironmentCommand
+
+
+ def main():
+     parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
+     commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
+
+     # Register commands
+     EnvironmentCommand.register_subcommand(commands_parser)
+
+     # Let's go
+     args = parser.parse_args()
+
+     if not hasattr(args, "func"):
+         parser.print_help()
+         exit(1)
+
+     # Run
+     service = args.func(args)
+     service.run()
+
+
+ if __name__ == "__main__":
+     main()
diffusers/commands/env.py ADDED
@@ -0,0 +1,84 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import platform
+ from argparse import ArgumentParser
+
+ import huggingface_hub
+
+ from .. import __version__ as version
+ from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
+ from . import BaseDiffusersCLICommand
+
+
+ def info_command_factory(_):
+     return EnvironmentCommand()
+
+
+ class EnvironmentCommand(BaseDiffusersCLICommand):
+     @staticmethod
+     def register_subcommand(parser: ArgumentParser):
+         download_parser = parser.add_parser("env")
+         download_parser.set_defaults(func=info_command_factory)
+
+     def run(self):
+         hub_version = huggingface_hub.__version__
+
+         pt_version = "not installed"
+         pt_cuda_available = "NA"
+         if is_torch_available():
+             import torch
+
+             pt_version = torch.__version__
+             pt_cuda_available = torch.cuda.is_available()
+
+         transformers_version = "not installed"
+         if is_transformers_available():
+             import transformers
+
+             transformers_version = transformers.__version__
+
+         accelerate_version = "not installed"
+         if is_accelerate_available():
+             import accelerate
+
+             accelerate_version = accelerate.__version__
+
+         xformers_version = "not installed"
+         if is_xformers_available():
+             import xformers
+
+             xformers_version = xformers.__version__
+
+         info = {
+             "`diffusers` version": version,
+             "Platform": platform.platform(),
+             "Python version": platform.python_version(),
+             "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
+             "Huggingface_hub version": hub_version,
+             "Transformers version": transformers_version,
+             "Accelerate version": accelerate_version,
+             "xFormers version": xformers_version,
+             "Using GPU in script?": "<fill in>",
+             "Using distributed or parallel set-up in script?": "<fill in>",
+         }
+
+         print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
+         print(self.format_dict(info))
+
+         return info
+
+     @staticmethod
+     def format_dict(d):
+         return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
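Note: the environment report above is normally reached through the `env` subcommand registered in diffusers_cli.py, but it can also be produced directly. A minimal sketch, assuming diffusers is importable in the current environment:

```python
# Sketch: call the env command directly; EnvironmentCommand and run() are defined above.
from diffusers.commands.env import EnvironmentCommand

report = EnvironmentCommand().run()  # prints the report and returns the underlying dict
assert "`diffusers` version" in report
```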
diffusers/configuration_utils.py ADDED
@@ -0,0 +1,615 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team.
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ ConfigMixin base class and utilities."""
+ import dataclasses
+ import functools
+ import importlib
+ import inspect
+ import json
+ import os
+ import re
+ from collections import OrderedDict
+ from pathlib import PosixPath
+ from typing import Any, Dict, Tuple, Union
+
+ import numpy as np
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+ from requests import HTTPError
+
+ from . import __version__
+ from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
+
+
+ logger = logging.get_logger(__name__)
+
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+ class FrozenDict(OrderedDict):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         for key, value in self.items():
+             setattr(self, key, value)
+
+         self.__frozen = True
+
+     def __delitem__(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+     def setdefault(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+     def pop(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+     def update(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+     def __setattr__(self, name, value):
+         if hasattr(self, "__frozen") and self.__frozen:
+             raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+         super().__setattr__(name, value)
+
+     def __setitem__(self, name, value):
+         if hasattr(self, "__frozen") and self.__frozen:
+             raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+         super().__setitem__(name, value)
+
+
+ class ConfigMixin:
+     r"""
+     Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
+     methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+         - [`~ConfigMixin.from_config`]
+         - [`~ConfigMixin.save_config`]
+
+     Class attributes:
+         - **config_name** (`str`) -- A filename under which the config should stored when calling
+           [`~ConfigMixin.save_config`] (should be overridden by parent class).
+         - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+           overridden by subclass).
+         - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
+         - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
+           should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
+           subclass).
+     """
+     config_name = None
+     ignore_for_config = []
+     has_compatibles = False
+
+     _deprecated_kwargs = []
+
+     def register_to_config(self, **kwargs):
+         if self.config_name is None:
+             raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+         # Special case for `kwargs` used in deprecation warning added to schedulers
+         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+         # or solve in a more general way.
+         kwargs.pop("kwargs", None)
+         for key, value in kwargs.items():
+             try:
+                 setattr(self, key, value)
+             except AttributeError as err:
+                 logger.error(f"Can't set {key} with value {value} for {self}")
+                 raise err
+
+         if not hasattr(self, "_internal_dict"):
+             internal_dict = kwargs
+         else:
+             previous_dict = dict(self._internal_dict)
+             internal_dict = {**self._internal_dict, **kwargs}
+             logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+         self._internal_dict = FrozenDict(internal_dict)
+
+     def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+         """
+         Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+         [`~ConfigMixin.from_config`] class method.
+
+         Args:
+             save_directory (`str` or `os.PathLike`):
+                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
+         """
+         if os.path.isfile(save_directory):
+             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+         os.makedirs(save_directory, exist_ok=True)
+
+         # If we save using the predefined names, we can load using `from_config`
+         output_config_file = os.path.join(save_directory, self.config_name)
+
+         self.to_json_file(output_config_file)
+         logger.info(f"Configuration saved in {output_config_file}")
+
+     @classmethod
+     def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+         r"""
+         Instantiate a Python class from a config dictionary
+
+         Parameters:
+             config (`Dict[str, Any]`):
+                 A config dictionary from which the Python class will be instantiated. Make sure to only load
+                 configuration files of compatible classes.
+             return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                 Whether kwargs that are not consumed by the Python class should be returned or not.
+
+             kwargs (remaining dictionary of keyword arguments, *optional*):
+                 Can be used to update the configuration object (after it being loaded) and initiate the Python class.
+                 `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
+                 overwrite same named arguments of `config`.
+
+         Examples:
+
+         ```python
+         >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+         >>> # Download scheduler from huggingface.co and cache.
+         >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+         >>> # Instantiate DDIM scheduler class with same config as DDPM
+         >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+         >>> # Instantiate PNDM scheduler class with same config as DDPM
+         >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+         ```
+         """
+         # <===== TO BE REMOVED WITH DEPRECATION
+         # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+         if "pretrained_model_name_or_path" in kwargs:
+             config = kwargs.pop("pretrained_model_name_or_path")
+
+         if config is None:
+             raise ValueError("Please make sure to provide a config as the first positional argument.")
+         # ======>
+
+         if not isinstance(config, dict):
+             deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
+             if "Scheduler" in cls.__name__:
+                 deprecation_message += (
+                     f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
+                     " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
+                     " be removed in v1.0.0."
+                 )
+             elif "Model" in cls.__name__:
+                 deprecation_message += (
+                     f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
+                     f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
+                     " instead. This functionality will be removed in v1.0.0."
+                 )
+             deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
+             config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
+
+         init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
+
+         # Allow dtype to be specified on initialization
+         if "dtype" in unused_kwargs:
+             init_dict["dtype"] = unused_kwargs.pop("dtype")
+
+         # add possible deprecated kwargs
+         for deprecated_kwarg in cls._deprecated_kwargs:
+             if deprecated_kwarg in unused_kwargs:
+                 init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
+
+         # Return model and optionally state and/or unused_kwargs
+         model = cls(**init_dict)
+
+         # make sure to also save config parameters that might be used for compatible classes
+         model.register_to_config(**hidden_dict)
+
+         # add hidden kwargs of compatible classes to unused_kwargs
+         unused_kwargs = {**unused_kwargs, **hidden_dict}
+
+         if return_unused_kwargs:
+             return (model, unused_kwargs)
+         else:
+             return model
+
+     @classmethod
+     def get_config_dict(cls, *args, **kwargs):
+         deprecation_message = (
+             f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
+             " removed in version v1.0.0"
+         )
+         deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
+         return cls.load_config(*args, **kwargs)
+
+     @classmethod
+     def load_config(
+         cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
+     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         r"""
+         Instantiate a Python class from a config dictionary
+
+         Parameters:
+             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                 Can be either:
+
+                     - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+                       organization name, like `google/ddpm-celebahq-256`.
+                     - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+                       `./my_model_directory/`.
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                 standard cache should not be used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                 cached versions if they exist.
+             resume_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                 file exists.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             output_loading_info(`bool`, *optional*, defaults to `False`):
+                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+             local_files_only(`bool`, *optional*, defaults to `False`):
+                 Whether or not to only look at local files (i.e., do not try to download the model).
+             use_auth_token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                 when running `transformers-cli login` (stored in `~/.huggingface`).
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                 identifier allowed by git.
+             subfolder (`str`, *optional*, defaults to `""`):
+                 In case the relevant files are located inside a subfolder of the model repo (either remote in
+                 huggingface.co or downloaded locally), you can specify the folder name here.
+
+         <Tip>
+
+         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+         models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+         </Tip>
+
+         <Tip>
+
+         Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+         use this method in a firewalled environment.
+
+         </Tip>
+         """
+         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", False)
+         proxies = kwargs.pop("proxies", None)
+         use_auth_token = kwargs.pop("use_auth_token", None)
+         local_files_only = kwargs.pop("local_files_only", False)
+         revision = kwargs.pop("revision", None)
+         _ = kwargs.pop("mirror", None)
+         subfolder = kwargs.pop("subfolder", None)
+
+         user_agent = {"file_type": "config"}
+
+         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+         if cls.config_name is None:
+             raise ValueError(
+                 "`self.config_name` is not defined. Note that one should not load a config from "
+                 "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+             )
+
+         if os.path.isfile(pretrained_model_name_or_path):
+             config_file = pretrained_model_name_or_path
+         elif os.path.isdir(pretrained_model_name_or_path):
+             if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                 # Load from a PyTorch checkpoint
+                 config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+             elif subfolder is not None and os.path.isfile(
+                 os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+             ):
+                 config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+             else:
+                 raise EnvironmentError(
+                     f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+                 )
+         else:
+             try:
+                 # Load from URL or cache if already cached
+                 config_file = hf_hub_download(
+                     pretrained_model_name_or_path,
+                     filename=cls.config_name,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     resume_download=resume_download,
+                     local_files_only=local_files_only,
+                     use_auth_token=use_auth_token,
+                     user_agent=user_agent,
+                     subfolder=subfolder,
+                     revision=revision,
+                 )
+
+             except RepositoryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+                     " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+                     " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+                     " login`."
+                 )
+             except RevisionNotFoundError:
+                 raise EnvironmentError(
+                     f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+                     " this model name. Check the model page at"
+                     f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+                 )
+             except EntryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+                 )
+             except HTTPError as err:
+                 raise EnvironmentError(
+                     "There was a specific connection error when trying to load"
+                     f" {pretrained_model_name_or_path}:\n{err}"
+                 )
+             except ValueError:
+                 raise EnvironmentError(
+                     f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                     f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                     f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+                     " run the library in offline mode at"
+                     " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                 )
+             except EnvironmentError:
+                 raise EnvironmentError(
+                     f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                     "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                     f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                     f"containing a {cls.config_name} file"
+                 )
+
+         try:
+             # Load config dict
+             config_dict = cls._dict_from_json_file(config_file)
+         except (json.JSONDecodeError, UnicodeDecodeError):
+             raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+         if return_unused_kwargs:
+             return config_dict, kwargs
+
+         return config_dict
+
+     @staticmethod
+     def _get_init_keys(cls):
+         return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
+     @classmethod
+     def extract_init_dict(cls, config_dict, **kwargs):
+         # 0. Copy origin config dict
+         original_dict = {k: v for k, v in config_dict.items()}
+
+         # 1. Retrieve expected config attributes from __init__ signature
+         expected_keys = cls._get_init_keys(cls)
+         expected_keys.remove("self")
+         # remove general kwargs if present in dict
+         if "kwargs" in expected_keys:
+             expected_keys.remove("kwargs")
+         # remove flax internal keys
+         if hasattr(cls, "_flax_internal_args"):
+             for arg in cls._flax_internal_args:
+                 expected_keys.remove(arg)
+
+         # 2. Remove attributes that cannot be expected from expected config attributes
+         # remove keys to be ignored
+         if len(cls.ignore_for_config) > 0:
+             expected_keys = expected_keys - set(cls.ignore_for_config)
+
+         # load diffusers library to import compatible and original scheduler
+         diffusers_library = importlib.import_module(__name__.split(".")[0])
+
+         if cls.has_compatibles:
+             compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
+         else:
+             compatible_classes = []
+
+         expected_keys_comp_cls = set()
+         for c in compatible_classes:
+             expected_keys_c = cls._get_init_keys(c)
+             expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
+         expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
+         config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
+
+         # remove attributes from orig class that cannot be expected
+         orig_cls_name = config_dict.pop("_class_name", cls.__name__)
+         if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
+             orig_cls = getattr(diffusers_library, orig_cls_name)
+             unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
+             config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
+
+         # remove private attributes
+         config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+
+         # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
+         init_dict = {}
+         for key in expected_keys:
+             # if config param is passed to kwarg and is present in config dict
+             # it should overwrite existing config dict key
+             if key in kwargs and key in config_dict:
+                 config_dict[key] = kwargs.pop(key)
+
+             if key in kwargs:
+                 # overwrite key
+                 init_dict[key] = kwargs.pop(key)
+             elif key in config_dict:
+                 # use value from config dict
+                 init_dict[key] = config_dict.pop(key)
+
+         # 4. Give nice warning if unexpected values have been passed
+         if len(config_dict) > 0:
+             logger.warning(
+                 f"The config attributes {config_dict} were passed to {cls.__name__}, "
+                 "but are not expected and will be ignored. Please verify your "
+                 f"{cls.config_name} configuration file."
+             )
+
+         # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
+         passed_keys = set(init_dict.keys())
+         if len(expected_keys - passed_keys) > 0:
+             logger.info(
+                 f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+             )
+
+         # 6. Define unused keyword arguments
+         unused_kwargs = {**config_dict, **kwargs}
+
+         # 7. Define "hidden" config parameters that were saved for compatible classes
+         hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
+
+         return init_dict, unused_kwargs, hidden_config_dict
+
+     @classmethod
+     def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+         with open(json_file, "r", encoding="utf-8") as reader:
+             text = reader.read()
+         return json.loads(text)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__} {self.to_json_string()}"
+
+     @property
+     def config(self) -> Dict[str, Any]:
+         """
+         Returns the config of the class as a frozen dictionary
+
+         Returns:
+             `Dict[str, Any]`: Config of the class.
+         """
+         return self._internal_dict
+
+     def to_json_string(self) -> str:
+         """
+         Serializes this instance to a JSON string.
+
+         Returns:
+             `str`: String containing all the attributes that make up this configuration instance in JSON format.
+         """
+         config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+         config_dict["_class_name"] = self.__class__.__name__
+         config_dict["_diffusers_version"] = __version__
+
+         def to_json_saveable(value):
+             if isinstance(value, np.ndarray):
+                 value = value.tolist()
+             elif isinstance(value, PosixPath):
+                 value = str(value)
+             return value
+
+         config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+         """
+         Save this instance to a JSON file.
+
+         Args:
+             json_file_path (`str` or `os.PathLike`):
+                 Path to the JSON file in which this configuration instance's parameters will be saved.
+         """
+         with open(json_file_path, "w", encoding="utf-8") as writer:
+             writer.write(self.to_json_string())
+
+
+ def register_to_config(init):
+     r"""
+     Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+     automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
+     shouldn't be registered in the config, use the `ignore_for_config` class variable
+
+     Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+     """
+
+     @functools.wraps(init)
+     def inner_init(self, *args, **kwargs):
+         # Ignore private kwargs in the init.
+         init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+         config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+         if not isinstance(self, ConfigMixin):
+             raise RuntimeError(
+                 f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+                 "not inherit from `ConfigMixin`."
+             )
+
+         ignore = getattr(self, "ignore_for_config", [])
+         # Get positional arguments aligned with kwargs
+         new_kwargs = {}
+         signature = inspect.signature(init)
+         parameters = {
+             name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+         }
+         for arg, name in zip(args, parameters.keys()):
+             new_kwargs[name] = arg
+
+         # Then add all kwargs
+         new_kwargs.update(
+             {
+                 k: init_kwargs.get(k, default)
+                 for k, default in parameters.items()
+                 if k not in ignore and k not in new_kwargs
+             }
+         )
+         new_kwargs = {**config_init_kwargs, **new_kwargs}
+         getattr(self, "register_to_config")(**new_kwargs)
+         init(self, *args, **init_kwargs)
+
+     return inner_init
+
+
+ def flax_register_to_config(cls):
+     original_init = cls.__init__
+
+     @functools.wraps(original_init)
+     def init(self, *args, **kwargs):
+         if not isinstance(self, ConfigMixin):
+             raise RuntimeError(
+                 f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+                 "not inherit from `ConfigMixin`."
+             )
+
+         # Ignore private kwargs in the init. Retrieve all passed attributes
+         init_kwargs = {k: v for k, v in kwargs.items()}
+
+         # Retrieve default values
+         fields = dataclasses.fields(self)
+         default_kwargs = {}
+         for field in fields:
+             # ignore flax specific attributes
+             if field.name in self._flax_internal_args:
+                 continue
+             if type(field.default) == dataclasses._MISSING_TYPE:
+                 default_kwargs[field.name] = None
+             else:
+                 default_kwargs[field.name] = getattr(self, field.name)
+
+         # Make sure init_kwargs override default kwargs
+         new_kwargs = {**default_kwargs, **init_kwargs}
+         # dtype should be part of `init_kwargs`, but not `new_kwargs`
+         if "dtype" in new_kwargs:
+             new_kwargs.pop("dtype")
+
+         # Get positional arguments aligned with kwargs
+         for i, arg in enumerate(args):
+             name = fields[i].name
+             new_kwargs[name] = arg
+
+         getattr(self, "register_to_config")(**new_kwargs)
+         original_init(self, *args, **kwargs)
+
+     cls.__init__ = init
+     return cls
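Note: to see how `ConfigMixin`, the `@register_to_config` decorator, and `FrozenDict` fit together, here is a minimal, hypothetical subclass. The class name, `config_name` value, and constructor arguments are made up for illustration; the API calls themselves are the ones defined in the file above.

```python
# Hypothetical ConfigMixin subclass; only the ConfigMixin API comes from this commit.
from diffusers.configuration_utils import ConfigMixin, register_to_config


class ToyScheduler(ConfigMixin):
    config_name = "toy_scheduler_config.json"  # file name used by save_config/load_config

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 0.0001):
        # The decorator records both arguments into self.config (a FrozenDict).
        pass


toy = ToyScheduler(num_train_timesteps=50)
print(toy.config.num_train_timesteps)  # 50, attribute-style access via FrozenDict
toy.save_config("./toy_scheduler")  # writes toy_scheduler_config.json
restored = ToyScheduler.from_config(ToyScheduler.load_config("./toy_scheduler"))
```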
diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,47 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import sys
+
+ from .dependency_versions_table import deps
+ from .utils.versions import require_version, require_version_core
+
+
+ # define which module versions we always want to check at run time
+ # (usually the ones defined in `install_requires` in setup.py)
+ #
+ # order specific notes:
+ # - tqdm must be checked before tokenizers
+
+ pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
+ if sys.version_info < (3, 7):
+     pkgs_to_check_at_runtime.append("dataclasses")
+ if sys.version_info < (3, 8):
+     pkgs_to_check_at_runtime.append("importlib_metadata")
+
+ for pkg in pkgs_to_check_at_runtime:
+     if pkg in deps:
+         if pkg == "tokenizers":
+             # must be loaded here, or else tqdm check may fail
+             from .utils import is_tokenizers_available
+
+             if not is_tokenizers_available():
+                 continue  # not required, check version only if installed
+
+         require_version_core(deps[pkg])
+     else:
+         raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+ def dep_version_check(pkg, hint=None):
+     require_version(deps[pkg], hint)
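Note: `dep_version_check` is a thin wrapper around `require_version` using the pinned specifiers from `dependency_versions_table.py` (the next file in this diff). A short usage sketch; the package name passed in must be a key of that `deps` table:

```python
# Sketch: verify installed dependencies against the pins recorded in the deps table.
from diffusers.dependency_versions_check import dep_version_check

dep_version_check("torch")  # raises if the installed torch does not satisfy "torch>=1.4"
dep_version_check("transformers", hint="needed by most Stable Diffusion pipelines")
```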
diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,35 @@
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
+ # 1. modify the `_deps` dict in setup.py
+ # 2. run `make deps_table_update``
+ deps = {
+     "Pillow": "Pillow",
+     "accelerate": "accelerate>=0.11.0",
+     "black": "black~=23.1",
+     "datasets": "datasets",
+     "filelock": "filelock",
+     "flax": "flax>=0.4.1",
+     "hf-doc-builder": "hf-doc-builder>=0.3.0",
+     "huggingface-hub": "huggingface-hub>=0.10.0",
+     "importlib_metadata": "importlib_metadata",
+     "isort": "isort>=5.5.4",
+     "jax": "jax>=0.2.8,!=0.3.2",
+     "jaxlib": "jaxlib>=0.1.65",
+     "Jinja2": "Jinja2",
+     "k-diffusion": "k-diffusion>=0.0.12",
+     "librosa": "librosa",
+     "numpy": "numpy",
+     "parameterized": "parameterized",
+     "pytest": "pytest",
+     "pytest-timeout": "pytest-timeout",
+     "pytest-xdist": "pytest-xdist",
+     "ruff": "ruff>=0.0.241",
+     "safetensors": "safetensors",
+     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+     "scipy": "scipy",
+     "regex": "regex!=2019.12.17",
+     "requests": "requests",
+     "tensorboard": "tensorboard",
+     "torch": "torch>=1.4",
+     "torchvision": "torchvision",
+     "transformers": "transformers>=4.25.1",
+ }
diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
+ from .rl import ValueGuidedRLPipeline
diffusers/experimental/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (264 Bytes).
 
diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
+ from .value_guided_sampling import ValueGuidedRLPipeline
diffusers/experimental/rl/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (286 Bytes).
 
diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-311.pyc ADDED
Binary file (8.86 kB).
 
diffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,152 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import numpy as np
+ import torch
+ import tqdm
+
+ from ...models.unet_1d import UNet1DModel
+ from ...pipelines import DiffusionPipeline
+ from ...utils import randn_tensor
+ from ...utils.dummy_pt_objects import DDPMScheduler
+
+
+ class ValueGuidedRLPipeline(DiffusionPipeline):
+     r"""
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+     Pipeline for sampling actions from a diffusion model trained to predict sequences of states.
+
+     Original implementation inspired by this repository: https://github.com/jannerm/diffuser.
+
+     Parameters:
+         value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward.
+         unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
+             application is [`DDPMScheduler`].
+         env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
+     """
+
+     def __init__(
+         self,
+         value_function: UNet1DModel,
+         unet: UNet1DModel,
+         scheduler: DDPMScheduler,
+         env,
+     ):
+         super().__init__()
+         self.value_function = value_function
+         self.unet = unet
+         self.scheduler = scheduler
+         self.env = env
+         self.data = env.get_dataset()
+         self.means = dict()
+         for key in self.data.keys():
+             try:
+                 self.means[key] = self.data[key].mean()
+             except:  # noqa: E722
+                 pass
+         self.stds = dict()
+         for key in self.data.keys():
+             try:
+                 self.stds[key] = self.data[key].std()
+             except:  # noqa: E722
+                 pass
+         self.state_dim = env.observation_space.shape[0]
+         self.action_dim = env.action_space.shape[0]
+
+     def normalize(self, x_in, key):
+         return (x_in - self.means[key]) / self.stds[key]
+
+     def de_normalize(self, x_in, key):
+         return x_in * self.stds[key] + self.means[key]
+
+     def to_torch(self, x_in):
+         if type(x_in) is dict:
+             return {k: self.to_torch(v) for k, v in x_in.items()}
+         elif torch.is_tensor(x_in):
+             return x_in.to(self.unet.device)
+         return torch.tensor(x_in, device=self.unet.device)
+
+     def reset_x0(self, x_in, cond, act_dim):
+         for key, val in cond.items():
+             x_in[:, key, act_dim:] = val.clone()
+         return x_in
+
+     def run_diffusion(self, x, conditions, n_guide_steps, scale):
+         batch_size = x.shape[0]
+         y = None
+         for i in tqdm.tqdm(self.scheduler.timesteps):
+             # create batch of timesteps to pass into model
+             timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
+             for _ in range(n_guide_steps):
+                 with torch.enable_grad():
+                     x.requires_grad_()
+
+                     # permute to match dimension for pre-trained models
+                     y = self.value_function(x.permute(0, 2, 1), timesteps).sample
+                     grad = torch.autograd.grad([y.sum()], [x])[0]
+
+                     posterior_variance = self.scheduler._get_variance(i)
+                     model_std = torch.exp(0.5 * posterior_variance)
+                     grad = model_std * grad
+
+                 grad[timesteps < 2] = 0
+                 x = x.detach()
+                 x = x + scale * grad
+                 x = self.reset_x0(x, conditions, self.action_dim)
+
+             prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
+
+             # TODO: verify deprecation of this kwarg
+             x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
+
+             # apply conditions to the trajectory (set the initial state)
+             x = self.reset_x0(x, conditions, self.action_dim)
+             x = self.to_torch(x)
+         return x, y
+
+     def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
+         # normalize the observations and create batch dimension
+         obs = self.normalize(obs, "observations")
+         obs = obs[None].repeat(batch_size, axis=0)
+
+         conditions = {0: self.to_torch(obs)}
+         shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
+
+         # generate initial noise and apply our conditions (to make the trajectories start at current state)
+         x1 = randn_tensor(shape, device=self.unet.device)
+         x = self.reset_x0(x1, conditions, self.action_dim)
+         x = self.to_torch(x)
+
+         # run the diffusion process
+         x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
+
+         # sort output trajectories by value
+         sorted_idx = y.argsort(0, descending=True).squeeze()
+         sorted_values = x[sorted_idx]
+         actions = sorted_values[:, :, : self.action_dim]
+         actions = actions.detach().cpu().numpy()
+         denorm_actions = self.de_normalize(actions, key="actions")
+
+         # select the action with the highest value
+         if y is not None:
+             selected_index = 0
+         else:
+             # if we didn't run value guiding, select a random action
+             selected_index = np.random.randint(0, batch_size)
+
+         denorm_actions = denorm_actions[selected_index, 0]
+         return denorm_actions
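Note: a usage sketch for the pipeline above. The checkpoint id and the environment id are placeholders; the pipeline expects a gym-style environment exposing `get_dataset()` (e.g. a D4RL Hopper task) plus a value function, a 1D UNet and a DDPM scheduler, which `from_pretrained` would have to supply from the named repository.

```python
# Hypothetical usage; "org/hopper-value-guided" and the env id are placeholders.
import gym  # assumes a D4RL-style environment that implements get_dataset()

from diffusers.experimental import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")  # placeholder environment id
pipeline = ValueGuidedRLPipeline.from_pretrained("org/hopper-value-guided", env=env)

obs = env.reset()
action = pipeline(obs, planning_horizon=32, n_guide_steps=2, scale=0.1)
next_obs, reward, done, info = env.step(action)
```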
diffusers/loaders.py ADDED
@@ -0,0 +1,243 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ from collections import defaultdict
+ from typing import Callable, Dict, Union
+
+ import torch
+
+ from .models.cross_attention import LoRACrossAttnProcessor
+ from .models.modeling_utils import _get_model_file
+ from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+
+
+ class AttnProcsLayers(torch.nn.Module):
+     def __init__(self, state_dict: Dict[str, torch.Tensor]):
+         super().__init__()
+         self.layers = torch.nn.ModuleList(state_dict.values())
+         self.mapping = {k: v for k, v in enumerate(state_dict.keys())}
+         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
+
+         # we add a hook to state_dict() and load_state_dict() so that the
+         # naming fits with `unet.attn_processors`
+         def map_to(module, state_dict, *args, **kwargs):
+             new_state_dict = {}
+             for key, value in state_dict.items():
+                 num = int(key.split(".")[1])  # 0 is always "layers"
+                 new_key = key.replace(f"layers.{num}", module.mapping[num])
+                 new_state_dict[new_key] = value
+
+             return new_state_dict
+
+         def map_from(module, state_dict, *args, **kwargs):
+             all_keys = list(state_dict.keys())
+             for key in all_keys:
+                 replace_key = key.split(".processor")[0] + ".processor"
+                 new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
+                 state_dict[new_key] = state_dict[key]
+                 del state_dict[key]
+
+         self._register_state_dict_hook(map_to)
+         self._register_load_state_dict_pre_hook(map_from, with_module=True)
+
+
+ class UNet2DConditionLoadersMixin:
+     def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+         r"""
+         Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
+         defined in
+         [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
+         and be a `torch.nn.Module` class.
+
+         <Tip warning={true}>
+
+         This function is experimental and might change in the future.
+
+         </Tip>
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 Can be either:
+
+                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                       Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+                     - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+                       `./my_model_directory/`.
+                     - A [torch state
+                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                 standard cache should not be used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                 cached versions if they exist.
+             resume_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                 file exists.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             local_files_only(`bool`, *optional*, defaults to `False`):
+                 Whether or not to only look at local files (i.e., do not try to download the model).
+             use_auth_token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                 when running `diffusers-cli login` (stored in `~/.huggingface`).
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                 identifier allowed by git.
+             subfolder (`str`, *optional*, defaults to `""`):
+                 In case the relevant files are located inside a subfolder of the model repo (either remote in
+                 huggingface.co or downloaded locally), you can specify the folder name here.
+
+             mirror (`str`, *optional*):
+                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                 Please refer to the mirror site for more information.
115
+
116
+ <Tip>
117
+
118
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
119
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
120
+
121
+ </Tip>
122
+
123
+ <Tip>
124
+
125
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
126
+ this method in a firewalled environment.
127
+
128
+ </Tip>
129
+ """
130
+
131
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
132
+ force_download = kwargs.pop("force_download", False)
133
+ resume_download = kwargs.pop("resume_download", False)
134
+ proxies = kwargs.pop("proxies", None)
135
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
136
+ use_auth_token = kwargs.pop("use_auth_token", None)
137
+ revision = kwargs.pop("revision", None)
138
+ subfolder = kwargs.pop("subfolder", None)
139
+ weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
140
+
141
+ user_agent = {
142
+ "file_type": "attn_procs_weights",
143
+ "framework": "pytorch",
144
+ }
145
+
146
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
147
+ model_file = _get_model_file(
148
+ pretrained_model_name_or_path_or_dict,
149
+ weights_name=weight_name,
150
+ cache_dir=cache_dir,
151
+ force_download=force_download,
152
+ resume_download=resume_download,
153
+ proxies=proxies,
154
+ local_files_only=local_files_only,
155
+ use_auth_token=use_auth_token,
156
+ revision=revision,
157
+ subfolder=subfolder,
158
+ user_agent=user_agent,
159
+ )
160
+ state_dict = torch.load(model_file, map_location="cpu")
161
+ else:
162
+ state_dict = pretrained_model_name_or_path_or_dict
163
+
164
+ # fill attn processors
165
+ attn_processors = {}
166
+
167
+ is_lora = all("lora" in k for k in state_dict.keys())
168
+
169
+ if is_lora:
170
+ lora_grouped_dict = defaultdict(dict)
171
+ for key, value in state_dict.items():
172
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
173
+ lora_grouped_dict[attn_processor_key][sub_key] = value
174
+
175
+ for key, value_dict in lora_grouped_dict.items():
176
+ rank = value_dict["to_k_lora.down.weight"].shape[0]
177
+ cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
178
+ hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
179
+
180
+ attn_processors[key] = LoRACrossAttnProcessor(
181
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
182
+ )
183
+ attn_processors[key].load_state_dict(value_dict)
184
+
185
+ else:
186
+ raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
187
+
188
+ # set correct dtype & device
189
+ attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
190
+
191
+ # set layers
192
+ self.set_attn_processor(attn_processors)
193
+
194
+ def save_attn_procs(
195
+ self,
196
+ save_directory: Union[str, os.PathLike],
197
+ is_main_process: bool = True,
198
+ weights_name: str = LORA_WEIGHT_NAME,
199
+ save_function: Callable = None,
200
+ ):
201
+ r"""
202
+ Save an attention processor to a directory, so that it can be re-loaded using the
203
+ [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
204
+
205
+ Arguments:
206
+ save_directory (`str` or `os.PathLike`):
207
+ Directory to which to save. Will be created if it doesn't exist.
208
+ is_main_process (`bool`, *optional*, defaults to `True`):
209
+ Whether the process calling this is the main process or not. Useful during distributed training (e.g., on
210
+ TPUs), when this function needs to be called on all processes. In this case, set `is_main_process=True` only on
211
+ the main process to avoid race conditions.
212
+ save_function (`Callable`):
213
+ The function to use to save the state dictionary. Useful during distributed training (e.g., on TPUs), when
214
+ you need to replace `torch.save` with another method. Can be configured with the environment variable
215
+ `DIFFUSERS_SAVE_MODE`.
216
+ """
217
+ if os.path.isfile(save_directory):
218
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
219
+ return
220
+
221
+ if save_function is None:
222
+ save_function = torch.save
223
+
224
+ os.makedirs(save_directory, exist_ok=True)
225
+
226
+ model_to_save = AttnProcsLayers(self.attn_processors)
227
+
228
+ # Save the model
229
+ state_dict = model_to_save.state_dict()
230
+
231
+ # Clean the folder from a previous save
232
+ for filename in os.listdir(save_directory):
233
+ full_filename = os.path.join(save_directory, filename)
234
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
235
+ # in distributed settings to avoid race conditions.
236
+ weights_no_suffix = weights_name.replace(".bin", "")
237
+ if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
238
+ os.remove(full_filename)
239
+
240
+ # Save the model
241
+ save_function(state_dict, os.path.join(save_directory, weights_name))
242
+
243
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
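Minimal usage sketch for the mixin above (the checkpoint id and paths are placeholders; the mixin is mixed into `UNet2DConditionModel` elsewhere in the package):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# load LoRA attention processors from a hub repo or a local directory
unet.load_attn_procs("path/to/lora_weights", weight_name="pytorch_lora_weights.bin")

# ... fine-tune or run inference, then persist only the attention processor weights
unet.save_attn_procs("./lora_weights")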
diffusers/models/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..utils import is_flax_available, is_torch_available
16
+
17
+
18
+ if is_torch_available():
19
+ from .autoencoder_kl import AutoencoderKL
20
+ from .controlnet import ControlNetModel
21
+ from .dual_transformer_2d import DualTransformer2DModel
22
+ from .modeling_utils import ModelMixin
23
+ from .prior_transformer import PriorTransformer
24
+ from .transformer_2d import Transformer2DModel
25
+ from .unet_1d import UNet1DModel
26
+ from .unet_2d import UNet2DModel
27
+ from .unet_2d_condition import UNet2DConditionModel
28
+ from .vq_model import VQModel
29
+
30
+ if is_flax_available():
31
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
32
+ from .vae_flax import FlaxAutoencoderKL
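A short sketch of how these conditional re-exports are consumed downstream (the guards mirror the ones used in the `__init__` above):

from diffusers.utils import is_flax_available, is_torch_available

if is_torch_available():
    from diffusers.models import AutoencoderKL, UNet2DConditionModel
if is_flax_available():
    from diffusers.models import FlaxUNet2DConditionModel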
diffusers/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.22 kB).
diffusers/models/__pycache__/attention.cpython-311.pyc ADDED
Binary file (25.8 kB).
diffusers/models/__pycache__/attention_flax.cpython-311.pyc ADDED
Binary file (14.6 kB).
diffusers/models/__pycache__/autoencoder_kl.cpython-311.pyc ADDED
Binary file (17.9 kB).
diffusers/models/__pycache__/controlnet.cpython-311.pyc ADDED
Binary file (23.7 kB).
diffusers/models/__pycache__/cross_attention.cpython-311.pyc ADDED
Binary file (33.2 kB).
diffusers/models/__pycache__/dual_transformer_2d.cpython-311.pyc ADDED
Binary file (7.08 kB).
diffusers/models/__pycache__/embeddings.cpython-311.pyc ADDED
Binary file (19.2 kB).
diffusers/models/__pycache__/embeddings_flax.cpython-311.pyc ADDED
Binary file (4.9 kB).
diffusers/models/__pycache__/modeling_flax_pytorch_utils.cpython-311.pyc ADDED
Binary file (4.6 kB).
diffusers/models/__pycache__/modeling_flax_utils.cpython-311.pyc ADDED
Binary file (28.4 kB).
diffusers/models/__pycache__/modeling_pytorch_flax_utils.cpython-311.pyc ADDED
Binary file (7.7 kB).
diffusers/models/__pycache__/modeling_utils.cpython-311.pyc ADDED
Binary file (44.3 kB).
diffusers/models/__pycache__/prior_transformer.cpython-311.pyc ADDED
Binary file (10.8 kB).
diffusers/models/__pycache__/resnet.cpython-311.pyc ADDED
Binary file (39.8 kB).
diffusers/models/__pycache__/resnet_flax.cpython-311.pyc ADDED
Binary file (5.04 kB).
diffusers/models/__pycache__/transformer_2d.cpython-311.pyc ADDED
Binary file (16.1 kB).
diffusers/models/__pycache__/unet_1d.cpython-311.pyc ADDED
Binary file (10.9 kB).
diffusers/models/__pycache__/unet_1d_blocks.cpython-311.pyc ADDED
Binary file (33.8 kB).
diffusers/models/__pycache__/unet_2d.cpython-311.pyc ADDED
Binary file (14.9 kB).
diffusers/models/__pycache__/unet_2d_blocks.cpython-311.pyc ADDED
Binary file (79.9 kB).
diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-311.pyc ADDED
Binary file (15.1 kB).
diffusers/models/__pycache__/unet_2d_condition.cpython-311.pyc ADDED
Binary file (31 kB).
diffusers/models/__pycache__/unet_2d_condition_flax.cpython-311.pyc ADDED
Binary file (14.4 kB).
diffusers/models/__pycache__/vae.cpython-311.pyc ADDED
Binary file (17.1 kB).
diffusers/models/__pycache__/vae_flax.cpython-311.pyc ADDED
Binary file (39.5 kB).
diffusers/models/__pycache__/vq_model.cpython-311.pyc ADDED
Binary file (7.41 kB).
diffusers/models/attention.py ADDED
@@ -0,0 +1,517 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Callable, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..utils.import_utils import is_xformers_available
22
+ from .cross_attention import CrossAttention
23
+ from .embeddings import CombinedTimestepLabelEmbeddings
24
+
25
+
26
+ if is_xformers_available():
27
+ import xformers
28
+ import xformers.ops
29
+ else:
30
+ xformers = None
31
+
32
+
33
+ class AttentionBlock(nn.Module):
34
+ """
35
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
36
+ to the N-d case.
37
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
38
+ Uses three q, k, v linear layers to compute attention.
39
+
40
+ Parameters:
41
+ channels (`int`): The number of channels in the input and output.
42
+ num_head_channels (`int`, *optional*):
43
+ The number of channels in each head. If None, then `num_heads` = 1.
44
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
45
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
46
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
47
+ """
48
+
49
+ # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
50
+
51
+ def __init__(
52
+ self,
53
+ channels: int,
54
+ num_head_channels: Optional[int] = None,
55
+ norm_num_groups: int = 32,
56
+ rescale_output_factor: float = 1.0,
57
+ eps: float = 1e-5,
58
+ ):
59
+ super().__init__()
60
+ self.channels = channels
61
+
62
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
63
+ self.num_head_size = num_head_channels
64
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
65
+
66
+ # define q,k,v as linear layers
67
+ self.query = nn.Linear(channels, channels)
68
+ self.key = nn.Linear(channels, channels)
69
+ self.value = nn.Linear(channels, channels)
70
+
71
+ self.rescale_output_factor = rescale_output_factor
72
+ self.proj_attn = nn.Linear(channels, channels, 1)
73
+
74
+ self._use_memory_efficient_attention_xformers = False
75
+ self._attention_op = None
76
+
77
+ def reshape_heads_to_batch_dim(self, tensor):
78
+ batch_size, seq_len, dim = tensor.shape
79
+ head_size = self.num_heads
80
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
81
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
82
+ return tensor
83
+
84
+ def reshape_batch_dim_to_heads(self, tensor):
85
+ batch_size, seq_len, dim = tensor.shape
86
+ head_size = self.num_heads
87
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
88
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
89
+ return tensor
90
+
91
+ def set_use_memory_efficient_attention_xformers(
92
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
93
+ ):
94
+ if use_memory_efficient_attention_xformers:
95
+ if not is_xformers_available():
96
+ raise ModuleNotFoundError(
97
+ (
98
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
99
+ " xformers"
100
+ ),
101
+ name="xformers",
102
+ )
103
+ elif not torch.cuda.is_available():
104
+ raise ValueError(
105
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
106
+ " only available for GPU "
107
+ )
108
+ else:
109
+ try:
110
+ # Make sure we can run the memory efficient attention
111
+ _ = xformers.ops.memory_efficient_attention(
112
+ torch.randn((1, 2, 40), device="cuda"),
113
+ torch.randn((1, 2, 40), device="cuda"),
114
+ torch.randn((1, 2, 40), device="cuda"),
115
+ )
116
+ except Exception as e:
117
+ raise e
118
+ self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
119
+ self._attention_op = attention_op
120
+
121
+ def forward(self, hidden_states):
122
+ residual = hidden_states
123
+ batch, channel, height, width = hidden_states.shape
124
+
125
+ # norm
126
+ hidden_states = self.group_norm(hidden_states)
127
+
128
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
129
+
130
+ # proj to q, k, v
131
+ query_proj = self.query(hidden_states)
132
+ key_proj = self.key(hidden_states)
133
+ value_proj = self.value(hidden_states)
134
+
135
+ scale = 1 / math.sqrt(self.channels / self.num_heads)
136
+
137
+ query_proj = self.reshape_heads_to_batch_dim(query_proj)
138
+ key_proj = self.reshape_heads_to_batch_dim(key_proj)
139
+ value_proj = self.reshape_heads_to_batch_dim(value_proj)
140
+
141
+ if self._use_memory_efficient_attention_xformers:
142
+ # Memory efficient attention
143
+ hidden_states = xformers.ops.memory_efficient_attention(
144
+ query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
145
+ )
146
+ hidden_states = hidden_states.to(query_proj.dtype)
147
+ else:
148
+ attention_scores = torch.baddbmm(
149
+ torch.empty(
150
+ query_proj.shape[0],
151
+ query_proj.shape[1],
152
+ key_proj.shape[1],
153
+ dtype=query_proj.dtype,
154
+ device=query_proj.device,
155
+ ),
156
+ query_proj,
157
+ key_proj.transpose(-1, -2),
158
+ beta=0,
159
+ alpha=scale,
160
+ )
161
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
162
+ hidden_states = torch.bmm(attention_probs, value_proj)
163
+
164
+ # reshape hidden_states
165
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
166
+
167
+ # compute next hidden_states
168
+ hidden_states = self.proj_attn(hidden_states)
169
+
170
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
171
+
172
+ # res connect and rescale
173
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
174
+ return hidden_states
175
+
176
+
177
+ class BasicTransformerBlock(nn.Module):
178
+ r"""
179
+ A basic Transformer block.
180
+
181
+ Parameters:
182
+ dim (`int`): The number of channels in the input and output.
183
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
184
+ attention_head_dim (`int`): The number of channels in each head.
185
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
186
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
187
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
188
+ num_embeds_ada_norm (:
189
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
190
+ attention_bias (:
191
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ dim: int,
197
+ num_attention_heads: int,
198
+ attention_head_dim: int,
199
+ dropout=0.0,
200
+ cross_attention_dim: Optional[int] = None,
201
+ activation_fn: str = "geglu",
202
+ num_embeds_ada_norm: Optional[int] = None,
203
+ attention_bias: bool = False,
204
+ only_cross_attention: bool = False,
205
+ upcast_attention: bool = False,
206
+ norm_elementwise_affine: bool = True,
207
+ norm_type: str = "layer_norm",
208
+ final_dropout: bool = False,
209
+ ):
210
+ super().__init__()
211
+ self.only_cross_attention = only_cross_attention
212
+
213
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
214
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
215
+
216
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
217
+ raise ValueError(
218
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
219
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
220
+ )
221
+
222
+ # 1. Self-Attn
223
+ self.attn1 = CrossAttention(
224
+ query_dim=dim,
225
+ heads=num_attention_heads,
226
+ dim_head=attention_head_dim,
227
+ dropout=dropout,
228
+ bias=attention_bias,
229
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
230
+ upcast_attention=upcast_attention,
231
+ )
232
+
233
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
234
+
235
+ # 2. Cross-Attn
236
+ if cross_attention_dim is not None:
237
+ self.attn2 = CrossAttention(
238
+ query_dim=dim,
239
+ cross_attention_dim=cross_attention_dim,
240
+ heads=num_attention_heads,
241
+ dim_head=attention_head_dim,
242
+ dropout=dropout,
243
+ bias=attention_bias,
244
+ upcast_attention=upcast_attention,
245
+ ) # is self-attn if encoder_hidden_states is none
246
+ else:
247
+ self.attn2 = None
248
+
249
+ if self.use_ada_layer_norm:
250
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
251
+ elif self.use_ada_layer_norm_zero:
252
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
253
+ else:
254
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
255
+
256
+ if cross_attention_dim is not None:
257
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
258
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
259
+ # the second cross attention block.
260
+ self.norm2 = (
261
+ AdaLayerNorm(dim, num_embeds_ada_norm)
262
+ if self.use_ada_layer_norm
263
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
264
+ )
265
+ else:
266
+ self.norm2 = None
267
+
268
+ # 3. Feed-forward
269
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
270
+
271
+ def forward(
272
+ self,
273
+ hidden_states,
274
+ encoder_hidden_states=None,
275
+ timestep=None,
276
+ attention_mask=None,
277
+ cross_attention_kwargs=None,
278
+ class_labels=None,
279
+ ):
280
+ if self.use_ada_layer_norm:
281
+ norm_hidden_states = self.norm1(hidden_states, timestep)
282
+ elif self.use_ada_layer_norm_zero:
283
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
284
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
285
+ )
286
+ else:
287
+ norm_hidden_states = self.norm1(hidden_states)
288
+
289
+ # 1. Self-Attention
290
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
291
+ attn_output = self.attn1(
292
+ norm_hidden_states,
293
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
294
+ attention_mask=attention_mask,
295
+ **cross_attention_kwargs,
296
+ )
297
+ if self.use_ada_layer_norm_zero:
298
+ attn_output = gate_msa.unsqueeze(1) * attn_output
299
+ hidden_states = attn_output + hidden_states
300
+
301
+ if self.attn2 is not None:
302
+ norm_hidden_states = (
303
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
304
+ )
305
+
306
+ # 2. Cross-Attention
307
+ attn_output = self.attn2(
308
+ norm_hidden_states,
309
+ encoder_hidden_states=encoder_hidden_states,
310
+ attention_mask=attention_mask,
311
+ **cross_attention_kwargs,
312
+ )
313
+ hidden_states = attn_output + hidden_states
314
+
315
+ # 3. Feed-forward
316
+ norm_hidden_states = self.norm3(hidden_states)
317
+
318
+ if self.use_ada_layer_norm_zero:
319
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
320
+
321
+ ff_output = self.ff(norm_hidden_states)
322
+
323
+ if self.use_ada_layer_norm_zero:
324
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
325
+
326
+ hidden_states = ff_output + hidden_states
327
+
328
+ return hidden_states
329
+
330
+
331
+ class FeedForward(nn.Module):
332
+ r"""
333
+ A feed-forward layer.
334
+
335
+ Parameters:
336
+ dim (`int`): The number of channels in the input.
337
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
338
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
339
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
340
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
341
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
342
+ """
343
+
344
+ def __init__(
345
+ self,
346
+ dim: int,
347
+ dim_out: Optional[int] = None,
348
+ mult: int = 4,
349
+ dropout: float = 0.0,
350
+ activation_fn: str = "geglu",
351
+ final_dropout: bool = False,
352
+ ):
353
+ super().__init__()
354
+ inner_dim = int(dim * mult)
355
+ dim_out = dim_out if dim_out is not None else dim
356
+
357
+ if activation_fn == "gelu":
358
+ act_fn = GELU(dim, inner_dim)
359
+ elif activation_fn == "gelu-approximate":
360
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
361
+ elif activation_fn == "geglu":
362
+ act_fn = GEGLU(dim, inner_dim)
363
+ elif activation_fn == "geglu-approximate":
364
+ act_fn = ApproximateGELU(dim, inner_dim)
365
+
366
+ self.net = nn.ModuleList([])
367
+ # project in
368
+ self.net.append(act_fn)
369
+ # project dropout
370
+ self.net.append(nn.Dropout(dropout))
371
+ # project out
372
+ self.net.append(nn.Linear(inner_dim, dim_out))
373
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
374
+ if final_dropout:
375
+ self.net.append(nn.Dropout(dropout))
376
+
377
+ def forward(self, hidden_states):
378
+ for module in self.net:
379
+ hidden_states = module(hidden_states)
380
+ return hidden_states
381
+
382
+
383
+ class GELU(nn.Module):
384
+ r"""
385
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
386
+ """
387
+
388
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
389
+ super().__init__()
390
+ self.proj = nn.Linear(dim_in, dim_out)
391
+ self.approximate = approximate
392
+
393
+ def gelu(self, gate):
394
+ if gate.device.type != "mps":
395
+ return F.gelu(gate, approximate=self.approximate)
396
+ # mps: gelu is not implemented for float16
397
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
398
+
399
+ def forward(self, hidden_states):
400
+ hidden_states = self.proj(hidden_states)
401
+ hidden_states = self.gelu(hidden_states)
402
+ return hidden_states
403
+
404
+
405
+ class GEGLU(nn.Module):
406
+ r"""
407
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
408
+
409
+ Parameters:
410
+ dim_in (`int`): The number of channels in the input.
411
+ dim_out (`int`): The number of channels in the output.
412
+ """
413
+
414
+ def __init__(self, dim_in: int, dim_out: int):
415
+ super().__init__()
416
+ self.proj = nn.Linear(dim_in, dim_out * 2)
417
+
418
+ def gelu(self, gate):
419
+ if gate.device.type != "mps":
420
+ return F.gelu(gate)
421
+ # mps: gelu is not implemented for float16
422
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
423
+
424
+ def forward(self, hidden_states):
425
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
426
+ return hidden_states * self.gelu(gate)
427
+
428
+
429
+ class ApproximateGELU(nn.Module):
430
+ """
431
+ The approximate form of Gaussian Error Linear Unit (GELU)
432
+
433
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
434
+ """
435
+
436
+ def __init__(self, dim_in: int, dim_out: int):
437
+ super().__init__()
438
+ self.proj = nn.Linear(dim_in, dim_out)
439
+
440
+ def forward(self, x):
441
+ x = self.proj(x)
442
+ return x * torch.sigmoid(1.702 * x)
443
+
444
+
445
+ class AdaLayerNorm(nn.Module):
446
+ """
447
+ Norm layer modified to incorporate timestep embeddings.
448
+ """
449
+
450
+ def __init__(self, embedding_dim, num_embeddings):
451
+ super().__init__()
452
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
453
+ self.silu = nn.SiLU()
454
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
455
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
456
+
457
+ def forward(self, x, timestep):
458
+ emb = self.linear(self.silu(self.emb(timestep)))
459
+ scale, shift = torch.chunk(emb, 2)
460
+ x = self.norm(x) * (1 + scale) + shift
461
+ return x
462
+
463
+
464
+ class AdaLayerNormZero(nn.Module):
465
+ """
466
+ Norm layer adaptive layer norm zero (adaLN-Zero).
467
+ """
468
+
469
+ def __init__(self, embedding_dim, num_embeddings):
470
+ super().__init__()
471
+
472
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
473
+
474
+ self.silu = nn.SiLU()
475
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
476
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
477
+
478
+ def forward(self, x, timestep, class_labels, hidden_dtype=None):
479
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
480
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
481
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
482
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
483
+
484
+
485
+ class AdaGroupNorm(nn.Module):
486
+ """
487
+ GroupNorm layer modified to incorporate timestep embeddings.
488
+ """
489
+
490
+ def __init__(
491
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
492
+ ):
493
+ super().__init__()
494
+ self.num_groups = num_groups
495
+ self.eps = eps
496
+ self.act = None
497
+ if act_fn == "swish":
498
+ self.act = lambda x: F.silu(x)
499
+ elif act_fn == "mish":
500
+ self.act = nn.Mish()
501
+ elif act_fn == "silu":
502
+ self.act = nn.SiLU()
503
+ elif act_fn == "gelu":
504
+ self.act = nn.GELU()
505
+
506
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
507
+
508
+ def forward(self, x, emb):
509
+ if self.act:
510
+ emb = self.act(emb)
511
+ emb = self.linear(emb)
512
+ emb = emb[:, :, None, None]
513
+ scale, shift = emb.chunk(2, dim=1)
514
+
515
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
516
+ x = x * (1 + scale) + shift
517
+ return x
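Minimal sketch exercising two of the modules defined above (shapes are illustrative):

import torch
from diffusers.models.attention import BasicTransformerBlock, FeedForward

# self-attention only: cross_attention_dim stays None, so attn2 is skipped
block = BasicTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)
hidden_states = torch.randn(2, 16, 64)  # (batch, sequence, channels)
print(block(hidden_states).shape)       # torch.Size([2, 16, 64])

ff = FeedForward(dim=64, activation_fn="geglu")
print(ff(hidden_states).shape)          # torch.Size([2, 16, 64])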
diffusers/models/attention_flax.py ADDED
@@ -0,0 +1,302 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import flax.linen as nn
16
+ import jax.numpy as jnp
17
+
18
+
19
+ class FlaxCrossAttention(nn.Module):
20
+ r"""
21
+ A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
22
+
23
+ Parameters:
24
+ query_dim (:obj:`int`):
25
+ Input hidden states dimension
26
+ heads (:obj:`int`, *optional*, defaults to 8):
27
+ Number of heads
28
+ dim_head (:obj:`int`, *optional*, defaults to 64):
29
+ Hidden states dimension inside each head
30
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
31
+ Dropout rate
32
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
33
+ Parameters `dtype`
34
+
35
+ """
36
+ query_dim: int
37
+ heads: int = 8
38
+ dim_head: int = 64
39
+ dropout: float = 0.0
40
+ dtype: jnp.dtype = jnp.float32
41
+
42
+ def setup(self):
43
+ inner_dim = self.dim_head * self.heads
44
+ self.scale = self.dim_head**-0.5
45
+
46
+ # Weights were exported with old names {to_q, to_k, to_v, to_out}
47
+ self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
48
+ self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
49
+ self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
50
+
51
+ self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
52
+
53
+ def reshape_heads_to_batch_dim(self, tensor):
54
+ batch_size, seq_len, dim = tensor.shape
55
+ head_size = self.heads
56
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
57
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
58
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
59
+ return tensor
60
+
61
+ def reshape_batch_dim_to_heads(self, tensor):
62
+ batch_size, seq_len, dim = tensor.shape
63
+ head_size = self.heads
64
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
65
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
66
+ tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
67
+ return tensor
68
+
69
+ def __call__(self, hidden_states, context=None, deterministic=True):
70
+ context = hidden_states if context is None else context
71
+
72
+ query_proj = self.query(hidden_states)
73
+ key_proj = self.key(context)
74
+ value_proj = self.value(context)
75
+
76
+ query_states = self.reshape_heads_to_batch_dim(query_proj)
77
+ key_states = self.reshape_heads_to_batch_dim(key_proj)
78
+ value_states = self.reshape_heads_to_batch_dim(value_proj)
79
+
80
+ # compute attentions
81
+ attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
82
+ attention_scores = attention_scores * self.scale
83
+ attention_probs = nn.softmax(attention_scores, axis=2)
84
+
85
+ # attend to values
86
+ hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
87
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
88
+ hidden_states = self.proj_attn(hidden_states)
89
+ return hidden_states
90
+
91
+
92
+ class FlaxBasicTransformerBlock(nn.Module):
93
+ r"""
94
+ A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
95
+ https://arxiv.org/abs/1706.03762
96
+
97
+
98
+ Parameters:
99
+ dim (:obj:`int`):
100
+ Inner hidden states dimension
101
+ n_heads (:obj:`int`):
102
+ Number of heads
103
+ d_head (:obj:`int`):
104
+ Hidden states dimension inside each head
105
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
106
+ Dropout rate
107
+ only_cross_attention (`bool`, defaults to `False`):
108
+ Whether to only apply cross attention.
109
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
110
+ Parameters `dtype`
111
+ """
112
+ dim: int
113
+ n_heads: int
114
+ d_head: int
115
+ dropout: float = 0.0
116
+ only_cross_attention: bool = False
117
+ dtype: jnp.dtype = jnp.float32
118
+
119
+ def setup(self):
120
+ # self attention (or cross_attention if only_cross_attention is True)
121
+ self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
122
+ # cross attention
123
+ self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
124
+ self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
125
+ self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
126
+ self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
127
+ self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
128
+
129
+ def __call__(self, hidden_states, context, deterministic=True):
130
+ # self attention
131
+ residual = hidden_states
132
+ if self.only_cross_attention:
133
+ hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
134
+ else:
135
+ hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
136
+ hidden_states = hidden_states + residual
137
+
138
+ # cross attention
139
+ residual = hidden_states
140
+ hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
141
+ hidden_states = hidden_states + residual
142
+
143
+ # feed forward
144
+ residual = hidden_states
145
+ hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
146
+ hidden_states = hidden_states + residual
147
+
148
+ return hidden_states
149
+
150
+
151
+ class FlaxTransformer2DModel(nn.Module):
152
+ r"""
153
+ A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
154
+ https://arxiv.org/pdf/1506.02025.pdf
155
+
156
+
157
+ Parameters:
158
+ in_channels (:obj:`int`):
159
+ Input number of channels
160
+ n_heads (:obj:`int`):
161
+ Number of heads
162
+ d_head (:obj:`int`):
163
+ Hidden states dimension inside each head
164
+ depth (:obj:`int`, *optional*, defaults to 1):
165
+ Number of transformer blocks
166
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
167
+ Dropout rate
168
+ use_linear_projection (`bool`, defaults to `False`): Whether to use a linear (`nn.Dense`) projection instead of a 1x1 convolution for the input/output projections
169
+ only_cross_attention (`bool`, defaults to `False`): Whether the first attention layer attends to the context instead of performing self-attention
170
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
171
+ Parameters `dtype`
172
+ """
173
+ in_channels: int
174
+ n_heads: int
175
+ d_head: int
176
+ depth: int = 1
177
+ dropout: float = 0.0
178
+ use_linear_projection: bool = False
179
+ only_cross_attention: bool = False
180
+ dtype: jnp.dtype = jnp.float32
181
+
182
+ def setup(self):
183
+ self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
184
+
185
+ inner_dim = self.n_heads * self.d_head
186
+ if self.use_linear_projection:
187
+ self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
188
+ else:
189
+ self.proj_in = nn.Conv(
190
+ inner_dim,
191
+ kernel_size=(1, 1),
192
+ strides=(1, 1),
193
+ padding="VALID",
194
+ dtype=self.dtype,
195
+ )
196
+
197
+ self.transformer_blocks = [
198
+ FlaxBasicTransformerBlock(
199
+ inner_dim,
200
+ self.n_heads,
201
+ self.d_head,
202
+ dropout=self.dropout,
203
+ only_cross_attention=self.only_cross_attention,
204
+ dtype=self.dtype,
205
+ )
206
+ for _ in range(self.depth)
207
+ ]
208
+
209
+ if self.use_linear_projection:
210
+ self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
211
+ else:
212
+ self.proj_out = nn.Conv(
213
+ inner_dim,
214
+ kernel_size=(1, 1),
215
+ strides=(1, 1),
216
+ padding="VALID",
217
+ dtype=self.dtype,
218
+ )
219
+
220
+ def __call__(self, hidden_states, context, deterministic=True):
221
+ batch, height, width, channels = hidden_states.shape
222
+ residual = hidden_states
223
+ hidden_states = self.norm(hidden_states)
224
+ if self.use_linear_projection:
225
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
226
+ hidden_states = self.proj_in(hidden_states)
227
+ else:
228
+ hidden_states = self.proj_in(hidden_states)
229
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
230
+
231
+ for transformer_block in self.transformer_blocks:
232
+ hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
233
+
234
+ if self.use_linear_projection:
235
+ hidden_states = self.proj_out(hidden_states)
236
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
237
+ else:
238
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
239
+ hidden_states = self.proj_out(hidden_states)
240
+
241
+ hidden_states = hidden_states + residual
242
+ return hidden_states
243
+
244
+
245
+ class FlaxFeedForward(nn.Module):
246
+ r"""
247
+ Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
248
+ [`FeedForward`] class, with the following simplifications:
249
+ - The activation function is currently hardcoded to a gated linear unit from:
250
+ https://arxiv.org/abs/2002.05202
251
+ - `dim_out` is equal to `dim`.
252
+ - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
253
+
254
+ Parameters:
255
+ dim (:obj:`int`):
256
+ Inner hidden states dimension
257
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
258
+ Dropout rate
259
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
260
+ Parameters `dtype`
261
+ """
262
+ dim: int
263
+ dropout: float = 0.0
264
+ dtype: jnp.dtype = jnp.float32
265
+
266
+ def setup(self):
267
+ # The second linear layer needs to be called
268
+ # net_2 for now to match the index of the Sequential layer
269
+ self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
270
+ self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
271
+
272
+ def __call__(self, hidden_states, deterministic=True):
273
+ hidden_states = self.net_0(hidden_states)
274
+ hidden_states = self.net_2(hidden_states)
275
+ return hidden_states
276
+
277
+
278
+ class FlaxGEGLU(nn.Module):
279
+ r"""
280
+ Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
281
+ https://arxiv.org/abs/2002.05202.
282
+
283
+ Parameters:
284
+ dim (:obj:`int`):
285
+ Input hidden states dimension
286
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
287
+ Dropout rate
288
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
289
+ Parameters `dtype`
290
+ """
291
+ dim: int
292
+ dropout: float = 0.0
293
+ dtype: jnp.dtype = jnp.float32
294
+
295
+ def setup(self):
296
+ inner_dim = self.dim * 4
297
+ self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
298
+
299
+ def __call__(self, hidden_states, deterministic=True):
300
+ hidden_states = self.proj(hidden_states)
301
+ hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
302
+ return hidden_linear * nn.gelu(hidden_gelu)
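Minimal Flax sketch for the blocks above (shapes and the PRNG seed are illustrative):

import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxBasicTransformerBlock

block = FlaxBasicTransformerBlock(dim=64, n_heads=4, d_head=16)
hidden_states = jnp.zeros((2, 16, 64))  # (batch, sequence, channels)
context = jnp.zeros((2, 77, 64))        # e.g. text-encoder hidden states
variables = block.init(jax.random.PRNGKey(0), hidden_states, context)
out = block.apply(variables, hidden_states, context)
print(out.shape)  # (2, 16, 64)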
diffusers/models/autoencoder_kl.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import BaseOutput, apply_forward_hook
22
+ from .modeling_utils import ModelMixin
23
+ from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
24
+
25
+
26
+ @dataclass
27
+ class AutoencoderKLOutput(BaseOutput):
28
+ """
29
+ Output of AutoencoderKL encoding method.
30
+
31
+ Args:
32
+ latent_dist (`DiagonalGaussianDistribution`):
33
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
34
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
35
+ """
36
+
37
+ latent_dist: "DiagonalGaussianDistribution"
38
+
39
+
40
+ class AutoencoderKL(ModelMixin, ConfigMixin):
41
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
42
+ and Max Welling.
43
+
44
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
45
+ implements for all the model (such as downloading or saving, etc.)
46
+
47
+ Parameters:
48
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
49
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
50
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
51
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
52
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
53
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
54
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
55
+ obj:`(64,)`): Tuple of block output channels.
56
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
57
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
58
+ sample_size (`int`, *optional*, defaults to `32`): TODO
59
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
60
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
61
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
62
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
63
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
64
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
65
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
66
+ """
67
+
68
+ @register_to_config
69
+ def __init__(
70
+ self,
71
+ in_channels: int = 3,
72
+ out_channels: int = 3,
73
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
74
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
75
+ block_out_channels: Tuple[int] = (64,),
76
+ layers_per_block: int = 1,
77
+ act_fn: str = "silu",
78
+ latent_channels: int = 4,
79
+ norm_num_groups: int = 32,
80
+ sample_size: int = 32,
81
+ scaling_factor: float = 0.18215,
82
+ ):
83
+ super().__init__()
84
+
85
+ # pass init params to Encoder
86
+ self.encoder = Encoder(
87
+ in_channels=in_channels,
88
+ out_channels=latent_channels,
89
+ down_block_types=down_block_types,
90
+ block_out_channels=block_out_channels,
91
+ layers_per_block=layers_per_block,
92
+ act_fn=act_fn,
93
+ norm_num_groups=norm_num_groups,
94
+ double_z=True,
95
+ )
96
+
97
+ # pass init params to Decoder
98
+ self.decoder = Decoder(
99
+ in_channels=latent_channels,
100
+ out_channels=out_channels,
101
+ up_block_types=up_block_types,
102
+ block_out_channels=block_out_channels,
103
+ layers_per_block=layers_per_block,
104
+ norm_num_groups=norm_num_groups,
105
+ act_fn=act_fn,
106
+ )
107
+
108
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
109
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
110
+
111
+ self.use_slicing = False
112
+ self.use_tiling = False
113
+
114
+ # only relevant if vae tiling is enabled
115
+ self.tile_sample_min_size = self.config.sample_size
116
+ sample_size = (
117
+ self.config.sample_size[0]
118
+ if isinstance(self.config.sample_size, (list, tuple))
119
+ else self.config.sample_size
120
+ )
121
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
122
+ self.tile_overlap_factor = 0.25
123
+
124
+ def enable_tiling(self, use_tiling: bool = True):
125
+ r"""
126
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
127
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
128
+ the processing of larger images.
129
+ """
130
+ self.use_tiling = use_tiling
131
+
132
+ def disable_tiling(self):
133
+ r"""
134
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
135
+ computing decoding in one step.
136
+ """
137
+ self.enable_tiling(False)
138
+
139
+ def enable_slicing(self):
140
+ r"""
141
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
142
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
143
+ """
144
+ self.use_slicing = True
145
+
146
+ def disable_slicing(self):
147
+ r"""
148
+ Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
149
+ decoding in one step.
150
+ """
151
+ self.use_slicing = False
152
+
153
+ @apply_forward_hook
154
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
155
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
156
+ return self.tiled_encode(x, return_dict=return_dict)
157
+
158
+ h = self.encoder(x)
159
+ moments = self.quant_conv(h)
160
+ posterior = DiagonalGaussianDistribution(moments)
161
+
162
+ if not return_dict:
163
+ return (posterior,)
164
+
165
+ return AutoencoderKLOutput(latent_dist=posterior)
166
+
167
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
168
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
169
+ return self.tiled_decode(z, return_dict=return_dict)
170
+
171
+ z = self.post_quant_conv(z)
172
+ dec = self.decoder(z)
173
+
174
+ if not return_dict:
175
+ return (dec,)
176
+
177
+ return DecoderOutput(sample=dec)
178
+
179
+ @apply_forward_hook
180
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
181
+ if self.use_slicing and z.shape[0] > 1:
182
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
183
+ decoded = torch.cat(decoded_slices)
184
+ else:
185
+ decoded = self._decode(z).sample
186
+
187
+ if not return_dict:
188
+ return (decoded,)
189
+
190
+ return DecoderOutput(sample=decoded)
191
+
192
+ def blend_v(self, a, b, blend_extent):
193
+ for y in range(blend_extent):
194
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
195
+ return b
196
+
197
+ def blend_h(self, a, b, blend_extent):
198
+ for x in range(blend_extent):
199
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
200
+ return b
201
+
202
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
203
+ r"""Encode a batch of images using a tiled encoder.
204
+ Args:
205
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
206
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is:
207
+ different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the
208
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
209
+ look of the output, but they should be much less noticeable.
210
+ x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`):
211
+ Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple.
212
+ """
213
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
214
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
215
+ row_limit = self.tile_latent_min_size - blend_extent
216
+
217
+ # Split the image into 512x512 tiles and encode them separately.
218
+ rows = []
219
+ for i in range(0, x.shape[2], overlap_size):
220
+ row = []
221
+ for j in range(0, x.shape[3], overlap_size):
222
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
223
+ tile = self.encoder(tile)
224
+ tile = self.quant_conv(tile)
225
+ row.append(tile)
226
+ rows.append(row)
227
+ result_rows = []
228
+ for i, row in enumerate(rows):
229
+ result_row = []
230
+ for j, tile in enumerate(row):
231
+ # blend the above tile and the left tile
232
+ # to the current tile and add the current tile to the result row
233
+ if i > 0:
234
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
235
+ if j > 0:
236
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
237
+ result_row.append(tile[:, :, :row_limit, :row_limit])
238
+ result_rows.append(torch.cat(result_row, dim=3))
239
+
240
+ moments = torch.cat(result_rows, dim=2)
241
+ posterior = DiagonalGaussianDistribution(moments)
242
+
243
+ if not return_dict:
244
+ return (posterior,)
245
+
246
+ return AutoencoderKLOutput(latent_dist=posterior)
247
+
248
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
249
+ r"""Decode a batch of images using a tiled decoder.
250
+ Args:
251
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
252
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is:
253
+ different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the
254
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
255
+ look of the output, but they should be much less noticeable.
256
+ z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to
257
+ `True`):
258
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
259
+ """
260
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
261
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
262
+ row_limit = self.tile_sample_min_size - blend_extent
263
+
264
+ # Split z into overlapping 64x64 tiles and decode them separately.
265
+ # The tiles have an overlap to avoid seams between tiles.
266
+ rows = []
267
+ for i in range(0, z.shape[2], overlap_size):
268
+ row = []
269
+ for j in range(0, z.shape[3], overlap_size):
270
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
271
+ tile = self.post_quant_conv(tile)
272
+ decoded = self.decoder(tile)
273
+ row.append(decoded)
274
+ rows.append(row)
275
+ result_rows = []
276
+ for i, row in enumerate(rows):
277
+ result_row = []
278
+ for j, tile in enumerate(row):
279
+ # blend the above tile and the left tile
280
+ # to the current tile and add the current tile to the result row
281
+ if i > 0:
282
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
283
+ if j > 0:
284
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
285
+ result_row.append(tile[:, :, :row_limit, :row_limit])
286
+ result_rows.append(torch.cat(result_row, dim=3))
287
+
288
+ dec = torch.cat(result_rows, dim=2)
289
+ if not return_dict:
290
+ return (dec,)
291
+
292
+ return DecoderOutput(sample=dec)
293
+
294
+ def forward(
295
+ self,
296
+ sample: torch.FloatTensor,
297
+ sample_posterior: bool = False,
298
+ return_dict: bool = True,
299
+ generator: Optional[torch.Generator] = None,
300
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
301
+ r"""
302
+ Args:
303
+ sample (`torch.FloatTensor`): Input sample.
304
+ sample_posterior (`bool`, *optional*, defaults to `False`):
305
+ Whether to sample from the posterior.
306
+ return_dict (`bool`, *optional*, defaults to `True`):
307
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
308
+ """
309
+ x = sample
310
+ posterior = self.encode(x).latent_dist
311
+ if sample_posterior:
312
+ z = posterior.sample(generator=generator)
313
+ else:
314
+ z = posterior.mode()
315
+ dec = self.decode(z).sample
316
+
317
+ if not return_dict:
318
+ return (dec,)
319
+
320
+ return DecoderOutput(sample=dec)
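End-to-end sketch for the autoencoder above (the checkpoint id is a placeholder; the latent scaling follows the class docstring):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # placeholder checkpoint id
vae.enable_tiling()  # optional: bound memory use on large inputs

image = torch.randn(1, 3, 256, 256)  # dummy image batch
posterior = vae.encode(image).latent_dist
latents = posterior.sample() * vae.config.scaling_factor  # scale as described in the docstring

decoded = vae.decode(latents / vae.config.scaling_factor).sample
print(decoded.shape)  # torch.Size([1, 3, 256, 256])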