Spaces:
Sleeping
Sleeping
Upload 273 files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- diffusers/__init__.py +163 -0
- diffusers/__pycache__/__init__.cpython-310.pyc +0 -0
- diffusers/__pycache__/configuration_utils.cpython-310.pyc +0 -0
- diffusers/__pycache__/dependency_versions_check.cpython-310.pyc +0 -0
- diffusers/__pycache__/dependency_versions_table.cpython-310.pyc +0 -0
- diffusers/__pycache__/dynamic_modules_utils.cpython-310.pyc +0 -0
- diffusers/__pycache__/hub_utils.cpython-310.pyc +0 -0
- diffusers/__pycache__/modeling_flax_pytorch_utils.cpython-310.pyc +0 -0
- diffusers/__pycache__/modeling_flax_utils.cpython-310.pyc +0 -0
- diffusers/__pycache__/modeling_utils.cpython-310.pyc +0 -0
- diffusers/__pycache__/onnx_utils.cpython-310.pyc +0 -0
- diffusers/__pycache__/optimization.cpython-310.pyc +0 -0
- diffusers/__pycache__/pipeline_flax_utils.cpython-310.pyc +0 -0
- diffusers/__pycache__/pipeline_utils.cpython-310.pyc +0 -0
- diffusers/__pycache__/training_utils.cpython-310.pyc +0 -0
- diffusers/commands/__init__.py +27 -0
- diffusers/commands/__pycache__/__init__.cpython-310.pyc +0 -0
- diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc +0 -0
- diffusers/commands/__pycache__/env.cpython-310.pyc +0 -0
- diffusers/commands/diffusers_cli.py +41 -0
- diffusers/commands/env.py +70 -0
- diffusers/configuration_utils.py +613 -0
- diffusers/dependency_versions_check.py +47 -0
- diffusers/dependency_versions_table.py +35 -0
- diffusers/dynamic_modules_utils.py +428 -0
- diffusers/experimental/__init__.py +1 -0
- diffusers/experimental/__pycache__/__init__.cpython-310.pyc +0 -0
- diffusers/experimental/rl/__init__.py +1 -0
- diffusers/experimental/rl/__pycache__/__init__.cpython-310.pyc +0 -0
- diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-310.pyc +0 -0
- diffusers/experimental/rl/value_guided_sampling.py +152 -0
- diffusers/hub_utils.py +154 -0
- diffusers/modeling_flax_pytorch_utils.py +117 -0
- diffusers/modeling_flax_utils.py +535 -0
- diffusers/modeling_utils.py +892 -0
- diffusers/models/__init__.py +27 -0
- diffusers/models/__pycache__/__init__.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/attention.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/attention_flax.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/embeddings.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/embeddings_flax.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/resnet.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/resnet_flax.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/unet_1d.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/unet_2d.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc +0 -0
- diffusers/models/__pycache__/unet_2d_condition_flax.cpython-310.pyc +0 -0
diffusers/__init__.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__version__ = "0.10.2"

from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
from .utils import (
    OptionalDependencyNotAvailable,
    is_flax_available,
    is_inflect_available,
    is_k_diffusion_available,
    is_librosa_available,
    is_onnx_available,
    is_scipy_available,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
    is_unidecode_available,
    logging,
)


# Every block below follows the same pattern: probe for an optional dependency
# and, if it is missing, star-import placeholder "dummy" objects that raise a
# helpful error when used, instead of failing here at import time.

# Core PyTorch API: models, LR schedules, pipelines, noise schedulers, EMA.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_pt_objects import *  # noqa F403
else:
    from .modeling_utils import ModelMixin
    from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
    from .optimization import (
        get_constant_schedule,
        get_constant_schedule_with_warmup,
        get_cosine_schedule_with_warmup,
        get_cosine_with_hard_restarts_schedule_with_warmup,
        get_linear_schedule_with_warmup,
        get_polynomial_decay_schedule_with_warmup,
        get_scheduler,
    )
    from .pipeline_utils import DiffusionPipeline
    from .pipelines import (
        DanceDiffusionPipeline,
        DDIMPipeline,
        DDPMPipeline,
        KarrasVePipeline,
        LDMPipeline,
        LDMSuperResolutionPipeline,
        PNDMPipeline,
        RePaintPipeline,
        ScoreSdeVePipeline,
    )
    from .schedulers import (
        DDIMScheduler,
        DDPMScheduler,
        DPMSolverMultistepScheduler,
        DPMSolverSinglestepScheduler,
        EulerAncestralDiscreteScheduler,
        EulerDiscreteScheduler,
        HeunDiscreteScheduler,
        IPNDMScheduler,
        KarrasVeScheduler,
        KDPM2AncestralDiscreteScheduler,
        KDPM2DiscreteScheduler,
        PNDMScheduler,
        RePaintScheduler,
        SchedulerMixin,
        ScoreSdeVeScheduler,
        VQDiffusionScheduler,
    )
    from .training_utils import EMAModel

# LMSDiscreteScheduler needs scipy in addition to torch.
try:
    if not (is_torch_available() and is_scipy_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_torch_and_scipy_objects import *  # noqa F403
else:
    from .schedulers import LMSDiscreteScheduler


# Text-conditioned pipelines need transformers for their text encoders.
try:
    if not (is_torch_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
    from .pipelines import (
        AltDiffusionImg2ImgPipeline,
        AltDiffusionPipeline,
        CycleDiffusionPipeline,
        LDMTextToImagePipeline,
        PaintByExamplePipeline,
        StableDiffusionDepth2ImgPipeline,
        StableDiffusionImageVariationPipeline,
        StableDiffusionImg2ImgPipeline,
        StableDiffusionInpaintPipeline,
        StableDiffusionInpaintPipelineLegacy,
        StableDiffusionPipeline,
        StableDiffusionPipelineSafe,
        StableDiffusionUpscalePipeline,
        VersatileDiffusionDualGuidedPipeline,
        VersatileDiffusionImageVariationPipeline,
        VersatileDiffusionPipeline,
        VersatileDiffusionTextToImagePipeline,
        VQDiffusionPipeline,
    )

# k-diffusion sampler wrappers.
try:
    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
else:
    from .pipelines import StableDiffusionKDiffusionPipeline

# ONNX Runtime inference pipelines.
try:
    if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
else:
    from .pipelines import (
        OnnxStableDiffusionImg2ImgPipeline,
        OnnxStableDiffusionInpaintPipeline,
        OnnxStableDiffusionInpaintPipelineLegacy,
        OnnxStableDiffusionPipeline,
        StableDiffusionOnnxPipeline,
    )

# Audio pipelines need librosa for spectrogram handling.
try:
    if not (is_torch_available() and is_librosa_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_torch_and_librosa_objects import *  # noqa F403
else:
    from .pipelines import AudioDiffusionPipeline, Mel

# Flax/JAX mirror of the PyTorch API.
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_flax_objects import *  # noqa F403
else:
    from .modeling_flax_utils import FlaxModelMixin
    from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
    from .models.vae_flax import FlaxAutoencoderKL
    from .pipeline_flax_utils import FlaxDiffusionPipeline
    from .schedulers import (
        FlaxDDIMScheduler,
        FlaxDDPMScheduler,
        FlaxDPMSolverMultistepScheduler,
        FlaxKarrasVeScheduler,
        FlaxLMSDiscreteScheduler,
        FlaxPNDMScheduler,
        FlaxSchedulerMixin,
        FlaxScoreSdeVeScheduler,
    )

# Flax Stable Diffusion additionally needs transformers.
try:
    if not (is_flax_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
else:
    from .pipelines import FlaxStableDiffusionPipeline
diffusers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (4.67 kB). View file
|
|
diffusers/__pycache__/configuration_utils.cpython-310.pyc
ADDED
Binary file (21.7 kB). View file
|
|
diffusers/__pycache__/dependency_versions_check.cpython-310.pyc
ADDED
Binary file (938 Bytes). View file
|
|
diffusers/__pycache__/dependency_versions_table.cpython-310.pyc
ADDED
Binary file (1.04 kB). View file
|
|
diffusers/__pycache__/dynamic_modules_utils.cpython-310.pyc
ADDED
Binary file (13.5 kB). View file
|
|
diffusers/__pycache__/hub_utils.cpython-310.pyc
ADDED
Binary file (4.53 kB). View file
|
|
diffusers/__pycache__/modeling_flax_pytorch_utils.cpython-310.pyc
ADDED
Binary file (2.61 kB). View file
|
|
diffusers/__pycache__/modeling_flax_utils.cpython-310.pyc
ADDED
Binary file (20.9 kB). View file
|
|
diffusers/__pycache__/modeling_utils.cpython-310.pyc
ADDED
Binary file (27.5 kB). View file
|
|
diffusers/__pycache__/onnx_utils.cpython-310.pyc
ADDED
Binary file (6.89 kB). View file
|
|
diffusers/__pycache__/optimization.cpython-310.pyc
ADDED
Binary file (10.1 kB). View file
|
|
diffusers/__pycache__/pipeline_flax_utils.cpython-310.pyc
ADDED
Binary file (15.8 kB). View file
|
|
diffusers/__pycache__/pipeline_utils.cpython-310.pyc
ADDED
Binary file (31.4 kB). View file
|
|
diffusers/__pycache__/training_utils.cpython-310.pyc
ADDED
Binary file (3.64 kB). View file
|
|
diffusers/commands/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from argparse import ArgumentParser


class BaseDiffusersCLICommand(ABC):
    """Abstract base class for ``diffusers-cli`` subcommands.

    Concrete subcommands register themselves on the CLI's subparser via
    ``register_subcommand`` and implement their behavior in ``run``.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        # Hook a subparser (plus its ``func`` default) onto the given parser.
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        # Execute the subcommand after argument parsing.
        raise NotImplementedError()
diffusers/commands/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (788 Bytes). View file
|
|
diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc
ADDED
Binary file (749 Bytes). View file
|
|
diffusers/commands/__pycache__/env.cpython-310.pyc
ADDED
Binary file (2.14 kB). View file
|
|
diffusers/commands/diffusers_cli.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from argparse import ArgumentParser

from .env import EnvironmentCommand


def main():
    """Entry point of the ``diffusers-cli`` tool.

    Builds the top-level argument parser, registers every known subcommand,
    then dispatches to the command factory stored in ``args.func``. Exits with
    status 1 (after printing help) when no subcommand was given.
    """
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        # No subcommand selected: argparse leaves ``func`` unset.
        parser.print_help()
        # BUG FIX: was the builtin ``exit(1)``, which is injected by the
        # ``site`` module and is not guaranteed to exist (e.g. ``python -S``
        # or frozen builds); ``sys.exit`` is the documented API.
        sys.exit(1)

    # Run
    service = args.func(args)
    service.run()


if __name__ == "__main__":
    main()
|
diffusers/commands/env.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
from argparse import ArgumentParser

import huggingface_hub

from .. import __version__ as version
from ..utils import is_torch_available, is_transformers_available
from . import BaseDiffusersCLICommand


def info_command_factory(_):
    """Factory used as the argparse ``func`` default; ignores parsed args."""
    return EnvironmentCommand()


class EnvironmentCommand(BaseDiffusersCLICommand):
    """``diffusers-cli env``: collect and print environment information
    (library versions, platform, CUDA availability) for bug reports."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # Attach the ``env`` subcommand and route it to the factory above.
        download_parser = parser.add_parser("env")
        download_parser.set_defaults(func=info_command_factory)

    def run(self):
        hub_version = huggingface_hub.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        transformers_version = "not installed"
        # BUG FIX: the original wrote ``if is_transformers_available:`` —
        # testing the function object itself, which is always truthy — so
        # ``import transformers`` ran (and raised ModuleNotFoundError) even
        # when the package is absent. Call the predicate, as the torch
        # check above does.
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        info = {
            "`diffusers` version": version,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        # Returned as well so callers can consume the data programmatically.
        return info

    @staticmethod
    def format_dict(d):
        """Render ``d`` as a Markdown bullet list, one ``- key: value`` per line."""
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
|
diffusers/configuration_utils.py
ADDED
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" ConfigMixin base class and utilities."""
|
17 |
+
import dataclasses
|
18 |
+
import functools
|
19 |
+
import importlib
|
20 |
+
import inspect
|
21 |
+
import json
|
22 |
+
import os
|
23 |
+
import re
|
24 |
+
from collections import OrderedDict
|
25 |
+
from typing import Any, Dict, Tuple, Union
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
|
29 |
+
from huggingface_hub import hf_hub_download
|
30 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
31 |
+
from requests import HTTPError
|
32 |
+
|
33 |
+
from . import __version__
|
34 |
+
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
|
35 |
+
|
36 |
+
|
37 |
+
logger = logging.get_logger(__name__)
|
38 |
+
|
39 |
+
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
40 |
+
|
41 |
+
|
42 |
+
class FrozenDict(OrderedDict):
    """An :class:`~collections.OrderedDict` that mirrors its entries as
    instance attributes and rejects its mutating APIs after construction.

    Used for immutable config containers: values can be read either as
    ``config["key"]`` or ``config.key``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Mirror every entry for dotted access in addition to item access.
        for name, val in self.items():
            setattr(self, name, val)

        self.__frozen = True

    def _blocked(self, op_name):
        # Single place that raises the "immutable instance" error for every
        # disabled mutating operation.
        raise Exception(f"You cannot use ``{op_name}`` on a {self.__class__.__name__} instance.")

    def __delitem__(self, *args, **kwargs):
        self._blocked("__delitem__")

    def setdefault(self, *args, **kwargs):
        self._blocked("setdefault")

    def pop(self, *args, **kwargs):
        self._blocked("pop")

    def update(self, *args, **kwargs):
        self._blocked("update")

    def __setattr__(self, name, value):
        # NOTE(review): ``hasattr(self, "__frozen")`` checks the literal name
        # "__frozen", while the attribute is stored under the mangled name
        # ``_FrozenDict__frozen`` — so this guard may never fire; confirm the
        # intended freezing semantics before changing it. Kept as-is to
        # preserve behavior.
        if hasattr(self, "__frozen") and self.__frozen:
            self._blocked("__setattr__")
        super().__setattr__(name, value)

    def __setitem__(self, name, value):
        # Same guard as above; the error message deliberately matches the
        # original and names "__setattr__".
        if hasattr(self, "__frozen") and self.__frozen:
            self._blocked("__setattr__")
        super().__setitem__(name, value)
|
72 |
+
|
73 |
+
|
74 |
+
class ConfigMixin:
|
75 |
+
r"""
|
76 |
+
Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
|
77 |
+
methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
|
78 |
+
- [`~ConfigMixin.from_config`]
|
79 |
+
- [`~ConfigMixin.save_config`]
|
80 |
+
|
81 |
+
Class attributes:
|
82 |
+
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
83 |
+
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
84 |
+
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
85 |
+
overridden by subclass).
|
86 |
+
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
87 |
+
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
|
88 |
+
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
89 |
+
subclass).
|
90 |
+
"""
|
91 |
+
config_name = None
|
92 |
+
ignore_for_config = []
|
93 |
+
has_compatibles = False
|
94 |
+
|
95 |
+
_deprecated_kwargs = []
|
96 |
+
|
97 |
+
def register_to_config(self, **kwargs):
|
98 |
+
if self.config_name is None:
|
99 |
+
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
100 |
+
# Special case for `kwargs` used in deprecation warning added to schedulers
|
101 |
+
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
|
102 |
+
# or solve in a more general way.
|
103 |
+
kwargs.pop("kwargs", None)
|
104 |
+
for key, value in kwargs.items():
|
105 |
+
try:
|
106 |
+
setattr(self, key, value)
|
107 |
+
except AttributeError as err:
|
108 |
+
logger.error(f"Can't set {key} with value {value} for {self}")
|
109 |
+
raise err
|
110 |
+
|
111 |
+
if not hasattr(self, "_internal_dict"):
|
112 |
+
internal_dict = kwargs
|
113 |
+
else:
|
114 |
+
previous_dict = dict(self._internal_dict)
|
115 |
+
internal_dict = {**self._internal_dict, **kwargs}
|
116 |
+
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
|
117 |
+
|
118 |
+
self._internal_dict = FrozenDict(internal_dict)
|
119 |
+
|
120 |
+
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
121 |
+
"""
|
122 |
+
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
123 |
+
[`~ConfigMixin.from_config`] class method.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
save_directory (`str` or `os.PathLike`):
|
127 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
128 |
+
"""
|
129 |
+
if os.path.isfile(save_directory):
|
130 |
+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
131 |
+
|
132 |
+
os.makedirs(save_directory, exist_ok=True)
|
133 |
+
|
134 |
+
# If we save using the predefined names, we can load using `from_config`
|
135 |
+
output_config_file = os.path.join(save_directory, self.config_name)
|
136 |
+
|
137 |
+
self.to_json_file(output_config_file)
|
138 |
+
logger.info(f"Configuration saved in {output_config_file}")
|
139 |
+
|
140 |
+
@classmethod
|
141 |
+
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
142 |
+
r"""
|
143 |
+
Instantiate a Python class from a config dictionary
|
144 |
+
|
145 |
+
Parameters:
|
146 |
+
config (`Dict[str, Any]`):
|
147 |
+
A config dictionary from which the Python class will be instantiated. Make sure to only load
|
148 |
+
configuration files of compatible classes.
|
149 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
150 |
+
Whether kwargs that are not consumed by the Python class should be returned or not.
|
151 |
+
|
152 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
153 |
+
Can be used to update the configuration object (after it being loaded) and initiate the Python class.
|
154 |
+
`**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
|
155 |
+
overwrite same named arguments of `config`.
|
156 |
+
|
157 |
+
Examples:
|
158 |
+
|
159 |
+
```python
|
160 |
+
>>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
|
161 |
+
|
162 |
+
>>> # Download scheduler from huggingface.co and cache.
|
163 |
+
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
|
164 |
+
|
165 |
+
>>> # Instantiate DDIM scheduler class with same config as DDPM
|
166 |
+
>>> scheduler = DDIMScheduler.from_config(scheduler.config)
|
167 |
+
|
168 |
+
>>> # Instantiate PNDM scheduler class with same config as DDPM
|
169 |
+
>>> scheduler = PNDMScheduler.from_config(scheduler.config)
|
170 |
+
```
|
171 |
+
"""
|
172 |
+
# <===== TO BE REMOVED WITH DEPRECATION
|
173 |
+
# TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
|
174 |
+
if "pretrained_model_name_or_path" in kwargs:
|
175 |
+
config = kwargs.pop("pretrained_model_name_or_path")
|
176 |
+
|
177 |
+
if config is None:
|
178 |
+
raise ValueError("Please make sure to provide a config as the first positional argument.")
|
179 |
+
# ======>
|
180 |
+
|
181 |
+
if not isinstance(config, dict):
|
182 |
+
deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
|
183 |
+
if "Scheduler" in cls.__name__:
|
184 |
+
deprecation_message += (
|
185 |
+
f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
|
186 |
+
" Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
|
187 |
+
" be removed in v1.0.0."
|
188 |
+
)
|
189 |
+
elif "Model" in cls.__name__:
|
190 |
+
deprecation_message += (
|
191 |
+
f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
|
192 |
+
f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
|
193 |
+
" instead. This functionality will be removed in v1.0.0."
|
194 |
+
)
|
195 |
+
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
|
196 |
+
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
|
197 |
+
|
198 |
+
init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
|
199 |
+
|
200 |
+
# Allow dtype to be specified on initialization
|
201 |
+
if "dtype" in unused_kwargs:
|
202 |
+
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
203 |
+
|
204 |
+
# add possible deprecated kwargs
|
205 |
+
for deprecated_kwarg in cls._deprecated_kwargs:
|
206 |
+
if deprecated_kwarg in unused_kwargs:
|
207 |
+
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
|
208 |
+
|
209 |
+
# Return model and optionally state and/or unused_kwargs
|
210 |
+
model = cls(**init_dict)
|
211 |
+
|
212 |
+
# make sure to also save config parameters that might be used for compatible classes
|
213 |
+
model.register_to_config(**hidden_dict)
|
214 |
+
|
215 |
+
# add hidden kwargs of compatible classes to unused_kwargs
|
216 |
+
unused_kwargs = {**unused_kwargs, **hidden_dict}
|
217 |
+
|
218 |
+
if return_unused_kwargs:
|
219 |
+
return (model, unused_kwargs)
|
220 |
+
else:
|
221 |
+
return model
|
222 |
+
|
223 |
+
@classmethod
|
224 |
+
def get_config_dict(cls, *args, **kwargs):
|
225 |
+
deprecation_message = (
|
226 |
+
f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
|
227 |
+
" removed in version v1.0.0"
|
228 |
+
)
|
229 |
+
deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
|
230 |
+
return cls.load_config(*args, **kwargs)
|
231 |
+
|
232 |
+
@classmethod
def load_config(
    cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    r"""
    Instantiate a Python class from a config dictionary

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
            Can be either:

                - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
                  organization name, like `google/ddpm-celebahq-256`.
                - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
                  `./my_model_directory/`.

        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received files. Will attempt to resume the download if such a
            file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info(`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        local_files_only(`bool`, *optional*, defaults to `False`):
            Whether or not to only look at local files (i.e., do not try to download the model).
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo (either remote in
            huggingface.co or downloaded locally), you can specify the folder name here.

    <Tip>

    It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
    models](https://huggingface.co/docs/hub/models-gated#gated-models).

    </Tip>

    <Tip>

    Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
    use this method in a firewalled environment.

    </Tip>
    """
    # Download/Hub options are consumed from `kwargs` here; whatever remains is
    # handed back to the caller when `return_unused_kwargs=True`.
    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    use_auth_token = kwargs.pop("use_auth_token", None)
    local_files_only = kwargs.pop("local_files_only", False)
    revision = kwargs.pop("revision", None)
    # `mirror` is accepted for API compatibility but intentionally ignored.
    _ = kwargs.pop("mirror", None)
    subfolder = kwargs.pop("subfolder", None)

    user_agent = {"file_type": "config"}

    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    # `config_name` must be declared by the concrete subclass (e.g. "config.json").
    if cls.config_name is None:
        raise ValueError(
            "`self.config_name` is not defined. Note that one should not load a config from "
            "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
        )

    # Resolution order: explicit file path -> local directory (optionally with
    # subfolder) -> Hub download. Each failure mode maps to an EnvironmentError
    # with an actionable message.
    if os.path.isfile(pretrained_model_name_or_path):
        config_file = pretrained_model_name_or_path
    elif os.path.isdir(pretrained_model_name_or_path):
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
            # Load from a PyTorch checkpoint
            config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
        elif subfolder is not None and os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
        ):
            config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
        else:
            raise EnvironmentError(
                f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
            )
    else:
        try:
            # Load from URL or cache if already cached
            config_file = hf_hub_download(
                pretrained_model_name_or_path,
                filename=cls.config_name,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                subfolder=subfolder,
                revision=revision,
            )

        except RepositoryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
                " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
                " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
                " login`."
            )
        except RevisionNotFoundError:
            raise EnvironmentError(
                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
                " this model name. Check the model page at"
                f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
            )
        except EntryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
            )
        except HTTPError as err:
            raise EnvironmentError(
                "There was a specific connection error when trying to load"
                f" {pretrained_model_name_or_path}:\n{err}"
            )
        except ValueError:
            # hf_hub_download raises ValueError in offline mode when the file
            # is not already in the local cache.
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
                " run the library in offline mode at"
                " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
            )
        except EnvironmentError:
            # Catch-all for remaining Hub errors; must come last since the
            # specific exceptions above are subclasses of EnvironmentError.
            raise EnvironmentError(
                f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                f"containing a {cls.config_name} file"
            )

    try:
        # Load config dict
        config_dict = cls._dict_from_json_file(config_file)
    except (json.JSONDecodeError, UnicodeDecodeError):
        raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

    if return_unused_kwargs:
        return config_dict, kwargs

    return config_dict
|
388 |
+
|
389 |
+
@staticmethod
def _get_init_keys(cls):
    """Return the set of parameter names accepted by ``cls.__init__`` (including ``self``)."""
    signature = inspect.signature(cls.__init__)
    return set(signature.parameters)
|
392 |
+
|
393 |
+
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
    """
    Partition `config_dict` (optionally overridden by `kwargs`) into three dicts:

    - `init_dict`: keyword arguments accepted by `cls.__init__`
    - `unused_kwargs`: leftover config entries and kwargs not consumed by `__init__`
    - `hidden_config_dict`: entries of the *original* config that did not end up in
      `init_dict` (preserved so compatible classes can still be re-instantiated)
    """
    # 0. Copy origin config dict
    original_dict = {k: v for k, v in config_dict.items()}

    # 1. Retrieve expected config attributes from __init__ signature
    expected_keys = cls._get_init_keys(cls)
    expected_keys.remove("self")
    # remove general kwargs if present in dict
    if "kwargs" in expected_keys:
        expected_keys.remove("kwargs")
    # remove flax internal keys
    if hasattr(cls, "_flax_internal_args"):
        for arg in cls._flax_internal_args:
            expected_keys.remove(arg)

    # 2. Remove attributes that cannot be expected from expected config attributes
    # remove keys to be ignored
    if len(cls.ignore_for_config) > 0:
        expected_keys = expected_keys - set(cls.ignore_for_config)

    # load diffusers library to import compatible and original scheduler
    diffusers_library = importlib.import_module(__name__.split(".")[0])

    if cls.has_compatibles:
        compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
    else:
        compatible_classes = []

    # Drop config keys that only exist on *compatible* classes (not on `cls`
    # itself) so they don't pollute `init_dict`.
    expected_keys_comp_cls = set()
    for c in compatible_classes:
        expected_keys_c = cls._get_init_keys(c)
        expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
    expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
    config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}

    # remove attributes from orig class that cannot be expected
    orig_cls_name = config_dict.pop("_class_name", cls.__name__)
    if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
        orig_cls = getattr(diffusers_library, orig_cls_name)
        unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
        config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}

    # remove private attributes
    config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

    # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
    init_dict = {}
    for key in expected_keys:
        # if config param is passed to kwarg and is present in config dict
        # it should overwrite existing config dict key
        if key in kwargs and key in config_dict:
            config_dict[key] = kwargs.pop(key)

        if key in kwargs:
            # overwrite key
            init_dict[key] = kwargs.pop(key)
        elif key in config_dict:
            # use value from config dict
            init_dict[key] = config_dict.pop(key)

    # 4. Give nice warning if unexpected values have been passed
    if len(config_dict) > 0:
        logger.warning(
            f"The config attributes {config_dict} were passed to {cls.__name__}, "
            "but are not expected and will be ignored. Please verify your "
            f"{cls.config_name} configuration file."
        )

    # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
    passed_keys = set(init_dict.keys())
    if len(expected_keys - passed_keys) > 0:
        logger.info(
            f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
        )

    # 6. Define unused keyword arguments
    unused_kwargs = {**config_dict, **kwargs}

    # 7. Define "hidden" config parameters that were saved for compatible classes
    hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

    return init_dict, unused_kwargs, hidden_config_dict
|
476 |
+
|
477 |
+
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
    """Read ``json_file`` as UTF-8 text and return its parsed JSON content."""
    with open(json_file, "r", encoding="utf-8") as reader:
        return json.loads(reader.read())
|
482 |
+
|
483 |
+
def __repr__(self):
    """Return ``"<ClassName> <json config>"`` for quick inspection."""
    class_name = self.__class__.__name__
    serialized = self.to_json_string()
    return f"{class_name} {serialized}"
|
485 |
+
|
486 |
+
@property
def config(self) -> Dict[str, Any]:
    """
    Returns the config of the class as a frozen dictionary

    Returns:
        `Dict[str, Any]`: Config of the class.
    """
    internal_dict = self._internal_dict
    return internal_dict
|
495 |
+
|
496 |
+
def to_json_string(self) -> str:
    """
    Serializes this instance to a JSON string.

    Returns:
        `str`: String containing all the attributes that make up this configuration instance in JSON format.
    """
    # Work on a shallow copy: the original code assigned the metadata keys
    # directly into `self._internal_dict`, mutating the live config (and
    # failing outright if it is a frozen mapping). The copy keeps the
    # serialization side-effect free.
    config_dict = dict(self._internal_dict) if hasattr(self, "_internal_dict") else {}
    config_dict["_class_name"] = self.__class__.__name__
    config_dict["_diffusers_version"] = __version__

    def to_json_saveable(value):
        # JSON cannot encode numpy arrays; emit them as (nested) lists.
        if isinstance(value, np.ndarray):
            value = value.tolist()
        return value

    config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
    # sort_keys gives a stable, diff-friendly on-disk representation.
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
514 |
+
|
515 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
    """
    Save this instance to a JSON file.

    Args:
        json_file_path (`str` or `os.PathLike`):
            Path to the JSON file in which this configuration instance's parameters will be saved.
    """
    serialized = self.to_json_string()
    with open(json_file_path, "w", encoding="utf-8") as writer:
        writer.write(serialized)
|
525 |
+
|
526 |
+
|
527 |
+
def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        # `i > 0` skips the `self` parameter of the wrapped __init__.
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        # zip truncates to the shorter sequence, so extra positional args
        # beyond the declared parameters are silently dropped here (they still
        # reach the real __init__ below).
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )
        # Private ("_"-prefixed) kwargs are recorded in the config but never
        # forwarded to the wrapped __init__.
        new_kwargs = {**config_init_kwargs, **new_kwargs}
        getattr(self, "register_to_config")(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init
|
570 |
+
|
571 |
+
|
572 |
+
def flax_register_to_config(cls):
    """
    Class decorator for Flax (dataclass-based) modules: wraps ``cls.__init__`` so that
    every dataclass field (positional or keyword) is recorded via ``register_to_config``
    before the original ``__init__`` runs.
    """
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        # Ignore private kwargs in the init. Retrieve all passed attributes
        init_kwargs = {k: v for k, v in kwargs.items()}

        # Retrieve default values
        fields = dataclasses.fields(self)
        default_kwargs = {}
        for field in fields:
            # ignore flax specific attributes
            if field.name in self._flax_internal_args:
                continue
            if type(field.default) == dataclasses._MISSING_TYPE:
                # Field has no declared default; register None as a placeholder.
                default_kwargs[field.name] = None
            else:
                default_kwargs[field.name] = getattr(self, field.name)

        # Make sure init_kwargs override default kwargs
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        if "dtype" in new_kwargs:
            new_kwargs.pop("dtype")

        # Get positional arguments aligned with kwargs
        # NOTE(review): this indexes `fields` positionally, assuming dataclass
        # field order matches __init__'s positional order — true for plain
        # (non-kw_only) dataclasses.
        for i, arg in enumerate(args):
            name = fields[i].name
            new_kwargs[name] = arg

        getattr(self, "register_to_config")(**new_kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls
|
diffusers/dependency_versions_check.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import sys
|
15 |
+
|
16 |
+
from .dependency_versions_table import deps
|
17 |
+
from .utils.versions import require_version, require_version_core
|
18 |
+
|
19 |
+
|
20 |
+
# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers

pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
# Stdlib backports that are only needed on older interpreters.
if sys.version_info < (3, 7):
    pkgs_to_check_at_runtime.append("dataclasses")
if sys.version_info < (3, 8):
    pkgs_to_check_at_runtime.append("importlib_metadata")

# Validate each runtime dependency against the pin recorded in the deps table;
# a missing table entry is a packaging bug, hence the hard ValueError.
for pkg in pkgs_to_check_at_runtime:
    if pkg in deps:
        if pkg == "tokenizers":
            # must be loaded here, or else tqdm check may fail
            from .utils import is_tokenizers_available

            if not is_tokenizers_available():
                continue  # not required, check version only if installed
        require_version_core(deps[pkg])
    else:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
|
44 |
+
|
45 |
+
|
46 |
+
def dep_version_check(pkg, hint=None):
    """Check that the installed version of `pkg` satisfies the pin in the deps table."""
    requirement = deps[pkg]
    require_version(requirement, hint)
|
diffusers/dependency_versions_table.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update``
# Maps a bare package name to the full pip requirement specifier used by the
# runtime version checks in dependency_versions_check.py.
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.11.0",
    "black": "black==22.8",
    "datasets": "datasets",
    "filelock": "filelock",
    "flake8": "flake8>=3.8.3",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.10.0",
    "importlib_metadata": "importlib_metadata",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.2.8,!=0.3.2",
    "jaxlib": "jaxlib>=0.1.65",
    "k-diffusion": "k-diffusion",
    "librosa": "librosa",
    "modelcards": "modelcards>=0.1.4",
    "numpy": "numpy",
    "parameterized": "parameterized",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "safetensors": "safetensors",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "scipy": "scipy",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.25.1",
}
|
diffusers/dynamic_modules_utils.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Utilities to dynamically load objects from the Hub."""
|
16 |
+
|
17 |
+
import importlib
|
18 |
+
import inspect
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
import shutil
|
22 |
+
import sys
|
23 |
+
from pathlib import Path
|
24 |
+
from typing import Dict, Optional, Union
|
25 |
+
|
26 |
+
from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
|
27 |
+
|
28 |
+
from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
|
29 |
+
|
30 |
+
|
31 |
+
# Template URL for fetching a community pipeline module straight from the
# `diffusers` GitHub repository ("{pipeline}" is filled in with the module name).
COMMUNITY_PIPELINES_URL = (
    "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/{pipeline}.py"
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
37 |
+
|
38 |
+
|
39 |
+
def init_hf_modules():
    """
    Creates the cache directory for modules with an init, and adds it to the Python path.
    """
    # Guard clause: if the cache dir is already importable, everything below
    # has been done before.
    if HF_MODULES_CACHE in sys.path:
        return

    sys.path.append(HF_MODULES_CACHE)
    os.makedirs(HF_MODULES_CACHE, exist_ok=True)
    init_file = Path(HF_MODULES_CACHE) / "__init__.py"
    if not init_file.exists():
        init_file.touch()
|
52 |
+
|
53 |
+
|
54 |
+
def create_dynamic_module(name: Union[str, os.PathLike]):
    """
    Creates a dynamic module in the cache directory for modules.
    """
    init_hf_modules()
    module_dir = Path(HF_MODULES_CACHE) / name
    # Materialize missing parent packages first (recursively), so every level
    # gets its own __init__.py.
    if not module_dir.parent.exists():
        create_dynamic_module(module_dir.parent)
    os.makedirs(module_dir, exist_ok=True)
    init_file = module_dir / "__init__.py"
    if not init_file.exists():
        init_file.touch()
|
67 |
+
|
68 |
+
|
69 |
+
def get_relative_imports(module_file):
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: deduplicated names of the relatively imported modules.
    """
    with open(module_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Raw strings so `\s`/`\S` reach the regex engine intact instead of being
    # (deprecated) string-literal escapes.
    # Imports of the form `import .xxx`
    relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from .xxx import yyy`
    relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
    # Unique-ify
    return list(set(relative_imports))
|
85 |
+
|
86 |
+
|
87 |
+
def get_relative_import_files(module_file):
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the relative
    imports (if a imports b and b imports c, it will return module files for b and c).

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: paths (ending in ".py") of every transitively required module file.
    """
    no_change = False
    files_to_check = [module_file]
    all_relative_imports = []

    # Let's recurse through all relative imports
    while not no_change:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        # Paths are resolved relative to the *original* module's directory,
        # i.e. all relative imports are assumed to live flat in that folder.
        module_path = Path(module_file).parent
        # Attach the ".py" suffix *before* deduplicating: `all_relative_imports`
        # stores suffixed names, so the original suffix-less comparison never
        # matched and circular relative imports would loop forever.
        new_import_files = [f"{module_path / m}.py" for m in new_imports]
        new_import_files = [f for f in new_import_files if f not in all_relative_imports]
        files_to_check = new_import_files

        no_change = len(new_import_files) == 0
        all_relative_imports.extend(files_to_check)

    return all_relative_imports
|
114 |
+
|
115 |
+
|
116 |
+
def check_imports(filename):
    """
    Check if the current Python environment contains all the libraries that are imported in a file.

    Raises:
        ImportError: if any absolute import in the file cannot be resolved.

    Returns:
        The list of relative imports of the file (see `get_relative_imports`).
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    # Raw strings so `\s`/`\S` reach the regex engine intact instead of being
    # (deprecated) string-literal escapes.
    # Imports of the form `import xxx`
    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from xxx import yyy`
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
    # Only keep the top-level module
    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]

    # Unique-ify and test we got them all
    imports = list(set(imports))
    missing_packages = []
    for imp in imports:
        try:
            importlib.import_module(imp)
        except ImportError:
            missing_packages.append(imp)

    if len(missing_packages) > 0:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )

    return get_relative_imports(filename)
|
146 |
+
|
147 |
+
|
148 |
+
def get_class_in_module(class_name, module_path):
    """
    Import a module on the cache directory for modules and extract a class from it.

    If `class_name` is None, the single `DiffusionPipeline` subclass defined in the
    module is returned instead.
    """
    dotted_path = module_path.replace(os.path.sep, ".")
    module = importlib.import_module(dotted_path)

    if class_name is None:
        return find_pipeline_class(module)
    return getattr(module, class_name)
|
158 |
+
|
159 |
+
|
160 |
+
def find_pipeline_class(loaded_module):
    """
    Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class
    inheriting from `DiffusionPipeline`.
    """
    # Imported lazily to avoid a circular import with pipeline_utils.
    from .pipeline_utils import DiffusionPipeline

    cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))

    pipeline_class = None
    for cls_name, cls in cls_members.items():
        # A candidate must be a strict DiffusionPipeline subclass that is
        # defined outside the diffusers package itself (i.e. a community class).
        if (
            cls_name != DiffusionPipeline.__name__
            and issubclass(cls, DiffusionPipeline)
            and cls.__module__.split(".")[0] != "diffusers"
        ):
            if pipeline_class is not None:
                raise ValueError(
                    f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
                    f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
                    f" {loaded_module}."
                )
            pipeline_class = cls

    # May be None if the module defines no matching class.
    return pipeline_class
|
185 |
+
|
186 |
+
|
187 |
+
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
):
    """
    Prepares and downloads a module from a local folder or a distant repo and returns its path inside the cached
    dynamic modules directory.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
              namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.

    <Tip>

    You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli login`) and want to use private
    or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).

    </Tip>

    Returns:
        `str`: The path to the module inside the cache.
    """
    # Download and cache module_file from the repo `pretrained_model_name_or_path`, or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)

    if os.path.isfile(module_file_or_url):
        # Local file path: no download needed.
        resolved_module_file = module_file_or_url
        submodule = "local"
    elif pretrained_model_name_or_path.count("/") == 0:
        # community pipeline on GitHub (single-segment names resolve against the
        # community pipelines URL, not the Hub)
        github_url = COMMUNITY_PIPELINES_URL.format(pipeline=pretrained_model_name_or_path)
        try:
            resolved_module_file = cached_download(
                github_url,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=False,  # GitHub downloads never use a Hub token
            )
            submodule = "git"
            # The cached file is named after the pipeline itself.
            module_file = pretrained_model_name_or_path + ".py"
        except EnvironmentError:
            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
            raise
    else:
        try:
            # Load from URL or cache if already cached
            resolved_module_file = hf_hub_download(
                pretrained_model_name_or_path,
                module_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )
            # Namespace the cached copy per repo ("user--repo") to avoid collisions.
            submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
        except EnvironmentError:
            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
            raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == "local" or submodule == "git":
        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
        # that hash, to only copy when there is a modification but it seems overkill for now).
        # The only reason we do the copy is to avoid putting too many folders in sys.path.
        shutil.copy(resolved_module_file, submodule_path / module_file)
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
    else:
        # Get the commit hash
        # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
        if isinstance(use_auth_token, str):
            token = use_auth_token
        elif use_auth_token is True:
            token = HfFolder.get_token()
        else:
            token = None

        commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
        # Make sure we also have every file with relative imports cached (recursive fetch).
        for module_needed in modules_needed:
            if not (submodule_path / module_needed).exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    local_files_only=local_files_only,
                )
    return os.path.join(full_submodule, module_file)
|
339 |
+
|
340 |
+
|
341 |
+
def get_class_from_dynamic_module(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    class_name: Optional[str] = None,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Extracts a class from a module file, present in the local folder or repository of a model.

    <Tip warning={true}>

    Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
    therefore only be called on trusted repos.

    </Tip>

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Either a *model id* of a configuration hosted on huggingface.co (root-level like
            `bert-base-uncased` or namespaced like `dbmdz/bert-base-german-cased`), or a path to a
            local *directory* containing the module file.
        module_file (`str`):
            The name of the module file containing the class to look for.
        class_name (`str`, *optional*):
            The name of the class to import in the module. When `None`, the unique pipeline class
            defined in the module is auto-detected.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Directory in which to cache downloaded configurations instead of the standard cache.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether to (re-)download the files and override any cached versions.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether to attempt resuming an incomplete download instead of deleting it.
        proxies (`Dict[str, str]`, *optional*):
            Proxy servers by protocol or endpoint, e.g. `{'http': 'foo.bar:3128'}`, used on each request.
        use_auth_token (`str` or `bool`, *optional*):
            HTTP bearer token for remote files; `True` uses the token saved by `huggingface-cli login`.
        revision (`str`, *optional*, defaults to `"main"`):
            Git revision (branch name, tag, or commit id) to download from.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, only look at local files and never hit the network.

    Returns:
        `type`: The class, dynamically imported from the module.

    Examples:

    ```python
    # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
    ```"""
    # Fetch (and locally cache) the module file, then import it and pull out the class.
    cached_module_file = get_cached_module_file(
        pretrained_model_name_or_path,
        module_file,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
    )
    module_name = cached_module_file.replace(".py", "")
    return get_class_in_module(class_name, module_name)
|
diffusers/experimental/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .rl import ValueGuidedRLPipeline
|
diffusers/experimental/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (210 Bytes). View file
|
|
diffusers/experimental/rl/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .value_guided_sampling import ValueGuidedRLPipeline
|
diffusers/experimental/rl/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (232 Bytes). View file
|
|
diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-310.pyc
ADDED
Binary file (4.84 kB). View file
|
|
diffusers/experimental/rl/value_guided_sampling.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
|
18 |
+
import tqdm
|
19 |
+
|
20 |
+
from ...models.unet_1d import UNet1DModel
|
21 |
+
from ...pipeline_utils import DiffusionPipeline
|
22 |
+
from ...utils.dummy_pt_objects import DDPMScheduler
|
23 |
+
|
24 |
+
|
25 |
+
class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    Pipeline for sampling actions from a diffusion model trained to predict sequences of states, with each
    denoising step guided by a learned value function.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Original implementation inspired by this repository: https://github.com/jannerm/diffuser.

    Parameters:
        value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward.
        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()

        # Per-key normalization statistics over the offline dataset. Some dataset
        # entries may not support mean()/std() (e.g. non-numeric metadata), so each
        # statistic is attempted independently and failures are skipped.
        # Fix: the original used bare `except:` clauses, which also swallow
        # KeyboardInterrupt/SystemExit; narrowed to `except Exception`.
        self.means = {}
        self.stds = {}
        for key, value in self.data.items():
            try:
                self.means[key] = value.mean()
            except Exception:
                pass
            try:
                self.stds[key] = value.std()
            except Exception:
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        """Standardize `x_in` (zero mean, unit std) using the dataset statistics for `key`."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Inverse of `normalize`: map standardized values back to the data scale of `key`."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """Recursively convert dicts/arrays/tensors to torch tensors on the UNet's device."""
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        if torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """Overwrite the state portion of trajectory `x_in` at each conditioned timestep."""
        for key, val in cond.items():
            # trajectory layout is (batch, horizon, action_dim + state_dim);
            # the slice past `act_dim` is the state part
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """
        Denoise trajectories `x`, nudging each intermediate sample toward higher predicted
        value by ascending the value-function gradient `n_guide_steps` times per timestep.

        Returns the denoised trajectories and the last value predictions `y` (used by the
        caller to rank trajectories).
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    # scale the guidance gradient by the scheduler's posterior std
                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # no guidance on the final (low-noise) timesteps
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """
        Plan a batch of trajectories conditioned on the current observation `obs` and return
        the first action of the highest-value trajectory (de-normalized to env scale).
        """
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = torch.randn(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # sort output trajectories by value
        # NOTE(review): if n_guide_steps is 0, y stays None and argsort would fail here
        # before reaching the `y is not None` fallback below — confirm intended usage.
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # select the action with the highest value
        if y is not None:
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
|
diffusers/hub_utils.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import os
|
18 |
+
import sys
|
19 |
+
from pathlib import Path
|
20 |
+
from typing import Dict, Optional, Union
|
21 |
+
from uuid import uuid4
|
22 |
+
|
23 |
+
import requests
|
24 |
+
from huggingface_hub import HfFolder, whoami
|
25 |
+
|
26 |
+
from . import __version__
|
27 |
+
from .utils import ENV_VARS_TRUE_VALUES, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
|
28 |
+
from .utils.import_utils import (
|
29 |
+
_flax_version,
|
30 |
+
_jax_version,
|
31 |
+
_onnxruntime_version,
|
32 |
+
_torch_version,
|
33 |
+
is_flax_available,
|
34 |
+
is_modelcards_available,
|
35 |
+
is_onnx_available,
|
36 |
+
is_torch_available,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
if is_modelcards_available():
|
41 |
+
from modelcards import CardData, ModelCard
|
42 |
+
|
43 |
+
|
44 |
+
logger = logging.get_logger(__name__)


# Markdown template rendered by `create_model_card` via `ModelCard.from_template`.
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
# Random per-process id attached to the telemetry user-agent string.
SESSION_ID = uuid4().hex
# Opt-out switches, read once at import time from the environment.
HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
# Base endpoint used by `send_telemetry` (a topic name is appended per call).
HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
|
52 |
+
|
53 |
+
|
54 |
+
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Formats a user-agent string with basic info about a request.

    The string always starts with the diffusers/python versions and a session id; when
    telemetry is enabled it is extended with installed-framework versions, a CI marker,
    and any caller-supplied `user_agent` fields.
    """
    segments = [f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"]
    if DISABLE_TELEMETRY:
        return segments[0] + "; telemetry/off"
    if is_torch_available():
        segments.append(f"torch/{_torch_version}")
    if is_flax_available():
        segments.append(f"jax/{_jax_version}")
        segments.append(f"flax/{_flax_version}")
    if is_onnx_available():
        segments.append(f"onnxruntime/{_onnxruntime_version}")
    # CI will set this value to True
    if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        segments.append("is_ci/true")
    if isinstance(user_agent, dict):
        segments.append("; ".join(f"{k}/{v}" for k, v in user_agent.items()))
    elif isinstance(user_agent, str):
        segments.append(user_agent)
    return "; ".join(segments)
|
76 |
+
|
77 |
+
|
78 |
+
def send_telemetry(data: Dict, name: str):
    """
    Sends logs to the Hub telemetry endpoint.

    Args:
        data: the fields to track, e.g. {"example_name": "dreambooth"}
        name: a unique name to differentiate the telemetry logs, e.g. "diffusers_examples" or "diffusers_notebooks"
    """
    # Bug fix: this guard previously read `pass`, which is a no-op — telemetry was
    # still sent even when the user opted out or was offline. `return` honors the
    # opt-out by skipping the request entirely.
    if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
        return

    headers = {"user-agent": http_user_agent(data)}
    endpoint = HUGGINGFACE_CO_TELEMETRY + name
    try:
        r = requests.head(endpoint, headers=headers)
        r.raise_for_status()
    except Exception:
        # We don't want to error in case of connection errors of any kind.
        pass
|
97 |
+
|
98 |
+
|
99 |
+
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """
    Return the fully-qualified Hub repo name `<namespace>/<model_id>`.

    The namespace is `organization` when given; otherwise the username of the token
    owner (resolved via `whoami`). A missing `token` is read from the local
    Hugging Face folder.
    """
    token = HfFolder.get_token() if token is None else token
    if organization is not None:
        return f"{organization}/{model_id}"
    username = whoami(token)["name"]
    return f"{username}/{model_id}"
|
107 |
+
|
108 |
+
|
109 |
+
def create_model_card(args, model_name):
    """
    Write a README.md model card for a training run into `args.output_dir`.

    Args:
        args: training-arguments namespace; optional hyperparameter attributes are read
            defensively with `hasattr` so different training scripts can share this helper.
        model_name: name of the model repo (resolved to `<namespace>/<model_name>`).

    Raises:
        ValueError: if the `modelcards` package is not installed.
    """
    # Bug fix: the original tested `if not is_modelcards_available:` — the function
    # object itself, which is always truthy, so the guard could never trigger and a
    # missing `modelcards` package crashed later with a NameError. Call the function.
    if not is_modelcards_available():
        raise ValueError(
            "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
            " install the package with `pip install modelcards`."
        )

    # In distributed runs only the main process (local_rank -1 or 0) writes the card.
    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    hub_token = args.hub_token if hasattr(args, "hub_token") else None
    repo_name = get_full_repo_name(model_name, token=hub_token)

    model_card = ModelCard.from_template(
        card_data=CardData(  # Card metadata object that will be converted to YAML block
            language="en",
            license="apache-2.0",
            library_name="diffusers",
            tags=[],
            datasets=args.dataset_name,
            metrics=[],
        ),
        template_path=MODEL_CARD_TEMPLATE_PATH,
        model_name=model_name,
        repo_name=repo_name,
        dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
        learning_rate=args.learning_rate,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps
        if hasattr(args, "gradient_accumulation_steps")
        else None,
        adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
        adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
        adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
        adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
        lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
        lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
        ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
        ema_power=args.ema_power if hasattr(args, "ema_power") else None,
        ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
        mixed_precision=args.mixed_precision,
    )

    card_path = os.path.join(args.output_dir, "README.md")
    model_card.save(card_path)
|
diffusers/modeling_flax_pytorch_utils.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch - Flax general utilities."""
|
16 |
+
import re
|
17 |
+
|
18 |
+
import jax.numpy as jnp
|
19 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
20 |
+
from jax.random import PRNGKey
|
21 |
+
|
22 |
+
from .utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
def rename_key(key):
    """
    Convert PyTorch-style "<name>.<index>" segments in `key` to Flax-style "<name>_<index>".

    e.g. "layers.0.weight" -> "layers_0.weight".
    """
    for segment in re.findall(r"\w+[.]\d+", key):
        key = key.replace(segment, segment.replace(".", "_"))
    return key
|
34 |
+
|
35 |
+
|
36 |
+
#####################
|
37 |
+
# PyTorch => Flax #
|
38 |
+
#####################
|
39 |
+
|
40 |
+
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
|
41 |
+
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
|
42 |
+
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary.

    Args:
        pt_tuple_key: the PyTorch parameter name split on "." into a tuple.
        pt_tensor: the parameter value (a numpy array after conversion from torch).
        random_flax_state_dict: flattened Flax parameter dict used to decide which
            Flax naming convention the key should map to.

    Returns:
        `(flax_tuple_key, tensor)` with the tensor transposed when the layer type requires it.
    """
    # conv norm or layer norm: PyTorch "weight"/"bias" maps to Flax "scale"
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and (pt_tuple_key[-1] == "bias")
        and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
        and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
    ):
        return renamed_pt_tuple_key, pt_tensor
    elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
        return renamed_pt_tuple_key, pt_tensor

    # embedding: PyTorch "weight" maps to Flax "embedding"
    # Bug fix: this branch previously assigned the new key to `pt_tuple_key` but
    # returned the stale `renamed_pt_tuple_key` (the "scale" rename from above), so
    # embedding weights received a wrong Flax key. Return the "embedding" rename.
    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        return renamed_pt_tuple_key, pt_tensor

    # conv layer: transpose OIHW -> HWIO
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer: Flax kernels are the transpose of PyTorch weights
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight":
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    if pt_tuple_key[-1] == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    if pt_tuple_key[-1] == "beta":
        return renamed_pt_tuple_key, pt_tensor

    return pt_tuple_key, pt_tensor
|
87 |
+
|
88 |
+
|
89 |
+
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
    """Convert a PyTorch state dict into an (unflattened) Flax parameter dict for ``flax_model``.

    Args:
        pt_state_dict: mapping of dotted PyTorch parameter names to torch tensors.
        flax_model: a Flax model exposing ``init_weights(rng)``; used only to obtain
            the expected Flax key names and shapes.
        init_key: seed for the PRNG used to draw the reference parameters.

    Returns:
        A nested dict of ``jnp`` arrays matching the Flax parameter tree layout.

    Raises:
        ValueError: if a checkpoint weight's shape disagrees with the model's.
    """
    # Step 1: move every torch tensor over to a numpy array.
    numpy_state = {name: tensor.numpy() for name, tensor in pt_state_dict.items()}

    # Step 2: the model is stateless, so draw a random parameter tree just to
    # learn the expected Flax key names and shapes.
    reference_params = flax_model.init_weights(PRNGKey(init_key))
    reference_flat = flatten_dict(reference_params)

    converted = {}
    for torch_name, torch_tensor in numpy_state.items():
        tuple_key = tuple(rename_key(torch_name).split("."))

        # Rename the key to the Flax convention and transpose/reshape if needed.
        flax_key, flax_tensor = rename_key_and_reshape_tensor(tuple_key, torch_tensor, reference_flat)

        expected = reference_flat.get(flax_key)
        if expected is not None and flax_tensor.shape != expected.shape:
            raise ValueError(
                f"PyTorch checkpoint seems to be incorrect. Weight {torch_name} was expected to be of shape "
                f"{expected.shape}, but is {flax_tensor.shape}."
            )

        # Unexpected weights are deliberately kept too, so the caller can warn about them.
        converted[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(converted)
|
diffusers/modeling_flax_utils.py
ADDED
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
from pickle import UnpicklingError
|
18 |
+
from typing import Any, Dict, Union
|
19 |
+
|
20 |
+
import jax
|
21 |
+
import jax.numpy as jnp
|
22 |
+
import msgpack.exceptions
|
23 |
+
from flax.core.frozen_dict import FrozenDict, unfreeze
|
24 |
+
from flax.serialization import from_bytes, to_bytes
|
25 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
26 |
+
from huggingface_hub import hf_hub_download
|
27 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
28 |
+
from requests import HTTPError
|
29 |
+
|
30 |
+
from . import __version__, is_torch_available
|
31 |
+
from .hub_utils import send_telemetry
|
32 |
+
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
33 |
+
from .utils import (
|
34 |
+
CONFIG_NAME,
|
35 |
+
DIFFUSERS_CACHE,
|
36 |
+
FLAX_WEIGHTS_NAME,
|
37 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
38 |
+
WEIGHTS_NAME,
|
39 |
+
logging,
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
logger = logging.get_logger(__name__)
|
44 |
+
|
45 |
+
|
46 |
+
class FlaxModelMixin:
    r"""
    Base class for all flax models.

    [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models.
    """
    # File name used when (de)serializing the model configuration — presumably
    # "config.json" via the shared CONFIG_NAME constant; see ConfigMixin.
    config_name = CONFIG_NAME
    # Config entries written automatically on save; not user-provided arguments.
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    # Flax-module constructor arguments that are internal and must not be persisted.
    _flax_internal_args = ["name", "parent", "dtype"]
|
56 |
+
|
57 |
+
@classmethod
def _from_config(cls, config, **kwargs):
    """
    Instantiate the model from ``config``.

    All context managers that the model should be initialized under go here.
    """
    model = cls(config, **kwargs)
    return model
|
63 |
+
|
64 |
+
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Cast the floating-point leaves of the parameter `PyTree` ``params`` to ``dtype``.

    Non-floating leaves are returned untouched. If ``mask`` is given, it must have the
    same structure as ``params`` and its boolean leaves select which entries are cast.
    """

    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def _maybe_cast(leaf):
        is_float_array = isinstance(leaf, jnp.ndarray) and jnp.issubdtype(leaf.dtype, jnp.floating)
        return leaf.astype(dtype) if is_float_array else leaf

    if mask is None:
        # No mask: cast every floating leaf in the tree.
        return jax.tree_map(_maybe_cast, params)

    # With a mask: cast only the leaves whose mask entry is truthy.
    flat = flatten_dict(params)
    mask_leaves, _ = jax.tree_flatten(mask)
    for selected, path in zip(mask_leaves, flat.keys()):
        if selected:
            flat[path] = _maybe_cast(flat[path])
    return unflatten_dict(flat)
|
87 |
+
|
88 |
+
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Return a copy of `params` whose floating-point leaves are cast to `jax.numpy.bfloat16`; the
    input tree is not modified in place.

    Casting to bfloat16 is typically done on TPU, either for full half-precision training or to
    shrink the weights for inference and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`):
            A `PyTree` with the same structure as `params`. Its leaves should be booleans: `True`
            for parameters you want to cast, `False` for those you want to skip.

    Examples:

    ```python
    >>> from diffusers import FlaxUNet2DConditionModel
    >>> from flax import traverse_util

    >>> # load model — parameters are in fp32 by default
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> # cast everything to bfloat16
    >>> params = model.to_bf16(params)
    >>> # or skip e.g. layer-norm bias and scale by passing a mask
    >>> flat_params = traverse_util.flatten_dict(params)
    >>> mask = {
    ...     path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> params = model.to_bf16(params, mask)
    ```"""
    return self._cast_floating_to(params, jnp.bfloat16, mask)
|
126 |
+
|
127 |
+
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
    model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip

    Examples:

    ```python
    >>> from diffusers import FlaxUNet2DConditionModel

    >>> # Download model and configuration from huggingface.co
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> # By default, the model params will be in fp32, to illustrate the use of this method,
    >>> # we'll first cast to fp16 and back to fp32
    >>> params = model.to_fp16(params)
    >>> # now cast back to fp32
    >>> params = model.to_fp32(params)
    ```"""
    return self._cast_floating_to(params, jnp.float32, mask)
|
153 |
+
|
154 |
+
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Return a copy of `params` whose floating-point leaves are cast to `jax.numpy.float16`; the
    input tree is not modified in place.

    Casting to float16 is typically done on GPU, either for full half-precision training or to
    shrink the weights for inference and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`):
            A `PyTree` with the same structure as `params`. Its leaves should be booleans: `True`
            for parameters you want to cast, `False` for those you want to skip.

    Examples:

    ```python
    >>> from diffusers import FlaxUNet2DConditionModel
    >>> from flax import traverse_util

    >>> # load model — parameters are in fp32 by default
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> # cast everything to float16
    >>> params = model.to_fp16(params)
    >>> # or skip e.g. layer-norm bias and scale by passing a mask
    >>> flat_params = traverse_util.flatten_dict(params)
    >>> mask = {
    ...     path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> params = model.to_fp16(params, mask)
    ```"""
    return self._cast_floating_to(params, jnp.float16, mask)
|
192 |
+
|
193 |
+
def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
    """Build and return the model's randomly initialized parameter `PyTree`. Subclasses must override this."""
    raise NotImplementedError(f"init_weights method has to be implemented for {self}")
|
195 |
+
|
196 |
+
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    dtype: jnp.dtype = jnp.float32,
    *model_args,
    **kwargs,
):
    r"""
    Instantiate a pretrained flax model from a pre-trained model configuration.

    The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
    pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
    task.

    The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids are namespaced under a user or organization name, like
                  `runwayml/stable-diffusion-v1-5`.
                - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
                  e.g., `./my_model_directory/`.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and
            [`~ModelMixin.to_bf16`].
        model_args (sequence of positional arguments, *optional*):
            All remaining positional arguments will be passed to the underlying model's `__init__` method.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received files. Will attempt to resume the download if such a
            file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only(`bool`, *optional*, defaults to `False`):
            Whether or not to only look at local files (i.e., do not try to download the model).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        from_pt (`bool`, *optional*, defaults to `False`):
            Load the model weights from a PyTorch checkpoint save file.
        kwargs (remaining dictionary of keyword arguments, *optional*):
            Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
            `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
            automatically loaded:

                - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                  underlying model's `__init__` method (we assume all relevant updates to the configuration have
                  already been done)
                - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                  initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to
                  a configuration attribute will be used to override said attribute with the supplied `kwargs`
                  value. Remaining keys that do not correspond to any configuration attribute will be passed to the
                  underlying model's `__init__` function.

    Examples:

    ```python
    >>> from diffusers import FlaxUNet2DConditionModel

    >>> # Download model and configuration from huggingface.co and cache.
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
    ```"""
    # Pop the hub/loading options off `kwargs`; whatever remains is forwarded to `from_config`.
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    force_download = kwargs.pop("force_download", False)
    from_pt = kwargs.pop("from_pt", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)

    user_agent = {
        "diffusers": __version__,
        "file_type": "model",
        "framework": "flax",
    }

    # Load config if we don't provide a configuration
    # NOTE(review): `model_kwargs` (the unused kwargs returned by `from_config`)
    # is never consumed below — confirm that is intentional.
    config_path = config if config is not None else pretrained_model_name_or_path
    model, model_kwargs = cls.from_config(
        config_path,
        cache_dir=cache_dir,
        return_unused_kwargs=True,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        revision=revision,
        subfolder=subfolder,
        # model args
        dtype=dtype,
        **kwargs,
    )

    # Load model: first resolve `model_file` either from a local directory or from the Hub.
    pretrained_path_with_subfolder = (
        pretrained_model_name_or_path
        if subfolder is None
        else os.path.join(pretrained_model_name_or_path, subfolder)
    )
    if os.path.isdir(pretrained_path_with_subfolder):
        if from_pt:
            if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                raise EnvironmentError(
                    f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
                )
            model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
        elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
            # Load from a Flax checkpoint
            model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
        # Check if pytorch weights exist instead
        elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
            raise EnvironmentError(
                f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
                " using `from_pt=True`."
            )
        else:
            raise EnvironmentError(
                f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                f"{pretrained_path_with_subfolder}."
            )
        send_telemetry(
            {"model_class": cls.__name__, "model_path": "local", "framework": "flax"},
            name="diffusers_from_pretrained",
        )
    else:
        try:
            model_file = hf_hub_download(
                pretrained_model_name_or_path,
                filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                subfolder=subfolder,
                revision=revision,
            )
            send_telemetry(
                {"model_class": cls.__name__, "model_path": "hub", "framework": "flax"},
                name="diffusers_from_pretrained",
            )

        except RepositoryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                "login`."
            )
        except RevisionNotFoundError:
            raise EnvironmentError(
                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                "this model name. Check the model page at "
                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
            )
        except EntryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
            )
        except HTTPError as err:
            raise EnvironmentError(
                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
                f"{err}"
            )
        # hf_hub_download raises ValueError when offline and the file is not cached.
        except ValueError:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
                " internet connection or see how to run the library in offline mode at"
                " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
            )
        except EnvironmentError:
            raise EnvironmentError(
                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
            )

    if from_pt:
        if is_torch_available():
            from .modeling_utils import load_state_dict
        else:
            raise EnvironmentError(
                "Can't load the model in PyTorch format because PyTorch is not installed. "
                "Please, install PyTorch or use native Flax weights."
            )

        # Step 1: Get the pytorch file
        pytorch_model_file = load_state_dict(model_file)

        # Step 2: Convert the weights
        state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
    else:
        try:
            with open(model_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
            # The msgpack payload is unreadable; re-open as text to distinguish a
            # git-lfs pointer file from a genuinely corrupt checkpoint.
            try:
                with open(model_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
        # make sure all arrays are stored as jnp.ndarray
        # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
        # https://github.com/google/flax/issues/1261
        state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

    # flatten dicts
    state = flatten_dict(state)

    # Compare the checkpoint's keys/shapes against the model's expected parameter
    # tree (computed abstractly — eval_shape does not allocate real arrays).
    params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
    required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

    shape_state = flatten_dict(unfreeze(params_shape_tree))

    missing_keys = required_params - set(state.keys())
    unexpected_keys = set(state.keys()) - required_params

    if missing_keys:
        logger.warning(
            f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
            "Make sure to call model.init_weights to initialize the missing weights."
        )
        # NOTE(review): this stores missing keys on the *class*, not the instance,
        # so it is shared across all models of this class — confirm intentional.
        cls._missing_keys = missing_keys

    for key in state.keys():
        if key in shape_state and state[key].shape != shape_state[key].shape:
            raise ValueError(
                f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
            )

    # remove unexpected keys to not be saved again
    for unexpected_key in unexpected_keys:
        del state[unexpected_key]

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
            " with another architecture."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
            f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
            " training."
        )

    return model, unflatten_dict(state)
|
495 |
+
|
496 |
+
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    params: Union[Dict, FrozenDict],
    is_main_process: bool = True,
):
    """
    Serialize the model configuration and the given parameter `PyTree` into `save_directory`,
    so they can be re-loaded with the `[`~FlaxModelMixin.from_pretrained`]` class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist.
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful when in distributed training like
            TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
            the main process to avoid race conditions.
    """
    # Refuse to clobber an existing file; callers must pass a directory.
    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    os.makedirs(save_directory, exist_ok=True)

    # Only the main process writes the config, to avoid races in distributed runs.
    if is_main_process:
        self.save_config(save_directory)

    # Serialize the parameter tree with Flax's msgpack format.
    weights_path = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
    serialized = to_bytes(params)
    with open(weights_path, "wb") as weights_file:
        weights_file.write(serialized)

    logger.info(f"Model weights saved in {weights_path}")
|
diffusers/modeling_utils.py
ADDED
@@ -0,0 +1,892 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import os
|
18 |
+
from functools import partial
|
19 |
+
from typing import Callable, List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from torch import Tensor, device
|
23 |
+
|
24 |
+
from huggingface_hub import hf_hub_download
|
25 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
26 |
+
from requests import HTTPError
|
27 |
+
|
28 |
+
from . import __version__
|
29 |
+
from .hub_utils import send_telemetry
|
30 |
+
from .utils import (
|
31 |
+
CONFIG_NAME,
|
32 |
+
DIFFUSERS_CACHE,
|
33 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
34 |
+
SAFETENSORS_WEIGHTS_NAME,
|
35 |
+
WEIGHTS_NAME,
|
36 |
+
is_accelerate_available,
|
37 |
+
is_safetensors_available,
|
38 |
+
is_torch_version,
|
39 |
+
logging,
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
logger = logging.get_logger(__name__)
|
44 |
+
|
45 |
+
|
46 |
+
if is_torch_version(">=", "1.9.0"):
|
47 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
48 |
+
else:
|
49 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
50 |
+
|
51 |
+
|
52 |
+
if is_accelerate_available():
|
53 |
+
import accelerate
|
54 |
+
from accelerate.utils import set_module_tensor_to_device
|
55 |
+
from accelerate.utils.versions import is_torch_version
|
56 |
+
|
57 |
+
if is_safetensors_available():
|
58 |
+
import safetensors
|
59 |
+
|
60 |
+
|
61 |
+
def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the first parameter (or tensor attribute) of *parameter*."""
    try:
        first_param = next(parameter.parameters())
        return first_param.device
    except StopIteration:
        # Parameter-less modules (e.g. torch.nn.DataParallel wrappers in PyTorch 1.5)
        # may still carry plain tensor attributes — fall back to those.

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            return [(name, attr) for name, attr in module.__dict__.items() if torch.is_tensor(attr)]

        members = parameter._named_members(get_members_fn=find_tensor_attributes)
        _, first_tensor = next(members)
        return first_tensor.device
|
74 |
+
|
75 |
+
|
76 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the first parameter (or tensor attribute) of *parameter*."""
    try:
        first_param = next(parameter.parameters())
        return first_param.dtype
    except StopIteration:
        # Parameter-less modules (e.g. torch.nn.DataParallel wrappers in PyTorch 1.5)
        # may still carry plain tensor attributes — fall back to those.

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            return [(name, attr) for name, attr in module.__dict__.items() if torch.is_tensor(attr)]

        members = parameter._named_members(get_members_fn=find_tensor_attributes)
        _, first_tensor = next(members)
        return first_tensor.dtype
|
89 |
+
|
90 |
+
|
91 |
+
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
    """
    try:
        # Files named WEIGHTS_NAME are plain torch pickles; anything else is
        # assumed to be a safetensors checkpoint.
        if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
            return torch.load(checkpoint_file, map_location="cpu")
        return safetensors.torch.load_file(checkpoint_file, device="cpu")
    except Exception as e:
        try:
            # Distinguish "git-lfs pointer file" from a genuinely broken/missing
            # checkpoint; reading a binary file as text raises UnicodeDecodeError,
            # which (like the ValueError below) is converted to a generic OSError.
            with open(checkpoint_file) as f:
                looks_like_lfs_pointer = f.read().startswith("version")
            if looks_like_lfs_pointer:
                raise OSError(
                    "You seem to have cloned a repository without having git-lfs installed. Please install "
                    "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                    "you cloned."
                )
            raise ValueError(
                f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                "model. Make sure you have saved the model properly."
            ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(
                f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
                f"at '{checkpoint_file}'. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
            )
|
120 |
+
|
121 |
+
|
122 |
+
def _load_state_dict_into_model(model_to_load, state_dict):
|
123 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
124 |
+
# copy state_dict so _load_from_state_dict can modify it
|
125 |
+
state_dict = state_dict.copy()
|
126 |
+
error_msgs = []
|
127 |
+
|
128 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
129 |
+
# so we need to apply the function recursively.
|
130 |
+
def load(module: torch.nn.Module, prefix=""):
|
131 |
+
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
132 |
+
module._load_from_state_dict(*args)
|
133 |
+
|
134 |
+
for name, child in module._modules.items():
|
135 |
+
if child is not None:
|
136 |
+
load(child, prefix + name + ".")
|
137 |
+
|
138 |
+
load(model_to_load)
|
139 |
+
|
140 |
+
return error_msgs
|
141 |
+
|
142 |
+
|
143 |
+
class ModelMixin(torch.nn.Module):
    r"""
    Base class for all models.

    [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
    and saving models.

        - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
          [`~modeling_utils.ModelMixin.save_pretrained`].
    """
    # Filename under which the configuration is saved/loaded by the config machinery.
    config_name = CONFIG_NAME
    # Config entries written automatically on save; not user hyperparameters.
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    # Subclasses that implement `_set_gradient_checkpointing` set this to True to
    # opt into enable/disable_gradient_checkpointing below.
    _supports_gradient_checkpointing = False
|
156 |
+
|
157 |
+
def __init__(self):
    # No state of its own — just the plain nn.Module initialization.
    super().__init__()
|
159 |
+
|
160 |
+
@property
def is_gradient_checkpointing(self) -> bool:
    """
    Whether gradient checkpointing is activated for this model or not.

    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
    activations".
    """
    # The model counts as checkpointing as soon as any submodule has a truthy
    # `gradient_checkpointing` attribute.
    for submodule in self.modules():
        if getattr(submodule, "gradient_checkpointing", False):
            return True
    return False
|
169 |
+
|
170 |
+
def enable_gradient_checkpointing(self):
    """
    Activates gradient checkpointing for the current model.

    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
    activations".
    """
    # Only subclasses that declare support (and implement
    # `_set_gradient_checkpointing`) may be switched on.
    if self._supports_gradient_checkpointing:
        self.apply(partial(self._set_gradient_checkpointing, value=True))
    else:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
180 |
+
|
181 |
+
def disable_gradient_checkpointing(self):
    """
    Deactivates gradient checkpointing for the current model.

    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
    activations".
    """
    # Unsupported models were never enabled, so disabling is a silent no-op there.
    if not self._supports_gradient_checkpointing:
        return
    self.apply(partial(self._set_gradient_checkpointing, value=False))
|
190 |
+
|
191 |
+
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
    """Propagate the xformers flag to every descendant exposing the same setter."""
    # Depth-first walk: a matching submodule both receives the flag and is
    # recursed into, so nested attention blocks are covered too.
    def _propagate(submodule: torch.nn.Module):
        if hasattr(submodule, "set_use_memory_efficient_attention_xformers"):
            submodule.set_use_memory_efficient_attention_xformers(valid)
        for grandchild in submodule.children():
            _propagate(grandchild)

    for child in self.children():
        if isinstance(child, torch.nn.Module):
            _propagate(child)
|
205 |
+
|
206 |
+
def enable_xformers_memory_efficient_attention(self):
    r"""
    Enable memory efficient attention as implemented in xformers.

    When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
    time. Speed up at training time is not guaranteed.

    Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
    is used.
    """
    # Delegates to the recursive setter so every attention submodule is flipped.
    self.set_use_memory_efficient_attention_xformers(True)
|
217 |
+
|
218 |
+
def disable_xformers_memory_efficient_attention(self):
    r"""
    Disable memory efficient attention as implemented in xformers.
    """
    # Delegates to the recursive setter so every attention submodule is flipped.
    self.set_use_memory_efficient_attention_xformers(False)
|
223 |
+
|
224 |
+
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    save_function: Callable = None,
    safe_serialization: bool = False,
):
    """
    Save a model and its configuration file to a directory, so that it can be re-loaded using the
    `[`~modeling_utils.ModelMixin.from_pretrained`]` class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful when in distributed training like
            TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
            the main process to avoid race conditions.
        save_function (`Callable`):
            The function to use to save the state dictionary. Useful on distributed training like TPUs when one
            need to replace `torch.save` by another method. Can be configured with the environment variable
            `DIFFUSERS_SAVE_MODE`.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
    """
    if safe_serialization and not is_safetensors_available():
        # NOTE: message previously had an unclosed backtick ("`safetensors library:").
        raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    if save_function is None:
        save_function = safetensors.torch.save_file if safe_serialization else torch.save

    os.makedirs(save_directory, exist_ok=True)

    model_to_save = self

    # Attach architecture to the config
    # Save the config
    if is_main_process:
        model_to_save.save_config(save_directory)

    # Save the model
    state_dict = model_to_save.state_dict()

    weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME

    # Clean the folder from a previous save.
    # `weights_no_suffix` is loop-invariant — compute it once (was recomputed per file).
    weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
    for filename in os.listdir(save_directory):
        full_filename = os.path.join(save_directory, filename)
        # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
        # in distributed settings to avoid race conditions.
        if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
            os.remove(full_filename)

    # Save the model
    save_function(state_dict, os.path.join(save_directory, weights_name))

    logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
286 |
+
|
287 |
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
    r"""
    Instantiate a pretrained pytorch model from a pre-trained model configuration.

    The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
    the model, you should first set it back in training mode with `model.train()`.

    The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
    pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
    task.

    The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
            Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
                - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
                  `./my_model_directory/`.

        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        torch_dtype (`str` or `torch.dtype`, *optional*):
            Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
            will be automatically derived from the model's weights.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received files. Will attempt to resume the download if such a
            file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info(`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        local_files_only(`bool`, *optional*, defaults to `False`):
            Whether or not to only look at local files (i.e., do not try to download the model).
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `diffusers-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo (either remote in
            huggingface.co or downloaded locally), you can specify the folder name here.

        mirror (`str`, *optional*):
            Mirror source to accelerate downloads in China. If you are from China and have an accessibility
            problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
            Please refer to the mirror site for more information.
        device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
            A map that specifies where each submodule should go. It doesn't need to be refined to each
            parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
            same device.

            To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
            more information about each option see [designing a device
            map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
        low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
            Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
            also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
            model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
            setting this argument to `True` will raise an error.

    <Tip>

    It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
    models](https://huggingface.co/docs/hub/models-gated#gated-models).

    </Tip>

    <Tip>

    Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
    this method in a firewalled environment.

    </Tip>

    """
    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    torch_dtype = kwargs.pop("torch_dtype", None)
    subfolder = kwargs.pop("subfolder", None)
    device_map = kwargs.pop("device_map", None)
    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

    if low_cpu_mem_usage and not is_accelerate_available():
        low_cpu_mem_usage = False
        logger.warning(
            "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
            " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
            " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
            " install accelerate\n```\n."
        )

    if device_map is not None and not is_accelerate_available():
        raise NotImplementedError(
            "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
            " `device_map=None`. You can install accelerate with `pip install accelerate`."
        )

    # Check if we can handle device_map and dispatching the weights
    if device_map is not None and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `device_map=None`."
        )

    if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `low_cpu_mem_usage=False`."
        )

    if low_cpu_mem_usage is False and device_map is not None:
        raise ValueError(
            f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
            " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
        )

    user_agent = {
        "diffusers": __version__,
        "file_type": "model",
        "framework": "pytorch",
    }

    # Load config if we don't provide a configuration
    config_path = pretrained_model_name_or_path

    # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
    # Load model

    model_file = None
    if is_safetensors_available():
        # Prefer a safetensors checkpoint when the library is installed; the .bin
        # checkpoint below is the fallback, so any failure here is non-fatal.
        try:
            model_file = cls._get_model_file(
                pretrained_model_name_or_path,
                weights_name=SAFETENSORS_WEIGHTS_NAME,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
        except Exception:
            # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit. `Exception` still covers every download/lookup error.
            pass
    if model_file is None:
        model_file = cls._get_model_file(
            pretrained_model_name_or_path,
            weights_name=WEIGHTS_NAME,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
        )

    if low_cpu_mem_usage:
        # Instantiate model with empty weights
        with accelerate.init_empty_weights():
            config, unused_kwargs = cls.load_config(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                device_map=device_map,
                **kwargs,
            )
            model = cls.from_config(config, **unused_kwargs)

        # if device_map is None, load the state dict and move the params from meta device to the cpu
        if device_map is None:
            param_device = "cpu"
            state_dict = load_state_dict(model_file)
            # move the params from meta device to cpu
            for param_name, param in state_dict.items():
                set_module_tensor_to_device(model, param_name, param_device, value=param)
        else:  # else let accelerate handle loading and dispatching.
            # Load weights and dispatch according to the device_map
            # by default the device_map is None and the weights are loaded on the CPU
            accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)

        # accelerate performed the key checks itself, so nothing to report here.
        loading_info = {
            "missing_keys": [],
            "unexpected_keys": [],
            "mismatched_keys": [],
            "error_msgs": [],
        }
    else:
        config, unused_kwargs = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            device_map=device_map,
            **kwargs,
        )
        model = cls.from_config(config, **unused_kwargs)

        state_dict = load_state_dict(model_file)
        dtype = set(v.dtype for v in state_dict.values())

        if len(dtype) > 1 and torch.float32 not in dtype:
            raise ValueError(
                f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
                f" make sure that {model_file} weights have only one dtype."
            )
        elif len(dtype) > 1 and torch.float32 in dtype:
            dtype = torch.float32
        else:
            dtype = dtype.pop()

        # move model to correct dtype
        model = model.to(dtype)

        model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
            model,
            state_dict,
            model_file,
            pretrained_model_name_or_path,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
        )

        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
            "error_msgs": error_msgs,
        }

    if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
        raise ValueError(
            f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
        )
    elif torch_dtype is not None:
        model = model.to(torch_dtype)

    model.register_to_config(_name_or_path=pretrained_model_name_or_path)

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()
    if output_loading_info:
        return model, loading_info

    return model
|
567 |
+
|
568 |
+
@classmethod
def _get_model_file(
    cls,
    pretrained_model_name_or_path,
    *,
    weights_name,
    subfolder,
    cache_dir,
    force_download,
    proxies,
    resume_download,
    local_files_only,
    use_auth_token,
    user_agent,
    revision,
):
    """
    Resolve `weights_name` to a concrete file path, either inside a local directory
    or by downloading from the Hub, translating hub errors into `EnvironmentError`s
    with actionable messages.
    """
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isdir(pretrained_model_name_or_path):
        # Local directory: look at the top level first, then inside `subfolder`.
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
            # Load from a PyTorch checkpoint
            model_file = os.path.join(pretrained_model_name_or_path, weights_name)
        elif subfolder is not None and os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
        ):
            model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
        else:
            raise EnvironmentError(
                f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
            )
        send_telemetry(
            {"model_class": cls.__name__, "model_path": "local", "framework": "pytorch"},
            name="diffusers_from_pretrained",
        )
        return model_file
    else:
        try:
            # Load from URL or cache if already cached
            model_file = hf_hub_download(
                pretrained_model_name_or_path,
                filename=weights_name,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                subfolder=subfolder,
                revision=revision,
            )
            send_telemetry(
                {"model_class": cls.__name__, "model_path": "hub", "framework": "pytorch"},
                name="diffusers_from_pretrained",
            )
            return model_file

        # NOTE: handler order matters — the specific hub errors must precede the
        # generic HTTPError/EnvironmentError fallbacks below.
        except RepositoryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                "login`."
            )
        except RevisionNotFoundError:
            raise EnvironmentError(
                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                "this model name. Check the model page at "
                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
            )
        except EntryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
            )
        except HTTPError as err:
            raise EnvironmentError(
                "There was a specific connection error when trying to load"
                f" {pretrained_model_name_or_path}:\n{err}"
            )
        except ValueError:
            # hf_hub_download raises ValueError when offline and not cached.
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                f" directory containing a file named {weights_name} or"
                " \nCheckout your internet connection or see how to run the library in"
                " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
            )
        except EnvironmentError:
            raise EnvironmentError(
                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                f"containing a file named {weights_name}"
            )
|
661 |
+
|
662 |
+
@classmethod
def _load_pretrained_model(
    cls,
    model,
    state_dict,
    resolved_archive_file,
    pretrained_model_name_or_path,
    ignore_mismatched_sizes=False,
):
    """
    Copy `state_dict` into `model`, reporting missing/unexpected/mismatched keys.

    Returns `(model, missing_keys, unexpected_keys, mismatched_keys, error_msgs)`.
    Raises `RuntimeError` if the underlying copy produced error messages.
    """
    # NOTE(review): `resolved_archive_file` is not used in this body — presumably
    # kept for interface parity with sharded-checkpoint loading; confirm.
    # Retrieve missing & unexpected_keys
    model_state_dict = model.state_dict()
    loaded_keys = [k for k in state_dict.keys()]

    expected_keys = list(model_state_dict.keys())

    original_loaded_keys = loaded_keys

    missing_keys = list(set(expected_keys) - set(loaded_keys))
    unexpected_keys = list(set(loaded_keys) - set(expected_keys))

    # Make sure we are able to load base models as well as derived models (with heads)
    model_to_load = model

    # With `ignore_mismatched_sizes`, shape-incompatible entries are recorded and
    # *removed* from `state_dict` so the copy below skips them (they then surface
    # as newly-initialized weights in the warning at the bottom).
    def _find_mismatched_keys(
        state_dict,
        model_state_dict,
        loaded_keys,
        ignore_mismatched_sizes,
    ):
        mismatched_keys = []
        if ignore_mismatched_sizes:
            for checkpoint_key in loaded_keys:
                model_key = checkpoint_key

                if (
                    model_key in model_state_dict
                    and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                ):
                    mismatched_keys.append(
                        (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                    )
                    del state_dict[checkpoint_key]
        return mismatched_keys

    if state_dict is not None:
        # Whole checkpoint
        mismatched_keys = _find_mismatched_keys(
            state_dict,
            model_state_dict,
            original_loaded_keys,
            ignore_mismatched_sizes,
        )
        error_msgs = _load_state_dict_into_model(model_to_load, state_dict)

    if len(error_msgs) > 0:
        error_msg = "\n\t".join(error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
            " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
            f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
            " identical (initializing a BertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif len(mismatched_keys) == 0:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
            f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
            " without further training."
        )
    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
            " able to use it for predictions and inference."
        )

    return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
765 |
+
|
766 |
+
@property
def device(self) -> device:
    """
    `torch.device`: Device that the module's parameters live on (all parameters are assumed to be on one
    and the same device).
    """
    # Delegate to the module-level helper that inspects the parameters.
    current_device = get_parameter_device(self)
    return current_device
|
773 |
+
|
774 |
+
@property
def dtype(self) -> torch.dtype:
    """
    `torch.dtype`: Data type of the module's parameters (all parameters are assumed to share one dtype).
    """
    # Delegate to the module-level helper that inspects the parameters.
    current_dtype = get_parameter_dtype(self)
    return current_dtype
|
780 |
+
|
781 |
+
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
    """
    Get number of (optionally, trainable or non-embeddings) parameters in the module.

    Args:
        only_trainable (`bool`, *optional*, defaults to `False`):
            Whether or not to return only the number of trainable parameters.

        exclude_embeddings (`bool`, *optional*, defaults to `False`):
            Whether or not to return only the number of non-embeddings parameters.

    Returns:
        `int`: The number of parameters.
    """

    if exclude_embeddings:
        # Fully-qualified names of parameters owned by embedding layers. Built as a set so the
        # per-parameter membership test below is O(1) instead of O(#embedding layers) per lookup.
        embedding_param_names = {
            f"{name}.weight"
            for name, module_type in self.named_modules()
            if isinstance(module_type, torch.nn.Embedding)
        }
        non_embedding_parameters = [
            parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
        ]
        # `p.requires_grad or not only_trainable` keeps every parameter when only_trainable is
        # False, and only the trainable ones when it is True.
        return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
    else:
        return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
808 |
+
|
809 |
+
|
810 |
+
def _get_model_file(
|
811 |
+
pretrained_model_name_or_path,
|
812 |
+
*,
|
813 |
+
weights_name,
|
814 |
+
subfolder,
|
815 |
+
cache_dir,
|
816 |
+
force_download,
|
817 |
+
proxies,
|
818 |
+
resume_download,
|
819 |
+
local_files_only,
|
820 |
+
use_auth_token,
|
821 |
+
user_agent,
|
822 |
+
revision,
|
823 |
+
):
|
824 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
825 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
826 |
+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
|
827 |
+
# Load from a PyTorch checkpoint
|
828 |
+
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
|
829 |
+
return model_file
|
830 |
+
elif subfolder is not None and os.path.isfile(
|
831 |
+
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
|
832 |
+
):
|
833 |
+
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
|
834 |
+
return model_file
|
835 |
+
else:
|
836 |
+
raise EnvironmentError(
|
837 |
+
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
|
838 |
+
)
|
839 |
+
else:
|
840 |
+
try:
|
841 |
+
# Load from URL or cache if already cached
|
842 |
+
model_file = hf_hub_download(
|
843 |
+
pretrained_model_name_or_path,
|
844 |
+
filename=weights_name,
|
845 |
+
cache_dir=cache_dir,
|
846 |
+
force_download=force_download,
|
847 |
+
proxies=proxies,
|
848 |
+
resume_download=resume_download,
|
849 |
+
local_files_only=local_files_only,
|
850 |
+
use_auth_token=use_auth_token,
|
851 |
+
user_agent=user_agent,
|
852 |
+
subfolder=subfolder,
|
853 |
+
revision=revision,
|
854 |
+
)
|
855 |
+
return model_file
|
856 |
+
|
857 |
+
except RepositoryNotFoundError:
|
858 |
+
raise EnvironmentError(
|
859 |
+
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
860 |
+
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
861 |
+
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
862 |
+
"login`."
|
863 |
+
)
|
864 |
+
except RevisionNotFoundError:
|
865 |
+
raise EnvironmentError(
|
866 |
+
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
867 |
+
"this model name. Check the model page at "
|
868 |
+
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
869 |
+
)
|
870 |
+
except EntryNotFoundError:
|
871 |
+
raise EnvironmentError(
|
872 |
+
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
|
873 |
+
)
|
874 |
+
except HTTPError as err:
|
875 |
+
raise EnvironmentError(
|
876 |
+
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
877 |
+
)
|
878 |
+
except ValueError:
|
879 |
+
raise EnvironmentError(
|
880 |
+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
881 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
882 |
+
f" directory containing a file named {weights_name} or"
|
883 |
+
" \nCheckout your internet connection or see how to run the library in"
|
884 |
+
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
885 |
+
)
|
886 |
+
except EnvironmentError:
|
887 |
+
raise EnvironmentError(
|
888 |
+
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
889 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
890 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
891 |
+
f"containing a file named {weights_name}"
|
892 |
+
)
|
diffusers/models/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_flax_available, is_torch_available


# Model classes are only re-exported when their backend framework is importable, so that
# importing `diffusers.models` does not hard-require having both PyTorch and Flax installed.
if is_torch_available():
    # PyTorch model implementations.
    from .attention import Transformer2DModel
    from .unet_1d import UNet1DModel
    from .unet_2d import UNet2DModel
    from .unet_2d_condition import UNet2DConditionModel
    from .vae import AutoencoderKL, VQModel

if is_flax_available():
    # Flax (JAX) model implementations.
    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
    from .vae_flax import FlaxAutoencoderKL
|
diffusers/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (639 Bytes). View file
|
|
diffusers/models/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (28.5 kB). View file
|
|
diffusers/models/__pycache__/attention_flax.cpython-310.pyc
ADDED
Binary file (8.99 kB). View file
|
|
diffusers/models/__pycache__/embeddings.cpython-310.pyc
ADDED
Binary file (5.87 kB). View file
|
|
diffusers/models/__pycache__/embeddings_flax.cpython-310.pyc
ADDED
Binary file (3.19 kB). View file
|
|
diffusers/models/__pycache__/resnet.cpython-310.pyc
ADDED
Binary file (18.4 kB). View file
|
|
diffusers/models/__pycache__/resnet_flax.cpython-310.pyc
ADDED
Binary file (2.77 kB). View file
|
|
diffusers/models/__pycache__/unet_1d.cpython-310.pyc
ADDED
Binary file (7.41 kB). View file
|
|
diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc
ADDED
Binary file (16.7 kB). View file
|
|
diffusers/models/__pycache__/unet_2d.cpython-310.pyc
ADDED
Binary file (7.97 kB). View file
|
|
diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc
ADDED
Binary file (26.1 kB). View file
|
|
diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-310.pyc
ADDED
Binary file (10 kB). View file
|
|
diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc
ADDED
Binary file (12.5 kB). View file
|
|
diffusers/models/__pycache__/unet_2d_condition_flax.cpython-310.pyc
ADDED
Binary file (9.63 kB). View file
|
|