ML-INTA committed
Commit 42c7345
Parent: 38de11e

Delete my_diffusers

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. my_diffusers/__init__.py +0 -204
  2. my_diffusers/__pycache__/__init__.cpython-39.pyc +0 -0
  3. my_diffusers/__pycache__/configuration_utils.cpython-39.pyc +0 -0
  4. my_diffusers/commands/__init__.py +0 -27
  5. my_diffusers/commands/__pycache__/__init__.cpython-311.pyc +0 -0
  6. my_diffusers/commands/__pycache__/diffusers_cli.cpython-311.pyc +0 -0
  7. my_diffusers/commands/__pycache__/env.cpython-311.pyc +0 -0
  8. my_diffusers/commands/diffusers_cli.py +0 -41
  9. my_diffusers/commands/env.py +0 -84
  10. my_diffusers/configuration_utils.py +0 -615
  11. my_diffusers/dependency_versions_check.py +0 -47
  12. my_diffusers/dependency_versions_table.py +0 -35
  13. my_diffusers/experimental/__init__.py +0 -1
  14. my_diffusers/experimental/__pycache__/__init__.cpython-311.pyc +0 -0
  15. my_diffusers/experimental/rl/__init__.py +0 -1
  16. my_diffusers/experimental/rl/__pycache__/__init__.cpython-311.pyc +0 -0
  17. my_diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-311.pyc +0 -0
  18. my_diffusers/experimental/rl/value_guided_sampling.py +0 -152
  19. my_diffusers/loaders.py +0 -243
  20. my_diffusers/models/__init__.py +0 -32
  21. my_diffusers/models/__pycache__/__init__.cpython-311.pyc +0 -0
  22. my_diffusers/models/__pycache__/attention.cpython-311.pyc +0 -0
  23. my_diffusers/models/__pycache__/attention_flax.cpython-311.pyc +0 -0
  24. my_diffusers/models/__pycache__/autoencoder_kl.cpython-311.pyc +0 -0
  25. my_diffusers/models/__pycache__/controlnet.cpython-311.pyc +0 -0
  26. my_diffusers/models/__pycache__/cross_attention.cpython-311.pyc +0 -0
  27. my_diffusers/models/__pycache__/dual_transformer_2d.cpython-311.pyc +0 -0
  28. my_diffusers/models/__pycache__/embeddings.cpython-311.pyc +0 -0
  29. my_diffusers/models/__pycache__/embeddings_flax.cpython-311.pyc +0 -0
  30. my_diffusers/models/__pycache__/modeling_flax_pytorch_utils.cpython-311.pyc +0 -0
  31. my_diffusers/models/__pycache__/modeling_flax_utils.cpython-311.pyc +0 -0
  32. my_diffusers/models/__pycache__/modeling_pytorch_flax_utils.cpython-311.pyc +0 -0
  33. my_diffusers/models/__pycache__/modeling_utils.cpython-311.pyc +0 -0
  34. my_diffusers/models/__pycache__/prior_transformer.cpython-311.pyc +0 -0
  35. my_diffusers/models/__pycache__/resnet.cpython-311.pyc +0 -0
  36. my_diffusers/models/__pycache__/resnet_flax.cpython-311.pyc +0 -0
  37. my_diffusers/models/__pycache__/transformer_2d.cpython-311.pyc +0 -0
  38. my_diffusers/models/__pycache__/unet_1d.cpython-311.pyc +0 -0
  39. my_diffusers/models/__pycache__/unet_1d_blocks.cpython-311.pyc +0 -0
  40. my_diffusers/models/__pycache__/unet_2d.cpython-311.pyc +0 -0
  41. my_diffusers/models/__pycache__/unet_2d_blocks.cpython-311.pyc +0 -0
  42. my_diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-311.pyc +0 -0
  43. my_diffusers/models/__pycache__/unet_2d_condition.cpython-311.pyc +0 -0
  44. my_diffusers/models/__pycache__/unet_2d_condition_flax.cpython-311.pyc +0 -0
  45. my_diffusers/models/__pycache__/vae.cpython-311.pyc +0 -0
  46. my_diffusers/models/__pycache__/vae_flax.cpython-311.pyc +0 -0
  47. my_diffusers/models/__pycache__/vq_model.cpython-311.pyc +0 -0
  48. my_diffusers/models/attention.py +0 -517
  49. my_diffusers/models/attention_flax.py +0 -302
  50. my_diffusers/models/autoencoder_kl.py +0 -320
my_diffusers/__init__.py DELETED
@@ -1,204 +0,0 @@
-__version__ = "0.14.0"
-
-from .configuration_utils import ConfigMixin
-from .utils import (
-    OptionalDependencyNotAvailable,
-    is_flax_available,
-    is_inflect_available,
-    is_k_diffusion_available,
-    is_k_diffusion_version,
-    is_librosa_available,
-    is_onnx_available,
-    is_scipy_available,
-    is_torch_available,
-    is_transformers_available,
-    is_transformers_version,
-    is_unidecode_available,
-    logging,
-)
-
-
-try:
-    if not is_onnx_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from .utils.dummy_onnx_objects import *  # noqa F403
-else:
-    from .pipelines import OnnxRuntimeModel
-
-try:
-    if not is_torch_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from .utils.dummy_pt_objects import *  # noqa F403
-else:
-    from .models import (
-        AutoencoderKL,
-        ControlNetModel,
-        ModelMixin,
-        PriorTransformer,
-        Transformer2DModel,
-        UNet1DModel,
-        UNet2DConditionModel,
-        UNet2DModel,
-        VQModel,
-    )
-    from .optimization import (
-        get_constant_schedule,
-        get_constant_schedule_with_warmup,
-        get_cosine_schedule_with_warmup,
-        get_cosine_with_hard_restarts_schedule_with_warmup,
-        get_linear_schedule_with_warmup,
-        get_polynomial_decay_schedule_with_warmup,
-        get_scheduler,
-    )
-    from .pipelines import (
-        AudioPipelineOutput,
-        DanceDiffusionPipeline,
-        DDIMPipeline,
-        DDPMPipeline,
-        DiffusionPipeline,
-        DiTPipeline,
-        ImagePipelineOutput,
-        KarrasVePipeline,
-        LDMPipeline,
-        LDMSuperResolutionPipeline,
-        PNDMPipeline,
-        RePaintPipeline,
-        ScoreSdeVePipeline,
-    )
-    from .schedulers import (
-        DDIMInverseScheduler,
-        DDIMScheduler,
-        DDPMScheduler,
-        DEISMultistepScheduler,
-        DPMSolverMultistepScheduler,
-        DPMSolverSinglestepScheduler,
-        EulerAncestralDiscreteScheduler,
-        EulerDiscreteScheduler,
-        HeunDiscreteScheduler,
-        IPNDMScheduler,
-        KarrasVeScheduler,
-        KDPM2AncestralDiscreteScheduler,
-        KDPM2DiscreteScheduler,
-        PNDMScheduler,
-        RePaintScheduler,
-        SchedulerMixin,
-        ScoreSdeVeScheduler,
-        UnCLIPScheduler,
-        UniPCMultistepScheduler,
-        VQDiffusionScheduler,
-    )
-    from .training_utils import EMAModel
-
-try:
-    if not (is_torch_available() and is_scipy_available()):
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from .utils.dummy_torch_and_scipy_objects import *  # noqa F403
-else:
-    from .schedulers import LMSDiscreteScheduler
-
-
-try:
-    if not (is_torch_available() and is_transformers_available()):
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
-else:
-    from .pipelines import (
-        AltDiffusionImg2ImgPipeline,
-        AltDiffusionPipeline,
-        CycleDiffusionPipeline,
-        LDMTextToImagePipeline,
-        PaintByExamplePipeline,
-        SemanticStableDiffusionPipeline,
-        StableDiffusionAttendAndExcitePipeline,
-        StableDiffusionControlNetPipeline,
-        StableDiffusionDepth2ImgPipeline,
-        StableDiffusionImageVariationPipeline,
-        StableDiffusionImg2ImgPipeline,
-        StableDiffusionInpaintPipeline,
-        StableDiffusionInpaintPipelineLegacy,
-        StableDiffusionInstructPix2PixPipeline,
-        StableDiffusionLatentUpscalePipeline,
-        StableDiffusionPanoramaPipeline,
-        StableDiffusionPipeline,
-        StableDiffusionPipelineSafe,
-        StableDiffusionPix2PixZeroPipeline,
-        StableDiffusionSAGPipeline,
-        StableDiffusionUpscalePipeline,
-        StableUnCLIPImg2ImgPipeline,
-        StableUnCLIPPipeline,
-        UnCLIPImageVariationPipeline,
-        UnCLIPPipeline,
-        VersatileDiffusionDualGuidedPipeline,
-        VersatileDiffusionImageVariationPipeline,
-        VersatileDiffusionPipeline,
-        VersatileDiffusionTextToImagePipeline,
-        VQDiffusionPipeline,
-    )
-
-try:
-    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
-else:
-    from .pipelines import StableDiffusionKDiffusionPipeline
-
-try:
-    if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
-else:
-    from .pipelines import (
-        OnnxStableDiffusionImg2ImgPipeline,
-        OnnxStableDiffusionInpaintPipeline,
-        OnnxStableDiffusionInpaintPipelineLegacy,
-        OnnxStableDiffusionPipeline,
-        StableDiffusionOnnxPipeline,
-    )
-
-try:
-    if not (is_torch_available() and is_librosa_available()):
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from .utils.dummy_torch_and_librosa_objects import *  # noqa F403
-else:
-    from .pipelines import AudioDiffusionPipeline, Mel
-
-try:
-    if not is_flax_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from .utils.dummy_flax_objects import *  # noqa F403
-else:
-    from .models.modeling_flax_utils import FlaxModelMixin
-    from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
-    from .models.vae_flax import FlaxAutoencoderKL
-    from .pipelines import FlaxDiffusionPipeline
-    from .schedulers import (
-        FlaxDDIMScheduler,
-        FlaxDDPMScheduler,
-        FlaxDPMSolverMultistepScheduler,
-        FlaxKarrasVeScheduler,
-        FlaxLMSDiscreteScheduler,
-        FlaxPNDMScheduler,
-        FlaxSchedulerMixin,
-        FlaxScoreSdeVeScheduler,
-    )
-
-
-try:
-    if not (is_flax_available() and is_transformers_available()):
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
-else:
-    from .pipelines import (
-        FlaxStableDiffusionImg2ImgPipeline,
-        FlaxStableDiffusionInpaintPipeline,
-        FlaxStableDiffusionPipeline,
-    )
my_diffusers/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (5.5 kB)
 
my_diffusers/__pycache__/configuration_utils.cpython-39.pyc DELETED
Binary file (22.1 kB)
 
my_diffusers/commands/__init__.py DELETED
@@ -1,27 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from abc import ABC, abstractmethod
-from argparse import ArgumentParser
-
-
-class BaseDiffusersCLICommand(ABC):
-    @staticmethod
-    @abstractmethod
-    def register_subcommand(parser: ArgumentParser):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def run(self):
-        raise NotImplementedError()
 
my_diffusers/commands/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (1.11 kB)
 
my_diffusers/commands/__pycache__/diffusers_cli.cpython-311.pyc DELETED
Binary file (1.28 kB)
 
my_diffusers/commands/__pycache__/env.cpython-311.pyc DELETED
Binary file (3.65 kB)
 
my_diffusers/commands/diffusers_cli.py DELETED
@@ -1,41 +0,0 @@
-#!/usr/bin/env python
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from argparse import ArgumentParser
-
-from .env import EnvironmentCommand
-
-
-def main():
-    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
-    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
-
-    # Register commands
-    EnvironmentCommand.register_subcommand(commands_parser)
-
-    # Let's go
-    args = parser.parse_args()
-
-    if not hasattr(args, "func"):
-        parser.print_help()
-        exit(1)
-
-    # Run
-    service = args.func(args)
-    service.run()
-
-
-if __name__ == "__main__":
-    main()
 
my_diffusers/commands/env.py DELETED
@@ -1,84 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import platform
-from argparse import ArgumentParser
-
-import huggingface_hub
-
-from .. import __version__ as version
-from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
-from . import BaseDiffusersCLICommand
-
-
-def info_command_factory(_):
-    return EnvironmentCommand()
-
-
-class EnvironmentCommand(BaseDiffusersCLICommand):
-    @staticmethod
-    def register_subcommand(parser: ArgumentParser):
-        download_parser = parser.add_parser("env")
-        download_parser.set_defaults(func=info_command_factory)
-
-    def run(self):
-        hub_version = huggingface_hub.__version__
-
-        pt_version = "not installed"
-        pt_cuda_available = "NA"
-        if is_torch_available():
-            import torch
-
-            pt_version = torch.__version__
-            pt_cuda_available = torch.cuda.is_available()
-
-        transformers_version = "not installed"
-        if is_transformers_available():
-            import transformers
-
-            transformers_version = transformers.__version__
-
-        accelerate_version = "not installed"
-        if is_accelerate_available():
-            import accelerate
-
-            accelerate_version = accelerate.__version__
-
-        xformers_version = "not installed"
-        if is_xformers_available():
-            import xformers
-
-            xformers_version = xformers.__version__
-
-        info = {
-            "`diffusers` version": version,
-            "Platform": platform.platform(),
-            "Python version": platform.python_version(),
-            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
-            "Huggingface_hub version": hub_version,
-            "Transformers version": transformers_version,
-            "Accelerate version": accelerate_version,
-            "xFormers version": xformers_version,
-            "Using GPU in script?": "<fill in>",
-            "Using distributed or parallel set-up in script?": "<fill in>",
-        }
-
-        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
-        print(self.format_dict(info))
-
-        return info
-
-    @staticmethod
-    def format_dict(d):
-        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
 
my_diffusers/configuration_utils.py DELETED
@@ -1,615 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The HuggingFace Inc. team.
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" ConfigMixin base class and utilities."""
-import dataclasses
-import functools
-import importlib
-import inspect
-import json
-import os
-import re
-from collections import OrderedDict
-from pathlib import PosixPath
-from typing import Any, Dict, Tuple, Union
-
-import numpy as np
-from huggingface_hub import hf_hub_download
-from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
-from requests import HTTPError
-
-from . import __version__
-from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
-
-
-logger = logging.get_logger(__name__)
-
-_re_configuration_file = re.compile(r"config\.(.*)\.json")
-
-
-class FrozenDict(OrderedDict):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        for key, value in self.items():
-            setattr(self, key, value)
-
-        self.__frozen = True
-
-    def __delitem__(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
-
-    def setdefault(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
-
-    def pop(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
-
-    def update(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
-
-    def __setattr__(self, name, value):
-        if hasattr(self, "__frozen") and self.__frozen:
-            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
-        super().__setattr__(name, value)
-
-    def __setitem__(self, name, value):
-        if hasattr(self, "__frozen") and self.__frozen:
-            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
-        super().__setitem__(name, value)
-
-
-class ConfigMixin:
-    r"""
-    Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
-    methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
-        - [`~ConfigMixin.from_config`]
-        - [`~ConfigMixin.save_config`]
-
-    Class attributes:
-        - **config_name** (`str`) -- A filename under which the config should stored when calling
-          [`~ConfigMixin.save_config`] (should be overridden by parent class).
-        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
-          overridden by subclass).
-        - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
-        - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
-          should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
-          subclass).
-    """
-    config_name = None
-    ignore_for_config = []
-    has_compatibles = False
-
-    _deprecated_kwargs = []
-
-    def register_to_config(self, **kwargs):
-        if self.config_name is None:
-            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
-        # Special case for `kwargs` used in deprecation warning added to schedulers
-        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
-        # or solve in a more general way.
-        kwargs.pop("kwargs", None)
-        for key, value in kwargs.items():
-            try:
-                setattr(self, key, value)
-            except AttributeError as err:
-                logger.error(f"Can't set {key} with value {value} for {self}")
-                raise err
-
-        if not hasattr(self, "_internal_dict"):
-            internal_dict = kwargs
-        else:
-            previous_dict = dict(self._internal_dict)
-            internal_dict = {**self._internal_dict, **kwargs}
-            logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
-
-        self._internal_dict = FrozenDict(internal_dict)
-
-    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
-        """
-        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
-        [`~ConfigMixin.from_config`] class method.
-
-        Args:
-            save_directory (`str` or `os.PathLike`):
-                Directory where the configuration JSON file will be saved (will be created if it does not exist).
-        """
-        if os.path.isfile(save_directory):
-            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
-
-        os.makedirs(save_directory, exist_ok=True)
-
-        # If we save using the predefined names, we can load using `from_config`
-        output_config_file = os.path.join(save_directory, self.config_name)
-
-        self.to_json_file(output_config_file)
-        logger.info(f"Configuration saved in {output_config_file}")
-
-    @classmethod
-    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
-        r"""
-        Instantiate a Python class from a config dictionary
-
-        Parameters:
-            config (`Dict[str, Any]`):
-                A config dictionary from which the Python class will be instantiated. Make sure to only load
-                configuration files of compatible classes.
-            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
-                Whether kwargs that are not consumed by the Python class should be returned or not.
-
-            kwargs (remaining dictionary of keyword arguments, *optional*):
-                Can be used to update the configuration object (after it being loaded) and initiate the Python class.
-                `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
-                overwrite same named arguments of `config`.
-
-        Examples:
-
-        ```python
-        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
-
-        >>> # Download scheduler from huggingface.co and cache.
-        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
-
-        >>> # Instantiate DDIM scheduler class with same config as DDPM
-        >>> scheduler = DDIMScheduler.from_config(scheduler.config)
-
-        >>> # Instantiate PNDM scheduler class with same config as DDPM
-        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
-        ```
-        """
-        # <===== TO BE REMOVED WITH DEPRECATION
-        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
-        if "pretrained_model_name_or_path" in kwargs:
-            config = kwargs.pop("pretrained_model_name_or_path")
-
-        if config is None:
-            raise ValueError("Please make sure to provide a config as the first positional argument.")
-        # ======>
-
-        if not isinstance(config, dict):
-            deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
-            if "Scheduler" in cls.__name__:
-                deprecation_message += (
-                    f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
-                    " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
-                    " be removed in v1.0.0."
-                )
-            elif "Model" in cls.__name__:
-                deprecation_message += (
-                    f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
-                    f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
-                    " instead. This functionality will be removed in v1.0.0."
-                )
-            deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
-            config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
-
-        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
-
-        # Allow dtype to be specified on initialization
-        if "dtype" in unused_kwargs:
-            init_dict["dtype"] = unused_kwargs.pop("dtype")
-
-        # add possible deprecated kwargs
-        for deprecated_kwarg in cls._deprecated_kwargs:
-            if deprecated_kwarg in unused_kwargs:
-                init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
-
-        # Return model and optionally state and/or unused_kwargs
-        model = cls(**init_dict)
-
-        # make sure to also save config parameters that might be used for compatible classes
-        model.register_to_config(**hidden_dict)
-
-        # add hidden kwargs of compatible classes to unused_kwargs
-        unused_kwargs = {**unused_kwargs, **hidden_dict}
-
-        if return_unused_kwargs:
-            return (model, unused_kwargs)
-        else:
-            return model
-
-    @classmethod
-    def get_config_dict(cls, *args, **kwargs):
-        deprecation_message = (
-            f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
-            " removed in version v1.0.0"
-        )
-        deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
-        return cls.load_config(*args, **kwargs)
-
-    @classmethod
-    def load_config(
-        cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
-    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        r"""
-        Instantiate a Python class from a config dictionary
-
-        Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
-                Can be either:
-
-                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
-                      organization name, like `google/ddpm-celebahq-256`.
-                    - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
-                      `./my_model_directory/`.
-
-            cache_dir (`Union[str, os.PathLike]`, *optional*):
-                Path to a directory in which a downloaded pretrained model configuration should be cached if the
-                standard cache should not be used.
-            force_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
-                cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
-                file exists.
-            proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
-                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
-            output_loading_info(`bool`, *optional*, defaults to `False`):
-                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
-            local_files_only(`bool`, *optional*, defaults to `False`):
-                Whether or not to only look at local files (i.e., do not try to download the model).
-            use_auth_token (`str` or *bool*, *optional*):
-                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-                when running `transformers-cli login` (stored in `~/.huggingface`).
-            revision (`str`, *optional*, defaults to `"main"`):
-                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
-                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
-                identifier allowed by git.
-            subfolder (`str`, *optional*, defaults to `""`):
-                In case the relevant files are located inside a subfolder of the model repo (either remote in
-                huggingface.co or downloaded locally), you can specify the folder name here.
-
-        <Tip>
-
-        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
-        models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-        </Tip>
-
-        <Tip>
-
-        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
-        use this method in a firewalled environment.
-
-        </Tip>
-        """
-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        local_files_only = kwargs.pop("local_files_only", False)
-        revision = kwargs.pop("revision", None)
-        _ = kwargs.pop("mirror", None)
-        subfolder = kwargs.pop("subfolder", None)
-
-        user_agent = {"file_type": "config"}
-
-        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-
-        if cls.config_name is None:
-            raise ValueError(
-                "`self.config_name` is not defined. Note that one should not load a config from "
-                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
-            )
-
-        if os.path.isfile(pretrained_model_name_or_path):
-            config_file = pretrained_model_name_or_path
-        elif os.path.isdir(pretrained_model_name_or_path):
-            if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
-                # Load from a PyTorch checkpoint
-                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
-            elif subfolder is not None and os.path.isfile(
-                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
-            ):
-                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
-            else:
-                raise EnvironmentError(
-                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
-                )
-        else:
-            try:
-                # Load from URL or cache if already cached
-                config_file = hf_hub_download(
-                    pretrained_model_name_or_path,
-                    filename=cls.config_name,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                    subfolder=subfolder,
-                    revision=revision,
-                )
-
-            except RepositoryNotFoundError:
-                raise EnvironmentError(
-                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
-                    " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
-                    " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
-                    " login`."
-                )
-            except RevisionNotFoundError:
-                raise EnvironmentError(
-                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
-                    " this model name. Check the model page at"
-                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
-                )
-            except EntryNotFoundError:
-                raise EnvironmentError(
-                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
-                )
-            except HTTPError as err:
-                raise EnvironmentError(
-                    "There was a specific connection error when trying to load"
-                    f" {pretrained_model_name_or_path}:\n{err}"
-                )
-            except ValueError:
-                raise EnvironmentError(
-                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
-                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
-                    f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
-                    " run the library in offline mode at"
-                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
-                )
-            except EnvironmentError:
-                raise EnvironmentError(
-                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
-                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                    f"containing a {cls.config_name} file"
-                )
-
-        try:
-            # Load config dict
-            config_dict = cls._dict_from_json_file(config_file)
-        except (json.JSONDecodeError, UnicodeDecodeError):
-            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
-
-        if return_unused_kwargs:
-            return config_dict, kwargs
-
-        return config_dict
-
-    @staticmethod
-    def _get_init_keys(cls):
-        return set(dict(inspect.signature(cls.__init__).parameters).keys())
-
-    @classmethod
-    def extract_init_dict(cls, config_dict, **kwargs):
-        # 0. Copy origin config dict
-        original_dict = {k: v for k, v in config_dict.items()}
-
-        # 1. Retrieve expected config attributes from __init__ signature
-        expected_keys = cls._get_init_keys(cls)
-        expected_keys.remove("self")
-        # remove general kwargs if present in dict
-        if "kwargs" in expected_keys:
-            expected_keys.remove("kwargs")
-        # remove flax internal keys
-        if hasattr(cls, "_flax_internal_args"):
-            for arg in cls._flax_internal_args:
-                expected_keys.remove(arg)
-
-        # 2. Remove attributes that cannot be expected from expected config attributes
-        # remove keys to be ignored
-        if len(cls.ignore_for_config) > 0:
-            expected_keys = expected_keys - set(cls.ignore_for_config)
-
-        # load diffusers library to import compatible and original scheduler
-        diffusers_library = importlib.import_module(__name__.split(".")[0])
-
-        if cls.has_compatibles:
-            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
-        else:
-            compatible_classes = []
-
-        expected_keys_comp_cls = set()
-        for c in compatible_classes:
-            expected_keys_c = cls._get_init_keys(c)
-            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
-        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
-        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
-
-        # remove attributes from orig class that cannot be expected
-        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
-        if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
-            orig_cls = getattr(diffusers_library, orig_cls_name)
-            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
-            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
-
-        # remove private attributes
-        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
-
-        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
-        init_dict = {}
-        for key in expected_keys:
-            # if config param is passed to kwarg and is present in config dict
-            # it should overwrite existing config dict key
-            if key in kwargs and key in config_dict:
-                config_dict[key] = kwargs.pop(key)
-
-            if key in kwargs:
-                # overwrite key
-                init_dict[key] = kwargs.pop(key)
-            elif key in config_dict:
-                # use value from config dict
-                init_dict[key] = config_dict.pop(key)
-
-        # 4. Give nice warning if unexpected values have been passed
-        if len(config_dict) > 0:
-            logger.warning(
-                f"The config attributes {config_dict} were passed to {cls.__name__}, "
-                "but are not expected and will be ignored. Please verify your "
-                f"{cls.config_name} configuration file."
-            )
-
-        # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
-        passed_keys = set(init_dict.keys())
-        if len(expected_keys - passed_keys) > 0:
-            logger.info(
-                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
-            )
-
-        # 6. Define unused keyword arguments
-        unused_kwargs = {**config_dict, **kwargs}
-
-        # 7. Define "hidden" config parameters that were saved for compatible classes
-        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
-
-        return init_dict, unused_kwargs, hidden_config_dict
-
-    @classmethod
-    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
-        with open(json_file, "r", encoding="utf-8") as reader:
-            text = reader.read()
-        return json.loads(text)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__} {self.to_json_string()}"
-
-    @property
-    def config(self) -> Dict[str, Any]:
-        """
-        Returns the config of the class as a frozen dictionary
-
-        Returns:
-            `Dict[str, Any]`: Config of the class.
-        """
-        return self._internal_dict
-
-    def to_json_string(self) -> str:
-        """
-        Serializes this instance to a JSON string.
-
-        Returns:
-            `str`: String containing all the attributes that make up this configuration instance in JSON format.
-        """
-        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
-        config_dict["_class_name"] = self.__class__.__name__
-        config_dict["_diffusers_version"] = __version__
-
-        def to_json_saveable(value):
-            if isinstance(value, np.ndarray):
-                value = value.tolist()
-            elif isinstance(value, PosixPath):
-                value = str(value)
-            return value
-
-        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
-        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
-
-    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
-        """
-        Save this instance to a JSON file.
-
-        Args:
-            json_file_path (`str` or `os.PathLike`):
-                Path to the JSON file in which this configuration instance's parameters will be saved.
-        """
-        with open(json_file_path, "w", encoding="utf-8") as writer:
-            writer.write(self.to_json_string())
-
-
-def register_to_config(init):
-    r"""
-    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
-    automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
-    shouldn't be registered in the config, use the `ignore_for_config` class variable
-
-    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
-    """
-
-    @functools.wraps(init)
-    def inner_init(self, *args, **kwargs):
-        # Ignore private kwargs in the init.
-        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
-        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
-        if not isinstance(self, ConfigMixin):
-            raise RuntimeError(
-                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
-                "not inherit from `ConfigMixin`."
-            )
-
-        ignore = getattr(self, "ignore_for_config", [])
-        # Get positional arguments aligned with kwargs
-        new_kwargs = {}
-        signature = inspect.signature(init)
-        parameters = {
-            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
-        }
-        for arg, name in zip(args, parameters.keys()):
-            new_kwargs[name] = arg
-
-        # Then add all kwargs
-        new_kwargs.update(
-            {
-                k: init_kwargs.get(k, default)
-                for k, default in parameters.items()
-                if k not in ignore and k not in new_kwargs
-            }
-        )
-        new_kwargs = {**config_init_kwargs, **new_kwargs}
-        getattr(self, "register_to_config")(**new_kwargs)
-        init(self, *args, **init_kwargs)
-
-    return inner_init
-
-
-def flax_register_to_config(cls):
-    original_init = cls.__init__
-
-    @functools.wraps(original_init)
-    def init(self, *args, **kwargs):
-        if not isinstance(self, ConfigMixin):
-            raise RuntimeError(
-                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
-                "not inherit from `ConfigMixin`."
-            )
-
-        # Ignore private kwargs in the init. Retrieve all passed attributes
-        init_kwargs = {k: v for k, v in kwargs.items()}
-
-        # Retrieve default values
-        fields = dataclasses.fields(self)
-        default_kwargs = {}
-        for field in fields:
-            # ignore flax specific attributes
-            if field.name in self._flax_internal_args:
-                continue
-            if type(field.default) == dataclasses._MISSING_TYPE:
-                default_kwargs[field.name] = None
-            else:
-                default_kwargs[field.name] = getattr(self, field.name)
-
-        # Make sure init_kwargs override default kwargs
-        new_kwargs = {**default_kwargs, **init_kwargs}
-        # dtype should be part of `init_kwargs`, but not `new_kwargs`
-        if "dtype" in new_kwargs:
-            new_kwargs.pop("dtype")
-
-        # Get positional arguments aligned with kwargs
-        for i, arg in enumerate(args):
-            name = fields[i].name
-            new_kwargs[name] = arg
-
-        getattr(self, "register_to_config")(**new_kwargs)
-        original_init(self, *args, **kwargs)
-
-    cls.__init__ = init
-    return cls
 
my_diffusers/dependency_versions_check.py DELETED
@@ -1,47 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import sys
-
-from .dependency_versions_table import deps
-from .utils.versions import require_version, require_version_core
-
-
-# define which module versions we always want to check at run time
-# (usually the ones defined in `install_requires` in setup.py)
-#
-# order specific notes:
-# - tqdm must be checked before tokenizers
-
-pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
-if sys.version_info < (3, 7):
-    pkgs_to_check_at_runtime.append("dataclasses")
-if sys.version_info < (3, 8):
-    pkgs_to_check_at_runtime.append("importlib_metadata")
-
-for pkg in pkgs_to_check_at_runtime:
-    if pkg in deps:
-        if pkg == "tokenizers":
-            # must be loaded here, or else tqdm check may fail
-            from .utils import is_tokenizers_available
-
-            if not is_tokenizers_available():
-                continue  # not required, check version only if installed
-
-        require_version_core(deps[pkg])
-    else:
-        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
-
-
-def dep_version_check(pkg, hint=None):
-    require_version(deps[pkg], hint)
 
my_diffusers/dependency_versions_table.py DELETED
@@ -1,35 +0,0 @@
-# THIS FILE HAS BEEN AUTOGENERATED. To update:
-# 1. modify the `_deps` dict in setup.py
-# 2. run `make deps_table_update``
-deps = {
-    "Pillow": "Pillow",
-    "accelerate": "accelerate>=0.11.0",
-    "black": "black~=23.1",
-    "datasets": "datasets",
-    "filelock": "filelock",
-    "flax": "flax>=0.4.1",
-    "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.10.0",
-    "importlib_metadata": "importlib_metadata",
-    "isort": "isort>=5.5.4",
-    "jax": "jax>=0.2.8,!=0.3.2",
-    "jaxlib": "jaxlib>=0.1.65",
-    "Jinja2": "Jinja2",
-    "k-diffusion": "k-diffusion>=0.0.12",
-    "librosa": "librosa",
-    "numpy": "numpy",
-    "parameterized": "parameterized",
-    "pytest": "pytest",
-    "pytest-timeout": "pytest-timeout",
-    "pytest-xdist": "pytest-xdist",
-    "ruff": "ruff>=0.0.241",
-    "safetensors": "safetensors",
-    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
-    "scipy": "scipy",
-    "regex": "regex!=2019.12.17",
-    "requests": "requests",
-    "tensorboard": "tensorboard",
-    "torch": "torch>=1.4",
-    "torchvision": "torchvision",
-    "transformers": "transformers>=4.25.1",
-}
 
my_diffusers/experimental/__init__.py DELETED
@@ -1 +0,0 @@
-from .rl import ValueGuidedRLPipeline
 
my_diffusers/experimental/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (264 Bytes)
 
my_diffusers/experimental/rl/__init__.py DELETED
@@ -1 +0,0 @@
-from .value_guided_sampling import ValueGuidedRLPipeline
 
my_diffusers/experimental/rl/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (286 Bytes)
 
my_diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-311.pyc DELETED
Binary file (8.86 kB)
 
my_diffusers/experimental/rl/value_guided_sampling.py DELETED
@@ -1,152 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import numpy as np
-import torch
-import tqdm
-
-from ...models.unet_1d import UNet1DModel
-from ...pipelines import DiffusionPipeline
-from ...utils import randn_tensor
-from ...utils.dummy_pt_objects import DDPMScheduler
-
-
-class ValueGuidedRLPipeline(DiffusionPipeline):
-    r"""
-    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
-    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-    Pipeline for sampling actions from a diffusion model trained to predict sequences of states.
-
-    Original implementation inspired by this repository: https://github.com/jannerm/diffuser.
-
-    Parameters:
-        value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward.
-        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
-        scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
-            application is [`DDPMScheduler`].
-        env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
-    """
-
-    def __init__(
-        self,
-        value_function: UNet1DModel,
-        unet: UNet1DModel,
-        scheduler: DDPMScheduler,
-        env,
-    ):
-        super().__init__()
-        self.value_function = value_function
-        self.unet = unet
-        self.scheduler = scheduler
-        self.env = env
-        self.data = env.get_dataset()
-        self.means = dict()
-        for key in self.data.keys():
-            try:
-                self.means[key] = self.data[key].mean()
-            except:  # noqa: E722
-                pass
-        self.stds = dict()
-        for key in self.data.keys():
-            try:
-                self.stds[key] = self.data[key].std()
-            except:  # noqa: E722
-                pass
-        self.state_dim = env.observation_space.shape[0]
-        self.action_dim = env.action_space.shape[0]
-
-    def normalize(self, x_in, key):
-        return (x_in - self.means[key]) / self.stds[key]
-
-    def de_normalize(self, x_in, key):
-        return x_in * self.stds[key] + self.means[key]
-
-    def to_torch(self, x_in):
-        if type(x_in) is dict:
-            return {k: self.to_torch(v) for k, v in x_in.items()}
-        elif torch.is_tensor(x_in):
-            return x_in.to(self.unet.device)
-        return torch.tensor(x_in, device=self.unet.device)
-
-    def reset_x0(self, x_in, cond, act_dim):
-        for key, val in cond.items():
-            x_in[:, key, act_dim:] = val.clone()
-        return x_in
-
-    def run_diffusion(self, x, conditions, n_guide_steps, scale):
-        batch_size = x.shape[0]
-        y = None
-        for i in tqdm.tqdm(self.scheduler.timesteps):
-            # create batch of timesteps to pass into model
-            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
-            for _ in range(n_guide_steps):
-                with torch.enable_grad():
-                    x.requires_grad_()
-
-                    # permute to match dimension for pre-trained models
-                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
-                    grad = torch.autograd.grad([y.sum()], [x])[0]
-
-                    posterior_variance = self.scheduler._get_variance(i)
-                    model_std = torch.exp(0.5 * posterior_variance)
-                    grad = model_std * grad
-
-                grad[timesteps < 2] = 0
-                x = x.detach()
-                x = x + scale * grad
-                x = self.reset_x0(x, conditions, self.action_dim)
-
-            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
-
-            # TODO: verify deprecation of this kwarg
-            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
-
-            # apply conditions to the trajectory (set the initial state)
-            x = self.reset_x0(x, conditions, self.action_dim)
-            x = self.to_torch(x)
-        return x, y
-
-    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
-        # normalize the observations and create batch dimension
-        obs = self.normalize(obs, "observations")
-        obs = obs[None].repeat(batch_size, axis=0)
-
-        conditions = {0: self.to_torch(obs)}
-        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
-
-        # generate initial noise and apply our conditions (to make the trajectories start at current state)
-        x1 = randn_tensor(shape, device=self.unet.device)
-        x = self.reset_x0(x1, conditions, self.action_dim)
-        x = self.to_torch(x)
-
-        # run the diffusion process
-        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
-
-        # sort output trajectories by value
-        sorted_idx = y.argsort(0, descending=True).squeeze()
-        sorted_values = x[sorted_idx]
-        actions = sorted_values[:, :, : self.action_dim]
-        actions = actions.detach().cpu().numpy()
-        denorm_actions = self.de_normalize(actions, key="actions")
-
-        # select the action with the highest value
-        if y is not None:
-            selected_index = 0
-        else:
-            # if we didn't run value guiding, select a random action
-            selected_index = np.random.randint(0, batch_size)
-
-        denorm_actions = denorm_actions[selected_index, 0]
-        return denorm_actions
 
my_diffusers/loaders.py DELETED
@@ -1,243 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-from collections import defaultdict
-from typing import Callable, Dict, Union
-
-import torch
-
-from .models.cross_attention import LoRACrossAttnProcessor
-from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
-
-
-class AttnProcsLayers(torch.nn.Module):
-    def __init__(self, state_dict: Dict[str, torch.Tensor]):
-        super().__init__()
-        self.layers = torch.nn.ModuleList(state_dict.values())
-        self.mapping = {k: v for k, v in enumerate(state_dict.keys())}
-        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
-
-        # we add a hook to state_dict() and load_state_dict() so that the
-        # naming fits with `unet.attn_processors`
-        def map_to(module, state_dict, *args, **kwargs):
-            new_state_dict = {}
-            for key, value in state_dict.items():
-                num = int(key.split(".")[1])  # 0 is always "layers"
-                new_key = key.replace(f"layers.{num}", module.mapping[num])
-                new_state_dict[new_key] = value
-
-            return new_state_dict
-
-        def map_from(module, state_dict, *args, **kwargs):
-            all_keys = list(state_dict.keys())
-            for key in all_keys:
-                replace_key = key.split(".processor")[0] + ".processor"
-                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
-                state_dict[new_key] = state_dict[key]
-                del state_dict[key]
-
-        self._register_state_dict_hook(map_to)
-        self._register_load_state_dict_pre_hook(map_from, with_module=True)
-
-
-class UNet2DConditionLoadersMixin:
-    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
-        r"""
-        Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
-        defined in
-        [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
-        and be a `torch.nn.Module` class.
-
-        <Tip warning={true}>
-
-        This function is experimental and might change in the future.
-
-        </Tip>
-
-        Parameters:
-            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
-                Can be either:
-
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
-                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
-                      `./my_model_directory/`.
-                    - A [torch state
-                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
-            cache_dir (`Union[str, os.PathLike]`, *optional*):
-                Path to a directory in which a downloaded pretrained model configuration should be cached if the
-                standard cache should not be used.
-            force_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
-                cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
-                file exists.
-            proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
-                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
-            local_files_only (`bool`, *optional*, defaults to `False`):
-                Whether or not to only look at local files (i.e., do not try to download the model).
-            use_auth_token (`str` or *bool*, *optional*):
-                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-                when running `diffusers-cli login` (stored in `~/.huggingface`).
-            revision (`str`, *optional*, defaults to `"main"`):
-                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
-                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
-                identifier allowed by git.
-            subfolder (`str`, *optional*, defaults to `""`):
-                In case the relevant files are located inside a subfolder of the model repo (either remote in
-                huggingface.co or downloaded locally), you can specify the folder name here.
-
-            mirror (`str`, *optional*):
-                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
-                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
-                Please refer to the mirror site for more information.
-
-        <Tip>
-
-        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
-        models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-        </Tip>
-
-        <Tip>
-
-        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
-        this method in a firewalled environment.
-
-        </Tip>
-        """
-
-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        revision = kwargs.pop("revision", None)
-        subfolder = kwargs.pop("subfolder", None)
-        weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
-
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
-
-        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            model_file = _get_model_file(
-                pretrained_model_name_or_path_or_dict,
-                weights_name=weight_name,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            state_dict = torch.load(model_file, map_location="cpu")
-        else:
-            # no file was downloaded; keep `model_file` defined for the error message below
-            model_file = None
-            state_dict = pretrained_model_name_or_path_or_dict
-
-        # fill attn processors
-        attn_processors = {}
-
-        is_lora = all("lora" in k for k in state_dict.keys())
-
-        if is_lora:
-            lora_grouped_dict = defaultdict(dict)
-            for key, value in state_dict.items():
-                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-                lora_grouped_dict[attn_processor_key][sub_key] = value
-
-            for key, value_dict in lora_grouped_dict.items():
-                rank = value_dict["to_k_lora.down.weight"].shape[0]
-                cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
-                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-
-                attn_processors[key] = LoRACrossAttnProcessor(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
-                )
-                attn_processors[key].load_state_dict(value_dict)
-
-        else:
-            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
-
-        # set correct dtype & device
-        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
-
-        # set layers
-        self.set_attn_processor(attn_processors)
-
-    def save_attn_procs(
-        self,
-        save_directory: Union[str, os.PathLike],
-        is_main_process: bool = True,
-        weights_name: str = LORA_WEIGHT_NAME,
-        save_function: Callable = None,
-    ):
-        r"""
-        Save an attention processor to a directory, so that it can be re-loaded using the
-        [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
-
-        Arguments:
-            save_directory (`str` or `os.PathLike`):
-                Directory to which to save. Will be created if it doesn't exist.
-            is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process or not. Useful when in distributed training like
-                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
-                the main process to avoid race conditions.
-            save_function (`Callable`):
-                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-                need to replace `torch.save` by another method. Can be configured with the environment variable
-                `DIFFUSERS_SAVE_MODE`.
-        """
-        if os.path.isfile(save_directory):
-            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
-            return
-
-        if save_function is None:
-            save_function = torch.save
-
-        os.makedirs(save_directory, exist_ok=True)
-
-        model_to_save = AttnProcsLayers(self.attn_processors)
-
-        # Save the model
-        state_dict = model_to_save.state_dict()
-
-        # Clean the folder from a previous save
-        for filename in os.listdir(save_directory):
-            full_filename = os.path.join(save_directory, filename)
-            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
-            # in distributed settings to avoid race conditions.
-            weights_no_suffix = weights_name.replace(".bin", "")
-            if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
-                os.remove(full_filename)
-
-        # Save the model
-        save_function(state_dict, os.path.join(save_directory, weights_name))
-
-        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
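The loader above mirrors the upstream `diffusers` LoRA workflow: `load_attn_procs` groups the `*_lora.*` keys per attention processor, rebuilds `LoRACrossAttnProcessor` modules from the inferred rank and dimensions, and installs them with `set_attn_processor`; `save_attn_procs` serializes them back through `AttnProcsLayers`. A short usage sketch (the repo id and directory are placeholders, and the calls assume a pipeline whose UNet includes this mixin):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# load LoRA attention-processor weights; the directory (or Hub repo) must contain
# pytorch_lora_weights.bin, matching LORA_WEIGHT_NAME above
pipe.unet.load_attn_procs("./my_lora_dir")

# ... after further fine-tuning, persist the processors back to disk
pipe.unet.save_attn_procs("./my_lora_dir")
```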
 
my_diffusers/models/__init__.py DELETED
@@ -1,32 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ..utils import is_flax_available, is_torch_available
-
-
-if is_torch_available():
-    from .autoencoder_kl import AutoencoderKL
-    from .controlnet import ControlNetModel
-    from .dual_transformer_2d import DualTransformer2DModel
-    from .modeling_utils import ModelMixin
-    from .prior_transformer import PriorTransformer
-    from .transformer_2d import Transformer2DModel
-    from .unet_1d import UNet1DModel
-    from .unet_2d import UNet2DModel
-    from .unet_2d_condition import UNet2DConditionModel
-    from .vq_model import VQModel
-
-if is_flax_available():
-    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
-    from .vae_flax import FlaxAutoencoderKL
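The `is_torch_available()` / `is_flax_available()` guards keep the package importable when only one backend is installed. A minimal sketch of how such a guard is typically implemented (an assumed pattern for illustration, not the actual `..utils` code):

```python
try:
    import torch  # noqa: F401

    _torch_available = True
except ImportError:
    _torch_available = False


def is_torch_available() -> bool:
    # import-time probe: True only if `torch` could be imported
    return _torch_available
```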
 
my_diffusers/models/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (1.22 kB)
 
my_diffusers/models/__pycache__/attention.cpython-311.pyc DELETED
Binary file (25.8 kB)
 
my_diffusers/models/__pycache__/attention_flax.cpython-311.pyc DELETED
Binary file (14.6 kB)
 
my_diffusers/models/__pycache__/autoencoder_kl.cpython-311.pyc DELETED
Binary file (17.9 kB)
 
my_diffusers/models/__pycache__/controlnet.cpython-311.pyc DELETED
Binary file (23.7 kB)
 
my_diffusers/models/__pycache__/cross_attention.cpython-311.pyc DELETED
Binary file (33.2 kB)
 
my_diffusers/models/__pycache__/dual_transformer_2d.cpython-311.pyc DELETED
Binary file (7.08 kB)
 
my_diffusers/models/__pycache__/embeddings.cpython-311.pyc DELETED
Binary file (19.2 kB)
 
my_diffusers/models/__pycache__/embeddings_flax.cpython-311.pyc DELETED
Binary file (4.9 kB)
 
my_diffusers/models/__pycache__/modeling_flax_pytorch_utils.cpython-311.pyc DELETED
Binary file (4.6 kB)
 
my_diffusers/models/__pycache__/modeling_flax_utils.cpython-311.pyc DELETED
Binary file (28.4 kB)
 
my_diffusers/models/__pycache__/modeling_pytorch_flax_utils.cpython-311.pyc DELETED
Binary file (7.7 kB)
 
my_diffusers/models/__pycache__/modeling_utils.cpython-311.pyc DELETED
Binary file (44.3 kB)
 
my_diffusers/models/__pycache__/prior_transformer.cpython-311.pyc DELETED
Binary file (10.8 kB)
 
my_diffusers/models/__pycache__/resnet.cpython-311.pyc DELETED
Binary file (39.8 kB)
 
my_diffusers/models/__pycache__/resnet_flax.cpython-311.pyc DELETED
Binary file (5.04 kB)
 
my_diffusers/models/__pycache__/transformer_2d.cpython-311.pyc DELETED
Binary file (16.1 kB)
 
my_diffusers/models/__pycache__/unet_1d.cpython-311.pyc DELETED
Binary file (10.9 kB)
 
my_diffusers/models/__pycache__/unet_1d_blocks.cpython-311.pyc DELETED
Binary file (33.8 kB)
 
my_diffusers/models/__pycache__/unet_2d.cpython-311.pyc DELETED
Binary file (14.9 kB)
 
my_diffusers/models/__pycache__/unet_2d_blocks.cpython-311.pyc DELETED
Binary file (79.9 kB)
 
my_diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-311.pyc DELETED
Binary file (15.1 kB)
 
my_diffusers/models/__pycache__/unet_2d_condition.cpython-311.pyc DELETED
Binary file (31 kB)
 
my_diffusers/models/__pycache__/unet_2d_condition_flax.cpython-311.pyc DELETED
Binary file (14.4 kB)
 
my_diffusers/models/__pycache__/vae.cpython-311.pyc DELETED
Binary file (17.1 kB)
 
my_diffusers/models/__pycache__/vae_flax.cpython-311.pyc DELETED
Binary file (39.5 kB)
 
my_diffusers/models/__pycache__/vq_model.cpython-311.pyc DELETED
Binary file (7.41 kB)
 
my_diffusers/models/attention.py DELETED
@@ -1,517 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-from typing import Callable, Optional
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from ..utils.import_utils import is_xformers_available
-from .cross_attention import CrossAttention
-from .embeddings import CombinedTimestepLabelEmbeddings
-
-
-if is_xformers_available():
-    import xformers
-    import xformers.ops
-else:
-    xformers = None
-
-
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other. Originally ported from
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66,
-    but adapted to the N-d case.
-    Uses three q, k, v linear layers to compute attention.
-
-    Parameters:
-        channels (`int`): The number of channels in the input and output.
-        num_head_channels (`int`, *optional*):
-            The number of channels in each head. If None, then `num_heads` = 1.
-        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
-        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
-        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
-    """
-
-    # IMPORTANT; TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
-
-    def __init__(
-        self,
-        channels: int,
-        num_head_channels: Optional[int] = None,
-        norm_num_groups: int = 32,
-        rescale_output_factor: float = 1.0,
-        eps: float = 1e-5,
-    ):
-        super().__init__()
-        self.channels = channels
-
-        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
-        self.num_head_size = num_head_channels
-        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
-
-        # define q,k,v as linear layers
-        self.query = nn.Linear(channels, channels)
-        self.key = nn.Linear(channels, channels)
-        self.value = nn.Linear(channels, channels)
-
-        self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = nn.Linear(channels, channels, 1)
-
-        self._use_memory_efficient_attention_xformers = False
-        self._attention_op = None
-
-    def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
-
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
-
-    def set_use_memory_efficient_attention_xformers(
-        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
-    ):
-        if use_memory_efficient_attention_xformers:
-            if not is_xformers_available():
-                raise ModuleNotFoundError(
-                    (
-                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                        " xformers"
-                    ),
-                    name="xformers",
-                )
-            elif not torch.cuda.is_available():
-                raise ValueError(
-                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
-                    " only available for GPU "
-                )
-            else:
-                try:
-                    # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
-                except Exception as e:
-                    raise e
-        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
-        self._attention_op = attention_op
-
-    def forward(self, hidden_states):
-        residual = hidden_states
-        batch, channel, height, width = hidden_states.shape
-
-        # norm
-        hidden_states = self.group_norm(hidden_states)
-
-        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
-
-        # proj to q, k, v
-        query_proj = self.query(hidden_states)
-        key_proj = self.key(hidden_states)
-        value_proj = self.value(hidden_states)
-
-        scale = 1 / math.sqrt(self.channels / self.num_heads)
-
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
-
-        if self._use_memory_efficient_attention_xformers:
-            # Memory efficient attention
-            hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
-            )
-            hidden_states = hidden_states.to(query_proj.dtype)
-        else:
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_proj.shape[0],
-                    query_proj.shape[1],
-                    key_proj.shape[1],
-                    dtype=query_proj.dtype,
-                    device=query_proj.device,
-                ),
-                query_proj,
-                key_proj.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
-            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-            hidden_states = torch.bmm(attention_probs, value_proj)
-
-        # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
-
-        # compute next hidden_states
-        hidden_states = self.proj_attn(hidden_states)
-
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
-
-        # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
-        return hidden_states
-
-
-class BasicTransformerBlock(nn.Module):
-    r"""
-    A basic Transformer block.
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
-        num_embeds_ada_norm (`int`, *optional*):
-            The number of diffusion steps used during training. See `Transformer2DModel`.
-        attention_bias (`bool`, *optional*, defaults to `False`):
-            Configure if the attentions should contain a bias parameter.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        num_attention_heads: int,
-        attention_head_dim: int,
-        dropout=0.0,
-        cross_attention_dim: Optional[int] = None,
-        activation_fn: str = "geglu",
-        num_embeds_ada_norm: Optional[int] = None,
-        attention_bias: bool = False,
-        only_cross_attention: bool = False,
-        upcast_attention: bool = False,
-        norm_elementwise_affine: bool = True,
-        norm_type: str = "layer_norm",
-        final_dropout: bool = False,
-    ):
-        super().__init__()
-        self.only_cross_attention = only_cross_attention
-
-        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
-        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
-
-        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
-            raise ValueError(
-                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
-                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
-            )
-
-        # 1. Self-Attn
-        self.attn1 = CrossAttention(
-            query_dim=dim,
-            heads=num_attention_heads,
-            dim_head=attention_head_dim,
-            dropout=dropout,
-            bias=attention_bias,
-            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-            upcast_attention=upcast_attention,
-        )
-
-        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
-
-        # 2. Cross-Attn
-        if cross_attention_dim is not None:
-            self.attn2 = CrossAttention(
-                query_dim=dim,
-                cross_attention_dim=cross_attention_dim,
-                heads=num_attention_heads,
-                dim_head=attention_head_dim,
-                dropout=dropout,
-                bias=attention_bias,
-                upcast_attention=upcast_attention,
-            )  # is self-attn if encoder_hidden_states is none
-        else:
-            self.attn2 = None
-
-        if self.use_ada_layer_norm:
-            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
-        elif self.use_ada_layer_norm_zero:
-            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
-        else:
-            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-
-        if cross_attention_dim is not None:
-            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
-            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
-            # the second cross attention block.
-            self.norm2 = (
-                AdaLayerNorm(dim, num_embeds_ada_norm)
-                if self.use_ada_layer_norm
-                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-            )
-        else:
-            self.norm2 = None
-
-        # 3. Feed-forward
-        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-
-    def forward(
-        self,
-        hidden_states,
-        encoder_hidden_states=None,
-        timestep=None,
-        attention_mask=None,
-        cross_attention_kwargs=None,
-        class_labels=None,
-    ):
-        if self.use_ada_layer_norm:
-            norm_hidden_states = self.norm1(hidden_states, timestep)
-        elif self.use_ada_layer_norm_zero:
-            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
-                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
-            )
-        else:
-            norm_hidden_states = self.norm1(hidden_states)
-
-        # 1. Self-Attention
-        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
-        attn_output = self.attn1(
-            norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=attention_mask,
-            **cross_attention_kwargs,
-        )
-        if self.use_ada_layer_norm_zero:
-            attn_output = gate_msa.unsqueeze(1) * attn_output
-        hidden_states = attn_output + hidden_states
-
-        if self.attn2 is not None:
-            norm_hidden_states = (
-                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
-            )
-
-            # 2. Cross-Attention
-            attn_output = self.attn2(
-                norm_hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                **cross_attention_kwargs,
-            )
-            hidden_states = attn_output + hidden_states
-
-        # 3. Feed-forward
-        norm_hidden_states = self.norm3(hidden_states)
-
-        if self.use_ada_layer_norm_zero:
-            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
-        ff_output = self.ff(norm_hidden_states)
-
-        if self.use_ada_layer_norm_zero:
-            ff_output = gate_mlp.unsqueeze(1) * ff_output
-
-        hidden_states = ff_output + hidden_states
-
-        return hidden_states
-
-
-class FeedForward(nn.Module):
-    r"""
-    A feed-forward layer.
-
-    Parameters:
-        dim (`int`): The number of channels in the input.
-        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
-        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
-        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        dim_out: Optional[int] = None,
-        mult: int = 4,
-        dropout: float = 0.0,
-        activation_fn: str = "geglu",
-        final_dropout: bool = False,
-    ):
-        super().__init__()
-        inner_dim = int(dim * mult)
-        dim_out = dim_out if dim_out is not None else dim
-
-        if activation_fn == "gelu":
-            act_fn = GELU(dim, inner_dim)
-        if activation_fn == "gelu-approximate":
-            act_fn = GELU(dim, inner_dim, approximate="tanh")
-        elif activation_fn == "geglu":
-            act_fn = GEGLU(dim, inner_dim)
-        elif activation_fn == "geglu-approximate":
-            act_fn = ApproximateGELU(dim, inner_dim)
-
-        self.net = nn.ModuleList([])
-        # project in
-        self.net.append(act_fn)
-        # project dropout
-        self.net.append(nn.Dropout(dropout))
-        # project out
-        self.net.append(nn.Linear(inner_dim, dim_out))
-        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
-        if final_dropout:
-            self.net.append(nn.Dropout(dropout))
-
-    def forward(self, hidden_states):
-        for module in self.net:
-            hidden_states = module(hidden_states)
-        return hidden_states
-
-
-class GELU(nn.Module):
-    r"""
-    GELU activation function with tanh approximation support with `approximate="tanh"`.
-    """
-
-    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out)
-        self.approximate = approximate
-
-    def gelu(self, gate):
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
-
-    def forward(self, hidden_states):
-        hidden_states = self.proj(hidden_states)
-        hidden_states = self.gelu(hidden_states)
-        return hidden_states
-
-
-class GEGLU(nn.Module):
-    r"""
-    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
-
-    Parameters:
-        dim_in (`int`): The number of channels in the input.
-        dim_out (`int`): The number of channels in the output.
-    """
-
-    def __init__(self, dim_in: int, dim_out: int):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def gelu(self, gate):
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
-
-    def forward(self, hidden_states):
-        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-        return hidden_states * self.gelu(gate)
-
-
-class ApproximateGELU(nn.Module):
-    """
-    The approximate form of Gaussian Error Linear Unit (GELU)
-
-    For more details, see section 2: https://arxiv.org/abs/1606.08415
-    """
-
-    def __init__(self, dim_in: int, dim_out: int):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out)
-
-    def forward(self, x):
-        x = self.proj(x)
-        return x * torch.sigmoid(1.702 * x)
-
-
-class AdaLayerNorm(nn.Module):
-    """
-    Norm layer modified to incorporate timestep embeddings.
-    """
-
-    def __init__(self, embedding_dim, num_embeddings):
-        super().__init__()
-        self.emb = nn.Embedding(num_embeddings, embedding_dim)
-        self.silu = nn.SiLU()
-        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
-        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
-
-    def forward(self, x, timestep):
-        emb = self.linear(self.silu(self.emb(timestep)))
-        scale, shift = torch.chunk(emb, 2)
-        x = self.norm(x) * (1 + scale) + shift
-        return x
-
-
-class AdaLayerNormZero(nn.Module):
-    """
-    Norm layer adaptive layer norm zero (adaLN-Zero).
-    """
-
-    def __init__(self, embedding_dim, num_embeddings):
-        super().__init__()
-
-        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
-
-        self.silu = nn.SiLU()
-        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
-        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
-
-    def forward(self, x, timestep, class_labels, hidden_dtype=None):
-        emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
-        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
-        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
-
-
-class AdaGroupNorm(nn.Module):
-    """
-    GroupNorm layer modified to incorporate timestep embeddings.
-    """
-
-    def __init__(
-        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
-    ):
-        super().__init__()
-        self.num_groups = num_groups
-        self.eps = eps
-        self.act = None
-        if act_fn == "swish":
-            self.act = lambda x: F.silu(x)
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
-
-        self.linear = nn.Linear(embedding_dim, out_dim * 2)
-
-    def forward(self, x, emb):
-        if self.act:
-            emb = self.act(emb)
-        emb = self.linear(emb)
-        emb = emb[:, :, None, None]
-        scale, shift = emb.chunk(2, dim=1)
-
-        x = F.group_norm(x, self.num_groups, eps=self.eps)
-        x = x * (1 + scale) + shift
-        return x
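The GEGLU variant above is the default feed-forward activation: project to twice the width, split, and gate one half with the GELU of the other. A standalone sketch of just that computation (dimensions chosen arbitrarily for illustration):

```python
import torch
import torch.nn.functional as F

dim_in, dim_out = 320, 1280
proj = torch.nn.Linear(dim_in, dim_out * 2)  # project to 2x width, as in GEGLU.__init__

hidden = torch.randn(2, 77, dim_in)          # (batch, seq_len, dim_in)
h, gate = proj(hidden).chunk(2, dim=-1)      # two (2, 77, dim_out) halves
out = h * F.gelu(gate)                       # the gated half modulates the linear half
```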
 
my_diffusers/models/attention_flax.py DELETED
@@ -1,302 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import flax.linen as nn
-import jax.numpy as jnp
-
-
-class FlaxCrossAttention(nn.Module):
-    r"""
-    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
-
-    Parameters:
-        query_dim (:obj:`int`):
-            Input hidden states dimension
-        heads (:obj:`int`, *optional*, defaults to 8):
-            Number of heads
-        dim_head (:obj:`int`, *optional*, defaults to 64):
-            Hidden states dimension inside each head
-        dropout (:obj:`float`, *optional*, defaults to 0.0):
-            Dropout rate
-        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-            Parameters `dtype`
-    """
-    query_dim: int
-    heads: int = 8
-    dim_head: int = 64
-    dropout: float = 0.0
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        inner_dim = self.dim_head * self.heads
-        self.scale = self.dim_head**-0.5
-
-        # Weights were exported with old names {to_q, to_k, to_v, to_out}
-        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
-        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
-        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
-
-        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
-
-    def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
-        tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
-
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
-        tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
-
-    def __call__(self, hidden_states, context=None, deterministic=True):
-        context = hidden_states if context is None else context
-
-        query_proj = self.query(hidden_states)
-        key_proj = self.key(context)
-        value_proj = self.value(context)
-
-        query_states = self.reshape_heads_to_batch_dim(query_proj)
-        key_states = self.reshape_heads_to_batch_dim(key_proj)
-        value_states = self.reshape_heads_to_batch_dim(value_proj)
-
-        # compute attentions
-        attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
-        attention_scores = attention_scores * self.scale
-        attention_probs = nn.softmax(attention_scores, axis=2)
-
-        # attend to values
-        hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
-        hidden_states = self.proj_attn(hidden_states)
-        return hidden_states
-
-
-class FlaxBasicTransformerBlock(nn.Module):
-    r"""
-    A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
-    https://arxiv.org/abs/1706.03762
-
-    Parameters:
-        dim (:obj:`int`):
-            Inner hidden states dimension
-        n_heads (:obj:`int`):
-            Number of heads
-        d_head (:obj:`int`):
-            Hidden states dimension inside each head
-        dropout (:obj:`float`, *optional*, defaults to 0.0):
-            Dropout rate
-        only_cross_attention (`bool`, defaults to `False`):
-            Whether to only apply cross attention.
-        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-            Parameters `dtype`
-    """
-    dim: int
-    n_heads: int
-    d_head: int
-    dropout: float = 0.0
-    only_cross_attention: bool = False
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
-        # cross attention
-        self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
-        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
-        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
-
-    def __call__(self, hidden_states, context, deterministic=True):
-        # self attention
-        residual = hidden_states
-        if self.only_cross_attention:
-            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
-        else:
-            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
-        hidden_states = hidden_states + residual
-
-        # cross attention
-        residual = hidden_states
-        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
-        hidden_states = hidden_states + residual
-
-        # feed forward
-        residual = hidden_states
-        hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
-        hidden_states = hidden_states + residual
-
-        return hidden_states
-
-
-class FlaxTransformer2DModel(nn.Module):
-    r"""
-    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
-    https://arxiv.org/pdf/1506.02025.pdf
-
-    Parameters:
-        in_channels (:obj:`int`):
-            Input number of channels
-        n_heads (:obj:`int`):
-            Number of heads
-        d_head (:obj:`int`):
-            Hidden states dimension inside each head
-        depth (:obj:`int`, *optional*, defaults to 1):
-            Number of transformers block
-        dropout (:obj:`float`, *optional*, defaults to 0.0):
-            Dropout rate
-        use_linear_projection (`bool`, defaults to `False`): tbd
-        only_cross_attention (`bool`, defaults to `False`): tbd
-        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-            Parameters `dtype`
-    """
-    in_channels: int
-    n_heads: int
-    d_head: int
-    depth: int = 1
-    dropout: float = 0.0
-    use_linear_projection: bool = False
-    only_cross_attention: bool = False
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
-
-        inner_dim = self.n_heads * self.d_head
-        if self.use_linear_projection:
-            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
-        else:
-            self.proj_in = nn.Conv(
-                inner_dim,
-                kernel_size=(1, 1),
-                strides=(1, 1),
-                padding="VALID",
-                dtype=self.dtype,
-            )
-
-        self.transformer_blocks = [
-            FlaxBasicTransformerBlock(
-                inner_dim,
-                self.n_heads,
-                self.d_head,
-                dropout=self.dropout,
-                only_cross_attention=self.only_cross_attention,
-                dtype=self.dtype,
-            )
-            for _ in range(self.depth)
-        ]
-
-        if self.use_linear_projection:
-            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
-        else:
-            self.proj_out = nn.Conv(
-                inner_dim,
-                kernel_size=(1, 1),
-                strides=(1, 1),
-                padding="VALID",
-                dtype=self.dtype,
-            )
-
-    def __call__(self, hidden_states, context, deterministic=True):
-        batch, height, width, channels = hidden_states.shape
-        residual = hidden_states
-        hidden_states = self.norm(hidden_states)
-        if self.use_linear_projection:
-            hidden_states = hidden_states.reshape(batch, height * width, channels)
-            hidden_states = self.proj_in(hidden_states)
-        else:
-            hidden_states = self.proj_in(hidden_states)
-            hidden_states = hidden_states.reshape(batch, height * width, channels)
-
-        for transformer_block in self.transformer_blocks:
-            hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
-
-        if self.use_linear_projection:
-            hidden_states = self.proj_out(hidden_states)
-            hidden_states = hidden_states.reshape(batch, height, width, channels)
-        else:
-            hidden_states = hidden_states.reshape(batch, height, width, channels)
-            hidden_states = self.proj_out(hidden_states)
-
-        hidden_states = hidden_states + residual
-        return hidden_states
-
-
-class FlaxFeedForward(nn.Module):
-    r"""
-    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
-    [`FeedForward`] class, with the following simplifications:
-    - The activation function is currently hardcoded to a gated linear unit from:
-      https://arxiv.org/abs/2002.05202
-    - `dim_out` is equal to `dim`.
-    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGEGLU`].
-
-    Parameters:
-        dim (:obj:`int`):
-            Inner hidden states dimension
-        dropout (:obj:`float`, *optional*, defaults to 0.0):
-            Dropout rate
-        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-            Parameters `dtype`
-    """
-    dim: int
-    dropout: float = 0.0
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        # The second linear layer needs to be called
-        # net_2 for now to match the index of the Sequential layer
-        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
-        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
-
-    def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.net_0(hidden_states)
-        hidden_states = self.net_2(hidden_states)
-        return hidden_states
-
-
-class FlaxGEGLU(nn.Module):
-    r"""
-    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
-    https://arxiv.org/abs/2002.05202.
-
-    Parameters:
-        dim (:obj:`int`):
-            Input hidden states dimension
-        dropout (:obj:`float`, *optional*, defaults to 0.0):
-            Dropout rate
-        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-            Parameters `dtype`
-    """
-    dim: int
-    dropout: float = 0.0
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        inner_dim = self.dim * 4
-        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
-
-    def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.proj(hidden_states)
-        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return hidden_linear * nn.gelu(hidden_gelu)
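The head reshaping used by `FlaxCrossAttention` folds the head axis into the batch axis so that attention runs as one batched matmul over all heads. Illustrated with concrete shapes (values chosen arbitrarily for the sketch):

```python
import jax.numpy as jnp

batch, seq_len, heads, dim_head = 2, 77, 8, 40
x = jnp.ones((batch, seq_len, heads * dim_head))   # (2, 77, 320)

# reshape_heads_to_batch_dim, step by step
x = x.reshape(batch, seq_len, heads, dim_head)     # (2, 77, 8, 40)
x = jnp.transpose(x, (0, 2, 1, 3))                 # (2, 8, 77, 40)
x = x.reshape(batch * heads, seq_len, dim_head)    # (16, 77, 40)
```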
 
my_diffusers/models/autoencoder_kl.py DELETED
@@ -1,320 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from dataclasses import dataclass
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, apply_forward_hook
-from .modeling_utils import ModelMixin
-from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
-
-
-@dataclass
-class AutoencoderKLOutput(BaseOutput):
-    """
-    Output of AutoencoderKL encoding method.
-
-    Args:
-        latent_dist (`DiagonalGaussianDistribution`):
-            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
-            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
-    """
-
-    latent_dist: "DiagonalGaussianDistribution"
-
-
-class AutoencoderKL(ModelMixin, ConfigMixin):
-    r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P.
-    Kingma and Max Welling.
-
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the models (such as downloading or saving, etc.)
-
-    Parameters:
-        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
-        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
-            Tuple of downsample block types.
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
-            Tuple of upsample block types.
-        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
-            Tuple of block output channels.
-        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
-        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
-        sample_size (`int`, *optional*, defaults to `32`): TODO
-        scaling_factor (`float`, *optional*, defaults to 0.18215):
-            The component-wise standard deviation of the trained latent space computed using the first batch of the
-            training set. This is used to scale the latent space to have unit variance when training the diffusion
-            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
-            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
-            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
-    """
-
-    @register_to_config
-    def __init__(
-        self,
-        in_channels: int = 3,
-        out_channels: int = 3,
-        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
-        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
-        block_out_channels: Tuple[int] = (64,),
-        layers_per_block: int = 1,
-        act_fn: str = "silu",
-        latent_channels: int = 4,
-        norm_num_groups: int = 32,
-        sample_size: int = 32,
-        scaling_factor: float = 0.18215,
-    ):
-        super().__init__()
-
-        # pass init params to Encoder
-        self.encoder = Encoder(
-            in_channels=in_channels,
-            out_channels=latent_channels,
-            down_block_types=down_block_types,
-            block_out_channels=block_out_channels,
-            layers_per_block=layers_per_block,
-            act_fn=act_fn,
-            norm_num_groups=norm_num_groups,
-            double_z=True,
-        )
-
-        # pass init params to Decoder
-        self.decoder = Decoder(
-            in_channels=latent_channels,
-            out_channels=out_channels,
-            up_block_types=up_block_types,
-            block_out_channels=block_out_channels,
-            layers_per_block=layers_per_block,
-            norm_num_groups=norm_num_groups,
-            act_fn=act_fn,
-        )
-
-        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
-        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
-
-        self.use_slicing = False
-        self.use_tiling = False
-
-        # only relevant if vae tiling is enabled
-        self.tile_sample_min_size = self.config.sample_size
-        sample_size = (
-            self.config.sample_size[0]
-            if isinstance(self.config.sample_size, (list, tuple))
-            else self.config.sample_size
-        )
-        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
-        self.tile_overlap_factor = 0.25
-
-    def enable_tiling(self, use_tiling: bool = True):
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
-        the processing of larger images.
-        """
-        self.use_tiling = use_tiling
-
-    def disable_tiling(self):
-        r"""
-        Disable tiled VAE decoding. If `enable_tiling` was previously invoked, this method will go back to computing
-        decoding in one step.
-        """
-        self.enable_tiling(False)
-
-    def enable_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
-    @apply_forward_hook
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
-        h = self.encoder(x)
-        moments = self.quant_conv(h)
-        posterior = DiagonalGaussianDistribution(moments)
-
-        if not return_dict:
-            return (posterior,)
-
-        return AutoencoderKLOutput(latent_dist=posterior)
-
-    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
-        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
-            return self.tiled_decode(z, return_dict=return_dict)
-
-        z = self.post_quant_conv(z)
-        dec = self.decoder(z)
-
-        if not return_dict:
-            return (dec,)
-
-        return DecoderOutput(sample=dec)
-
-    @apply_forward_hook
-    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
-        if self.use_slicing and z.shape[0] > 1:
-            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
-            decoded = torch.cat(decoded_slices)
-        else:
-            decoded = self._decode(z).sample
-
-        if not return_dict:
-            return (decoded,)
-
-        return DecoderOutput(sample=decoded)
-
-    def blend_v(self, a, b, blend_extent):
-        for y in range(blend_extent):
-            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
-        return b
-
-    def blend_h(self, a, b, blend_extent):
-        for x in range(blend_extent):
-            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
-        return b
-
-    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
-        r"""Encode a batch of images using a tiled encoder.
-
-        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
-        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
-        is different from non-tiled encoding because each tile is encoded separately. To avoid tiling artifacts, the
-        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
-        look of the output, but they should be much less noticeable.
-
-        Args:
-            x (`torch.FloatTensor`): Input batch of images.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple.
-        """
-        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
-        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
-        row_limit = self.tile_latent_min_size - blend_extent
-
-        # Split the image into 512x512 tiles and encode them separately.
-        rows = []
-        for i in range(0, x.shape[2], overlap_size):
-            row = []
-            for j in range(0, x.shape[3], overlap_size):
-                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
-                tile = self.encoder(tile)
-                tile = self.quant_conv(tile)
-                row.append(tile)
-            rows.append(row)
-        result_rows = []
-        for i, row in enumerate(rows):
-            result_row = []
-            for j, tile in enumerate(row):
-                # blend the above tile and the left tile
-                # to the current tile and add the current tile to the result row
-                if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
-                if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent)
-                result_row.append(tile[:, :, :row_limit, :row_limit])
-            result_rows.append(torch.cat(result_row, dim=3))
-
-        moments = torch.cat(result_rows, dim=2)
-        posterior = DiagonalGaussianDistribution(moments)
-
-        if not return_dict:
-            return (posterior,)
-
-        return AutoencoderKLOutput(latent_dist=posterior)
-
-    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
-        r"""Decode a batch of images using a tiled decoder.
-
-        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
-        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding
-        is different from non-tiled decoding because each tile is decoded separately. To avoid tiling artifacts, the
-        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
-        look of the output, but they should be much less noticeable.
-
-        Args:
-            z (`torch.FloatTensor`): Input batch of latent vectors.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
-        """
-        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
-        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
-        row_limit = self.tile_sample_min_size - blend_extent
-
-        # Split z into overlapping 64x64 tiles and decode them separately.
-        # The tiles have an overlap to avoid seams between tiles.
-        rows = []
-        for i in range(0, z.shape[2], overlap_size):
-            row = []
-            for j in range(0, z.shape[3], overlap_size):
-                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
-                tile = self.post_quant_conv(tile)
-                decoded = self.decoder(tile)
-                row.append(decoded)
-            rows.append(row)
-        result_rows = []
-        for i, row in enumerate(rows):
-            result_row = []
-            for j, tile in enumerate(row):
-                # blend the above tile and the left tile
-                # to the current tile and add the current tile to the result row
-                if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
-                if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent)
-                result_row.append(tile[:, :, :row_limit, :row_limit])
-            result_rows.append(torch.cat(result_row, dim=3))
-
-        dec = torch.cat(result_rows, dim=2)
-        if not return_dict:
-            return (dec,)
-
-        return DecoderOutput(sample=dec)
-
-    def forward(
-        self,
-        sample: torch.FloatTensor,
-        sample_posterior: bool = False,
-        return_dict: bool = True,
-        generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
-        r"""
-        Args:
-            sample (`torch.FloatTensor`): Input sample.
-            sample_posterior (`bool`, *optional*, defaults to `False`):
-                Whether to sample from the posterior.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
-        """
-        x = sample
-        posterior = self.encode(x).latent_dist
-        if sample_posterior:
-            z = posterior.sample(generator=generator)
-        else:
-            z = posterior.mode()
-        dec = self.decode(z).sample
-
-        if not return_dict:
-            return (dec,)
-
-        return DecoderOutput(sample=dec)
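The tile blending in `blend_h`/`blend_v` is a plain linear crossfade over the overlap region. A small self-contained demonstration of the horizontal case (tensor sizes chosen arbitrarily):

```python
import torch

a = torch.ones(1, 4, 64, 64)   # left tile
b = torch.zeros(1, 4, 64, 64)  # right tile
blend_extent = 16

# identical to blend_h above: the left tile fades out as the right tile fades in
for x in range(blend_extent):
    b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)

# the first 16 columns of b now ramp smoothly from a's values toward b's original values
```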