Adi-69s committed on
Commit
0000c2e
·
verified ·
1 Parent(s): e8a594e

Upload 273 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. diffusers/__init__.py +163 -0
  2. diffusers/__pycache__/__init__.cpython-310.pyc +0 -0
  3. diffusers/__pycache__/configuration_utils.cpython-310.pyc +0 -0
  4. diffusers/__pycache__/dependency_versions_check.cpython-310.pyc +0 -0
  5. diffusers/__pycache__/dependency_versions_table.cpython-310.pyc +0 -0
  6. diffusers/__pycache__/dynamic_modules_utils.cpython-310.pyc +0 -0
  7. diffusers/__pycache__/hub_utils.cpython-310.pyc +0 -0
  8. diffusers/__pycache__/modeling_flax_pytorch_utils.cpython-310.pyc +0 -0
  9. diffusers/__pycache__/modeling_flax_utils.cpython-310.pyc +0 -0
  10. diffusers/__pycache__/modeling_utils.cpython-310.pyc +0 -0
  11. diffusers/__pycache__/onnx_utils.cpython-310.pyc +0 -0
  12. diffusers/__pycache__/optimization.cpython-310.pyc +0 -0
  13. diffusers/__pycache__/pipeline_flax_utils.cpython-310.pyc +0 -0
  14. diffusers/__pycache__/pipeline_utils.cpython-310.pyc +0 -0
  15. diffusers/__pycache__/training_utils.cpython-310.pyc +0 -0
  16. diffusers/commands/__init__.py +27 -0
  17. diffusers/commands/__pycache__/__init__.cpython-310.pyc +0 -0
  18. diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc +0 -0
  19. diffusers/commands/__pycache__/env.cpython-310.pyc +0 -0
  20. diffusers/commands/diffusers_cli.py +41 -0
  21. diffusers/commands/env.py +70 -0
  22. diffusers/configuration_utils.py +613 -0
  23. diffusers/dependency_versions_check.py +47 -0
  24. diffusers/dependency_versions_table.py +35 -0
  25. diffusers/dynamic_modules_utils.py +428 -0
  26. diffusers/experimental/__init__.py +1 -0
  27. diffusers/experimental/__pycache__/__init__.cpython-310.pyc +0 -0
  28. diffusers/experimental/rl/__init__.py +1 -0
  29. diffusers/experimental/rl/__pycache__/__init__.cpython-310.pyc +0 -0
  30. diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-310.pyc +0 -0
  31. diffusers/experimental/rl/value_guided_sampling.py +152 -0
  32. diffusers/hub_utils.py +154 -0
  33. diffusers/modeling_flax_pytorch_utils.py +117 -0
  34. diffusers/modeling_flax_utils.py +535 -0
  35. diffusers/modeling_utils.py +892 -0
  36. diffusers/models/__init__.py +27 -0
  37. diffusers/models/__pycache__/__init__.cpython-310.pyc +0 -0
  38. diffusers/models/__pycache__/attention.cpython-310.pyc +0 -0
  39. diffusers/models/__pycache__/attention_flax.cpython-310.pyc +0 -0
  40. diffusers/models/__pycache__/embeddings.cpython-310.pyc +0 -0
  41. diffusers/models/__pycache__/embeddings_flax.cpython-310.pyc +0 -0
  42. diffusers/models/__pycache__/resnet.cpython-310.pyc +0 -0
  43. diffusers/models/__pycache__/resnet_flax.cpython-310.pyc +0 -0
  44. diffusers/models/__pycache__/unet_1d.cpython-310.pyc +0 -0
  45. diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc +0 -0
  46. diffusers/models/__pycache__/unet_2d.cpython-310.pyc +0 -0
  47. diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc +0 -0
  48. diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-310.pyc +0 -0
  49. diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc +0 -0
  50. diffusers/models/__pycache__/unet_2d_condition_flax.cpython-310.pyc +0 -0
diffusers/__init__.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.10.2"
2
+
3
+ from .configuration_utils import ConfigMixin
4
+ from .onnx_utils import OnnxRuntimeModel
5
+ from .utils import (
6
+ OptionalDependencyNotAvailable,
7
+ is_flax_available,
8
+ is_inflect_available,
9
+ is_k_diffusion_available,
10
+ is_librosa_available,
11
+ is_onnx_available,
12
+ is_scipy_available,
13
+ is_torch_available,
14
+ is_transformers_available,
15
+ is_transformers_version,
16
+ is_unidecode_available,
17
+ logging,
18
+ )
19
+
20
+
21
+ try:
22
+ if not is_torch_available():
23
+ raise OptionalDependencyNotAvailable()
24
+ except OptionalDependencyNotAvailable:
25
+ from .utils.dummy_pt_objects import * # noqa F403
26
+ else:
27
+ from .modeling_utils import ModelMixin
28
+ from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
29
+ from .optimization import (
30
+ get_constant_schedule,
31
+ get_constant_schedule_with_warmup,
32
+ get_cosine_schedule_with_warmup,
33
+ get_cosine_with_hard_restarts_schedule_with_warmup,
34
+ get_linear_schedule_with_warmup,
35
+ get_polynomial_decay_schedule_with_warmup,
36
+ get_scheduler,
37
+ )
38
+ from .pipeline_utils import DiffusionPipeline
39
+ from .pipelines import (
40
+ DanceDiffusionPipeline,
41
+ DDIMPipeline,
42
+ DDPMPipeline,
43
+ KarrasVePipeline,
44
+ LDMPipeline,
45
+ LDMSuperResolutionPipeline,
46
+ PNDMPipeline,
47
+ RePaintPipeline,
48
+ ScoreSdeVePipeline,
49
+ )
50
+ from .schedulers import (
51
+ DDIMScheduler,
52
+ DDPMScheduler,
53
+ DPMSolverMultistepScheduler,
54
+ DPMSolverSinglestepScheduler,
55
+ EulerAncestralDiscreteScheduler,
56
+ EulerDiscreteScheduler,
57
+ HeunDiscreteScheduler,
58
+ IPNDMScheduler,
59
+ KarrasVeScheduler,
60
+ KDPM2AncestralDiscreteScheduler,
61
+ KDPM2DiscreteScheduler,
62
+ PNDMScheduler,
63
+ RePaintScheduler,
64
+ SchedulerMixin,
65
+ ScoreSdeVeScheduler,
66
+ VQDiffusionScheduler,
67
+ )
68
+ from .training_utils import EMAModel
69
+
70
+ try:
71
+ if not (is_torch_available() and is_scipy_available()):
72
+ raise OptionalDependencyNotAvailable()
73
+ except OptionalDependencyNotAvailable:
74
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
75
+ else:
76
+ from .schedulers import LMSDiscreteScheduler
77
+
78
+
79
+ try:
80
+ if not (is_torch_available() and is_transformers_available()):
81
+ raise OptionalDependencyNotAvailable()
82
+ except OptionalDependencyNotAvailable:
83
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
84
+ else:
85
+ from .pipelines import (
86
+ AltDiffusionImg2ImgPipeline,
87
+ AltDiffusionPipeline,
88
+ CycleDiffusionPipeline,
89
+ LDMTextToImagePipeline,
90
+ PaintByExamplePipeline,
91
+ StableDiffusionDepth2ImgPipeline,
92
+ StableDiffusionImageVariationPipeline,
93
+ StableDiffusionImg2ImgPipeline,
94
+ StableDiffusionInpaintPipeline,
95
+ StableDiffusionInpaintPipelineLegacy,
96
+ StableDiffusionPipeline,
97
+ StableDiffusionPipelineSafe,
98
+ StableDiffusionUpscalePipeline,
99
+ VersatileDiffusionDualGuidedPipeline,
100
+ VersatileDiffusionImageVariationPipeline,
101
+ VersatileDiffusionPipeline,
102
+ VersatileDiffusionTextToImagePipeline,
103
+ VQDiffusionPipeline,
104
+ )
105
+
106
+ try:
107
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
108
+ raise OptionalDependencyNotAvailable()
109
+ except OptionalDependencyNotAvailable:
110
+ from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
111
+ else:
112
+ from .pipelines import StableDiffusionKDiffusionPipeline
113
+
114
+ try:
115
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
116
+ raise OptionalDependencyNotAvailable()
117
+ except OptionalDependencyNotAvailable:
118
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
119
+ else:
120
+ from .pipelines import (
121
+ OnnxStableDiffusionImg2ImgPipeline,
122
+ OnnxStableDiffusionInpaintPipeline,
123
+ OnnxStableDiffusionInpaintPipelineLegacy,
124
+ OnnxStableDiffusionPipeline,
125
+ StableDiffusionOnnxPipeline,
126
+ )
127
+
128
+ try:
129
+ if not (is_torch_available() and is_librosa_available()):
130
+ raise OptionalDependencyNotAvailable()
131
+ except OptionalDependencyNotAvailable:
132
+ from .utils.dummy_torch_and_librosa_objects import * # noqa F403
133
+ else:
134
+ from .pipelines import AudioDiffusionPipeline, Mel
135
+
136
+ try:
137
+ if not is_flax_available():
138
+ raise OptionalDependencyNotAvailable()
139
+ except OptionalDependencyNotAvailable:
140
+ from .utils.dummy_flax_objects import * # noqa F403
141
+ else:
142
+ from .modeling_flax_utils import FlaxModelMixin
143
+ from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
144
+ from .models.vae_flax import FlaxAutoencoderKL
145
+ from .pipeline_flax_utils import FlaxDiffusionPipeline
146
+ from .schedulers import (
147
+ FlaxDDIMScheduler,
148
+ FlaxDDPMScheduler,
149
+ FlaxDPMSolverMultistepScheduler,
150
+ FlaxKarrasVeScheduler,
151
+ FlaxLMSDiscreteScheduler,
152
+ FlaxPNDMScheduler,
153
+ FlaxSchedulerMixin,
154
+ FlaxScoreSdeVeScheduler,
155
+ )
156
+
157
+ try:
158
+ if not (is_flax_available() and is_transformers_available()):
159
+ raise OptionalDependencyNotAvailable()
160
+ except OptionalDependencyNotAvailable:
161
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
162
+ else:
163
+ from .pipelines import FlaxStableDiffusionPipeline
diffusers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (4.67 kB). View file
 
diffusers/__pycache__/configuration_utils.cpython-310.pyc ADDED
Binary file (21.7 kB). View file
 
diffusers/__pycache__/dependency_versions_check.cpython-310.pyc ADDED
Binary file (938 Bytes). View file
 
diffusers/__pycache__/dependency_versions_table.cpython-310.pyc ADDED
Binary file (1.04 kB). View file
 
diffusers/__pycache__/dynamic_modules_utils.cpython-310.pyc ADDED
Binary file (13.5 kB). View file
 
diffusers/__pycache__/hub_utils.cpython-310.pyc ADDED
Binary file (4.53 kB). View file
 
diffusers/__pycache__/modeling_flax_pytorch_utils.cpython-310.pyc ADDED
Binary file (2.61 kB). View file
 
diffusers/__pycache__/modeling_flax_utils.cpython-310.pyc ADDED
Binary file (20.9 kB). View file
 
diffusers/__pycache__/modeling_utils.cpython-310.pyc ADDED
Binary file (27.5 kB). View file
 
diffusers/__pycache__/onnx_utils.cpython-310.pyc ADDED
Binary file (6.89 kB). View file
 
diffusers/__pycache__/optimization.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
diffusers/__pycache__/pipeline_flax_utils.cpython-310.pyc ADDED
Binary file (15.8 kB). View file
 
diffusers/__pycache__/pipeline_utils.cpython-310.pyc ADDED
Binary file (31.4 kB). View file
 
diffusers/__pycache__/training_utils.cpython-310.pyc ADDED
Binary file (3.64 kB). View file
 
diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
class BaseDiffusersCLICommand(ABC):
    """Abstract interface that every `diffusers-cli` sub-command implements.

    Both members are abstract, so the class cannot be instantiated directly; a
    concrete command has to override `register_subcommand` and `run`.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the command to ``parser``. Must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def run(self):
        """Execute the command. Must be overridden."""
        raise NotImplementedError
diffusers/commands/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (788 Bytes). View file
 
diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc ADDED
Binary file (749 Bytes). View file
 
diffusers/commands/__pycache__/env.cpython-310.pyc ADDED
Binary file (2.14 kB). View file
 
diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .env import EnvironmentCommand
19
+
20
+
21
def main():
    """Entry point of the `diffusers-cli` command line tool.

    Builds the top-level argument parser, lets every known command register its
    sub-parser, parses the process arguments and runs the selected command.

    Raises:
        SystemExit: with status 1 (after printing the help text) when the parsed
            arguments carry no `func` attribute, i.e. no sub-command was selected.
    """
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        # The bare `exit()` helper is injected by the `site` module for interactive
        # sessions and is not guaranteed to exist (e.g. `python -S`, frozen apps).
        # Raising SystemExit directly has the same effect without that dependency.
        raise SystemExit(1)

    # Run: `func` is a factory that turns the parsed args into a command object.
    service = args.func(args)
    service.run()


if __name__ == "__main__":
    main()
diffusers/commands/env.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ from argparse import ArgumentParser
17
+
18
+ import huggingface_hub
19
+
20
+ from .. import __version__ as version
21
+ from ..utils import is_torch_available, is_transformers_available
22
+ from . import BaseDiffusersCLICommand
23
+
24
+
25
def info_command_factory(_):
    """Create the command object behind `diffusers-cli env`.

    The single positional argument (the parsed CLI arguments when invoked through
    the `func` default set in `register_subcommand`) is ignored.
    """
    command = EnvironmentCommand()
    return command
27
+
28
+
29
class EnvironmentCommand(BaseDiffusersCLICommand):
    """`diffusers-cli env`: print an environment report to paste into a GitHub issue."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``env`` sub-command on ``parser``.

        Args:
            parser: object exposing ``add_parser`` (the value returned by
                ``ArgumentParser.add_subparsers`` in the CLI entry point).
        """
        env_parser = parser.add_parser("env")
        env_parser.set_defaults(func=info_command_factory)

    def run(self):
        """Collect version information, print it, and return it.

        Returns:
            dict: report entries keyed by their human-readable label, in print order.
        """
        hub_version = huggingface_hub.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        transformers_version = "not installed"
        # The helper must be *called*: the bare function object is always truthy,
        # which made the import below run (and fail) when transformers is absent.
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        info = {
            "`diffusers` version": version,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d):
        """Render ``d`` as one ``- key: value`` bullet per line, newline-terminated."""
        return "\n".join(f"- {prop}: {val}" for prop, val in d.items()) + "\n"
diffusers/configuration_utils.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixin base class and utilities."""
17
+ import dataclasses
18
+ import functools
19
+ import importlib
20
+ import inspect
21
+ import json
22
+ import os
23
+ import re
24
+ from collections import OrderedDict
25
+ from typing import Any, Dict, Tuple, Union
26
+
27
+ import numpy as np
28
+
29
+ from huggingface_hub import hf_hub_download
30
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
31
+ from requests import HTTPError
32
+
33
+ from . import __version__
34
+ from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
40
+
41
+
42
class FrozenDict(OrderedDict):
    """Ordered mapping whose entries are mirrored as instance attributes.

    `__delitem__`, `setdefault`, `pop` and `update` always raise. `__setattr__`
    and `__setitem__` raise only when the instance has an attribute literally
    named ``"__frozen"`` and the private flag set at the end of `__init__` is
    truthy.

    NOTE(review): inside the class body `self.__frozen` is name-mangled to
    `_FrozenDict__frozen`, while the `hasattr` probe uses the unmangled string,
    so the probe does not see the flag set in `__init__`. Behaviour is kept as
    is because callers may depend on it (e.g. re-inserting items on copy).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Expose every entry as an attribute as well (cfg["x"] <-> cfg.x).
        for key in self:
            setattr(self, key, self[key])

        self.__frozen = True

    def __delitem__(self, *args, **kwargs):
        owner = self.__class__.__name__
        raise Exception(f"You cannot use ``__delitem__`` on a {owner} instance.")

    def setdefault(self, *args, **kwargs):
        owner = self.__class__.__name__
        raise Exception(f"You cannot use ``setdefault`` on a {owner} instance.")

    def pop(self, *args, **kwargs):
        owner = self.__class__.__name__
        raise Exception(f"You cannot use ``pop`` on a {owner} instance.")

    def update(self, *args, **kwargs):
        owner = self.__class__.__name__
        raise Exception(f"You cannot use ``update`` on a {owner} instance.")

    def __setattr__(self, name, value):
        locked = hasattr(self, "__frozen") and self.__frozen
        if locked:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setattr__(name, value)

    def __setitem__(self, name, value):
        locked = hasattr(self, "__frozen") and self.__frozen
        if locked:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)
72
+
73
+
74
+ class ConfigMixin:
75
+ r"""
76
+ Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles all
77
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
78
+ - [`~ConfigMixin.from_config`]
79
+ - [`~ConfigMixin.save_config`]
80
+
81
+ Class attributes:
82
+ - **config_name** (`str`) -- A filename under which the config should be stored when calling
83
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
84
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
85
+ overridden by subclass).
86
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
87
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
88
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
89
+ subclass).
90
+ """
91
+ config_name = None
92
+ ignore_for_config = []
93
+ has_compatibles = False
94
+
95
+ _deprecated_kwargs = []
96
+
97
def register_to_config(self, **kwargs):
    """Record ``kwargs`` as configuration values of this instance.

    Each keyword is set as an attribute on ``self`` and merged into
    ``self._internal_dict``, which is rebuilt as a new `FrozenDict` on every call.

    Raises:
        NotImplementedError: if ``config_name`` is ``None``.
        AttributeError: re-raised (after logging) when an attribute cannot be set.
    """
    if self.config_name is None:
        raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")

    # Special case for `kwargs` used in deprecation warning added to schedulers
    # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
    # or solve in a more general way.
    kwargs.pop("kwargs", None)

    for name, val in kwargs.items():
        try:
            setattr(self, name, val)
        except AttributeError:
            logger.error(f"Can't set {name} with value {val} for {self}")
            raise

    if hasattr(self, "_internal_dict"):
        # Later registrations override earlier values for the same key.
        old_config = dict(self._internal_dict)
        merged = {**self._internal_dict, **kwargs}
        logger.debug(f"Updating config from {old_config} to {merged}")
    else:
        merged = kwargs

    self._internal_dict = FrozenDict(merged)
119
+
120
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
    """
    Write this object's configuration as JSON into `save_directory`, under the file name `self.config_name`.

    Args:
        save_directory (`str` or `os.PathLike`):
            Target directory; it is created when missing.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Accepted for interface compatibility; not used inside this method.
        kwargs:
            Accepted for interface compatibility; not used inside this method.

    Raises:
        AssertionError: if `save_directory` points to an existing file.
    """
    if os.path.isfile(save_directory):
        raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

    os.makedirs(save_directory, exist_ok=True)

    # Using the predefined file name keeps the result loadable by the config loaders.
    target_path = os.path.join(save_directory, self.config_name)
    self.to_json_file(target_path)

    logger.info(f"Configuration saved in {target_path}")
139
+
140
@classmethod
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
    r"""
    Build an instance of `cls` from a configuration dictionary.

    Parameters:
        config (`Dict[str, Any]`):
            Configuration dictionary used to instantiate the class. Passing anything that is not a `dict`
            (e.g. a model name or path) is deprecated and is forwarded to `cls.load_config`.
        return_unused_kwargs (`bool`, *optional*, defaults to `False`):
            When `True`, also return the keyword arguments that were not consumed by the class.
        kwargs (remaining dictionary of keyword arguments, *optional*):
            Forwarded to `cls.extract_init_dict` together with `config`. A `dtype` entry and any name listed
            in `cls._deprecated_kwargs` that end up unused are passed on to `__init__`.

    Returns:
        The new instance, or the tuple `(instance, unused_kwargs)` when `return_unused_kwargs` is `True`.

    Raises:
        ValueError: if no config is supplied.

    Examples:

    ```python
    >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

    >>> # Download scheduler from huggingface.co and cache.
    >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")

    >>> # Instantiate DDIM scheduler class with same config as DDPM
    >>> scheduler = DDIMScheduler.from_config(scheduler.config)

    >>> # Instantiate PNDM scheduler class with same config as DDPM
    >>> scheduler = PNDMScheduler.from_config(scheduler.config)
    ```
    """
    # <===== TO BE REMOVED WITH DEPRECATION
    # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
    if "pretrained_model_name_or_path" in kwargs:
        config = kwargs.pop("pretrained_model_name_or_path")

    if config is None:
        raise ValueError("Please make sure to provide a config as the first positional argument.")
    # ======>

    if not isinstance(config, dict):
        # Deprecated path: `config` is a model name / path, resolve it to a dictionary first.
        class_name = cls.__name__
        message = "It is deprecated to pass a pretrained model name or path to `from_config`."
        if "Scheduler" in class_name:
            message = message + (
                f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
                " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
                " be removed in v1.0.0."
            )
        elif "Model" in class_name:
            message = message + (
                f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
                f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
                " instead. This functionality will be removed in v1.0.0."
            )
        deprecate("config-passed-as-path", "1.0.0", message, standard_warn=False)
        config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

    init_kwargs, leftover, hidden = cls.extract_init_dict(config, **kwargs)

    # `dtype` may be given at initialization time; deprecated kwargs are still handed to `__init__`.
    for passthrough in ["dtype", *cls._deprecated_kwargs]:
        if passthrough in leftover:
            init_kwargs[passthrough] = leftover.pop(passthrough)

    instance = cls(**init_kwargs)

    # Keep config parameters that might be used by compatible classes.
    instance.register_to_config(**hidden)

    # Hidden kwargs of compatible classes are reported as unused as well.
    leftover = {**leftover, **hidden}

    if not return_unused_kwargs:
        return instance
    return (instance, leftover)
222
+
223
@classmethod
def get_config_dict(cls, *args, **kwargs):
    """Deprecated alias: emit a deprecation notice, then delegate to `cls.load_config` with the same arguments."""
    deprecate(
        "get_config_dict",
        "1.0.0",
        f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
        " removed in version v1.0.0",
        standard_warn=False,
    )
    loaded = cls.load_config(*args, **kwargs)
    return loaded
231
+
232
+ @classmethod
233
+ def load_config(
234
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
235
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
236
+ r"""
237
+ Instantiate a Python class from a config dictionary
238
+
239
+ Parameters:
240
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
241
+ Can be either:
242
+
243
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
244
+ organization name, like `google/ddpm-celebahq-256`.
245
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
246
+ `./my_model_directory/`.
247
+
248
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
249
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
250
+ standard cache should not be used.
251
+ force_download (`bool`, *optional*, defaults to `False`):
252
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
253
+ cached versions if they exist.
254
+ resume_download (`bool`, *optional*, defaults to `False`):
255
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
256
+ file exists.
257
+ proxies (`Dict[str, str]`, *optional*):
258
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
259
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
260
+ output_loading_info(`bool`, *optional*, defaults to `False`):
261
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
262
+ local_files_only(`bool`, *optional*, defaults to `False`):
263
+ Whether or not to only look at local files (i.e., do not try to download the model).
264
+ use_auth_token (`str` or *bool*, *optional*):
265
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
266
+ when running `transformers-cli login` (stored in `~/.huggingface`).
267
+ revision (`str`, *optional*, defaults to `"main"`):
268
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
269
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
270
+ identifier allowed by git.
271
+ subfolder (`str`, *optional*, defaults to `""`):
272
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
273
+ huggingface.co or downloaded locally), you can specify the folder name here.
274
+
275
+ <Tip>
276
+
277
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
278
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
279
+
280
+ </Tip>
281
+
282
+ <Tip>
283
+
284
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
285
+ use this method in a firewalled environment.
286
+
287
+ </Tip>
288
+ """
289
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
290
+ force_download = kwargs.pop("force_download", False)
291
+ resume_download = kwargs.pop("resume_download", False)
292
+ proxies = kwargs.pop("proxies", None)
293
+ use_auth_token = kwargs.pop("use_auth_token", None)
294
+ local_files_only = kwargs.pop("local_files_only", False)
295
+ revision = kwargs.pop("revision", None)
296
+ _ = kwargs.pop("mirror", None)
297
+ subfolder = kwargs.pop("subfolder", None)
298
+
299
+ user_agent = {"file_type": "config"}
300
+
301
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
302
+
303
+ if cls.config_name is None:
304
+ raise ValueError(
305
+ "`self.config_name` is not defined. Note that one should not load a config from "
306
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
307
+ )
308
+
309
+ if os.path.isfile(pretrained_model_name_or_path):
310
+ config_file = pretrained_model_name_or_path
311
+ elif os.path.isdir(pretrained_model_name_or_path):
312
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
313
+ # Load from a PyTorch checkpoint
314
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
315
+ elif subfolder is not None and os.path.isfile(
316
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
317
+ ):
318
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
319
+ else:
320
+ raise EnvironmentError(
321
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
322
+ )
323
+ else:
324
+ try:
325
+ # Load from URL or cache if already cached
326
+ config_file = hf_hub_download(
327
+ pretrained_model_name_or_path,
328
+ filename=cls.config_name,
329
+ cache_dir=cache_dir,
330
+ force_download=force_download,
331
+ proxies=proxies,
332
+ resume_download=resume_download,
333
+ local_files_only=local_files_only,
334
+ use_auth_token=use_auth_token,
335
+ user_agent=user_agent,
336
+ subfolder=subfolder,
337
+ revision=revision,
338
+ )
339
+
340
+ except RepositoryNotFoundError:
341
+ raise EnvironmentError(
342
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
343
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
344
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
345
+ " login`."
346
+ )
347
+ except RevisionNotFoundError:
348
+ raise EnvironmentError(
349
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
350
+ " this model name. Check the model page at"
351
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
352
+ )
353
+ except EntryNotFoundError:
354
+ raise EnvironmentError(
355
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
356
+ )
357
+ except HTTPError as err:
358
+ raise EnvironmentError(
359
+ "There was a specific connection error when trying to load"
360
+ f" {pretrained_model_name_or_path}:\n{err}"
361
+ )
362
+ except ValueError:
363
+ raise EnvironmentError(
364
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
365
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
366
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
367
+ " run the library in offline mode at"
368
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
369
+ )
370
+ except EnvironmentError:
371
+ raise EnvironmentError(
372
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
373
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
374
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
375
+ f"containing a {cls.config_name} file"
376
+ )
377
+
378
+ try:
379
+ # Load config dict
380
+ config_dict = cls._dict_from_json_file(config_file)
381
+ except (json.JSONDecodeError, UnicodeDecodeError):
382
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
383
+
384
+ if return_unused_kwargs:
385
+ return config_dict, kwargs
386
+
387
+ return config_dict
388
+
389
@staticmethod
def _get_init_keys(cls):
    """
    Return the names of all parameters accepted by `cls.__init__`.

    The result is a `set` built from the signature of `cls.__init__`, so it contains every parameter name of that
    signature, including the first one (`self`) and the name of a `**kwargs` catch-all if there is one.
    """
    init_signature = inspect.signature(cls.__init__)
    return set(init_signature.parameters)
392
+
393
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
    """
    Split a config dictionary into the arguments of `cls.__init__` and everything that is left over.

    Args:
        config_dict (`dict`):
            The loaded configuration. It is not modified; all filtering happens on copies.
        kwargs:
            Overrides. A keyword argument whose name is an expected init argument takes precedence over the value
            stored in `config_dict`.

    Returns:
        `tuple`: `(init_dict, unused_kwargs, hidden_config_dict)` where `init_dict` holds the keyword arguments for
        `cls.__init__`, `unused_kwargs` holds the remaining config entries and keyword arguments, and
        `hidden_config_dict` holds every entry of the original `config_dict` that did not end up in `init_dict`.
    """
    # 0. Keep an untouched snapshot of the incoming config; it is needed for step 7.
    original_dict = dict(config_dict)

    # 1. The expected config attributes are the parameters of __init__ ...
    own_init_keys = cls._get_init_keys(cls)
    expected_keys = set(own_init_keys)
    expected_keys.remove("self")
    # ... without the generic kwargs catch-all ...
    expected_keys.discard("kwargs")
    # ... and without the flax internal arguments.
    if hasattr(cls, "_flax_internal_args"):
        for flax_arg in cls._flax_internal_args:
            expected_keys.remove(flax_arg)

    # 2. Drop everything that cannot be expected.
    # keys explicitly excluded from the config
    if len(cls.ignore_for_config) > 0:
        expected_keys = expected_keys - set(cls.ignore_for_config)

    # the top-level package is needed to look up the class named in `_class_name`
    diffusers_library = importlib.import_module(__name__.split(".")[0])

    compatible_classes = []
    if cls.has_compatibles:
        compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]

    # init arguments that only exist on compatible classes are silently removed from the config
    expected_keys_comp_cls = set().union(*(cls._get_init_keys(c) for c in compatible_classes))
    expected_keys_comp_cls = expected_keys_comp_cls - own_init_keys
    config_dict = {name: value for name, value in config_dict.items() if name not in expected_keys_comp_cls}

    # init arguments of the class that originally wrote the config are removed as well
    orig_cls_name = config_dict.pop("_class_name", cls.__name__)
    if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
        orig_cls = getattr(diffusers_library, orig_cls_name)
        unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
        config_dict = {name: value for name, value in config_dict.items() if name not in unexpected_keys_from_orig}

    # private attributes never reach __init__
    config_dict = {name: value for name, value in config_dict.items() if not name.startswith("_")}

    # 3. Collect the keyword arguments for __init__; an explicit kwarg wins over the stored config value.
    init_dict = {}
    for key in expected_keys:
        if key in kwargs:
            init_dict[key] = kwargs.pop(key)
            # the overridden config value must not show up as "unexpected" below
            config_dict.pop(key, None)
        elif key in config_dict:
            init_dict[key] = config_dict.pop(key)

    # 4. Whatever is still in the config at this point was not expected by the class.
    if len(config_dict) > 0:
        logger.warning(
            f"The config attributes {config_dict} were passed to {cls.__name__}, "
            "but are not expected and will be ignored. Please verify your "
            f"{cls.config_name} configuration file."
        )

    # 5. Tell the user which arguments fall back to their default values.
    missing_keys = expected_keys - set(init_dict.keys())
    if len(missing_keys) > 0:
        logger.info(f"{missing_keys} was not found in config. Values will be initialized to default values.")

    # 6. Leftover config entries and keyword arguments
    unused_kwargs = {**config_dict, **kwargs}

    # 7. "Hidden" config parameters: everything from the original config that is not an init argument
    hidden_config_dict = {name: value for name, value in original_dict.items() if name not in init_dict}

    return init_dict, unused_kwargs, hidden_config_dict
476
+
477
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
    """
    Read the UTF-8 encoded file at `json_file` and return its decoded JSON content.
    """
    with open(json_file, "r", encoding="utf-8") as handle:
        return json.load(handle)
482
+
483
def __repr__(self):
    """Return the class name followed by a space and the JSON string of the config."""
    class_name = self.__class__.__name__
    return "{} {}".format(class_name, self.to_json_string())
485
+
486
@property
def config(self) -> Dict[str, Any]:
    """
    The configuration of this object.

    Returns:
        `Dict[str, Any]`: The object stored in `self._internal_dict`, returned as is (no copy is made).
    """
    internal_dict = self._internal_dict
    return internal_dict
495
+
496
def to_json_string(self) -> str:
    """
    Serializes this instance to a JSON string.

    The output contains every entry of `self._internal_dict` (an empty config if the attribute does not exist) plus
    the two metadata keys `_class_name` and `_diffusers_version`. `np.ndarray` values are converted to nested lists.
    Keys are sorted, the indentation is two spaces and the string ends with a newline.

    Returns:
        `str`: String containing all the attributes that make up this configuration instance in JSON format.
    """
    # Work on a shallow copy. Writing the metadata keys into `self._internal_dict` itself would silently change the
    # object's config every time it is serialized or printed.
    config_dict = dict(self._internal_dict) if hasattr(self, "_internal_dict") else {}
    config_dict["_class_name"] = self.__class__.__name__
    config_dict["_diffusers_version"] = __version__

    def to_json_saveable(value):
        # `json` cannot encode numpy arrays, nested lists carry the same data
        if isinstance(value, np.ndarray):
            value = value.tolist()
        return value

    config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
514
+
515
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
    """
    Write the JSON representation of this instance to a file.

    The file is opened for writing with UTF-8 encoding (an existing file is overwritten) and receives the output of
    `self.to_json_string()`.

    Args:
        json_file_path (`str` or `os.PathLike`):
            Path to the JSON file in which this configuration instance's parameters will be saved.
    """
    with open(json_file_path, mode="w", encoding="utf-8") as json_file:
        # serialization happens after the file has been opened, same as before
        json_file.write(self.to_json_string())
525
+
526
+
527
def register_to_config(init):
    r"""
    Decorator for the `__init__` of classes inheriting from [`ConfigMixin`]. Before the decorated init runs, all of its
    arguments (positional arguments, keyword arguments and defaults of arguments that were not passed) are forwarded
    to `self.register_to_config`. To keep an argument accepted by the init out of the config, list its name in the
    `ignore_for_config` class variable.

    Warning: Once decorated, all private keyword arguments (beginning with an underscore) are only passed to
    `self.register_to_config` and are not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Private kwargs only go to the config, all other kwargs also go to the wrapped init.
        init_kwargs = {}
        config_init_kwargs = {}
        for key, value in kwargs.items():
            target = config_init_kwargs if key.startswith("_") else init_kwargs
            target[key] = value

        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])

        # name -> default value of every init parameter, except the first one (`self`) and the ignored ones
        parameters = {}
        for position, (name, parameter) in enumerate(inspect.signature(init).parameters.items()):
            if position > 0 and name not in ignore:
                parameters[name] = parameter.default

        # Positional arguments are matched to the parameter names in order ...
        new_kwargs = dict(zip(parameters.keys(), args))

        # ... every remaining parameter takes the passed keyword argument or falls back to its default.
        for name, default in parameters.items():
            if name not in ignore and name not in new_kwargs:
                new_kwargs[name] = init_kwargs.get(name, default)

        new_kwargs = {**config_init_kwargs, **new_kwargs}
        self.register_to_config(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init
570
+
571
+
572
def flax_register_to_config(cls):
    """
    Class decorator that replaces `cls.__init__` with a wrapper which forwards the dataclass field values to
    `self.register_to_config` before calling the original init.

    For every dataclass field that is not listed in `self._flax_internal_args` the registered value is, in order of
    precedence: the positional argument at the field's position, the keyword argument of the same name, the value
    read from the instance with `getattr` for fields that declare a default, or `None`. `dtype` is never registered
    but is still passed on to the original init.

    Returns:
        The same class object `cls`, with its `__init__` replaced.
    """
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        # All passed keyword arguments (copied so the caller's mapping is left alone)
        init_kwargs = dict(kwargs)

        # Default values of all non-flax-internal fields; a field without a default is registered as `None`
        fields = dataclasses.fields(self)
        default_kwargs = {
            field.name: None if field.default is dataclasses.MISSING else getattr(self, field.name)
            for field in fields
            if field.name not in self._flax_internal_args
        }

        # Passed keyword arguments take precedence over the defaults
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        new_kwargs.pop("dtype", None)

        # Positional arguments are matched to the fields by position
        for position, arg in enumerate(args):
            new_kwargs[fields[position].name] = arg

        self.register_to_config(**new_kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls
diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+
16
+ from .dependency_versions_table import deps
17
+ from .utils.versions import require_version, require_version_core
18
+
19
+
20
# Packages whose installed version is verified every time this module is imported
# (usually the ones defined in `install_requires` in setup.py).
#
# The order of the list is the order of the checks:
# - tqdm must be checked before tokenizers
#
# NOTE(review): the `deps` table of dependency_versions_table.py, as it appears in this commit, has no entry for
# "python", "tqdm", "packaging" or "tokenizers", so the loop below looks like it raises `ValueError` on its first
# iteration - confirm before importing this module anywhere.
pkgs_to_check_at_runtime = [
    "python",
    "tqdm",
    "regex",
    "requests",
    "packaging",
    "filelock",
    "numpy",
    "tokenizers",
]
# Backport packages are only checked on interpreters older than the given version.
if sys.version_info < (3, 7):
    pkgs_to_check_at_runtime.append("dataclasses")
if sys.version_info < (3, 8):
    pkgs_to_check_at_runtime.append("importlib_metadata")

for pkg in pkgs_to_check_at_runtime:
    if pkg not in deps:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")

    if pkg == "tokenizers":
        # must be loaded here, or else tqdm check may fail
        from .utils import is_tokenizers_available

        if not is_tokenizers_available():
            continue  # not required, check version only if installed

    require_version_core(deps[pkg])
44
+
45
+
46
def dep_version_check(pkg, hint=None):
    """
    Look up the requirement string stored for `pkg` in `deps` and pass it, together with `hint`, to
    `require_version`.

    Raises:
        `KeyError`: If `pkg` has no entry in `deps`.
    """
    requirement = deps[pkg]
    require_version(requirement, hint)
diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
#
# Maps a package name to its pip requirement string.
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.11.0",
    "black": "black==22.8",
    "datasets": "datasets",
    "filelock": "filelock",
    "flake8": "flake8>=3.8.3",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.10.0",
    "importlib_metadata": "importlib_metadata",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.2.8,!=0.3.2",
    "jaxlib": "jaxlib>=0.1.65",
    "k-diffusion": "k-diffusion",
    "librosa": "librosa",
    "modelcards": "modelcards>=0.1.4",
    "numpy": "numpy",
    "parameterized": "parameterized",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "safetensors": "safetensors",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "scipy": "scipy",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.25.1",
}
diffusers/dynamic_modules_utils.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities to dynamically load objects from the Hub."""
16
+
17
+ import importlib
18
+ import inspect
19
+ import os
20
+ import re
21
+ import shutil
22
+ import sys
23
+ from pathlib import Path
24
+ from typing import Dict, Optional, Union
25
+
26
+ from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
27
+
28
+ from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
29
+
30
+
31
+ COMMUNITY_PIPELINES_URL = (
32
+ "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/{pipeline}.py"
33
+ )
34
+
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
def init_hf_modules():
    """
    Registers the modules cache directory: appends `HF_MODULES_CACHE` to `sys.path`, creates the directory and makes
    sure it contains an `__init__.py`.

    Nothing is done when `HF_MODULES_CACHE` already is in `sys.path`, which is taken as the sign that this function
    has been executed before.
    """
    cache_root = HF_MODULES_CACHE
    if cache_root not in sys.path:
        sys.path.append(cache_root)
        os.makedirs(cache_root, exist_ok=True)
        package_marker = Path(cache_root) / "__init__.py"
        if not package_marker.exists():
            package_marker.touch()
52
+
53
+
54
def create_dynamic_module(name: Union[str, os.PathLike]):
    """
    Creates the package `name` (a directory with an `__init__.py`) inside the modules cache directory.

    Missing parent packages are created first, through a recursive call.

    Args:
        name (`str` or `os.PathLike`): Path of the package, joined onto `HF_MODULES_CACHE`.
    """
    init_hf_modules()
    package_dir = Path(HF_MODULES_CACHE) / name
    parent_dir = package_dir.parent
    # Parents go first so that each level gets its own `__init__.py`.
    if not parent_dir.exists():
        create_dynamic_module(parent_dir)
    os.makedirs(package_dir, exist_ok=True)
    package_marker = package_dir / "__init__.py"
    if not package_marker.exists():
        package_marker.touch()
67
+
68
+
69
def get_relative_imports(module_file):
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The imported names without their leading dot, each name once, in no particular order.
    """
    with open(module_file, "r", encoding="utf-8") as f:
        content = f.read()

    # The patterns are raw strings: `\s`, `\.` and `\S` are regular expression escapes, not string escapes, and
    # Python warns about them (invalid escape sequence) in normal string literals.
    # Imports of the form `import .xxx`
    relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from .xxx import yyy`
    relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
    # Unique-ify
    return list(set(relative_imports))
85
+
86
+
87
def get_relative_import_files(module_file):
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the relative
    imports (if a imports b and b imports c, it will return module files for b and c).

    Each file is listed once, in the order in which it was discovered, and `module_file` itself is never part of the
    result. Import cycles are therefore handled: a file that has been seen already is not inspected again.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The paths of the needed files, built as `<folder of module_file>/<imported name>.py`.
    """
    module_dir = Path(module_file).parent
    # The root file counts as seen, otherwise a module importing it back would pull it into the result.
    seen = {str(Path(module_file))}
    all_relative_imports = []
    files_to_check = [module_file]

    # Breadth-first walk through the relative imports, one level per iteration
    while files_to_check:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        files_to_check = []
        for imported_name in new_imports:
            # Compare complete file names (including `.py`), which is what `seen` and the result contain.
            import_file = f"{module_dir / imported_name}.py"
            if import_file not in seen:
                seen.add(import_file)
                files_to_check.append(import_file)

        all_relative_imports.extend(files_to_check)

    return all_relative_imports
114
+
115
+
116
def check_imports(filename):
    """
    Check if the current Python environment contains all the libraries that are imported in a file.

    Every absolute import found in the file is reduced to its top-level package, which is then imported with
    `importlib.import_module`.

    Args:
        filename (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        The result of `get_relative_imports(filename)`, i.e. the modules that are relatively imported in the file.

    Raises:
        `ImportError`: If at least one of the top-level packages cannot be imported. The message lists all of them.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    # The patterns are raw strings: `\s` and `\S` are regular expression escapes, not string escapes, and Python warns
    # about them (invalid escape sequence) in normal string literals.
    # Imports of the form `import xxx`
    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from xxx import yyy`
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
    # Only keep the top-level module of absolute imports, each one once
    top_level_packages = {imp.split(".")[0] for imp in imports if not imp.startswith(".")}

    # Test that we got them all
    missing_packages = []
    for package in top_level_packages:
        try:
            importlib.import_module(package)
        except ImportError:
            missing_packages.append(package)

    if len(missing_packages) > 0:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )

    return get_relative_imports(filename)
146
+
147
+
148
def get_class_in_module(class_name, module_path):
    """
    Import a module on the cache directory for modules and extract a class from it.

    Args:
        class_name (`str` or `None`):
            Name of the attribute to read from the module. With `None`, the class is searched with
            `find_pipeline_class` instead.
        module_path (`str`):
            Path of the module; the path separators of the OS are replaced by dots to get the import name.
    """
    dotted_name = module_path.replace(os.path.sep, ".")
    module = importlib.import_module(dotted_name)
    return find_pipeline_class(module) if class_name is None else getattr(module, class_name)
158
+
159
+
160
def find_pipeline_class(loaded_module):
    """
    Retrieve the pipeline class of a module: the class that inherits from `DiffusionPipeline`, is not named like
    `DiffusionPipeline` itself and is not defined inside the `diffusers` package.

    Returns:
        The class that was found, or `None` if the module has no such class.

    Raises:
        `ValueError`: If more than one class of the module qualifies.
    """
    from .pipeline_utils import DiffusionPipeline

    pipeline_class = None
    for cls_name, cls in dict(inspect.getmembers(loaded_module, inspect.isclass)).items():
        # the base class itself does not count
        if cls_name == DiffusionPipeline.__name__:
            continue
        if not issubclass(cls, DiffusionPipeline):
            continue
        # pipelines shipped with the library do not count either
        if cls.__module__.split(".")[0] == "diffusers":
            continue

        if pipeline_class is not None:
            raise ValueError(
                f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
                f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
                f" {loaded_module}."
            )
        pipeline_class = cls

    return pipeline_class
185
+
186
+
187
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
):
    """
    Downloads a module from a local folder, from the community pipelines on GitHub or from a distant repo, places it
    inside the cached dynamic modules and returns its path.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a path to a *directory* that contains `module_file`, e.g., `./my_model_directory/`.
            - a string without any `/`, the name of a community pipeline. The file `<name>.py` is downloaded from
              `COMMUNITY_PIPELINES_URL`.
            - a string, the *model id* of a repo hosted on huggingface.co, namespaced under a user or organization
              name, like `dbmdz/bert-base-german-cased`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which the downloaded file should be cached if the standard cache should not be
            used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the file and override the cached version if it exists.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the stored token returned by
            `HfFolder.get_token()` is used to look up the commit hash of the repo.
        revision (`str`, *optional*):
            The specific version of the repo to use. It can be a branch name, a tag name, or a commit id. It is used
            both to download the file and to look up the commit hash the file is cached under.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the file from local files.

    <Tip>

    You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli login`) and want to use private
    or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).

    </Tip>

    Returns:
        `str`: The path to the module inside the cache, relative to `HF_MODULES_CACHE`.
    """
    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)

    if os.path.isfile(module_file_or_url):
        resolved_module_file = module_file_or_url
        submodule = "local"
    elif pretrained_model_name_or_path.count("/") == 0:
        # community pipeline on GitHub
        github_url = COMMUNITY_PIPELINES_URL.format(pipeline=pretrained_model_name_or_path)
        try:
            resolved_module_file = cached_download(
                github_url,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=False,
            )
            submodule = "git"
            module_file = pretrained_model_name_or_path + ".py"
        except EnvironmentError:
            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
            raise
    else:
        try:
            # Load from URL or cache if already cached
            resolved_module_file = hf_hub_download(
                pretrained_model_name_or_path,
                module_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                # Without this the file always came from the default branch while the commit hash below is
                # resolved for `revision`, so the cache entry could hold code of a different version.
                revision=revision,
            )
            submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
        except EnvironmentError:
            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
            raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == "local" or submodule == "git":
        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
        # that hash, to only copy when there is a modification but it seems overkill for now).
        # The only reason we do the copy is to avoid putting too many folders in sys.path.
        shutil.copy(resolved_module_file, submodule_path / module_file)
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
    else:
        # Get the commit hash
        # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
        if isinstance(use_auth_token, str):
            token = use_auth_token
        elif use_auth_token is True:
            token = HfFolder.get_token()
        else:
            token = None

        commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
        # Make sure we also have every file with relative imports
        for module_needed in modules_needed:
            # `modules_needed` holds module names, the cached files carry the `.py` suffix
            if not (submodule_path / f"{module_needed}.py").exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    local_files_only=local_files_only,
                )
    return os.path.join(full_submodule, module_file)
339
+
340
+
341
def get_class_from_dynamic_module(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    class_name: Optional[str] = None,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Extracts a class from a module file, present in the local folder or repository of a model.

    The module file is first placed in the dynamic modules cache with `get_cached_module_file`, then the class is
    read from it with `get_class_in_module`.

    <Tip warning={true}>

    Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
    therefore only be called on trusted repos.

    </Tip>

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Where to get the module file from; passed unchanged to `get_cached_module_file`.
        module_file (`str`):
            The name of the module file containing the class to look for.
        class_name (`str`, *optional*):
            The name of the class to import in the module; passed unchanged to `get_class_in_module`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Passed to `get_cached_module_file`.
        force_download (`bool`, *optional*, defaults to `False`):
            Passed to `get_cached_module_file`.
        resume_download (`bool`, *optional*, defaults to `False`):
            Passed to `get_cached_module_file`.
        proxies (`Dict[str, str]`, *optional*):
            Passed to `get_cached_module_file`.
        use_auth_token (`str` or `bool`, *optional*):
            Passed to `get_cached_module_file`.
        revision (`str`, *optional*):
            Passed to `get_cached_module_file`.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Passed to `get_cached_module_file`.
        kwargs:
            Accepted but not used.

    Returns:
        `type`: The class, dynamically imported from the module.

    Examples:

    ```python
    # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
    ```"""
    download_options = {
        "cache_dir": cache_dir,
        "force_download": force_download,
        "resume_download": resume_download,
        "proxies": proxies,
        "use_auth_token": use_auth_token,
        "revision": revision,
        "local_files_only": local_files_only,
    }
    cached_module_file = get_cached_module_file(pretrained_model_name_or_path, module_file, **download_options)
    # the import name of the module is its path without the `.py` part
    module_path = cached_module_file.replace(".py", "")
    return get_class_in_module(class_name, module_path)
diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .rl import ValueGuidedRLPipeline
diffusers/experimental/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (210 Bytes). View file
 
diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .value_guided_sampling import ValueGuidedRLPipeline
diffusers/experimental/rl/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (232 Bytes). View file
 
diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-310.pyc ADDED
Binary file (4.84 kB). View file
 
diffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ import tqdm
19
+
20
+ from ...models.unet_1d import UNet1DModel
21
+ from ...pipeline_utils import DiffusionPipeline
22
+ from ...utils.dummy_pt_objects import DDPMScheduler
23
+
24
+
25
class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    Pipeline for sampling actions from a diffusion model trained to predict sequences of states.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Original implementation inspired by this repository: https://github.com/jannerm/diffuser.

    Parameters:
        value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories based on reward.
        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()
        # Per-key normalization statistics; keys whose statistic cannot be computed are simply absent.
        self.means = self._dataset_statistic("mean")
        self.stds = self._dataset_statistic("std")
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def _dataset_statistic(self, stat_name):
        """Return `{key: self.data[key].<stat_name>()}` for every dataset entry that supports the statistic."""
        stats = {}
        for key in self.data.keys():
            try:
                stats[key] = getattr(self.data[key], stat_name)()
            except Exception:
                # Best effort: entries that do not support the statistic are skipped. `Exception` (rather than a
                # bare `except`) keeps KeyboardInterrupt / SystemExit propagating.
                continue
        return stats

    def normalize(self, x_in, key):
        """Standardize `x_in` with the mean / std stored for dataset entry `key`."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Inverse of `normalize`: scale by the stored std and shift by the stored mean of `key`."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """Move `x_in` to the device of `self.unet`, recursing into dict values and wrapping non-tensors."""
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """Overwrite `x_in[:, key, act_dim:]` with a copy of each conditioning value in `cond` and return `x_in`."""
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """
        Run the denoising loop over `self.scheduler.timesteps`, optionally nudging `x` along the gradient of the
        value function.

        Returns:
            `(x, y)` where `x` is the final sample and `y` is the last output of the value function, or `None` when
            no guidance step was executed (e.g. `n_guide_steps == 0`).
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # no guidance on the samples whose timestep is below 2
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """
        Plan a batch of trajectories starting from `obs` and return the first de-normalized action of the selected
        trajectory.

        Args:
            obs: Current observation; it is normalized with the "observations" statistics and repeated
                `batch_size` times (presumably a NumPy array, since `repeat(..., axis=0)` is used -- confirm).
            batch_size: Number of trajectories sampled in parallel.
            planning_horizon: Number of steps in each sampled trajectory.
            n_guide_steps: Value-guidance steps per diffusion timestep. With `0`, no value is computed and a random
                trajectory is selected.
            scale: Step size applied to the value gradient.
        """
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = torch.randn(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        if y is not None:
            # sort output trajectories by value and select the one with the highest value
            sorted_idx = y.argsort(0, descending=True).squeeze()
            x = x[sorted_idx]
            selected_index = 0
        else:
            # if we didn't run value guiding there is nothing to sort by: select a random trajectory
            selected_index = np.random.randint(0, batch_size)

        actions = x[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        return denorm_actions[selected_index, 0]
diffusers/hub_utils.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import os
18
+ import sys
19
+ from pathlib import Path
20
+ from typing import Dict, Optional, Union
21
+ from uuid import uuid4
22
+
23
+ import requests
24
+ from huggingface_hub import HfFolder, whoami
25
+
26
+ from . import __version__
27
+ from .utils import ENV_VARS_TRUE_VALUES, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
28
+ from .utils.import_utils import (
29
+ _flax_version,
30
+ _jax_version,
31
+ _onnxruntime_version,
32
+ _torch_version,
33
+ is_flax_available,
34
+ is_modelcards_available,
35
+ is_onnx_available,
36
+ is_torch_available,
37
+ )
38
+
39
+
40
+ if is_modelcards_available():
41
+ from modelcards import CardData, ModelCard
42
+
43
+
44
# Module-level logger named after this module.
logger = logging.get_logger(__name__)


# Path of the Markdown template passed to `ModelCard.from_template` in `create_model_card`.
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
# Random hex identifier generated once when this module is imported; embedded in the user-agent string.
SESSION_ID = uuid4().hex
# True when the `HF_HUB_OFFLINE` environment variable (upper-cased) is one of `ENV_VARS_TRUE_VALUES`.
HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
# True when the `DISABLE_TELEMETRY` environment variable (upper-cased) is one of `ENV_VARS_TRUE_VALUES`.
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
# Base URL of the telemetry endpoint; `send_telemetry` appends the log name to it.
HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
52
+
53
+
54
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Build the `"; "`-separated user-agent string describing this session.

    The string always starts with the diffusers version, the Python version and the session id. When telemetry is
    disabled only `telemetry/off` is added. Otherwise the versions of the available frameworks, a CI marker and the
    caller-supplied `user_agent` (a dict rendered as `key/value` pairs, or a plain string) are appended.
    """
    python_version = sys.version.split()[0]
    fields = [f"diffusers/{__version__}", f"python/{python_version}", f"session_id/{SESSION_ID}"]

    if DISABLE_TELEMETRY:
        fields.append("telemetry/off")
        return "; ".join(fields)

    if is_torch_available():
        fields.append(f"torch/{_torch_version}")
    if is_flax_available():
        fields.extend((f"jax/{_jax_version}", f"flax/{_flax_version}"))
    if is_onnx_available():
        fields.append(f"onnxruntime/{_onnxruntime_version}")
    # CI will set this value to True
    if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        fields.append("is_ci/true")

    # Caller-provided extras go last; the dict is rendered as a single pre-joined field.
    if isinstance(user_agent, dict):
        fields.append("; ".join(f"{k}/{v}" for k, v in user_agent.items()))
    elif isinstance(user_agent, str):
        fields.append(user_agent)
    return "; ".join(fields)
76
+
77
+
78
def send_telemetry(data: Dict, name: str):
    """
    Sends logs to the Hub telemetry endpoint.

    Nothing is sent when telemetry is disabled (`DISABLE_TELEMETRY`) or the Hub is in offline mode
    (`HF_HUB_OFFLINE`). Network and HTTP errors are swallowed on purpose so telemetry can never break the caller.

    Args:
        data: the fields to track, e.g. {"example_name": "dreambooth"}
        name: a unique name to differentiate the telemetry logs, e.g. "diffusers_examples" or "diffusers_notebooks"
    """
    if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
        # Honor the opt-out: return before any request is built or sent.
        return

    headers = {"user-agent": http_user_agent(data)}
    endpoint = HUGGINGFACE_CO_TELEMETRY + name
    try:
        r = requests.head(endpoint, headers=headers)
        r.raise_for_status()
    except Exception:
        # We don't want to error in case of connection errors of any kind.
        pass
97
+
98
+
99
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """
    Return the `namespace/model_id` repository name.

    The namespace is `organization` when given; otherwise it is the `"name"` field returned by `whoami` for the
    token. When `token` is `None` it is read from `HfFolder`.
    """
    token = HfFolder.get_token() if token is None else token
    namespace = organization if organization is not None else whoami(token)["name"]
    return f"{namespace}/{model_id}"
107
+
108
+
109
def create_model_card(args, model_name):
    """
    Render the model card template with the training arguments and save it as `README.md` in `args.output_dir`.

    Args:
        args: Namespace-like object holding the training arguments. `learning_rate`, `train_batch_size`,
            `eval_batch_size`, `mixed_precision` and `output_dir` are required; every other attribute read here is
            optional and rendered as `None` when missing. When `args.local_rank` exists and is not `-1` or `0`,
            nothing is written.
        model_name: Name of the model; also used to build the full repository name.

    Raises:
        ValueError: If the `modelcards` package is not installed.
    """
    # `is_modelcards_available` is a function: it has to be called, the bare function object is always truthy.
    if not is_modelcards_available():
        raise ValueError(
            "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
            " install the package with `pip install modelcards`."
        )

    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    def optional(attr_name):
        # Optional training arguments fall back to `None` when absent from `args`.
        return getattr(args, attr_name, None)

    hub_token = optional("hub_token")
    repo_name = get_full_repo_name(model_name, token=hub_token)
    # Read once so the card metadata and the template variable agree (and neither fails when it is missing).
    dataset_name = optional("dataset_name")

    model_card = ModelCard.from_template(
        card_data=CardData(  # Card metadata object that will be converted to YAML block
            language="en",
            license="apache-2.0",
            library_name="diffusers",
            tags=[],
            datasets=dataset_name,
            metrics=[],
        ),
        template_path=MODEL_CARD_TEMPLATE_PATH,
        model_name=model_name,
        repo_name=repo_name,
        dataset_name=dataset_name,
        learning_rate=args.learning_rate,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=optional("gradient_accumulation_steps"),
        adam_beta1=optional("adam_beta1"),
        adam_beta2=optional("adam_beta2"),
        adam_weight_decay=optional("adam_weight_decay"),
        adam_epsilon=optional("adam_epsilon"),
        lr_scheduler=optional("lr_scheduler"),
        lr_warmup_steps=optional("lr_warmup_steps"),
        ema_inv_gamma=optional("ema_inv_gamma"),
        ema_power=optional("ema_power"),
        ema_max_decay=optional("ema_max_decay"),
        mixed_precision=args.mixed_precision,
    )

    card_path = os.path.join(args.output_dir, "README.md")
    model_card.save(card_path)
diffusers/modeling_flax_pytorch_utils.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch - Flax general utilities."""
16
+ import re
17
+
18
+ import jax.numpy as jnp
19
+ from flax.traverse_util import flatten_dict, unflatten_dict
20
+ from jax.random import PRNGKey
21
+
22
+ from .utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
def rename_key(key):
    """
    Replace the dot in every `<word>.<digits>` fragment of `key` with an underscore.

    Example: `"down_blocks.0.conv.weight"` becomes `"down_blocks_0.conv.weight"`.
    """
    for fragment in re.findall(r"\w+[.]\d+", key):
        key = key.replace(fragment, fragment.replace(".", "_"))
    return key
34
+
35
+
36
+ #####################
37
+ # PyTorch => Flax #
38
+ #####################
39
+
40
+ # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
41
+ # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
42
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """
    Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary.

    Args:
        pt_tuple_key: The PyTorch parameter name split into a tuple of path components.
        pt_tensor: The parameter value; it must expose `ndim`, `transpose` and `T` (e.g. a NumPy array).
        random_flax_state_dict: Flattened Flax parameters (tuple keys) used to decide which Flax name applies.

    Returns:
        A `(flax_tuple_key, tensor)` pair. 4-D `weight` tensors are transposed with `(2, 3, 1, 0)`, other `weight`
        tensors are transposed with `.T`, everything else is returned unchanged.
    """
    prefix, leaf = pt_tuple_key[:-1], pt_tuple_key[-1]

    # conv norm or layer norm
    scale_key = prefix + ("scale",)
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and leaf == "bias"
        and prefix + ("bias",) not in random_flax_state_dict
        and scale_key in random_flax_state_dict
    ):
        return scale_key, pt_tensor
    if leaf in ["weight", "gamma"] and scale_key in random_flax_state_dict:
        return scale_key, pt_tensor

    # embedding: return the `embedding` key that was just matched (previously the stale `scale` key leaked out)
    embedding_key = prefix + ("embedding",)
    if leaf == "weight" and embedding_key in random_flax_state_dict:
        return embedding_key, pt_tensor

    if leaf == "weight":
        kernel_key = prefix + ("kernel",)
        if pt_tensor.ndim == 4:
            # conv layer
            return kernel_key, pt_tensor.transpose(2, 3, 1, 0)
        # linear layer
        return kernel_key, pt_tensor.T

    # old PyTorch layer norm weight
    if leaf == "gamma":
        return prefix + ("weight",), pt_tensor

    # old PyTorch layer norm bias
    if leaf == "beta":
        return prefix + ("bias",), pt_tensor

    return pt_tuple_key, pt_tensor
87
+
88
+
89
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
    """
    Convert a PyTorch state dict into a nested dict of Flax parameters for `flax_model`.

    Randomly initialized parameters of `flax_model` (seeded with `init_key`) serve as the reference for the expected
    names and shapes. A `ValueError` is raised when a converted weight matches a reference name but not its shape.
    """
    # Step 1: Convert pytorch tensor to numpy
    numpy_state = {name: tensor.numpy() for name, tensor in pt_state_dict.items()}

    # Step 2: Since the model is stateless, get random Flax params
    reference = flatten_dict(flax_model.init_weights(PRNGKey(init_key)))

    converted = {}
    # Need to change some parameters name to match Flax names
    for pt_key, pt_tensor in numpy_state.items():
        pt_tuple_key = tuple(rename_key(pt_key).split("."))

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, reference)

        shape_mismatch = flax_key in reference and flax_tensor.shape != reference[flax_key].shape
        if shape_mismatch:
            raise ValueError(
                f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                f"{reference[flax_key].shape}, but is {flax_tensor.shape}."
            )

        # also add unexpected weight so that warning is thrown
        converted[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(converted)
diffusers/modeling_flax_utils.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from pickle import UnpicklingError
18
+ from typing import Any, Dict, Union
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import msgpack.exceptions
23
+ from flax.core.frozen_dict import FrozenDict, unfreeze
24
+ from flax.serialization import from_bytes, to_bytes
25
+ from flax.traverse_util import flatten_dict, unflatten_dict
26
+ from huggingface_hub import hf_hub_download
27
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
28
+ from requests import HTTPError
29
+
30
+ from . import __version__, is_torch_available
31
+ from .hub_utils import send_telemetry
32
+ from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
33
+ from .utils import (
34
+ CONFIG_NAME,
35
+ DIFFUSERS_CACHE,
36
+ FLAX_WEIGHTS_NAME,
37
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
38
+ WEIGHTS_NAME,
39
+ logging,
40
+ )
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ class FlaxModelMixin:
47
+ r"""
48
+ Base class for all flax models.
49
+
50
+ [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
51
+ downloading and saving models.
52
+ """
53
+ config_name = CONFIG_NAME
54
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
55
+ _flax_internal_args = ["name", "parent", "dtype"]
56
+
57
@classmethod
def _from_config(cls, config, **kwargs):
    """
    All context managers that the model should be initialized under go here.

    As written, this simply instantiates `cls` with `config` and forwards the extra keyword arguments.
    """
    return cls(config, **kwargs)
63
+
64
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.

    Args:
        params: A `PyTree` of model parameters.
        dtype: Target dtype for the floating-point leaves; non floating-point leaves are left untouched.
        mask: Optional `PyTree` of booleans. When given, only the leaves whose mask entry is truthy are cast.

    Returns:
        A new parameter tree; when `mask` is given it is rebuilt with `unflatten_dict`.
    """

    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(dtype)
        return param

    if mask is None:
        # `jax.tree_util.tree_map` exists in every JAX release; the top-level `jax.tree_map` alias is deprecated.
        return jax.tree_util.tree_map(conditional_cast, params)

    flat_params = flatten_dict(params)
    flat_mask, _ = jax.tree_util.tree_flatten(mask)

    # NOTE(review): mask leaves and flattened parameter keys are paired purely by position, which assumes both
    # flattenings enumerate the leaves in the same order -- confirm for non-trivially ordered trees.
    for masked, key in zip(flat_mask, flat_params.keys()):
        if masked:
            flat_params[key] = conditional_cast(flat_params[key])

    return unflatten_dict(flat_params)
87
+
88
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
    the `params` in place.

    This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
    half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip.

    Examples:

    ```python
    >>> from diffusers import FlaxUNet2DConditionModel

    >>> # load model
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
    >>> params = model.to_bf16(params)
    >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
    >>> # then pass the mask as follows
    >>> from flax import traverse_util

    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> flat_params = traverse_util.flatten_dict(params)
    >>> mask = {
    ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> params = model.to_bf16(params, mask)
    ```"""
    return self._cast_floating_to(params, jnp.bfloat16, mask)
126
+
127
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
    model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip.

    Examples:

    ```python
    >>> from diffusers import FlaxUNet2DConditionModel

    >>> # Download model and configuration from huggingface.co
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> # By default, the model params will be in fp32, to illustrate the use of this method,
    >>> # we'll first cast to fp16 and back to fp32
    >>> params = model.to_fp16(params)
    >>> # now cast back to fp32
    >>> params = model.to_fp32(params)
    ```"""
    return self._cast_floating_to(params, jnp.float32, mask)
153
+
154
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
    `params` in place.

    This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
    half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip.

    Examples:

    ```python
    >>> from diffusers import FlaxUNet2DConditionModel

    >>> # load model
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> # By default, the model params will be in fp32, to cast these to float16
    >>> params = model.to_fp16(params)
    >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
    >>> # then pass the mask as follows
    >>> from flax import traverse_util

    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> flat_params = traverse_util.flatten_dict(params)
    >>> mask = {
    ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> params = model.to_fp16(params, mask)
    ```"""
    return self._cast_floating_to(params, jnp.float16, mask)
192
+
193
def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
    """
    Placeholder for parameter initialization.

    This base implementation always raises `NotImplementedError`.
    """
    raise NotImplementedError(f"init_weights method has to be implemented for {self}")
195
+
196
+ @classmethod
197
+ def from_pretrained(
198
+ cls,
199
+ pretrained_model_name_or_path: Union[str, os.PathLike],
200
+ dtype: jnp.dtype = jnp.float32,
201
+ *model_args,
202
+ **kwargs,
203
+ ):
204
+ r"""
205
+ Instantiate a pretrained flax model from a pre-trained model configuration.
206
+
207
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
208
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
209
+ task.
210
+
211
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
212
+ weights are discarded.
213
+
214
+ Parameters:
215
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
216
+ Can be either:
217
+
218
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
219
+ Valid model ids are namespaced under a user or organization name, like
220
+ `runwayml/stable-diffusion-v1-5`.
221
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
222
+ e.g., `./my_model_directory/`.
223
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
224
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
225
+ `jax.numpy.bfloat16` (on TPUs).
226
+
227
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
228
+ specified all the computation will be performed with the given `dtype`.
229
+
230
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
231
+ parameters.**
232
+
233
+ If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and
234
+ [`~ModelMixin.to_bf16`].
235
+ model_args (sequence of positional arguments, *optional*):
236
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
237
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
238
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
239
+ standard cache should not be used.
240
+ force_download (`bool`, *optional*, defaults to `False`):
241
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
242
+ cached versions if they exist.
243
+ resume_download (`bool`, *optional*, defaults to `False`):
244
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
245
+ file exists.
246
+ proxies (`Dict[str, str]`, *optional*):
247
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
248
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
249
+ local_files_only(`bool`, *optional*, defaults to `False`):
250
+ Whether or not to only look at local files (i.e., do not try to download the model).
251
+ revision (`str`, *optional*, defaults to `"main"`):
252
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
253
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
254
+ identifier allowed by git.
255
+ from_pt (`bool`, *optional*, defaults to `False`):
256
+ Load the model weights from a PyTorch checkpoint save file.
257
+ kwargs (remaining dictionary of keyword arguments, *optional*):
258
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
259
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
260
+ automatically loaded:
261
+
262
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
263
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
264
+ already been done)
265
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
266
+ initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to
267
+ a configuration attribute will be used to override said attribute with the supplied `kwargs`
268
+ value. Remaining keys that do not correspond to any configuration attribute will be passed to the
269
+ underlying model's `__init__` function.
270
+
271
+ Examples:
272
+
273
+ ```python
274
+ >>> from diffusers import FlaxUNet2DConditionModel
275
+
276
+ >>> # Download model and configuration from huggingface.co and cache.
277
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
278
+ >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
279
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
280
+ ```"""
281
+ config = kwargs.pop("config", None)
282
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
283
+ force_download = kwargs.pop("force_download", False)
284
+ from_pt = kwargs.pop("from_pt", False)
285
+ resume_download = kwargs.pop("resume_download", False)
286
+ proxies = kwargs.pop("proxies", None)
287
+ local_files_only = kwargs.pop("local_files_only", False)
288
+ use_auth_token = kwargs.pop("use_auth_token", None)
289
+ revision = kwargs.pop("revision", None)
290
+ subfolder = kwargs.pop("subfolder", None)
291
+
292
+ user_agent = {
293
+ "diffusers": __version__,
294
+ "file_type": "model",
295
+ "framework": "flax",
296
+ }
297
+
298
+ # Load config if we don't provide a configuration
299
+ config_path = config if config is not None else pretrained_model_name_or_path
300
+ model, model_kwargs = cls.from_config(
301
+ config_path,
302
+ cache_dir=cache_dir,
303
+ return_unused_kwargs=True,
304
+ force_download=force_download,
305
+ resume_download=resume_download,
306
+ proxies=proxies,
307
+ local_files_only=local_files_only,
308
+ use_auth_token=use_auth_token,
309
+ revision=revision,
310
+ subfolder=subfolder,
311
+ # model args
312
+ dtype=dtype,
313
+ **kwargs,
314
+ )
315
+
316
+ # Load model
317
+ pretrained_path_with_subfolder = (
318
+ pretrained_model_name_or_path
319
+ if subfolder is None
320
+ else os.path.join(pretrained_model_name_or_path, subfolder)
321
+ )
322
+ if os.path.isdir(pretrained_path_with_subfolder):
323
+ if from_pt:
324
+ if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
325
+ raise EnvironmentError(
326
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
327
+ )
328
+ model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
329
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
330
+ # Load from a Flax checkpoint
331
+ model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
332
+ # Check if pytorch weights exist instead
333
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
334
+ raise EnvironmentError(
335
+ f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
336
+ " using `from_pt=True`."
337
+ )
338
+ else:
339
+ raise EnvironmentError(
340
+ f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
341
+ f"{pretrained_path_with_subfolder}."
342
+ )
343
+ send_telemetry(
344
+ {"model_class": cls.__name__, "model_path": "local", "framework": "flax"},
345
+ name="diffusers_from_pretrained",
346
+ )
347
+ else:
348
+ try:
349
+ model_file = hf_hub_download(
350
+ pretrained_model_name_or_path,
351
+ filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
352
+ cache_dir=cache_dir,
353
+ force_download=force_download,
354
+ proxies=proxies,
355
+ resume_download=resume_download,
356
+ local_files_only=local_files_only,
357
+ use_auth_token=use_auth_token,
358
+ user_agent=user_agent,
359
+ subfolder=subfolder,
360
+ revision=revision,
361
+ )
362
+ send_telemetry(
363
+ {"model_class": cls.__name__, "model_path": "hub", "framework": "flax"},
364
+ name="diffusers_from_pretrained",
365
+ )
366
+
367
+ except RepositoryNotFoundError:
368
+ raise EnvironmentError(
369
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
370
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
371
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
372
+ "login`."
373
+ )
374
+ except RevisionNotFoundError:
375
+ raise EnvironmentError(
376
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
377
+ "this model name. Check the model page at "
378
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
379
+ )
380
+ except EntryNotFoundError:
381
+ raise EnvironmentError(
382
+ f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
383
+ )
384
+ except HTTPError as err:
385
+ raise EnvironmentError(
386
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
387
+ f"{err}"
388
+ )
389
+ except ValueError:
390
+ raise EnvironmentError(
391
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
392
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
393
+ f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
394
+ " internet connection or see how to run the library in offline mode at"
395
+ " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
396
+ )
397
+ except EnvironmentError:
398
+ raise EnvironmentError(
399
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
400
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
401
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
402
+ f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
403
+ )
404
+
405
+ if from_pt:
406
+ if is_torch_available():
407
+ from .modeling_utils import load_state_dict
408
+ else:
409
+ raise EnvironmentError(
410
+ "Can't load the model in PyTorch format because PyTorch is not installed. "
411
+ "Please, install PyTorch or use native Flax weights."
412
+ )
413
+
414
+ # Step 1: Get the pytorch file
415
+ pytorch_model_file = load_state_dict(model_file)
416
+
417
+ # Step 2: Convert the weights
418
+ state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
419
+ else:
420
+ try:
421
+ with open(model_file, "rb") as state_f:
422
+ state = from_bytes(cls, state_f.read())
423
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
424
+ try:
425
+ with open(model_file) as f:
426
+ if f.read().startswith("version"):
427
+ raise OSError(
428
+ "You seem to have cloned a repository without having git-lfs installed. Please"
429
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
430
+ " folder you cloned."
431
+ )
432
+ else:
433
+ raise ValueError from e
434
+ except (UnicodeDecodeError, ValueError):
435
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
436
+ # make sure all arrays are stored as jnp.ndarray
437
+ # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
438
+ # https://github.com/google/flax/issues/1261
439
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
440
+
441
+ # flatten dicts
442
+ state = flatten_dict(state)
443
+
444
+ params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
445
+ required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
446
+
447
+ shape_state = flatten_dict(unfreeze(params_shape_tree))
448
+
449
+ missing_keys = required_params - set(state.keys())
450
+ unexpected_keys = set(state.keys()) - required_params
451
+
452
+ if missing_keys:
453
+ logger.warning(
454
+ f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
455
+ "Make sure to call model.init_weights to initialize the missing weights."
456
+ )
457
+ cls._missing_keys = missing_keys
458
+
459
+ for key in state.keys():
460
+ if key in shape_state and state[key].shape != shape_state[key].shape:
461
+ raise ValueError(
462
+ f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
463
+ f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
464
+ )
465
+
466
+ # remove unexpected keys to not be saved again
467
+ for unexpected_key in unexpected_keys:
468
+ del state[unexpected_key]
469
+
470
+ if len(unexpected_keys) > 0:
471
+ logger.warning(
472
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
473
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
474
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
475
+ " with another architecture."
476
+ )
477
+ else:
478
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
479
+
480
+ if len(missing_keys) > 0:
481
+ logger.warning(
482
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
483
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
484
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
485
+ )
486
+ else:
487
+ logger.info(
488
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
489
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
490
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
491
+ " training."
492
+ )
493
+
494
+ return model, unflatten_dict(state)
495
+
496
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    params: Union[Dict, FrozenDict],
    is_main_process: bool = True,
):
    """
    Write the model configuration and the serialized `params` into `save_directory`, so the result can be read
    back with the [`~FlaxModelMixin.from_pretrained`] class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Target directory. It is created when missing. If the path points at an existing file, an error is
            logged and nothing is written.
        params (`Union[Dict, FrozenDict]`):
            The `PyTree` of model parameters, serialized with `to_bytes`.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether this call also writes the configuration through `save_config`. Note that the weight file
            itself is written regardless of this flag.
    """
    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    os.makedirs(save_directory, exist_ok=True)

    # The configuration is only persisted by the main process.
    if is_main_process:
        self.save_config(save_directory)

    # The file is opened before serializing, exactly as before, so the on-disk outcome of a failing
    # `to_bytes` call is unchanged.
    weights_path = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
    with open(weights_path, "wb") as weights_file:
        weights_file.write(to_bytes(params))

    logger.info(f"Model weights saved in {weights_path}")
diffusers/modeling_utils.py ADDED
@@ -0,0 +1,892 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ from functools import partial
19
+ from typing import Callable, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ from torch import Tensor, device
23
+
24
+ from huggingface_hub import hf_hub_download
25
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
26
+ from requests import HTTPError
27
+
28
+ from . import __version__
29
+ from .hub_utils import send_telemetry
30
+ from .utils import (
31
+ CONFIG_NAME,
32
+ DIFFUSERS_CACHE,
33
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
34
+ SAFETENSORS_WEIGHTS_NAME,
35
+ WEIGHTS_NAME,
36
+ is_accelerate_available,
37
+ is_safetensors_available,
38
+ is_torch_version,
39
+ logging,
40
+ )
41
+
42
+
43
logger = logging.get_logger(__name__)


# Default for the `low_cpu_mem_usage` loading option: enabled only when torch >= 1.9.0.
_LOW_CPU_MEM_USAGE_DEFAULT = bool(is_torch_version(">=", "1.9.0"))
50
+
51
+
52
+ if is_accelerate_available():
53
+ import accelerate
54
+ from accelerate.utils import set_module_tensor_to_device
55
+ from accelerate.utils.versions import is_torch_version
56
+
57
+ if is_safetensors_available():
58
+ import safetensors
59
+
60
+
61
def get_parameter_device(parameter: torch.nn.Module):
    """
    Return the device of the first parameter of `parameter`.

    When the module exposes no parameters, fall back to the first tensor stored as a plain attribute, as
    enumerated by `Module._named_members` (kept for `torch.nn.DataParallel` compatibility in PyTorch 1.5).
    `StopIteration` is raised when no such tensor exists either.
    """
    leading_param = next(parameter.parameters(), None)
    if leading_param is not None:
        return leading_param.device

    def _tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        # Collect every attribute stored directly on the module that is a tensor.
        return [(attr, value) for attr, value in vars(module).items() if torch.is_tensor(value)]

    _, first_tensor = next(parameter._named_members(get_members_fn=_tensor_attributes))
    return first_tensor.device
74
+
75
+
76
def get_parameter_dtype(parameter: torch.nn.Module):
    """
    Return the dtype of the first parameter of `parameter`.

    When the module exposes no parameters, fall back to the first tensor stored as a plain attribute, as
    enumerated by `Module._named_members` (kept for `torch.nn.DataParallel` compatibility in PyTorch 1.5).
    `StopIteration` is raised when no such tensor exists either.
    """
    leading_param = next(parameter.parameters(), None)
    if leading_param is not None:
        return leading_param.dtype

    def _tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        # Collect every attribute stored directly on the module that is a tensor.
        return [(attr, value) for attr, value in vars(module).items() if torch.is_tensor(value)]

    _, first_tensor = next(parameter._named_members(get_members_fn=_tensor_attributes))
    return first_tensor.dtype
89
+
90
+
91
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
    """
    Read a checkpoint file, raising a descriptive `OSError` when it cannot be loaded.

    A file whose base name equals `WEIGHTS_NAME` is read with `torch.load`; any other file is read with
    `safetensors.torch.load_file`. Both loaders are asked to place the result on the CPU.

    Arguments:
        checkpoint_file (`str` or `os.PathLike`):
            Path of the checkpoint to read.

    Returns:
        The object produced by the selected loader.

    Raises:
        OSError: If loading fails. The message points at a missing git-lfs checkout when the file starts with
            the text `version`; otherwise it is a generic "unable to load weights" message. In both cases the
            loader's exception is attached as `__cause__`. Errors raised while re-opening the file for this
            diagnosis (for example `FileNotFoundError`) propagate unchanged.
    """
    try:
        if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
            # NOTE(security): `torch.load` unpickles the file; only load checkpoints from trusted sources.
            return torch.load(checkpoint_file, map_location="cpu")
        return safetensors.torch.load_file(checkpoint_file, device="cpu")
    except Exception as e:
        # Only the first characters are needed to recognise a git-lfs pointer, so avoid reading a
        # potentially very large checkpoint into memory as text.
        try:
            with open(checkpoint_file) as f:
                looks_like_lfs_pointer = f.read(7) == "version"
        except (UnicodeDecodeError, ValueError):
            # Not decodable as text, hence not a pointer file.
            looks_like_lfs_pointer = False

        if looks_like_lfs_pointer:
            raise OSError(
                "You seem to have cloned a repository without having git-lfs installed. Please install "
                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                "you cloned."
            ) from e

        # Chain the loader's exception so the root cause stays visible in the traceback.
        raise OSError(
            f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
            f"at '{checkpoint_file}'. "
            "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
        ) from e
120
+
121
+
122
+ def _load_state_dict_into_model(model_to_load, state_dict):
123
+ # Convert old format to new format if needed from a PyTorch state_dict
124
+ # copy state_dict so _load_from_state_dict can modify it
125
+ state_dict = state_dict.copy()
126
+ error_msgs = []
127
+
128
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
129
+ # so we need to apply the function recursively.
130
+ def load(module: torch.nn.Module, prefix=""):
131
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
132
+ module._load_from_state_dict(*args)
133
+
134
+ for name, child in module._modules.items():
135
+ if child is not None:
136
+ load(child, prefix + name + ".")
137
+
138
+ load(model_to_load)
139
+
140
+ return error_msgs
141
+
142
+
143
+ class ModelMixin(torch.nn.Module):
144
+ r"""
145
+ Base class for all models.
146
+
147
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
148
+ and saving models.
149
+
150
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
151
+ [`~modeling_utils.ModelMixin.save_pretrained`].
152
+ """
153
+ config_name = CONFIG_NAME
154
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
155
+ _supports_gradient_checkpointing = False
156
+
157
def __init__(self):
    """Run the parent initializer; the mixin itself adds no instance state here."""
    super().__init__()
159
+
160
@property
def is_gradient_checkpointing(self) -> bool:
    """
    Whether gradient checkpointing is currently active for this model.

    The model counts as checkpointing as soon as any module returned by `self.modules()` carries a truthy
    `gradient_checkpointing` attribute. Other frameworks may call this feature "activation checkpointing" or
    "checkpoint activations".
    """
    for submodule in self.modules():
        if getattr(submodule, "gradient_checkpointing", False):
            return True
    return False
169
+
170
def enable_gradient_checkpointing(self):
    """
    Turn gradient checkpointing on for this model.

    `_set_gradient_checkpointing(module, value=True)` is applied to every module visited by `self.apply`.
    Other frameworks may call this feature "activation checkpointing" or "checkpoint activations".

    Raises:
        ValueError: If the class sets `_supports_gradient_checkpointing` to a falsy value.
    """
    if self._supports_gradient_checkpointing:
        switch_on = partial(self._set_gradient_checkpointing, value=True)
        self.apply(switch_on)
        return

    raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
180
+
181
def disable_gradient_checkpointing(self):
    """
    Turn gradient checkpointing off for this model.

    `_set_gradient_checkpointing(module, value=False)` is applied to every module visited by `self.apply`.
    Models whose class does not support gradient checkpointing are left untouched and no error is raised.
    Other frameworks may call this feature "activation checkpointing" or "checkpoint activations".
    """
    if not self._supports_gradient_checkpointing:
        return

    switch_off = partial(self._set_gradient_checkpointing, value=False)
    self.apply(switch_off)
190
+
191
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
    """
    Forward the `valid` flag to every descendant module that accepts it.

    Each descendant of this model is visited parent-first; whenever a visited module defines
    `set_use_memory_efficient_attention_xformers`, that method is called with `valid`.

    Arguments:
        valid (`bool`): The value handed to each descendant's setter.
    """

    def _propagate(node: torch.nn.Module):
        # Notify the node itself first, then continue with everything below it.
        if hasattr(node, "set_use_memory_efficient_attention_xformers"):
            node.set_use_memory_efficient_attention_xformers(valid)

        for descendant in node.children():
            _propagate(descendant)

    for top_level in self.children():
        if not isinstance(top_level, torch.nn.Module):
            continue
        _propagate(top_level)
205
+
206
def enable_xformers_memory_efficient_attention(self):
    r"""
    Switch on the xformers implementation of memory efficient attention.

    With this option enabled you should see lower GPU memory usage and possibly faster inference; a speed up
    during training is not guaranteed.

    Warning: if Memory Efficient Attention and Sliced attention are both enabled, Memory Efficient Attention is
    the one that is used.
    """
    use_xformers = True
    self.set_use_memory_efficient_attention_xformers(use_xformers)
217
+
218
def disable_xformers_memory_efficient_attention(self):
    r"""
    Switch off the xformers implementation of memory efficient attention.
    """
    use_xformers = False
    self.set_use_memory_efficient_attention_xformers(use_xformers)
223
+
224
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    save_function: Callable = None,
    safe_serialization: bool = False,
):
    """
    Write the model configuration and weights into `save_directory`, so the result can be read back with the
    [`~modeling_utils.ModelMixin.from_pretrained`] class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Target directory. It is created when missing. If the path points at an existing file, an error is
            logged and nothing is written.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the calling process is the main one. Only the main process writes the configuration and
            removes weight files left over from a previous save; the weight file itself is written by every
            caller.
        save_function (`Callable`):
            Called as `save_function(state_dict, path)` to write the weights. When left as `None`, it is
            `safetensors.torch.save_file` if `safe_serialization` is set and `torch.save` otherwise.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).

    Raises:
        ImportError: If `safe_serialization` is requested but `safetensors` is not available.
    """
    if safe_serialization and not is_safetensors_available():
        raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")

    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    if save_function is None:
        if safe_serialization:
            save_function = safetensors.torch.save_file
        else:
            save_function = torch.save

    os.makedirs(save_directory, exist_ok=True)

    # The configuration is only persisted by the main process.
    if is_main_process:
        self.save_config(save_directory)

    state_dict = self.state_dict()

    if safe_serialization:
        weights_name = SAFETENSORS_WEIGHTS_NAME
    else:
        weights_name = WEIGHTS_NAME
    weights_path = os.path.join(save_directory, weights_name)

    # Files from a previous save share the weight file name without its extension. They are removed by the
    # main process only, to avoid race conditions in distributed settings.
    stale_prefix = weights_name.replace(".bin", "").replace(".safetensors", "")
    for entry in os.listdir(save_directory):
        candidate = os.path.join(save_directory, entry)
        if not entry.startswith(stale_prefix):
            continue
        if os.path.isfile(candidate) and is_main_process:
            os.remove(candidate)

    save_function(state_dict, weights_path)

    logger.info(f"Model weights saved in {weights_path}")
286
+
287
+ @classmethod
288
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
289
+ r"""
290
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
291
+
292
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
293
+ the model, you should first set it back in training mode with `model.train()`.
294
+
295
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
296
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
297
+ task.
298
+
299
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
300
+ weights are discarded.
301
+
302
+ Parameters:
303
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
304
+ Can be either:
305
+
306
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
307
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
308
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
309
+ `./my_model_directory/`.
310
+
311
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
312
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
313
+ standard cache should not be used.
314
+ torch_dtype (`str` or `torch.dtype`, *optional*):
315
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
316
+ will be automatically derived from the model's weights.
317
+ force_download (`bool`, *optional*, defaults to `False`):
318
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
319
+ cached versions if they exist.
320
+ resume_download (`bool`, *optional*, defaults to `False`):
321
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
322
+ file exists.
323
+ proxies (`Dict[str, str]`, *optional*):
324
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
325
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
326
+ output_loading_info(`bool`, *optional*, defaults to `False`):
327
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
328
+ local_files_only(`bool`, *optional*, defaults to `False`):
329
+ Whether or not to only look at local files (i.e., do not try to download the model).
330
+ use_auth_token (`str` or *bool*, *optional*):
331
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
332
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
333
+ revision (`str`, *optional*, defaults to `"main"`):
334
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
335
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
336
+ identifier allowed by git.
337
+ subfolder (`str`, *optional*, defaults to `""`):
338
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
339
+ huggingface.co or downloaded locally), you can specify the folder name here.
340
+
341
+ mirror (`str`, *optional*):
342
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
343
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
344
+ Please refer to the mirror site for more information.
345
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
346
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
347
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
348
+ same device.
349
+
350
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
351
+ more information about each option see [designing a device
352
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
353
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
354
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
355
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
356
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
357
+ setting this argument to `True` will raise an error.
358
+
359
+ <Tip>
360
+
361
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
362
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
363
+
364
+ </Tip>
365
+
366
+ <Tip>
367
+
368
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
369
+ this method in a firewalled environment.
370
+
371
+ </Tip>
372
+
373
+ """
374
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
375
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
376
+ force_download = kwargs.pop("force_download", False)
377
+ resume_download = kwargs.pop("resume_download", False)
378
+ proxies = kwargs.pop("proxies", None)
379
+ output_loading_info = kwargs.pop("output_loading_info", False)
380
+ local_files_only = kwargs.pop("local_files_only", False)
381
+ use_auth_token = kwargs.pop("use_auth_token", None)
382
+ revision = kwargs.pop("revision", None)
383
+ torch_dtype = kwargs.pop("torch_dtype", None)
384
+ subfolder = kwargs.pop("subfolder", None)
385
+ device_map = kwargs.pop("device_map", None)
386
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
387
+
388
+ if low_cpu_mem_usage and not is_accelerate_available():
389
+ low_cpu_mem_usage = False
390
+ logger.warning(
391
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
392
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
393
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
394
+ " install accelerate\n```\n."
395
+ )
396
+
397
+ if device_map is not None and not is_accelerate_available():
398
+ raise NotImplementedError(
399
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
400
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
401
+ )
402
+
403
+ # Check if we can handle device_map and dispatching the weights
404
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
405
+ raise NotImplementedError(
406
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
407
+ " `device_map=None`."
408
+ )
409
+
410
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
411
+ raise NotImplementedError(
412
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
413
+ " `low_cpu_mem_usage=False`."
414
+ )
415
+
416
+ if low_cpu_mem_usage is False and device_map is not None:
417
+ raise ValueError(
418
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
419
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
420
+ )
421
+
422
+ user_agent = {
423
+ "diffusers": __version__,
424
+ "file_type": "model",
425
+ "framework": "pytorch",
426
+ }
427
+
428
+ # Load config if we don't provide a configuration
429
+ config_path = pretrained_model_name_or_path
430
+
431
+ # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
432
+ # Load model
433
+
434
+ model_file = None
435
+ if is_safetensors_available():
436
+ try:
437
+ model_file = cls._get_model_file(
438
+ pretrained_model_name_or_path,
439
+ weights_name=SAFETENSORS_WEIGHTS_NAME,
440
+ cache_dir=cache_dir,
441
+ force_download=force_download,
442
+ resume_download=resume_download,
443
+ proxies=proxies,
444
+ local_files_only=local_files_only,
445
+ use_auth_token=use_auth_token,
446
+ revision=revision,
447
+ subfolder=subfolder,
448
+ user_agent=user_agent,
449
+ )
450
+ except:
451
+ pass
452
+ if model_file is None:
453
+ model_file = cls._get_model_file(
454
+ pretrained_model_name_or_path,
455
+ weights_name=WEIGHTS_NAME,
456
+ cache_dir=cache_dir,
457
+ force_download=force_download,
458
+ resume_download=resume_download,
459
+ proxies=proxies,
460
+ local_files_only=local_files_only,
461
+ use_auth_token=use_auth_token,
462
+ revision=revision,
463
+ subfolder=subfolder,
464
+ user_agent=user_agent,
465
+ )
466
+
467
+ if low_cpu_mem_usage:
468
+ # Instantiate model with empty weights
469
+ with accelerate.init_empty_weights():
470
+ config, unused_kwargs = cls.load_config(
471
+ config_path,
472
+ cache_dir=cache_dir,
473
+ return_unused_kwargs=True,
474
+ force_download=force_download,
475
+ resume_download=resume_download,
476
+ proxies=proxies,
477
+ local_files_only=local_files_only,
478
+ use_auth_token=use_auth_token,
479
+ revision=revision,
480
+ subfolder=subfolder,
481
+ device_map=device_map,
482
+ **kwargs,
483
+ )
484
+ model = cls.from_config(config, **unused_kwargs)
485
+
486
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
487
+ if device_map is None:
488
+ param_device = "cpu"
489
+ state_dict = load_state_dict(model_file)
490
+ # move the params from meta device to cpu
491
+ for param_name, param in state_dict.items():
492
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
493
+ else: # else let accelerate handle loading and dispatching.
494
+ # Load weights and dispatch according to the device_map
495
+ # by default the device_map is None and the weights are loaded on the CPU
496
+ accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
497
+
498
+ loading_info = {
499
+ "missing_keys": [],
500
+ "unexpected_keys": [],
501
+ "mismatched_keys": [],
502
+ "error_msgs": [],
503
+ }
504
+ else:
505
+ config, unused_kwargs = cls.load_config(
506
+ config_path,
507
+ cache_dir=cache_dir,
508
+ return_unused_kwargs=True,
509
+ force_download=force_download,
510
+ resume_download=resume_download,
511
+ proxies=proxies,
512
+ local_files_only=local_files_only,
513
+ use_auth_token=use_auth_token,
514
+ revision=revision,
515
+ subfolder=subfolder,
516
+ device_map=device_map,
517
+ **kwargs,
518
+ )
519
+ model = cls.from_config(config, **unused_kwargs)
520
+
521
+ state_dict = load_state_dict(model_file)
522
+ dtype = set(v.dtype for v in state_dict.values())
523
+
524
+ if len(dtype) > 1 and torch.float32 not in dtype:
525
+ raise ValueError(
526
+ f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
527
+ f" make sure that {model_file} weights have only one dtype."
528
+ )
529
+ elif len(dtype) > 1 and torch.float32 in dtype:
530
+ dtype = torch.float32
531
+ else:
532
+ dtype = dtype.pop()
533
+
534
+ # move model to correct dtype
535
+ model = model.to(dtype)
536
+
537
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
538
+ model,
539
+ state_dict,
540
+ model_file,
541
+ pretrained_model_name_or_path,
542
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
543
+ )
544
+
545
+ loading_info = {
546
+ "missing_keys": missing_keys,
547
+ "unexpected_keys": unexpected_keys,
548
+ "mismatched_keys": mismatched_keys,
549
+ "error_msgs": error_msgs,
550
+ }
551
+
552
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
553
+ raise ValueError(
554
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
555
+ )
556
+ elif torch_dtype is not None:
557
+ model = model.to(torch_dtype)
558
+
559
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
560
+
561
+ # Set model in evaluation mode to deactivate DropOut modules by default
562
+ model.eval()
563
+ if output_loading_info:
564
+ return model, loading_info
565
+
566
+ return model
567
+
568
@classmethod
def _get_model_file(
    cls,
    pretrained_model_name_or_path,
    *,
    weights_name,
    subfolder,
    cache_dir,
    force_download,
    proxies,
    resume_download,
    local_files_only,
    use_auth_token,
    user_agent,
    revision,
):
    """
    Resolve the weights file for `pretrained_model_name_or_path` and report the load via telemetry.

    The actual lookup (local directory first, otherwise `hf_hub_download`) is delegated to the module-level
    `_get_model_file` function defined in this file, so that both entry points share one implementation and one
    set of error messages. This method only adds the `send_telemetry` call on top of it.

    Args:
        pretrained_model_name_or_path: Local directory or Hub model id; converted with `str()` before use.
        weights_name: File name of the weights to look for.
        subfolder: Optional sub-directory that is also searched for `weights_name`.
        cache_dir, force_download, proxies, resume_download, local_files_only, use_auth_token, user_agent,
        revision: Forwarded unchanged to the module-level helper (and from there to `hf_hub_download`).

    Returns:
        The path of the resolved weights file, as returned by the module-level helper.

    Raises:
        EnvironmentError: Propagated from the module-level helper when the file cannot be resolved. No telemetry
            is sent in that case.
    """
    # Decide the telemetry label before resolving: a local directory is reported as "local", anything else as
    # "hub". This mirrors the branch the helper takes.
    is_local_dir = os.path.isdir(str(pretrained_model_name_or_path))

    # Inside a method body the bare name `_get_model_file` is looked up in the module globals (class attributes
    # are not part of a method's scope), so this calls the module-level function, not this classmethod.
    model_file = _get_model_file(
        pretrained_model_name_or_path,
        weights_name=weights_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        user_agent=user_agent,
        revision=revision,
    )

    # Telemetry is only sent once the file has been resolved successfully.
    send_telemetry(
        {"model_class": cls.__name__, "model_path": "local" if is_local_dir else "hub", "framework": "pytorch"},
        name="diffusers_from_pretrained",
    )
    return model_file
661
+
662
@classmethod
def _load_pretrained_model(
    cls,
    model,
    state_dict,
    resolved_archive_file,
    pretrained_model_name_or_path,
    ignore_mismatched_sizes=False,
):
    """
    Load `state_dict` into `model` and report which keys did not line up.

    Args:
        model: The instantiated model whose `state_dict()` defines the expected keys.
        state_dict: Mapping from parameter name to tensor. When `ignore_mismatched_sizes` is `True`, entries
            whose shape differs from the model's are deleted from this mapping in place. May be `None`, in which
            case nothing is loaded and every expected key is reported as missing.
        resolved_archive_file: Not used by this method; kept for interface compatibility.
        pretrained_model_name_or_path: Only used in log messages.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether shape mismatches are skipped (and reported) instead of being passed on to the loader.

    Returns:
        `tuple`: `(model, missing_keys, unexpected_keys, mismatched_keys, error_msgs)`.

    Raises:
        RuntimeError: If `_load_state_dict_into_model` returns any error messages.
    """
    # Retrieve missing & unexpected_keys
    model_state_dict = model.state_dict()
    expected_keys = list(model_state_dict.keys())
    # Snapshot the checkpoint keys: `_find_mismatched_keys` deletes from `state_dict` while walking this list.
    loaded_keys = list(state_dict) if state_dict is not None else []

    missing_keys = list(set(expected_keys) - set(loaded_keys))
    unexpected_keys = list(set(loaded_keys) - set(expected_keys))

    # Make sure we are able to load base models as well as derived models (with heads)
    model_to_load = model

    def _find_mismatched_keys(
        state_dict,
        model_state_dict,
        loaded_keys,
        ignore_mismatched_sizes,
    ):
        # Collect (key, checkpoint shape, model shape) for every shape mismatch and drop those entries from
        # `state_dict` so that they are not handed to the loader.
        mismatched_keys = []
        if ignore_mismatched_sizes:
            for checkpoint_key in loaded_keys:
                model_key = checkpoint_key

                if (
                    model_key in model_state_dict
                    and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                ):
                    mismatched_keys.append(
                        (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                    )
                    del state_dict[checkpoint_key]
        return mismatched_keys

    # Bind both results up front so the reporting code below never reads an unbound name.
    mismatched_keys = []
    error_msgs = []
    if state_dict is not None:
        # Whole checkpoint
        mismatched_keys = _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            ignore_mismatched_sizes,
        )
        error_msgs = _load_state_dict_into_model(model_to_load, state_dict)

    if len(error_msgs) > 0:
        error_msg = "\n\t".join(error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
            " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
            f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
            " identical (initializing a BertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif len(mismatched_keys) == 0:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
            f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
            " without further training."
        )

    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
            " able to use it for predictions and inference."
        )

    return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
765
+
766
@property
def device(self) -> device:
    """
    `torch.device`: The device this module lives on, as reported by `get_parameter_device` (this assumes every
    parameter of the module sits on one and the same device).
    """
    module_device = get_parameter_device(self)
    return module_device
773
+
774
@property
def dtype(self) -> torch.dtype:
    """
    `torch.dtype`: The dtype of this module, as reported by `get_parameter_dtype` (this assumes every parameter
    of the module shares one dtype).
    """
    module_dtype = get_parameter_dtype(self)
    return module_dtype
780
+
781
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
    """
    Get number of (optionally, trainable or non-embeddings) parameters in the module.

    Args:
        only_trainable (`bool`, *optional*, defaults to `False`):
            Whether or not to return only the number of trainable parameters

        exclude_embeddings (`bool`, *optional*, defaults to `False`):
            Whether or not to return only the number of non-embeddings parameters

    Returns:
        `int`: The number of parameters.
    """
    if exclude_embeddings:
        # Names of the `weight` parameter of every `torch.nn.Embedding` sub-module. Kept in a set so that the
        # membership test below is O(1) per parameter instead of a list scan.
        embedding_param_names = {
            f"{name}.weight"
            for name, module in self.named_modules()
            if isinstance(module, torch.nn.Embedding)
        }
        parameters = (
            parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
        )
    else:
        parameters = self.parameters()

    # A parameter is counted when it is trainable, or when frozen parameters were not excluded.
    return sum(p.numel() for p in parameters if p.requires_grad or not only_trainable)
808
+
809
+
810
def _get_model_file(
    pretrained_model_name_or_path,
    *,
    weights_name,
    subfolder,
    cache_dir,
    force_download,
    proxies,
    resume_download,
    local_files_only,
    use_auth_token,
    user_agent,
    revision,
):
    """
    Resolve the path of the weights file `weights_name` for a local directory or a Hub model id.

    If `pretrained_model_name_or_path` is an existing directory, the file is looked up directly inside it and
    then, when `subfolder` is given, inside that sub-directory (in this order). Otherwise the file is fetched
    through `hf_hub_download`.

    Args:
        pretrained_model_name_or_path: Local directory or Hub model id; converted with `str()` before use.
        weights_name: File name of the weights to look for.
        subfolder: Optional sub-directory to search / to pass to `hf_hub_download`.
        cache_dir, force_download, proxies, resume_download, local_files_only, use_auth_token, user_agent,
        revision: Forwarded unchanged to `hf_hub_download`.

    Returns:
        The path of the weights file (a joined local path, or whatever `hf_hub_download` returns).

    Raises:
        EnvironmentError: If the file is not present in the local directory, or if `hf_hub_download` fails. In
            the latter case the original exception is attached as `__cause__`.
    """
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    if os.path.isdir(pretrained_model_name_or_path):
        # Load from a PyTorch checkpoint: the directory itself takes precedence over `subfolder`.
        candidates = [os.path.join(pretrained_model_name_or_path, weights_name)]
        if subfolder is not None:
            candidates.append(os.path.join(pretrained_model_name_or_path, subfolder, weights_name))
        for candidate in candidates:
            if os.path.isfile(candidate):
                return candidate
        raise EnvironmentError(
            f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
        )

    # Keep the handler order below as is: the generic `EnvironmentError` handler has to stay last so that it
    # does not shadow the more specific ones.
    try:
        # Load from URL or cache if already cached
        return hf_hub_download(
            pretrained_model_name_or_path,
            filename=weights_name,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            user_agent=user_agent,
            subfolder=subfolder,
            revision=revision,
        )
    except RepositoryNotFoundError as err:
        raise EnvironmentError(
            f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
            "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
            "login`."
        ) from err
    except RevisionNotFoundError as err:
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
            "this model name. Check the model page at "
            f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
        ) from err
    except EntryNotFoundError as err:
        raise EnvironmentError(
            f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
        ) from err
    except HTTPError as err:
        raise EnvironmentError(
            f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
        ) from err
    except ValueError as err:
        raise EnvironmentError(
            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
            f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
            f" directory containing a file named {weights_name} or"
            " \nCheckout your internet connection or see how to run the library in"
            " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
        ) from err
    except EnvironmentError as err:
        raise EnvironmentError(
            f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
            "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
            f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
            f"containing a file named {weights_name}"
        ) from err
diffusers/models/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..utils import is_flax_available, is_torch_available
16
+
17
+
18
+ if is_torch_available():
19
+ from .attention import Transformer2DModel
20
+ from .unet_1d import UNet1DModel
21
+ from .unet_2d import UNet2DModel
22
+ from .unet_2d_condition import UNet2DConditionModel
23
+ from .vae import AutoencoderKL, VQModel
24
+
25
+ if is_flax_available():
26
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
27
+ from .vae_flax import FlaxAutoencoderKL
diffusers/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (639 Bytes). View file
 
diffusers/models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (28.5 kB). View file
 
diffusers/models/__pycache__/attention_flax.cpython-310.pyc ADDED
Binary file (8.99 kB). View file
 
diffusers/models/__pycache__/embeddings.cpython-310.pyc ADDED
Binary file (5.87 kB). View file
 
diffusers/models/__pycache__/embeddings_flax.cpython-310.pyc ADDED
Binary file (3.19 kB). View file
 
diffusers/models/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (18.4 kB). View file
 
diffusers/models/__pycache__/resnet_flax.cpython-310.pyc ADDED
Binary file (2.77 kB). View file
 
diffusers/models/__pycache__/unet_1d.cpython-310.pyc ADDED
Binary file (7.41 kB). View file
 
diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc ADDED
Binary file (16.7 kB). View file
 
diffusers/models/__pycache__/unet_2d.cpython-310.pyc ADDED
Binary file (7.97 kB). View file
 
diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc ADDED
Binary file (26.1 kB). View file
 
diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-310.pyc ADDED
Binary file (10 kB). View file
 
diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
diffusers/models/__pycache__/unet_2d_condition_flax.cpython-310.pyc ADDED
Binary file (9.63 kB). View file