lakshyana committed
Commit 7c8c2c8
1 Parent(s): 6f6356e

updated diffusers

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. diffusers +0 -0
  2. diffusers/__init__.py +60 -0
  3. diffusers/__pycache__/__init__.cpython-310.pyc +0 -0
  4. diffusers/__pycache__/__init__.cpython-37.pyc +0 -0
  5. diffusers/__pycache__/configuration_utils.cpython-310.pyc +0 -0
  6. diffusers/__pycache__/configuration_utils.cpython-37.pyc +0 -0
  7. diffusers/__pycache__/dependency_versions_check.cpython-310.pyc +0 -0
  8. diffusers/__pycache__/dependency_versions_table.cpython-310.pyc +0 -0
  9. diffusers/__pycache__/dynamic_modules_utils.cpython-310.pyc +0 -0
  10. diffusers/__pycache__/hub_utils.cpython-310.pyc +0 -0
  11. diffusers/__pycache__/modeling_utils.cpython-310.pyc +0 -0
  12. diffusers/__pycache__/modeling_utils.cpython-37.pyc +0 -0
  13. diffusers/__pycache__/onnx_utils.cpython-310.pyc +0 -0
  14. diffusers/__pycache__/onnx_utils.cpython-37.pyc +0 -0
  15. diffusers/__pycache__/optimization.cpython-310.pyc +0 -0
  16. diffusers/__pycache__/optimization.cpython-37.pyc +0 -0
  17. diffusers/__pycache__/pipeline_utils.cpython-310.pyc +0 -0
  18. diffusers/__pycache__/pipeline_utils.cpython-37.pyc +0 -0
  19. diffusers/__pycache__/testing_utils.cpython-310.pyc +0 -0
  20. diffusers/__pycache__/training_utils.cpython-310.pyc +0 -0
  21. diffusers/__pycache__/training_utils.cpython-37.pyc +0 -0
  22. diffusers/commands/__init__.py +27 -0
  23. diffusers/commands/__pycache__/__init__.cpython-310.pyc +0 -0
  24. diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc +0 -0
  25. diffusers/commands/__pycache__/env.cpython-310.pyc +0 -0
  26. diffusers/commands/diffusers_cli.py +41 -0
  27. diffusers/commands/env.py +70 -0
  28. diffusers/configuration_utils.py +403 -0
  29. diffusers/dependency_versions_check.py +47 -0
  30. diffusers/dependency_versions_table.py +26 -0
  31. diffusers/dynamic_modules_utils.py +335 -0
  32. diffusers/hub_utils.py +197 -0
  33. diffusers/modeling_utils.py +542 -0
  34. diffusers/models/__init__.py +17 -0
  35. diffusers/models/__pycache__/__init__.cpython-310.pyc +0 -0
  36. diffusers/models/__pycache__/__init__.cpython-37.pyc +0 -0
  37. diffusers/models/__pycache__/attention.cpython-310.pyc +0 -0
  38. diffusers/models/__pycache__/attention.cpython-37.pyc +0 -0
  39. diffusers/models/__pycache__/embeddings.cpython-310.pyc +0 -0
  40. diffusers/models/__pycache__/embeddings.cpython-37.pyc +0 -0
  41. diffusers/models/__pycache__/resnet.cpython-310.pyc +0 -0
  42. diffusers/models/__pycache__/resnet.cpython-37.pyc +0 -0
  43. diffusers/models/__pycache__/unet_2d.cpython-310.pyc +0 -0
  44. diffusers/models/__pycache__/unet_2d.cpython-37.pyc +0 -0
  45. diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc +0 -0
  46. diffusers/models/__pycache__/unet_2d_condition.cpython-37.pyc +0 -0
  47. diffusers/models/__pycache__/unet_blocks.cpython-310.pyc +0 -0
  48. diffusers/models/__pycache__/unet_blocks.cpython-37.pyc +0 -0
  49. diffusers/models/__pycache__/vae.cpython-310.pyc +0 -0
  50. diffusers/models/__pycache__/vae.cpython-37.pyc +0 -0
diffusers DELETED
File without changes
diffusers/__init__.py ADDED
@@ -0,0 +1,60 @@
+ from .utils import (
+     is_inflect_available,
+     is_onnx_available,
+     is_scipy_available,
+     is_transformers_available,
+     is_unidecode_available,
+ )
+
+
+ __version__ = "0.3.0"
+
+ from .configuration_utils import ConfigMixin
+ from .modeling_utils import ModelMixin
+ from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+ from .onnx_utils import OnnxRuntimeModel
+ from .optimization import (
+     get_constant_schedule,
+     get_constant_schedule_with_warmup,
+     get_cosine_schedule_with_warmup,
+     get_cosine_with_hard_restarts_schedule_with_warmup,
+     get_linear_schedule_with_warmup,
+     get_polynomial_decay_schedule_with_warmup,
+     get_scheduler,
+ )
+ from .pipeline_utils import DiffusionPipeline
+ from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+ from .schedulers import (
+     DDIMScheduler,
+     DDPMScheduler,
+     KarrasVeScheduler,
+     PNDMScheduler,
+     SchedulerMixin,
+     ScoreSdeVeScheduler,
+ )
+ from .utils import logging
+
+
+ if is_scipy_available():
+     from .schedulers import LMSDiscreteScheduler
+ else:
+     from .utils.dummy_scipy_objects import *  # noqa F403
+
+ from .training_utils import EMAModel
+
+
+ if is_transformers_available():
+     from .pipelines import (
+         LDMTextToImagePipeline,
+         StableDiffusionImg2ImgPipeline,
+         StableDiffusionInpaintPipeline,
+         StableDiffusionPipeline,
+     )
+ else:
+     from .utils.dummy_transformers_objects import *  # noqa F403
+
+
+ if is_transformers_available() and is_onnx_available():
+     from .pipelines import StableDiffusionOnnxPipeline
+ else:
+     from .utils.dummy_transformers_and_onnx_objects import *  # noqa F403
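For orientation, a minimal sketch of how the optional-dependency gating in this `__init__.py` is meant to be consumed downstream; the guards mirror the ones above and nothing here is added to the package itself:

```python
# Sketch: consuming the conditional exports from diffusers/__init__.py above.
from diffusers import is_scipy_available, is_transformers_available

if is_scipy_available():
    # only a real scheduler when scipy is installed (see the guard above)
    from diffusers import LMSDiscreteScheduler

if is_transformers_available():
    # the Stable Diffusion pipelines require transformers
    from diffusers import StableDiffusionPipeline
```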
diffusers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.85 kB).
diffusers/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.91 kB).
diffusers/__pycache__/configuration_utils.cpython-310.pyc ADDED
Binary file (15.4 kB).
diffusers/__pycache__/configuration_utils.cpython-37.pyc ADDED
Binary file (15.6 kB).
diffusers/__pycache__/dependency_versions_check.cpython-310.pyc ADDED
Binary file (967 Bytes).
diffusers/__pycache__/dependency_versions_table.cpython-310.pyc ADDED
Binary file (819 Bytes).
diffusers/__pycache__/dynamic_modules_utils.cpython-310.pyc ADDED
Binary file (11.6 kB).
diffusers/__pycache__/hub_utils.cpython-310.pyc ADDED
Binary file (5.46 kB).
diffusers/__pycache__/modeling_utils.cpython-310.pyc ADDED
Binary file (18.7 kB).
diffusers/__pycache__/modeling_utils.cpython-37.pyc ADDED
Binary file (18.9 kB).
diffusers/__pycache__/onnx_utils.cpython-310.pyc ADDED
Binary file (6.3 kB).
diffusers/__pycache__/onnx_utils.cpython-37.pyc ADDED
Binary file (6.26 kB).
diffusers/__pycache__/optimization.cpython-310.pyc ADDED
Binary file (10.1 kB).
diffusers/__pycache__/optimization.cpython-37.pyc ADDED
Binary file (10.3 kB).
diffusers/__pycache__/pipeline_utils.cpython-310.pyc ADDED
Binary file (14 kB).
diffusers/__pycache__/pipeline_utils.cpython-37.pyc ADDED
Binary file (14 kB).
diffusers/__pycache__/testing_utils.cpython-310.pyc ADDED
Binary file (1.66 kB).
diffusers/__pycache__/training_utils.cpython-310.pyc ADDED
Binary file (3.64 kB).
diffusers/__pycache__/training_utils.cpython-37.pyc ADDED
Binary file (3.67 kB).
diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from argparse import ArgumentParser
+
+
+ class BaseDiffusersCLICommand(ABC):
+     @staticmethod
+     @abstractmethod
+     def register_subcommand(parser: ArgumentParser):
+         raise NotImplementedError()
+
+     @abstractmethod
+     def run(self):
+         raise NotImplementedError()
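As a sketch of how this abstract base is meant to be extended; the `HelloCommand` subcommand below is hypothetical, not part of this commit:

```python
# Sketch: a hypothetical subcommand built on BaseDiffusersCLICommand above.
from argparse import ArgumentParser

from diffusers.commands import BaseDiffusersCLICommand


class HelloCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # `parser` is the sub-parsers object created by diffusers_cli.main()
        hello_parser = parser.add_parser("hello")
        hello_parser.set_defaults(func=lambda args: HelloCommand())

    def run(self):
        print("hello from diffusers-cli")
```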
diffusers/commands/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (817 Bytes).
diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc ADDED
Binary file (778 Bytes).
diffusers/commands/__pycache__/env.cpython-310.pyc ADDED
Binary file (2.17 kB).
diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,41 @@
+ #!/usr/bin/env python
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from argparse import ArgumentParser
+
+ from .env import EnvironmentCommand
+
+
+ def main():
+     parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
+     commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
+
+     # Register commands
+     EnvironmentCommand.register_subcommand(commands_parser)
+
+     # Let's go
+     args = parser.parse_args()
+
+     if not hasattr(args, "func"):
+         parser.print_help()
+         exit(1)
+
+     # Run
+     service = args.func(args)
+     service.run()
+
+
+ if __name__ == "__main__":
+     main()
diffusers/commands/env.py ADDED
@@ -0,0 +1,70 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import platform
+ from argparse import ArgumentParser
+
+ import huggingface_hub
+
+ from .. import __version__ as version
+ from ..utils import is_torch_available, is_transformers_available
+ from . import BaseDiffusersCLICommand
+
+
+ def info_command_factory(_):
+     return EnvironmentCommand()
+
+
+ class EnvironmentCommand(BaseDiffusersCLICommand):
+     @staticmethod
+     def register_subcommand(parser: ArgumentParser):
+         download_parser = parser.add_parser("env")
+         download_parser.set_defaults(func=info_command_factory)
+
+     def run(self):
+         hub_version = huggingface_hub.__version__
+
+         pt_version = "not installed"
+         pt_cuda_available = "NA"
+         if is_torch_available():
+             import torch
+
+             pt_version = torch.__version__
+             pt_cuda_available = torch.cuda.is_available()
+
+         transformers_version = "not installed"
+         if is_transformers_available():
+             import transformers
+
+             transformers_version = transformers.__version__
+
+         info = {
+             "`diffusers` version": version,
+             "Platform": platform.platform(),
+             "Python version": platform.python_version(),
+             "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
+             "Huggingface_hub version": hub_version,
+             "Transformers version": transformers_version,
+             "Using GPU in script?": "<fill in>",
+             "Using distributed or parallel set-up in script?": "<fill in>",
+         }
+
+         print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
+         print(self.format_dict(info))
+
+         return info
+
+     @staticmethod
+     def format_dict(d):
+         return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
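A minimal usage sketch: the same report the `diffusers-cli env` entry point prints can be produced programmatically, since `run()` both prints and returns the `info` dict built above:

```python
# Sketch: calling the env command directly instead of via the CLI.
from diffusers.commands.env import EnvironmentCommand

info = EnvironmentCommand().run()   # prints the report and returns the dict
print(info["`diffusers` version"])  # "0.3.0" for this commit
```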
diffusers/configuration_utils.py ADDED
@@ -0,0 +1,403 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Configuration base class and utilities."""
+ import functools
+ import inspect
+ import json
+ import os
+ import re
+ from collections import OrderedDict
+ from typing import Any, Dict, Union
+
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+ from requests import HTTPError
+
+ from . import __version__
+ from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
+
+
+ logger = logging.get_logger(__name__)
+
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+ class ConfigMixin:
+     r"""
+     Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles
+     all methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+         - [`~ConfigMixin.from_config`]
+         - [`~ConfigMixin.save_config`]
+
+     Class attributes:
+         - **config_name** (`str`) -- A filename under which the config should be stored when calling
+           [`~ConfigMixin.save_config`] (should be overridden by the parent class).
+         - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should
+           be overridden by the parent class).
+     """
+     config_name = None
+     ignore_for_config = []
+
+     def register_to_config(self, **kwargs):
+         if self.config_name is None:
+             raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+         kwargs["_class_name"] = self.__class__.__name__
+         kwargs["_diffusers_version"] = __version__
+
+         for key, value in kwargs.items():
+             try:
+                 setattr(self, key, value)
+             except AttributeError as err:
+                 logger.error(f"Can't set {key} with value {value} for {self}")
+                 raise err
+
+         if not hasattr(self, "_internal_dict"):
+             internal_dict = kwargs
+         else:
+             previous_dict = dict(self._internal_dict)
+             internal_dict = {**self._internal_dict, **kwargs}
+             logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+         self._internal_dict = FrozenDict(internal_dict)
+
+     def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+         """
+         Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+         [`~ConfigMixin.from_config`] class method.
+
+         Args:
+             save_directory (`str` or `os.PathLike`):
+                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
+         """
+         if os.path.isfile(save_directory):
+             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+         os.makedirs(save_directory, exist_ok=True)
+
+         # If we save using the predefined names, we can load using `from_config`
+         output_config_file = os.path.join(save_directory, self.config_name)
+
+         self.to_json_file(output_config_file)
+         logger.info(f"Configuration saved in {output_config_file}")
+
+     @classmethod
+     def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+         r"""
+         Instantiate a Python class from a pre-defined JSON file.
+
+         Parameters:
+             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                 Can be either:
+
+                     - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+                       organization name, like `google/ddpm-celebahq-256`.
+                     - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+                       `./my_model_directory/`.
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                 standard cache should not be used.
+             ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+                 Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+                 as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+                 checkpoint with 3 labels).
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                 cached versions if they exist.
+             resume_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                 file exists.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             output_loading_info (`bool`, *optional*, defaults to `False`):
+                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether or not to only look at local files (i.e., do not try to download the model).
+             use_auth_token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                 when running `transformers-cli login` (stored in `~/.huggingface`).
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                 identifier allowed by git.
+             mirror (`str`, *optional*):
+                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                 Please refer to the mirror site for more information.
+
+         <Tip>
+
+         Passing `use_auth_token=True` is required when you want to use a private model.
+
+         </Tip>
+
+         <Tip>
+
+         Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+         use this method in a firewalled environment.
+
+         </Tip>
+
+         """
+         config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+
+         init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
+
+         model = cls(**init_dict)
+
+         if return_unused_kwargs:
+             return model, unused_kwargs
+         else:
+             return model
+
+     @classmethod
+     def get_config_dict(
+         cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+     ) -> Dict[str, Any]:
+         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", False)
+         proxies = kwargs.pop("proxies", None)
+         use_auth_token = kwargs.pop("use_auth_token", None)
+         local_files_only = kwargs.pop("local_files_only", False)
+         revision = kwargs.pop("revision", None)
+         subfolder = kwargs.pop("subfolder", None)
+
+         user_agent = {"file_type": "config"}
+
+         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+         if cls.config_name is None:
+             raise ValueError(
+                 "`self.config_name` is not defined. Note that one should not load a config from "
+                 "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+             )
+
+         if os.path.isfile(pretrained_model_name_or_path):
+             config_file = pretrained_model_name_or_path
+         elif os.path.isdir(pretrained_model_name_or_path):
+             if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                 # Load from a PyTorch checkpoint
+                 config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+             elif subfolder is not None and os.path.isfile(
+                 os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+             ):
+                 config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+             else:
+                 raise EnvironmentError(
+                     f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+                 )
+         else:
+             try:
+                 # Load from URL or cache if already cached
+                 config_file = hf_hub_download(
+                     pretrained_model_name_or_path,
+                     filename=cls.config_name,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     resume_download=resume_download,
+                     local_files_only=local_files_only,
+                     use_auth_token=use_auth_token,
+                     user_agent=user_agent,
+                     subfolder=subfolder,
+                     revision=revision,
+                 )
+
+             except RepositoryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+                     " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+                     " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+                     " login` and pass `use_auth_token=True`."
+                 )
+             except RevisionNotFoundError:
+                 raise EnvironmentError(
+                     f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+                     " this model name. Check the model page at"
+                     f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+                 )
+             except EntryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+                 )
+             except HTTPError as err:
+                 raise EnvironmentError(
+                     "There was a specific connection error when trying to load"
+                     f" {pretrained_model_name_or_path}:\n{err}"
+                 )
+             except ValueError:
+                 raise EnvironmentError(
+                     f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                     f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                     f" directory containing a {cls.config_name} file.\nCheck your internet connection or see how to"
+                     " run the library in offline mode at"
+                     " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                 )
+             except EnvironmentError:
+                 raise EnvironmentError(
+                     f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                     "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                     f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                     f"containing a {cls.config_name} file"
+                 )
+
+         try:
+             # Load config dict
+             config_dict = cls._dict_from_json_file(config_file)
+         except (json.JSONDecodeError, UnicodeDecodeError):
+             raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+         return config_dict
+
+     @classmethod
+     def extract_init_dict(cls, config_dict, **kwargs):
+         expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
+         expected_keys.remove("self")
+         # remove general kwargs if present in dict
+         if "kwargs" in expected_keys:
+             expected_keys.remove("kwargs")
+         # remove keys to be ignored
+         if len(cls.ignore_for_config) > 0:
+             expected_keys = expected_keys - set(cls.ignore_for_config)
+         init_dict = {}
+         for key in expected_keys:
+             if key in kwargs:
+                 # overwrite key
+                 init_dict[key] = kwargs.pop(key)
+             elif key in config_dict:
+                 # use value from config dict
+                 init_dict[key] = config_dict.pop(key)
+
+         # everything left in the config dict plus the remaining kwargs was not consumed by `__init__`
+         unused_kwargs = {**config_dict, **kwargs}
+
+         passed_keys = set(init_dict.keys())
+         if len(expected_keys - passed_keys) > 0:
+             logger.warning(
+                 f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+             )
+
+         return init_dict, unused_kwargs
+
+     @classmethod
+     def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+         with open(json_file, "r", encoding="utf-8") as reader:
+             text = reader.read()
+         return json.loads(text)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__} {self.to_json_string()}"
+
+     @property
+     def config(self) -> Dict[str, Any]:
+         return self._internal_dict
+
+     def to_json_string(self) -> str:
+         """
+         Serializes this instance to a JSON string.
+
+         Returns:
+             `str`: String containing all the attributes that make up this configuration instance in JSON format.
+         """
+         config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+         """
+         Save this instance to a JSON file.
+
+         Args:
+             json_file_path (`str` or `os.PathLike`):
+                 Path to the JSON file in which this configuration instance's parameters will be saved.
+         """
+         with open(json_file_path, "w", encoding="utf-8") as writer:
+             writer.write(self.to_json_string())
+
+
+ class FrozenDict(OrderedDict):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         for key, value in self.items():
+             setattr(self, key, value)
+
+         self.__frozen = True
+
+     def __delitem__(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+     def setdefault(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+     def pop(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+     def update(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+     def __setattr__(self, name, value):
+         if hasattr(self, "__frozen") and self.__frozen:
+             raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+         super().__setattr__(name, value)
+
+     def __setitem__(self, name, value):
+         if hasattr(self, "__frozen") and self.__frozen:
+             raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
+         super().__setitem__(name, value)
+
+
+ def register_to_config(init):
+     r"""
+     Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+     automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
+     shouldn't be registered in the config, use the `ignore_for_config` class variable.
+
+     Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+     """
+
+     @functools.wraps(init)
+     def inner_init(self, *args, **kwargs):
+         # Ignore private kwargs in the init.
+         init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+         init(self, *args, **init_kwargs)
+         if not isinstance(self, ConfigMixin):
+             raise RuntimeError(
+                 f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
+                 "not inherit from `ConfigMixin`."
+             )
+
+         ignore = getattr(self, "ignore_for_config", [])
+         # Get positional arguments aligned with kwargs
+         new_kwargs = {}
+         signature = inspect.signature(init)
+         parameters = {
+             name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+         }
+         for arg, name in zip(args, parameters.keys()):
+             new_kwargs[name] = arg
+
+         # Then add all kwargs
+         new_kwargs.update(
+             {
+                 k: init_kwargs.get(k, default)
+                 for k, default in parameters.items()
+                 if k not in ignore and k not in new_kwargs
+             }
+         )
+         getattr(self, "register_to_config")(**new_kwargs)
+
+     return inner_init
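To make the moving parts above concrete, a minimal sketch of a `ConfigMixin` subclass wired up with the `register_to_config` decorator; the `ToyScheduler` class and `toy_config.json` filename are hypothetical:

```python
# Sketch: a toy ConfigMixin subclass exercising register_to_config,
# save_config and from_config from configuration_utils.py above.
from diffusers.configuration_utils import ConfigMixin, register_to_config


class ToyScheduler(ConfigMixin):
    config_name = "toy_config.json"  # filename used by save_config/from_config

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        pass  # init arguments are captured into self.config by the decorator


scheduler = ToyScheduler(num_train_timesteps=500)
print(scheduler.config.num_train_timesteps)  # 500, via FrozenDict attribute access
scheduler.save_config("./toy")                # writes ./toy/toy_config.json
restored = ToyScheduler.from_config("./toy")  # re-instantiates from the JSON
```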
diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,47 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import sys
+
+ from .dependency_versions_table import deps
+ from .utils.versions import require_version, require_version_core
+
+
+ # define which module versions we always want to check at run time
+ # (usually the ones defined in `install_requires` in setup.py)
+ #
+ # order specific notes:
+ # - tqdm must be checked before tokenizers
+
+ pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
+ if sys.version_info < (3, 7):
+     pkgs_to_check_at_runtime.append("dataclasses")
+ if sys.version_info < (3, 8):
+     pkgs_to_check_at_runtime.append("importlib_metadata")
+
+ for pkg in pkgs_to_check_at_runtime:
+     if pkg in deps:
+         if pkg == "tokenizers":
+             # must be loaded here, or else tqdm check may fail
+             from .utils import is_tokenizers_available
+
+             if not is_tokenizers_available():
+                 continue  # not required, check version only if installed
+
+         require_version_core(deps[pkg])
+     else:
+         raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+ def dep_version_check(pkg, hint=None):
+     require_version(deps[pkg], hint)
diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,26 @@
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
+ # 1. modify the `_deps` dict in setup.py
+ # 2. run `make deps_table_update`
+ deps = {
+     "Pillow": "Pillow",
+     "accelerate": "accelerate>=0.11.0",
+     "black": "black==22.3",
+     "datasets": "datasets",
+     "filelock": "filelock",
+     "flake8": "flake8>=3.8.3",
+     "hf-doc-builder": "hf-doc-builder>=0.3.0",
+     "huggingface-hub": "huggingface-hub>=0.8.1",
+     "importlib_metadata": "importlib_metadata",
+     "isort": "isort>=5.5.4",
+     "modelcards": "modelcards==0.1.4",
+     "numpy": "numpy",
+     "pytest": "pytest",
+     "pytest-timeout": "pytest-timeout",
+     "pytest-xdist": "pytest-xdist",
+     "scipy": "scipy",
+     "regex": "regex!=2019.12.17",
+     "requests": "requests",
+     "tensorboard": "tensorboard",
+     "torch": "torch>=1.4",
+     "transformers": "transformers>=4.21.0",
+ }
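A quick sketch of how these pins are consumed; the table is plain data, so reading it has no side effects:

```python
# Sketch: reading a version pin from the autogenerated table above.
from diffusers.dependency_versions_table import deps

print(deps["transformers"])  # "transformers>=4.21.0"
print(deps["torch"])         # "torch>=1.4"
```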
diffusers/dynamic_modules_utils.py ADDED
@@ -0,0 +1,335 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Utilities to dynamically load objects from the Hub."""
+
+ import importlib
+ import os
+ import re
+ import shutil
+ import sys
+ from pathlib import Path
+ from typing import Dict, Optional, Union
+
+ from huggingface_hub import cached_download
+
+ from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ def init_hf_modules():
+     """
+     Creates the cache directory for modules with an init, and adds it to the Python path.
+     """
+     # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
+     if HF_MODULES_CACHE in sys.path:
+         return
+
+     sys.path.append(HF_MODULES_CACHE)
+     os.makedirs(HF_MODULES_CACHE, exist_ok=True)
+     init_path = Path(HF_MODULES_CACHE) / "__init__.py"
+     if not init_path.exists():
+         init_path.touch()
+
+
+ def create_dynamic_module(name: Union[str, os.PathLike]):
+     """
+     Creates a dynamic module in the cache directory for modules.
+     """
+     init_hf_modules()
+     dynamic_module_path = Path(HF_MODULES_CACHE) / name
+     # If the parent module does not exist yet, recursively create it.
+     if not dynamic_module_path.parent.exists():
+         create_dynamic_module(dynamic_module_path.parent)
+     os.makedirs(dynamic_module_path, exist_ok=True)
+     init_path = dynamic_module_path / "__init__.py"
+     if not init_path.exists():
+         init_path.touch()
+
+
+ def get_relative_imports(module_file):
+     """
+     Get the list of modules that are relatively imported in a module file.
+
+     Args:
+         module_file (`str` or `os.PathLike`): The module file to inspect.
+     """
+     with open(module_file, "r", encoding="utf-8") as f:
+         content = f.read()
+
+     # Imports of the form `import .xxx`
+     relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
+     # Imports of the form `from .xxx import yyy`
+     relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
+     # Unique-ify
+     return list(set(relative_imports))
+
+
+ def get_relative_import_files(module_file):
+     """
+     Get the list of all files that are needed for a given module. Note that this function recurses through the
+     relative imports (if `a` imports `b` and `b` imports `c`, it will return module files for `b` and `c`).
+
+     Args:
+         module_file (`str` or `os.PathLike`): The module file to inspect.
+     """
+     no_change = False
+     files_to_check = [module_file]
+     all_relative_imports = []
+
+     # Let's recurse through all relative imports
+     while not no_change:
+         new_imports = []
+         for f in files_to_check:
+             new_imports.extend(get_relative_imports(f))
+
+         module_path = Path(module_file).parent
+         new_import_files = [str(module_path / m) for m in new_imports]
+         new_import_files = [f for f in new_import_files if f not in all_relative_imports]
+         files_to_check = [f"{f}.py" for f in new_import_files]
+
+         no_change = len(new_import_files) == 0
+         all_relative_imports.extend(files_to_check)
+
+     return all_relative_imports
+
+
+ def check_imports(filename):
+     """
+     Check if the current Python environment contains all the libraries that are imported in a file.
+     """
+     with open(filename, "r", encoding="utf-8") as f:
+         content = f.read()
+
+     # Imports of the form `import xxx`
+     imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+     # Imports of the form `from xxx import yyy`
+     imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+     # Only keep the top-level module
+     imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
+
+     # Unique-ify and test we got them all
+     imports = list(set(imports))
+     missing_packages = []
+     for imp in imports:
+         try:
+             importlib.import_module(imp)
+         except ImportError:
+             missing_packages.append(imp)
+
+     if len(missing_packages) > 0:
+         raise ImportError(
+             "This modeling file requires the following packages that were not found in your environment: "
+             f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
+         )
+
+     return get_relative_imports(filename)
+
+
+ def get_class_in_module(class_name, module_path):
+     """
+     Import a module on the cache directory for modules and extract a class from it.
+     """
+     module_path = module_path.replace(os.path.sep, ".")
+     module = importlib.import_module(module_path)
+     return getattr(module, class_name)
+
+
+ def get_cached_module_file(
+     pretrained_model_name_or_path: Union[str, os.PathLike],
+     module_file: str,
+     cache_dir: Optional[Union[str, os.PathLike]] = None,
+     force_download: bool = False,
+     resume_download: bool = False,
+     proxies: Optional[Dict[str, str]] = None,
+     use_auth_token: Optional[Union[bool, str]] = None,
+     revision: Optional[str] = None,
+     local_files_only: bool = False,
+ ):
+     """
+     Downloads a module from a local folder or a distant repo and returns its path inside the cache of dynamic
+     modules.
+
+     Args:
+         pretrained_model_name_or_path (`str` or `os.PathLike`):
+             This can be either:
+
+             - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+               huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
+               namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
+             - a path to a *directory* containing a configuration file saved using the
+               [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+         module_file (`str`):
+             The name of the module file containing the class to look for.
+         cache_dir (`str` or `os.PathLike`, *optional*):
+             Path to a directory in which a downloaded pretrained model configuration should be cached if the
+             standard cache should not be used.
+         force_download (`bool`, *optional*, defaults to `False`):
+             Whether or not to force to (re-)download the configuration files and override the cached versions if
+             they exist.
+         resume_download (`bool`, *optional*, defaults to `False`):
+             Whether or not to delete incompletely received file. Attempts to resume the download if such a file
+             exists.
+         proxies (`Dict[str, str]`, *optional*):
+             A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+             'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+         use_auth_token (`str` or *bool*, *optional*):
+             The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+             when running `transformers-cli login` (stored in `~/.huggingface`).
+         revision (`str`, *optional*, defaults to `"main"`):
+             The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+             git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+             identifier allowed by git.
+         local_files_only (`bool`, *optional*, defaults to `False`):
+             If `True`, will only try to load the tokenizer configuration from local files.
+
+     <Tip>
+
+     Passing `use_auth_token=True` is required when you want to use a private model.
+
+     </Tip>
+
+     Returns:
+         `str`: The path to the module inside the cache.
+     """
+     # Download and cache module_file from the repo `pretrained_model_name_or_path`, or grab it if it's a local file.
+     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+     module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
+     submodule = "local"
+
+     if os.path.isfile(module_file_or_url):
+         resolved_module_file = module_file_or_url
+     else:
+         try:
+             # Load from URL or cache if already cached
+             resolved_module_file = cached_download(
+                 module_file_or_url,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 proxies=proxies,
+                 resume_download=resume_download,
+                 local_files_only=local_files_only,
+                 use_auth_token=use_auth_token,
+             )
+
+         except EnvironmentError:
+             logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+             raise
+
+     # Check we have all the requirements in our environment
+     modules_needed = check_imports(resolved_module_file)
+
+     # Now we move the module inside our cached dynamic modules.
+     full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+     create_dynamic_module(full_submodule)
+     submodule_path = Path(HF_MODULES_CACHE) / full_submodule
+     # We always copy local files (we could hash the file to see if there was a change, and give them the name of
+     # that hash, to only copy when there is a modification but it seems overkill for now).
+     # The only reason we do the copy is to avoid putting too many folders in sys.path.
+     shutil.copy(resolved_module_file, submodule_path / module_file)
+     for module_needed in modules_needed:
+         module_needed = f"{module_needed}.py"
+         shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+     return os.path.join(full_submodule, module_file)
+
+
+ def get_class_from_dynamic_module(
+     pretrained_model_name_or_path: Union[str, os.PathLike],
+     module_file: str,
+     class_name: str,
+     cache_dir: Optional[Union[str, os.PathLike]] = None,
+     force_download: bool = False,
+     resume_download: bool = False,
+     proxies: Optional[Dict[str, str]] = None,
+     use_auth_token: Optional[Union[bool, str]] = None,
+     revision: Optional[str] = None,
+     local_files_only: bool = False,
+     **kwargs,
+ ):
+     """
+     Extracts a class from a module file, present in the local folder or repository of a model.
+
+     <Tip warning={true}>
+
+     Calling this function will execute the code in the module file found locally or downloaded from the Hub. It
+     should therefore only be called on trusted repos.
+
+     </Tip>
+
+     Args:
+         pretrained_model_name_or_path (`str` or `os.PathLike`):
+             This can be either:
+
+             - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+               huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
+               namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
+             - a path to a *directory* containing a configuration file saved using the
+               [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+         module_file (`str`):
+             The name of the module file containing the class to look for.
+         class_name (`str`):
+             The name of the class to import in the module.
+         cache_dir (`str` or `os.PathLike`, *optional*):
+             Path to a directory in which a downloaded pretrained model configuration should be cached if the
+             standard cache should not be used.
+         force_download (`bool`, *optional*, defaults to `False`):
+             Whether or not to force to (re-)download the configuration files and override the cached versions if
+             they exist.
+         resume_download (`bool`, *optional*, defaults to `False`):
+             Whether or not to delete incompletely received file. Attempts to resume the download if such a file
+             exists.
+         proxies (`Dict[str, str]`, *optional*):
+             A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+             'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+         use_auth_token (`str` or `bool`, *optional*):
+             The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+             when running `transformers-cli login` (stored in `~/.huggingface`).
+         revision (`str`, *optional*, defaults to `"main"`):
+             The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+             git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+             identifier allowed by git.
+         local_files_only (`bool`, *optional*, defaults to `False`):
+             If `True`, will only try to load the tokenizer configuration from local files.
+
+     <Tip>
+
+     Passing `use_auth_token=True` is required when you want to use a private model.
+
+     </Tip>
+
+     Returns:
+         `type`: The class, dynamically imported from the module.
+
+     Examples:
+
+     ```python
+     # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
+     # module.
+     cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
+     ```"""
+     # And lastly we get the class inside our newly created module
+     final_module = get_cached_module_file(
+         pretrained_model_name_or_path,
+         module_file,
+         cache_dir=cache_dir,
+         force_download=force_download,
+         resume_download=resume_download,
+         proxies=proxies,
+         use_auth_token=use_auth_token,
+         revision=revision,
+         local_files_only=local_files_only,
+     )
+     return get_class_in_module(class_name, final_module.replace(".py", ""))
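For a feel of what the import scanner above actually extracts, a small self-contained sketch; `my_pipeline.py` is a throwaway file created just for the demo:

```python
# Sketch: get_relative_imports only picks up `from .xxx import ...` style imports.
from diffusers.dynamic_modules_utils import get_relative_imports

with open("my_pipeline.py", "w", encoding="utf-8") as f:
    f.write("from .unet_blocks import DownBlock2D\nimport os\n")

print(get_relative_imports("my_pipeline.py"))  # ['unet_blocks']; `import os` is ignored
```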
diffusers/hub_utils.py ADDED
@@ -0,0 +1,197 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import os
+ import shutil
+ from pathlib import Path
+ from typing import Optional
+
+ from huggingface_hub import HfFolder, Repository, whoami
+
+ from .pipeline_utils import DiffusionPipeline
+ from .utils import is_modelcards_available, logging
+
+
+ if is_modelcards_available():
+     from modelcards import CardData, ModelCard
+
+
+ logger = logging.get_logger(__name__)
+
+
+ MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
+
+
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+     if token is None:
+         token = HfFolder.get_token()
+     if organization is None:
+         username = whoami(token)["name"]
+         return f"{username}/{model_id}"
+     else:
+         return f"{organization}/{model_id}"
+
+
+ def init_git_repo(args, at_init: bool = False):
+     """
+     Initializes a git repo in `args.hub_model_id`.
+
+     Args:
+         at_init (`bool`, *optional*, defaults to `False`):
+             Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
+             and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
+     """
+     if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+         return
+     hub_token = args.hub_token if hasattr(args, "hub_token") else None
+     use_auth_token = True if hub_token is None else hub_token
+     if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
+         repo_name = Path(args.output_dir).absolute().name
+     else:
+         repo_name = args.hub_model_id
+     if "/" not in repo_name:
+         repo_name = get_full_repo_name(repo_name, token=hub_token)
+
+     try:
+         repo = Repository(
+             args.output_dir,
+             clone_from=repo_name,
+             use_auth_token=use_auth_token,
+             private=args.hub_private_repo,
+         )
+     except EnvironmentError:
+         if args.overwrite_output_dir and at_init:
+             # Try again after wiping output_dir
+             shutil.rmtree(args.output_dir)
+             repo = Repository(
+                 args.output_dir,
+                 clone_from=repo_name,
+                 use_auth_token=use_auth_token,
+             )
+         else:
+             raise
+
+     repo.git_pull()
+
+     # By default, ignore the checkpoint folders
+     if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
+         with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
+             writer.writelines(["checkpoint-*/"])
+
+     return repo
+
+
+ def push_to_hub(
+     args,
+     pipeline: DiffusionPipeline,
+     repo: Repository,
+     commit_message: Optional[str] = "End of training",
+     blocking: bool = True,
+     **kwargs,
+ ) -> str:
+     """
+     Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
+
+     Parameters:
+         commit_message (`str`, *optional*, defaults to `"End of training"`):
+             Message to commit while pushing.
+         blocking (`bool`, *optional*, defaults to `True`):
+             Whether the function should return only when the `git push` has finished.
+         kwargs:
+             Additional keyword arguments passed along to [`create_model_card`].
+
+     Returns:
+         The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
+         commit and an object to track the progress of the commit if `blocking=True`
+     """
+
+     if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
+         model_name = Path(args.output_dir).name
+     else:
+         model_name = args.hub_model_id.split("/")[-1]
+
+     output_dir = args.output_dir
+     os.makedirs(output_dir, exist_ok=True)
+     logger.info(f"Saving pipeline checkpoint to {output_dir}")
+     pipeline.save_pretrained(output_dir)
+
+     # Only push from one node.
+     if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+         return
+
+     # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
+     if (
+         blocking
+         and len(repo.command_queue) > 0
+         and repo.command_queue[-1] is not None
+         and not repo.command_queue[-1].is_done
+     ):
+         repo.command_queue[-1]._process.kill()
+
+     git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
+     # push separately the model card to be independent from the rest of the model
+     create_model_card(args, model_name=model_name)
+     try:
+         repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
+     except EnvironmentError as exc:
+         logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
+
+     return git_head_commit_url
+
+
+ def create_model_card(args, model_name):
+     if not is_modelcards_available():
+         raise ValueError(
+             "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
+             " install the package with `pip install modelcards`."
+         )
+
+     if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+         return
+
+     hub_token = args.hub_token if hasattr(args, "hub_token") else None
+     repo_name = get_full_repo_name(model_name, token=hub_token)
+
+     model_card = ModelCard.from_template(
+         card_data=CardData(  # Card metadata object that will be converted to YAML block
+             language="en",
+             license="apache-2.0",
+             library_name="diffusers",
+             tags=[],
+             datasets=args.dataset_name,
+             metrics=[],
+         ),
+         template_path=MODEL_CARD_TEMPLATE_PATH,
+         model_name=model_name,
+         repo_name=repo_name,
+         dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
+         learning_rate=args.learning_rate,
+         train_batch_size=args.train_batch_size,
+         eval_batch_size=args.eval_batch_size,
+         gradient_accumulation_steps=args.gradient_accumulation_steps
+         if hasattr(args, "gradient_accumulation_steps")
+         else None,
+         adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
+         adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
+         adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
+         adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
+         lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
+         lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
+         ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
+         ema_power=args.ema_power if hasattr(args, "ema_power") else None,
+         ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
+         mixed_precision=args.mixed_precision,
+     )
+
+     card_path = os.path.join(args.output_dir, "README.md")
+     model_card.save(card_path)
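A small sketch of the repo-name resolution these helpers rely on; the username comes from the token stored by `huggingface-cli login`:

```python
# Sketch: get_full_repo_name from hub_utils.py above.
from diffusers.hub_utils import get_full_repo_name

print(get_full_repo_name("ddpm-butterflies-128"))                      # "<username>/ddpm-butterflies-128"
print(get_full_repo_name("ddpm-butterflies-128", organization="org"))  # "org/ddpm-butterflies-128"
```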
diffusers/modeling_utils.py ADDED
@@ -0,0 +1,542 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ from typing import Callable, List, Optional, Tuple, Union
+
+ import torch
+ from torch import Tensor, device
+
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+ from requests import HTTPError
+
+ from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
+
+
+ WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+
+
+ logger = logging.get_logger(__name__)
+
+
+ def get_parameter_device(parameter: torch.nn.Module):
+     try:
+         return next(parameter.parameters()).device
+     except StopIteration:
+         # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+         def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+             tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+             return tuples
+
+         gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+         first_tuple = next(gen)
+         return first_tuple[1].device
+
+
+ def get_parameter_dtype(parameter: torch.nn.Module):
+     try:
+         return next(parameter.parameters()).dtype
+     except StopIteration:
+         # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+         def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+             tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+             return tuples
+
+         gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+         first_tuple = next(gen)
+         return first_tuple[1].dtype
+
+
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+     """
+     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+     """
+     try:
+         return torch.load(checkpoint_file, map_location="cpu")
+     except Exception as e:
+         try:
+             with open(checkpoint_file) as f:
+                 if f.read().startswith("version"):
+                     raise OSError(
+                         "You seem to have cloned a repository without having git-lfs installed. Please install "
+                         "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                         "you cloned."
+                     )
+                 else:
+                     raise ValueError(
+                         f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                         "model. Make sure you have saved the model properly."
+                     ) from e
+         except (UnicodeDecodeError, ValueError):
+             raise OSError(
+                 f"Unable to load weights from the PyTorch checkpoint file at '{checkpoint_file}'. "
+                 "Make sure the file contains a valid PyTorch state dict saved with `torch.save`."
+             )
+
+
+ def _load_state_dict_into_model(model_to_load, state_dict):
+     # Convert old format to new format if needed from a PyTorch state_dict
+     # copy state_dict so _load_from_state_dict can modify it
+     state_dict = state_dict.copy()
+     error_msgs = []
+
+     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
+     # so we need to apply the function recursively.
+     def load(module: torch.nn.Module, prefix=""):
+         args = (state_dict, prefix, {}, True, [], [], error_msgs)
+         module._load_from_state_dict(*args)
+
+         for name, child in module._modules.items():
+             if child is not None:
+                 load(child, prefix + name + ".")
+
+     load(model_to_load)
+
+     return error_msgs
+
+
+ class ModelMixin(torch.nn.Module):
+     r"""
+     Base class for all models.
+
+     [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
+     and saving models.
+
+     - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+       [`~modeling_utils.ModelMixin.save_pretrained`].
+     """
+     config_name = CONFIG_NAME
+     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+
+     def __init__(self):
+         super().__init__()
+
+     def save_pretrained(
+         self,
+         save_directory: Union[str, os.PathLike],
+         is_main_process: bool = True,
+         save_function: Callable = torch.save,
+     ):
+         """
+         Save a model and its configuration file to a directory, so that it can be re-loaded using the
+         [`~modeling_utils.ModelMixin.from_pretrained`] class method.
+
+         Arguments:
+             save_directory (`str` or `os.PathLike`):
+                 Directory to which to save. Will be created if it doesn't exist.
+             is_main_process (`bool`, *optional*, defaults to `True`):
+                 Whether the process calling this is the main process or not. Useful in distributed training (e.g., on
+                 TPUs) when this function needs to be called on all processes. In this case, set
+                 `is_main_process=True` only on the main process to avoid race conditions.
+             save_function (`Callable`):
+                 The function to use to save the state dictionary. Useful in distributed training (e.g., on TPUs)
+                 when one needs to replace `torch.save` with another method.
+         """
+         if os.path.isfile(save_directory):
+             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+             return
+
+         os.makedirs(save_directory, exist_ok=True)
+
+         model_to_save = self
+
+         # Attach architecture to the config
+         # Save the config
+         if is_main_process:
+             model_to_save.save_config(save_directory)
+
+         # Save the model
+         state_dict = model_to_save.state_dict()
+
+         # Clean the folder from a previous save
+         for filename in os.listdir(save_directory):
+             full_filename = os.path.join(save_directory, filename)
+             # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+             # in distributed settings to avoid race conditions.
+             if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
+                 os.remove(full_filename)
+
+         # Save the model
+         save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
+
+         logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+         r"""
+         Instantiate a pretrained PyTorch model from a pretrained model configuration.
+
+         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To
+         train the model, you should first set it back in training mode with `model.train()`.
+
+         The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not
+         come pretrained with the rest of the model. It is up to you to train those weights with a downstream
+         fine-tuning task.
+
+         The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+         weights are discarded.
+
+         Parameters:
+             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                 Can be either:
+
+                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                       Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+                     - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`],
+                       e.g., `./my_model_directory/`.
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                 standard cache should not be used.
+             torch_dtype (`str` or `torch.dtype`, *optional*):
+                 Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed, the
+                 dtype will be automatically derived from the model's weights.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding
+                 the cached versions if they exist.
+             resume_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                 file exists.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             output_loading_info (`bool`, *optional*, defaults to `False`):
+                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error
+                 messages.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether or not to only look at local files (i.e., do not try to download the model).
+             use_auth_token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                 generated when running `diffusers-cli login` (stored in `~/.huggingface`).
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
+                 a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be
+                 any identifier allowed by git.
+             mirror (`str`, *optional*):
+                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or
+                 safety. Please refer to the mirror site for more information.
+
+         <Tip>
+
+         Passing `use_auth_token=True` is required when you want to use a private model.
+
+         </Tip>
+
+         <Tip>
+
+         Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to
+         use this method in a firewalled environment.
+
+         </Tip>
+
+         """
+         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", False)
+         proxies = kwargs.pop("proxies", None)
+         output_loading_info = kwargs.pop("output_loading_info", False)
+         local_files_only = kwargs.pop("local_files_only", False)
+         use_auth_token = kwargs.pop("use_auth_token", None)
+         revision = kwargs.pop("revision", None)
+         from_auto_class = kwargs.pop("_from_auto", False)
+         torch_dtype = kwargs.pop("torch_dtype", None)
+         subfolder = kwargs.pop("subfolder", None)
+
+         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
+
+         # Load config if we don't provide a configuration
+         config_path = pretrained_model_name_or_path
+         model, unused_kwargs = cls.from_config(
+             config_path,
+             cache_dir=cache_dir,
+             return_unused_kwargs=True,
+             force_download=force_download,
+             resume_download=resume_download,
+             proxies=proxies,
+             local_files_only=local_files_only,
+             use_auth_token=use_auth_token,
+             revision=revision,
+             subfolder=subfolder,
+             **kwargs,
+         )
+
+         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+             raise ValueError(
+                 f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+             )
+         elif torch_dtype is not None:
+             model = model.to(torch_dtype)
+
+         model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+         # Load model
+         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+         if os.path.isdir(pretrained_model_name_or_path):
+             if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                 # Load from a PyTorch checkpoint
+                 model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+             elif subfolder is not None and os.path.isfile(
+                 os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+             ):
+                 model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+             else:
+                 raise EnvironmentError(
+                     f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
+                 )
+         else:
+             try:
+                 # Load from URL or cache if already cached
+                 model_file = hf_hub_download(
+                     pretrained_model_name_or_path,
+                     filename=WEIGHTS_NAME,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     resume_download=resume_download,
+                     local_files_only=local_files_only,
+                     use_auth_token=use_auth_token,
+                     user_agent=user_agent,
+                     subfolder=subfolder,
+                     revision=revision,
+                 )
+
+             except RepositoryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+                     "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+                     "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+                     "login` and pass `use_auth_token=True`."
+                 )
+             except RevisionNotFoundError:
+                 raise EnvironmentError(
+                     f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+                     "this model name. Check the model page at "
+                     f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+                 )
+             except EntryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
+                 )
+             except HTTPError as err:
+                 raise EnvironmentError(
+                     "There was a specific connection error when trying to load"
+                     f" {pretrained_model_name_or_path}:\n{err}"
+                 )
+             except ValueError:
+                 raise EnvironmentError(
+                     f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find"
+                     f" it in the cached files, and it looks like {pretrained_model_name_or_path} is not the path to"
+                     f" a directory containing a file named {WEIGHTS_NAME}.\nCheck your internet connection or see"
+                     " how to run the library in offline mode at"
+                     " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                 )
+             except EnvironmentError:
+                 raise EnvironmentError(
+                     f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                     "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                     f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                     f"containing a file named {WEIGHTS_NAME}"
+                 )
+
+         # Load the checkpoint weights into the freshly initialized model
+         state_dict = load_state_dict(model_file)
+         model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+             model,
+             state_dict,
+             model_file,
+             pretrained_model_name_or_path,
+             ignore_mismatched_sizes=ignore_mismatched_sizes,
+         )
+
+         # Set model in evaluation mode to deactivate DropOut modules by default
+         model.eval()
+
+         if output_loading_info:
+             loading_info = {
+                 "missing_keys": missing_keys,
+                 "unexpected_keys": unexpected_keys,
+                 "mismatched_keys": mismatched_keys,
+                 "error_msgs": error_msgs,
+             }
+             return model, loading_info
+
+         return model
+
+     @classmethod
+     def _load_pretrained_model(
+         cls,
+         model,
+         state_dict,
+         resolved_archive_file,
+         pretrained_model_name_or_path,
+         ignore_mismatched_sizes=False,
+     ):
+         # Retrieve missing & unexpected keys
+         model_state_dict = model.state_dict()
+         loaded_keys = [k for k in state_dict.keys()]
+
+         expected_keys = list(model_state_dict.keys())
+
+         original_loaded_keys = loaded_keys
+
+         missing_keys = list(set(expected_keys) - set(loaded_keys))
+         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+         # Make sure we are able to load base models as well as derived models (with heads)
+         model_to_load = model
+
+         def _find_mismatched_keys(
+             state_dict,
+             model_state_dict,
+             loaded_keys,
+             ignore_mismatched_sizes,
+         ):
+             mismatched_keys = []
+             if ignore_mismatched_sizes:
+                 for checkpoint_key in loaded_keys:
+                     model_key = checkpoint_key
+
+                     if (
+                         model_key in model_state_dict
+                         and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+                     ):
+                         mismatched_keys.append(
+                             (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                         )
+                         del state_dict[checkpoint_key]
+             return mismatched_keys
+
+         if state_dict is not None:
+             # Whole checkpoint
+             mismatched_keys = _find_mismatched_keys(
+                 state_dict,
+                 model_state_dict,
+                 original_loaded_keys,
+                 ignore_mismatched_sizes,
+             )
+             error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+         if len(error_msgs) > 0:
+             error_msg = "\n\t".join(error_msgs)
+             if "size mismatch" in error_msg:
+                 error_msg += (
+                     "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+                 )
+             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+         if len(unexpected_keys) > 0:
+             logger.warning(
+                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+                 " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                 " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                 f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+                 " identical (initializing a BertForSequenceClassification model from a"
+                 " BertForSequenceClassification model)."
+             )
+         else:
+             logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+         if len(missing_keys) > 0:
+             logger.warning(
+                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+             )
+         elif len(mismatched_keys) == 0:
+             logger.info(
+                 f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                 f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+                 f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+                 " without further training."
+             )
+         if len(mismatched_keys) > 0:
+             mismatched_warning = "\n".join(
+                 [
+                     f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                     for key, shape1, shape2 in mismatched_keys
+                 ]
+             )
+             logger.warning(
+                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                 f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                 f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+                 " able to use it for predictions and inference."
+             )
+
+         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+     @property
+     def device(self) -> device:
+         """
+         `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+         device).
+         """
+         return get_parameter_device(self)
+
+     @property
+     def dtype(self) -> torch.dtype:
+         """
+         `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+         """
+         return get_parameter_dtype(self)
+
+     def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+         """
+         Get the number of (optionally, trainable or non-embedding) parameters in the module.
+
+         Args:
+             only_trainable (`bool`, *optional*, defaults to `False`):
+                 Whether or not to return only the number of trainable parameters.
+             exclude_embeddings (`bool`, *optional*, defaults to `False`):
+                 Whether or not to return only the number of non-embedding parameters.
+
+         Returns:
+             `int`: The number of parameters.
+         """
+
+         if exclude_embeddings:
+             embedding_param_names = [
+                 f"{name}.weight"
+                 for name, module_type in self.named_modules()
+                 if isinstance(module_type, torch.nn.Embedding)
+             ]
+             non_embedding_parameters = [
+                 parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+             ]
+             return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+         else:
+             return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+
+ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
+     """
+     Recursively unwraps a model from potential containers (as used in distributed training).
+
+     Args:
+         model (`torch.nn.Module`): The model to unwrap.
+     """
+     # Since there could be multiple levels of wrapping, unwrap recursively.
+     if hasattr(model, "module"):
+         return unwrap_model(model.module)
+     else:
+         return model
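For orientation, a minimal end-to-end sketch of the loading/saving surface defined above. It assumes network access; the model id is the one cited in the `from_pretrained` docstring, and `UNet2DModel` is imported from the `diffusers/models` package added in this commit:

    from diffusers.models import UNet2DModel

    # Downloads config.json plus diffusion_pytorch_model.bin and reports
    # missing/unexpected keys (all lists should be empty for a matching checkpoint).
    model, info = UNet2DModel.from_pretrained(
        "google/ddpm-celebahq-256", output_loading_info=True
    )
    print(info["missing_keys"], info["unexpected_keys"])

    # ModelMixin conveniences defined above.
    print(model.device, model.dtype)                  # e.g. cpu torch.float32
    print(model.num_parameters(only_trainable=True))

    # Round-trip through save_pretrained / from_pretrained with a local directory.
    model.save_pretrained("./ddpm-celebahq-256-local")
    reloaded = UNet2DModel.from_pretrained("./ddpm-celebahq-256-local")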
diffusers/models/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .unet_2d import UNet2DModel
+ from .unet_2d_condition import UNet2DConditionModel
+ from .vae import AutoencoderKL, VQModel
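Because `from_pretrained` above also resolves a `subfolder` inside a repo, the classes re-exported here can be pulled out of a multi-component pipeline repository one by one. A sketch, assuming network access and (for repos that were gated at the time, such as the Stable Diffusion weights) `use_auth_token=True`:

    from diffusers.models import AutoencoderKL, UNet2DConditionModel

    # Each sub-model lives in its own subfolder of the pipeline repo; from_pretrained
    # resolves <repo>/<subfolder>/diffusion_pytorch_model.bin accordingly.
    vae = AutoencoderKL.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True
    )
    unet = UNet2DConditionModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True
    )
    print(vae.num_parameters(), unet.num_parameters())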
diffusers/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (313 Bytes)
diffusers/models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (372 Bytes)
diffusers/models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (14.3 kB)
diffusers/models/__pycache__/attention.cpython-37.pyc ADDED
Binary file (14.2 kB)
diffusers/models/__pycache__/embeddings.cpython-310.pyc ADDED
Binary file (3.72 kB)
diffusers/models/__pycache__/embeddings.cpython-37.pyc ADDED
Binary file (3.71 kB)
diffusers/models/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (14.5 kB)
diffusers/models/__pycache__/resnet.cpython-37.pyc ADDED
Binary file (14.9 kB)
diffusers/models/__pycache__/unet_2d.cpython-310.pyc ADDED
Binary file (7.94 kB)
diffusers/models/__pycache__/unet_2d.cpython-37.pyc ADDED
Binary file (7.84 kB)
diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc ADDED
Binary file (8.73 kB)
diffusers/models/__pycache__/unet_2d_condition.cpython-37.pyc ADDED
Binary file (8.63 kB)
diffusers/models/__pycache__/unet_blocks.cpython-310.pyc ADDED
Binary file (23.7 kB)
diffusers/models/__pycache__/unet_blocks.cpython-37.pyc ADDED
Binary file (25.5 kB)
diffusers/models/__pycache__/vae.cpython-310.pyc ADDED
Binary file (16.5 kB)
diffusers/models/__pycache__/vae.cpython-37.pyc ADDED
Binary file (16.5 kB)