HiPer0 committed
Commit 325aac4
Parent: 5a2a4e4

Initial commit

This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. app.py +99 -0
  2. src/diffusers_/__init__.py +15 -0
  3. src/diffusers_/__pycache__/__init__.cpython-310.pyc +0 -0
  4. src/diffusers_/__pycache__/__init__.cpython-37.pyc +0 -0
  5. src/diffusers_/__pycache__/__init__.cpython-38.pyc +0 -0
  6. src/diffusers_/__pycache__/configuration_utils.cpython-310.pyc +0 -0
  7. src/diffusers_/__pycache__/configuration_utils.cpython-37.pyc +0 -0
  8. src/diffusers_/__pycache__/configuration_utils.cpython-38.pyc +0 -0
  9. src/diffusers_/__pycache__/dependency_versions_check.cpython-38.pyc +0 -0
  10. src/diffusers_/__pycache__/dependency_versions_table.cpython-38.pyc +0 -0
  11. src/diffusers_/__pycache__/dynamic_modules_utils.cpython-310.pyc +0 -0
  12. src/diffusers_/__pycache__/dynamic_modules_utils.cpython-37.pyc +0 -0
  13. src/diffusers_/__pycache__/dynamic_modules_utils.cpython-38.pyc +0 -0
  14. src/diffusers_/__pycache__/hub_utils.cpython-310.pyc +0 -0
  15. src/diffusers_/__pycache__/hub_utils.cpython-37.pyc +0 -0
  16. src/diffusers_/__pycache__/hub_utils.cpython-38.pyc +0 -0
  17. src/diffusers_/__pycache__/modeling_flax_pytorch_utils.cpython-38.pyc +0 -0
  18. src/diffusers_/__pycache__/modeling_flax_utils.cpython-38.pyc +0 -0
  19. src/diffusers_/__pycache__/modeling_utils.cpython-310.pyc +0 -0
  20. src/diffusers_/__pycache__/modeling_utils.cpython-37.pyc +0 -0
  21. src/diffusers_/__pycache__/modeling_utils.cpython-38.pyc +0 -0
  22. src/diffusers_/__pycache__/onnx_utils.cpython-37.pyc +0 -0
  23. src/diffusers_/__pycache__/onnx_utils.cpython-38.pyc +0 -0
  24. src/diffusers_/__pycache__/optimization.cpython-37.pyc +0 -0
  25. src/diffusers_/__pycache__/optimization.cpython-38.pyc +0 -0
  26. src/diffusers_/__pycache__/pipeline_flax_utils.cpython-38.pyc +0 -0
  27. src/diffusers_/__pycache__/pipeline_utils.cpython-310.pyc +0 -0
  28. src/diffusers_/__pycache__/pipeline_utils.cpython-37.pyc +0 -0
  29. src/diffusers_/__pycache__/pipeline_utils.cpython-38.pyc +0 -0
  30. src/diffusers_/__pycache__/scheduling_utils.cpython-310.pyc +0 -0
  31. src/diffusers_/__pycache__/scheduling_utils.cpython-38.pyc +0 -0
  32. src/diffusers_/__pycache__/training_utils.cpython-37.pyc +0 -0
  33. src/diffusers_/__pycache__/training_utils.cpython-38.pyc +0 -0
  34. src/diffusers_/configuration_utils.py +605 -0
  35. src/diffusers_/dynamic_modules_utils.py +428 -0
  36. src/diffusers_/hub_utils.py +246 -0
  37. src/diffusers_/modeling_utils.py +693 -0
  38. src/diffusers_/pipeline_utils.py +755 -0
  39. src/diffusers_/scheduling_utils.py +154 -0
  40. src/diffusers_/stable_diffusion/__init__.py +35 -0
  41. src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  42. src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-37.pyc +0 -0
  43. src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  44. src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-37.pyc +0 -0
  45. src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-38.pyc +0 -0
  46. src/diffusers_/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion.cpython-38.pyc +0 -0
  47. src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-38.pyc +0 -0
  48. src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-38.pyc +0 -0
  49. src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-38.pyc +0 -0
  50. src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-38.pyc +0 -0
app.py ADDED
@@ -0,0 +1,99 @@
+ import os
+ import pdb
+ from PIL import Image
+ import gradio as gr
+
+ from src.utils.gradio_utils import *
+
+
+ if __name__ == "__main__":
+     step_dict = {'800': 800, '900': 900, '1000': 1000, '1100': 1100}
+
+     with gr.Blocks(css=CSS_main) as demo:
+         gr.HTML(HTML_header)
+
+         with gr.Row():
+             # col A: Optimize personalized embedding
+             with gr.Column(scale=2) as gc_left:
+                 gr.HTML(" <center> <p style='font-size:150%;'> [Step 1] Optimize personalized embedding </p> </center>")
+                 img_in_real = gr.Image(type="pil", label="Start by uploading the source image", elem_id="input_image").style(height=300, width=300)
+                 gr.Examples(examples="src_image", inputs=[img_in_real])
+                 prompt = gr.Textbox(value="a standing dog", label="Source text prompt (describe the source image)", interactive=True)
+                 n_hiper = gr.Slider(5, 10, 5, label="Number of personalized embeddings (tip: 5 for animals / 10 for humans)", interactive=True, step=1)
+
+                 btn_optimize = gr.Button("Optimize", label="")
+                 fpath_z_gen = gr.Textbox(value="placeholder", visible=False)
+
+                 gr.HTML(" <center> <p style='font-size:150%;'> See the [Step 1] results with different optimization steps </p> </center>")
+                 with gr.Row():
+                     with gr.Column(scale=0.3, min_width=0.7) as gc_left:
+                         # btn_source = gr.Button("Source image", label="")
+                         btn_opt_step800 = gr.Button("Step 800", label="")
+                         btn_opt_step900 = gr.Button("Step 900", label="")
+                         btn_opt_step1000 = gr.Button("Step 1000", label="")
+                         btn_opt_step1100 = gr.Button("Step 1100", label="")
+                     with gr.Column(scale=0.5, min_width=0.8) as gc_left:
+                         img_src = gr.Image(type="pil", label="Source image", visible=True).style(height=250, width=250)
+                     with gr.Column(scale=0.5, min_width=0.8) as gc_left:
+                         img_out_opt = gr.Image(type="pil", label="Optimization step output", visible=True).style(height=250, width=250)
+
+             # col B: Generate target image
+             with gr.Column(scale=2) as gc_left:
+                 gr.HTML(" <center> <p style='font-size:150%;'> [Step 2] Generate target image </p> </center>")
+                 with gr.Row():
+                     with gr.Column():
+                         dest = gr.Textbox(value="a sitting dog", label="Target text prompt", interactive=True)
+                         step = gr.Radio(["Step 800", "Step 900", "Step 1000", "Step 1100"], value="Step 1000", label="Training optimization step\n(Refer to the personalized results for each optimization step listed in the left column.)")
+                         seed = gr.Number(value=111111, label="Random seed", interactive=True)
+                 with gr.Row():
+                     btn_generate = gr.Button("Generate", label="")
+                 img_out = gr.Image(type="pil", label="Output Image", visible=True)
+
+         with gr.Accordion("Instruction", open=True):
+             gr.Textbox("On an NVIDIA GeForce RTX 3090, [Step 1] takes about 4 minutes and [Step 2] takes about 1 minute.", show_label=False)
+             gr.Textbox("At [Step 1], upload the desired source image and write a source text that describes it. If it is difficult to describe, a simple noun such as 'a dog' or 'a woman' works. Then choose the number of personalized embeddings.", show_label=False)
+             gr.Textbox("After [Step 1], you can check the personalized results at different optimization steps and select one. First, check whether the image at step 1000 has a subject similar to the source image. In the paper, we use 1000 optimization steps in most cases.", show_label=False)
+             gr.Textbox("At [Step 2], write the desired target text. Then, referring to the personalized images in the bottom left, choose an optimization step. If the desired image is not obtained, try another random seed.", show_label=False)
+
+         ############
+         btn_optimize.click(launch_optimize, [img_in_real, prompt, n_hiper], [fpath_z_gen, img_src])
+
+         def fn_set_none():
+             return gr.update(value=None)
+
+         btn_optimize.click(fn_set_none, [], img_in_real)
+         # btn_optimize.click(set_visible_true, [], img_in_synth)
+         btn_optimize.click(set_visible_false, [], img_in_real)
+
+         ############
+         def fn_clear_all():
+             # Return two updates to match the two registered outputs (img_out, img_in_real);
+             # a third update would be needed if the commented-out img_in_synth output is restored.
+             return gr.update(value=None), gr.update(value=None)
+
+         img_in_real.clear(fn_clear_all, [], [img_out, img_in_real])  # , img_in_synth]
+         # img_in_real.clear(set_visible_true, [], img_in_synth)
+         img_in_real.clear(set_visible_false, [], img_in_real)
+
+         img_out.clear(fn_clear_all, [], [img_out, img_in_real])  # , img_in_synth]
+
+         ############
+         btn_generate.click(
+             launch_main,
+             [dest, step, fpath_z_gen, seed],
+             [img_out],
+         )
+
+         ############
+         btn_opt_step800.click(launch_opt800, [], [img_out_opt])
+         btn_opt_step900.click(launch_opt900, [], [img_out_opt])
+         btn_opt_step1000.click(launch_opt1000, [], [img_out_opt])
+         btn_opt_step1100.click(launch_opt1100, [], [img_out_opt])
+         gr.HTML("<hr>")
+
+     gr.close_all()
+     demo.queue(concurrency_count=1)
+     demo.launch(server_port=2222, server_name="0.0.0.0", debug=True, share=True)
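For readers unfamiliar with Gradio's event model, here is a minimal, self-contained sketch of the wiring pattern app.py relies on (the component names and the `generate` stand-in below are illustrative, not from this repo):

```python
import gradio as gr

def generate(prompt: str) -> str:
    # Stand-in for the real pipeline call (launch_main in app.py).
    return f"(generated image for: {prompt})"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Target text prompt")
    btn = gr.Button("Generate")
    out = gr.Textbox(label="Output")
    # click(fn, inputs, outputs): the callback's return values are mapped
    # positionally onto the listed output components, which is why the number
    # of values a handler returns must match the number of outputs it declares.
    btn.click(generate, [prompt], [out])

demo.queue().launch()
```

Under the same assumption, `demo.queue(concurrency_count=1)` in app.py serializes incoming requests so only one GPU job runs at a time.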
src/diffusers_/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from .utils import (
+     is_torch_available,
+     is_transformers_available,
+ )
+
+
+ __version__ = "0.9.0"
+
+
+ if is_torch_available() and is_transformers_available():
+     from .stable_diffusion import (
+         StableDiffusionPipeline,
+     )
+ else:
+     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
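The conditional import above is the standard soft-dependency pattern: real pipeline classes are exposed only when the optional backends can be imported, and dummy placeholders that fail loudly at use time are substituted otherwise. A hedged, self-contained sketch of the idea (the class body here is illustrative, not the package's actual dummy objects):

```python
import importlib.util

def _is_available(pkg: str) -> bool:
    # Mirrors what is_torch_available / is_transformers_available check.
    return importlib.util.find_spec(pkg) is not None

if _is_available("torch") and _is_available("transformers"):
    # In the real package: from .stable_diffusion import StableDiffusionPipeline
    pass
else:
    class StableDiffusionPipeline:
        # Dummy stand-in: importable everywhere, raises only when instantiated.
        def __init__(self, *args, **kwargs):
            raise ImportError("StableDiffusionPipeline requires torch and transformers.")
```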
src/diffusers_/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (383 Bytes).
src/diffusers_/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (3.7 kB).
src/diffusers_/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (398 Bytes).
src/diffusers_/__pycache__/configuration_utils.cpython-310.pyc ADDED
Binary file (21.3 kB).
src/diffusers_/__pycache__/configuration_utils.cpython-37.pyc ADDED
Binary file (21.7 kB).
src/diffusers_/__pycache__/configuration_utils.cpython-38.pyc ADDED
Binary file (21.5 kB).
src/diffusers_/__pycache__/dependency_versions_check.cpython-38.pyc ADDED
Binary file (960 Bytes).
src/diffusers_/__pycache__/dependency_versions_table.cpython-38.pyc ADDED
Binary file (927 Bytes).
src/diffusers_/__pycache__/dynamic_modules_utils.cpython-310.pyc ADDED
Binary file (13.4 kB).
src/diffusers_/__pycache__/dynamic_modules_utils.cpython-37.pyc ADDED
Binary file (13.3 kB).
src/diffusers_/__pycache__/dynamic_modules_utils.cpython-38.pyc ADDED
Binary file (13.4 kB).
src/diffusers_/__pycache__/hub_utils.cpython-310.pyc ADDED
Binary file (7.15 kB).
src/diffusers_/__pycache__/hub_utils.cpython-37.pyc ADDED
Binary file (6.92 kB).
src/diffusers_/__pycache__/hub_utils.cpython-38.pyc ADDED
Binary file (6.99 kB).
src/diffusers_/__pycache__/modeling_flax_pytorch_utils.cpython-38.pyc ADDED
Binary file (2.64 kB).
src/diffusers_/__pycache__/modeling_flax_utils.cpython-38.pyc ADDED
Binary file (20.7 kB).
src/diffusers_/__pycache__/modeling_utils.cpython-310.pyc ADDED
Binary file (23.7 kB).
src/diffusers_/__pycache__/modeling_utils.cpython-37.pyc ADDED
Binary file (23.8 kB).
src/diffusers_/__pycache__/modeling_utils.cpython-38.pyc ADDED
Binary file (23.9 kB).
src/diffusers_/__pycache__/onnx_utils.cpython-37.pyc ADDED
Binary file (6.73 kB).
src/diffusers_/__pycache__/onnx_utils.cpython-38.pyc ADDED
Binary file (6.85 kB).
src/diffusers_/__pycache__/optimization.cpython-37.pyc ADDED
Binary file (10.2 kB).
src/diffusers_/__pycache__/optimization.cpython-38.pyc ADDED
Binary file (10.2 kB).
src/diffusers_/__pycache__/pipeline_flax_utils.cpython-38.pyc ADDED
Binary file (15.7 kB).
src/diffusers_/__pycache__/pipeline_utils.cpython-310.pyc ADDED
Binary file (26.1 kB).
src/diffusers_/__pycache__/pipeline_utils.cpython-37.pyc ADDED
Binary file (26.1 kB).
src/diffusers_/__pycache__/pipeline_utils.cpython-38.pyc ADDED
Binary file (26.2 kB).
src/diffusers_/__pycache__/scheduling_utils.cpython-310.pyc ADDED
Binary file (6.87 kB).
src/diffusers_/__pycache__/scheduling_utils.cpython-38.pyc ADDED
Binary file (6.85 kB).
src/diffusers_/__pycache__/training_utils.cpython-37.pyc ADDED
Binary file (3.59 kB).
src/diffusers_/__pycache__/training_utils.cpython-38.pyc ADDED
Binary file (3.63 kB).
src/diffusers_/configuration_utils.py ADDED
@@ -0,0 +1,605 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ConfigMixin base class and utilities."""
+ import dataclasses
+ import functools
+ import importlib
+ import inspect
+ import json
+ import os
+ import re
+ from collections import OrderedDict
+ from typing import Any, Dict, Tuple, Union
+
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+ from requests import HTTPError
+
+ from . import __version__
+ from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
+
+
+ logger = logging.get_logger(__name__)
+
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+ class FrozenDict(OrderedDict):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         for key, value in self.items():
+             setattr(self, key, value)
+
+         self.__frozen = True
+
+     def __delitem__(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+     def setdefault(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+     def pop(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+     def update(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+     def __setattr__(self, name, value):
+         if hasattr(self, "__frozen") and self.__frozen:
+             raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+         super().__setattr__(name, value)
+
+     def __setitem__(self, name, value):
+         if hasattr(self, "__frozen") and self.__frozen:
+             raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
+         super().__setitem__(name, value)
+
+
+ class ConfigMixin:
+     r"""
+     Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles
+     all methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+         - [`~ConfigMixin.from_config`]
+         - [`~ConfigMixin.save_config`]
+
+     Class attributes:
+         - **config_name** (`str`) -- A filename under which the config should be stored when calling
+           [`~ConfigMixin.save_config`] (should be overridden by subclass).
+         - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should
+           be overridden by subclass).
+         - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
+         - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
+           should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
+           subclass).
+     """
+     config_name = None
+     ignore_for_config = []
+     has_compatibles = False
+
+     _deprecated_kwargs = []
+
+     def register_to_config(self, **kwargs):
+         if self.config_name is None:
+             raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+         # Special case for `kwargs` used in deprecation warning added to schedulers
+         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+         # or solve in a more general way.
+         kwargs.pop("kwargs", None)
+         for key, value in kwargs.items():
+             try:
+                 setattr(self, key, value)
+             except AttributeError as err:
+                 logger.error(f"Can't set {key} with value {value} for {self}")
+                 raise err
+
+         if not hasattr(self, "_internal_dict"):
+             internal_dict = kwargs
+         else:
+             previous_dict = dict(self._internal_dict)
+             internal_dict = {**self._internal_dict, **kwargs}
+             logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+         self._internal_dict = FrozenDict(internal_dict)
+
+     def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+         """
+         Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+         [`~ConfigMixin.from_config`] class method.
+
+         Args:
+             save_directory (`str` or `os.PathLike`):
+                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
+         """
+         if os.path.isfile(save_directory):
+             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+         os.makedirs(save_directory, exist_ok=True)
+
+         # If we save using the predefined names, we can load using `from_config`
+         output_config_file = os.path.join(save_directory, self.config_name)
+
+         self.to_json_file(output_config_file)
+         logger.info(f"Configuration saved in {output_config_file}")
+
+     @classmethod
+     def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+         r"""
+         Instantiate a Python class from a config dictionary.
+
+         Parameters:
+             config (`Dict[str, Any]`):
+                 A config dictionary from which the Python class will be instantiated. Make sure to only load
+                 configuration files of compatible classes.
+             return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                 Whether kwargs that are not consumed by the Python class should be returned or not.
+             kwargs (remaining dictionary of keyword arguments, *optional*):
+                 Can be used to update the configuration object (after it has been loaded) and initiate the Python
+                 class. `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and
+                 eventually overwrite same-named arguments of `config`.
+
+         Examples:
+
+         ```python
+         >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+         >>> # Download scheduler from huggingface.co and cache.
+         >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+         >>> # Instantiate DDIM scheduler class with same config as DDPM
+         >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+         >>> # Instantiate PNDM scheduler class with same config as DDPM
+         >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+         ```
+         """
+         # <===== TO BE REMOVED WITH DEPRECATION
+         # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+         if "pretrained_model_name_or_path" in kwargs:
+             config = kwargs.pop("pretrained_model_name_or_path")
+
+         if config is None:
+             raise ValueError("Please make sure to provide a config as the first positional argument.")
+         # ======>
+
+         if not isinstance(config, dict):
+             deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
+             if "Scheduler" in cls.__name__:
+                 deprecation_message += (
+                     f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
+                     " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
+                     " be removed in v1.0.0."
+                 )
+             elif "Model" in cls.__name__:
+                 deprecation_message += (
+                     f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
+                     f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
+                     " instead. This functionality will be removed in v1.0.0."
+                 )
+             deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
+             config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
+
+         init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
+
+         # Allow dtype to be specified on initialization
+         if "dtype" in unused_kwargs:
+             init_dict["dtype"] = unused_kwargs.pop("dtype")
+
+         # add possible deprecated kwargs
+         for deprecated_kwarg in cls._deprecated_kwargs:
+             if deprecated_kwarg in unused_kwargs:
+                 init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
+
+         # Return model and optionally state and/or unused_kwargs
+         model = cls(**init_dict)
+
+         # make sure to also save config parameters that might be used for compatible classes
+         model.register_to_config(**hidden_dict)
+
+         # add hidden kwargs of compatible classes to unused_kwargs
+         unused_kwargs = {**unused_kwargs, **hidden_dict}
+
+         if return_unused_kwargs:
+             return (model, unused_kwargs)
+         else:
+             return model
+
+     @classmethod
+     def get_config_dict(cls, *args, **kwargs):
+         deprecation_message = (
+             f"The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
+             " removed in version v1.0.0"
+         )
+         deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
+         return cls.load_config(*args, **kwargs)
+
+     @classmethod
+     def load_config(
+         cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
+     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         r"""
+         Load a configuration dictionary from a model repo on huggingface.co or from a local directory.
+
+         Parameters:
+             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                 Can be either:
+
+                     - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+                       organization name, like `google/ddpm-celebahq-256`.
+                     - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+                       `./my_model_directory/`.
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                 standard cache should not be used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                 cached versions if they exist.
+             resume_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                 file exists.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             output_loading_info (`bool`, *optional*, defaults to `False`):
+                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether or not to only look at local files (i.e., do not try to download the model).
+             use_auth_token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                 when running `transformers-cli login` (stored in `~/.huggingface`).
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                 identifier allowed by git.
+             subfolder (`str`, *optional*, defaults to `""`):
+                 In case the relevant files are located inside a subfolder of the model repo (either remote on
+                 huggingface.co or downloaded locally), you can specify the folder name here.
+
+         <Tip>
+
+         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+         models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+         </Tip>
+
+         <Tip>
+
+         Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+         use this method in a firewalled environment.
+
+         </Tip>
+         """
+         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", False)
+         proxies = kwargs.pop("proxies", None)
+         use_auth_token = kwargs.pop("use_auth_token", None)
+         local_files_only = kwargs.pop("local_files_only", False)
+         revision = kwargs.pop("revision", None)
+         _ = kwargs.pop("mirror", None)
+         subfolder = kwargs.pop("subfolder", None)
+
+         user_agent = {"file_type": "config"}
+
+         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+         if cls.config_name is None:
+             raise ValueError(
+                 "`self.config_name` is not defined. Note that one should not load a config from "
+                 "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+             )
+
+         if os.path.isfile(pretrained_model_name_or_path):
+             config_file = pretrained_model_name_or_path
+         elif os.path.isdir(pretrained_model_name_or_path):
+             if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                 # Load from a PyTorch checkpoint
+                 config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+             elif subfolder is not None and os.path.isfile(
+                 os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+             ):
+                 config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+             else:
+                 raise EnvironmentError(
+                     f"Error: no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+                 )
+         else:
+             try:
+                 # Load from URL or cache if already cached
+                 config_file = hf_hub_download(
+                     pretrained_model_name_or_path,
+                     filename=cls.config_name,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     resume_download=resume_download,
+                     local_files_only=local_files_only,
+                     use_auth_token=use_auth_token,
+                     user_agent=user_agent,
+                     subfolder=subfolder,
+                     revision=revision,
+                 )
+
+             except RepositoryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+                     " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+                     " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+                     " login`."
+                 )
+             except RevisionNotFoundError:
+                 raise EnvironmentError(
+                     f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+                     " this model name. Check the model page at"
+                     f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+                 )
+             except EntryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+                 )
+             except HTTPError as err:
+                 raise EnvironmentError(
+                     "There was a specific connection error when trying to load"
+                     f" {pretrained_model_name_or_path}:\n{err}"
+                 )
+             except ValueError:
+                 raise EnvironmentError(
+                     f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                     f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                     f" directory containing a {cls.config_name} file.\nCheck your internet connection or see how to"
+                     " run the library in offline mode at"
+                     " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                 )
+             except EnvironmentError:
+                 raise EnvironmentError(
+                     f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                     "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                     f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                     f"containing a {cls.config_name} file"
+                 )
+
+         try:
+             # Load config dict
+             config_dict = cls._dict_from_json_file(config_file)
+         except (json.JSONDecodeError, UnicodeDecodeError):
+             raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+         if return_unused_kwargs:
+             return config_dict, kwargs
+
+         return config_dict
+
+     @staticmethod
+     def _get_init_keys(cls):
+         return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
+     @classmethod
+     def extract_init_dict(cls, config_dict, **kwargs):
+         # 0. Copy original config dict
+         original_dict = {k: v for k, v in config_dict.items()}
+
+         # 1. Retrieve expected config attributes from __init__ signature
+         expected_keys = cls._get_init_keys(cls)
+         expected_keys.remove("self")
+         # remove general kwargs if present in dict
+         if "kwargs" in expected_keys:
+             expected_keys.remove("kwargs")
+         # remove flax internal keys
+         if hasattr(cls, "_flax_internal_args"):
+             for arg in cls._flax_internal_args:
+                 expected_keys.remove(arg)
+
+         # 2. Remove attributes that cannot be expected from expected config attributes
+         # remove keys to be ignored
+         if len(cls.ignore_for_config) > 0:
+             expected_keys = expected_keys - set(cls.ignore_for_config)
+
+         # load diffusers library to import compatible and original scheduler
+         diffusers_library = importlib.import_module(__name__.split(".")[0])
+
+         if cls.has_compatibles:
+             compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
+         else:
+             compatible_classes = []
+
+         expected_keys_comp_cls = set()
+         for c in compatible_classes:
+             expected_keys_c = cls._get_init_keys(c)
+             expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
+         expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
+         config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
+
+         # remove attributes from orig class that cannot be expected
+         orig_cls_name = config_dict.pop("_class_name", cls.__name__)
+         if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
+             orig_cls = getattr(diffusers_library, orig_cls_name)
+             unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
+             config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
+
+         # remove private attributes
+         config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+
+         # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
+         init_dict = {}
+         for key in expected_keys:
+             # if a config param is passed as a kwarg and is also present in the config dict,
+             # it should overwrite the existing config dict key
+             if key in kwargs and key in config_dict:
+                 config_dict[key] = kwargs.pop(key)
+
+             if key in kwargs:
+                 # overwrite key
+                 init_dict[key] = kwargs.pop(key)
+             elif key in config_dict:
+                 # use value from config dict
+                 init_dict[key] = config_dict.pop(key)
+
+         # 4. Give a nice warning if unexpected values have been passed
+         if len(config_dict) > 0:
+             logger.warning(
+                 f"The config attributes {config_dict} were passed to {cls.__name__}, "
+                 "but are not expected and will be ignored. Please verify your "
+                 f"{cls.config_name} configuration file."
+             )
+
+         # 5. Give nice info if config attributes are initialized to default because they have not been passed
+         passed_keys = set(init_dict.keys())
+         if len(expected_keys - passed_keys) > 0:
+             logger.info(
+                 f"{expected_keys - passed_keys} were not found in config. Values will be initialized to default values."
+             )
+
+         # 6. Define unused keyword arguments
+         unused_kwargs = {**config_dict, **kwargs}
+
+         # 7. Define "hidden" config parameters that were saved for compatible classes
+         hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
+
+         return init_dict, unused_kwargs, hidden_config_dict
+
+     @classmethod
+     def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+         with open(json_file, "r", encoding="utf-8") as reader:
+             text = reader.read()
+         return json.loads(text)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__} {self.to_json_string()}"
+
+     @property
+     def config(self) -> Dict[str, Any]:
+         """
+         Returns the config of the class as a frozen dictionary.
+
+         Returns:
+             `Dict[str, Any]`: Config of the class.
+         """
+         return self._internal_dict
+
+     def to_json_string(self) -> str:
+         """
+         Serializes this instance to a JSON string.
+
+         Returns:
+             `str`: String containing all the attributes that make up this configuration instance in JSON format.
+         """
+         config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+         config_dict["_class_name"] = self.__class__.__name__
+         config_dict["_diffusers_version"] = __version__
+
+         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+         """
+         Save this instance to a JSON file.
+
+         Args:
+             json_file_path (`str` or `os.PathLike`):
+                 Path to the JSON file in which this configuration instance's parameters will be saved.
+         """
+         with open(json_file_path, "w", encoding="utf-8") as writer:
+             writer.write(self.to_json_string())
+
+
+ def register_to_config(init):
+     r"""
+     Decorator to apply to the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+     automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
+     shouldn't be registered in the config, use the `ignore_for_config` class variable.
+
+     Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+     """
+
+     @functools.wraps(init)
+     def inner_init(self, *args, **kwargs):
+         # Ignore private kwargs in the init.
+         init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+         config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+         if not isinstance(self, ConfigMixin):
+             raise RuntimeError(
+                 f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
+                 "not inherit from `ConfigMixin`."
+             )
+
+         ignore = getattr(self, "ignore_for_config", [])
+         # Get positional arguments aligned with kwargs
+         new_kwargs = {}
+         signature = inspect.signature(init)
+         parameters = {
+             name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+         }
+         for arg, name in zip(args, parameters.keys()):
+             new_kwargs[name] = arg
+
+         # Then add all kwargs
+         new_kwargs.update(
+             {
+                 k: init_kwargs.get(k, default)
+                 for k, default in parameters.items()
+                 if k not in ignore and k not in new_kwargs
+             }
+         )
+         new_kwargs = {**config_init_kwargs, **new_kwargs}
+         getattr(self, "register_to_config")(**new_kwargs)
+         init(self, *args, **init_kwargs)
+
+     return inner_init
+
+
+ def flax_register_to_config(cls):
+     original_init = cls.__init__
+
+     @functools.wraps(original_init)
+     def init(self, *args, **kwargs):
+         if not isinstance(self, ConfigMixin):
+             raise RuntimeError(
+                 f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
+                 "not inherit from `ConfigMixin`."
+             )
+
+         # Ignore private kwargs in the init. Retrieve all passed attributes
+         init_kwargs = {k: v for k, v in kwargs.items()}
+
+         # Retrieve default values
+         fields = dataclasses.fields(self)
+         default_kwargs = {}
+         for field in fields:
+             # ignore flax-specific attributes
+             if field.name in self._flax_internal_args:
+                 continue
+             if type(field.default) == dataclasses._MISSING_TYPE:
+                 default_kwargs[field.name] = None
+             else:
+                 default_kwargs[field.name] = getattr(self, field.name)
+
+         # Make sure init_kwargs override default kwargs
+         new_kwargs = {**default_kwargs, **init_kwargs}
+         # dtype should be part of `init_kwargs`, but not `new_kwargs`
+         if "dtype" in new_kwargs:
+             new_kwargs.pop("dtype")
+
+         # Get positional arguments aligned with kwargs
+         for i, arg in enumerate(args):
+             name = fields[i].name
+             new_kwargs[name] = arg
+
+         getattr(self, "register_to_config")(**new_kwargs)
+         original_init(self, *args, **kwargs)
+
+     cls.__init__ = init
+     return cls
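To make the interplay of `ConfigMixin`, `FrozenDict`, and the `@register_to_config` decorator concrete, here is a minimal sketch of how a subclass typically uses them (`SimpleScheduler` and its arguments are hypothetical, and the snippet assumes the utilities above are importable as a package):

```python
class SimpleScheduler(ConfigMixin):
    config_name = "scheduler_config.json"  # file name used by save_config/load_config

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        # The decorator has already captured both arguments (passed or defaulted)
        # and stored them in self._internal_dict before this body runs.
        pass

s = SimpleScheduler(num_train_timesteps=500)
assert s.config["num_train_timesteps"] == 500  # explicitly passed value
assert s.config["beta_start"] == 1e-4          # captured default

# Round-trip through JSON: save_config writes scheduler_config.json,
# load_config reads it back, and from_config re-instantiates the class.
s.save_config("./my_scheduler")
s2 = SimpleScheduler.from_config(SimpleScheduler.load_config("./my_scheduler"))
```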
src/diffusers_/dynamic_modules_utils.py ADDED
@@ -0,0 +1,428 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities to dynamically load objects from the Hub."""
16
+
17
+ import importlib
18
+ import inspect
19
+ import os
20
+ import re
21
+ import shutil
22
+ import sys
23
+ from pathlib import Path
24
+ from typing import Dict, Optional, Union
25
+
26
+ from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
27
+
28
+ from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
29
+
30
+
31
+ COMMUNITY_PIPELINES_URL = (
32
+ "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/{pipeline}.py"
33
+ )
34
+
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ def init_hf_modules():
40
+ """
41
+ Creates the cache directory for modules with an init, and adds it to the Python path.
42
+ """
43
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
44
+ if HF_MODULES_CACHE in sys.path:
45
+ return
46
+
47
+ sys.path.append(HF_MODULES_CACHE)
48
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
49
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
50
+ if not init_path.exists():
51
+ init_path.touch()
52
+
53
+
54
+ def create_dynamic_module(name: Union[str, os.PathLike]):
55
+ """
56
+ Creates a dynamic module in the cache directory for modules.
57
+ """
58
+ init_hf_modules()
59
+ dynamic_module_path = Path(HF_MODULES_CACHE) / name
60
+ # If the parent module does not exist yet, recursively create it.
61
+ if not dynamic_module_path.parent.exists():
62
+ create_dynamic_module(dynamic_module_path.parent)
63
+ os.makedirs(dynamic_module_path, exist_ok=True)
64
+ init_path = dynamic_module_path / "__init__.py"
65
+ if not init_path.exists():
66
+ init_path.touch()
67
+
68
+
69
+ def get_relative_imports(module_file):
70
+ """
71
+ Get the list of modules that are relatively imported in a module file.
72
+
73
+ Args:
74
+ module_file (`str` or `os.PathLike`): The module file to inspect.
75
+ """
76
+ with open(module_file, "r", encoding="utf-8") as f:
77
+ content = f.read()
78
+
79
+ # Imports of the form `import .xxx`
80
+ relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
81
+ # Imports of the form `from .xxx import yyy`
82
+ relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
83
+ # Unique-ify
84
+ return list(set(relative_imports))
85
+
86
+
87
+ def get_relative_import_files(module_file):
88
+ """
89
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
90
+ imports (if a imports b and b imports c, it will return module files for b and c).
91
+
92
+ Args:
93
+ module_file (`str` or `os.PathLike`): The module file to inspect.
94
+ """
95
+ no_change = False
96
+ files_to_check = [module_file]
97
+ all_relative_imports = []
98
+
99
+ # Let's recurse through all relative imports
100
+ while not no_change:
101
+ new_imports = []
102
+ for f in files_to_check:
103
+ new_imports.extend(get_relative_imports(f))
104
+
105
+ module_path = Path(module_file).parent
106
+ new_import_files = [str(module_path / m) for m in new_imports]
107
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
108
+ files_to_check = [f"{f}.py" for f in new_import_files]
109
+
110
+ no_change = len(new_import_files) == 0
111
+ all_relative_imports.extend(files_to_check)
112
+
113
+ return all_relative_imports
114
+
115
+
116
+ def check_imports(filename):
117
+ """
118
+ Check if the current Python environment contains all the libraries that are imported in a file.
119
+ """
120
+ with open(filename, "r", encoding="utf-8") as f:
121
+ content = f.read()
122
+
123
+ # Imports of the form `import xxx`
124
+ imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
125
+ # Imports of the form `from xxx import yyy`
126
+ imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
127
+ # Only keep the top-level module
128
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
129
+
130
+ # Unique-ify and test we got them all
131
+ imports = list(set(imports))
132
+ missing_packages = []
133
+ for imp in imports:
134
+ try:
135
+ importlib.import_module(imp)
136
+ except ImportError:
137
+ missing_packages.append(imp)
138
+
139
+ if len(missing_packages) > 0:
140
+ raise ImportError(
141
+ "This modeling file requires the following packages that were not found in your environment: "
142
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
143
+ )
144
+
145
+ return get_relative_imports(filename)
146
+
147
+
148
+ def get_class_in_module(class_name, module_path):
149
+ """
150
+ Import a module on the cache directory for modules and extract a class from it.
151
+ """
152
+ module_path = module_path.replace(os.path.sep, ".")
153
+ module = importlib.import_module(module_path)
154
+
155
+ if class_name is None:
156
+ return find_pipeline_class(module)
157
+ return getattr(module, class_name)
158
+
159
+
160
+ def find_pipeline_class(loaded_module):
161
+ """
162
+ Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class
163
+ inheriting from `DiffusionPipeline`.
164
+ """
165
+ from .pipeline_utils import DiffusionPipeline
166
+
167
+ cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))
168
+
169
+ pipeline_class = None
170
+ for cls_name, cls in cls_members.items():
171
+ if (
172
+ cls_name != DiffusionPipeline.__name__
173
+ and issubclass(cls, DiffusionPipeline)
174
+ and cls.__module__.split(".")[0] != "diffusers"
175
+ ):
176
+ if pipeline_class is not None:
177
+ raise ValueError(
178
+ f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
179
+ f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
180
+ f" {loaded_module}."
181
+ )
182
+ pipeline_class = cls
183
+
184
+ return pipeline_class
185
+
186
+
187
+ def get_cached_module_file(
188
+ pretrained_model_name_or_path: Union[str, os.PathLike],
189
+ module_file: str,
190
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
191
+ force_download: bool = False,
192
+ resume_download: bool = False,
193
+ proxies: Optional[Dict[str, str]] = None,
194
+ use_auth_token: Optional[Union[bool, str]] = None,
195
+ revision: Optional[str] = None,
196
+ local_files_only: bool = False,
197
+ ):
198
+ """
199
+ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
200
+ Transformers module.
201
+
202
+ Args:
203
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
204
+ This can be either:
205
+
206
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
207
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
208
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
209
+ - a path to a *directory* containing a configuration file saved using the
210
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
211
+
212
+ module_file (`str`):
213
+ The name of the module file containing the class to look for.
214
+ cache_dir (`str` or `os.PathLike`, *optional*):
215
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
216
+ cache should not be used.
217
+ force_download (`bool`, *optional*, defaults to `False`):
218
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
219
+ exist.
220
+ resume_download (`bool`, *optional*, defaults to `False`):
221
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
222
+ proxies (`Dict[str, str]`, *optional*):
223
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
224
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
225
+ use_auth_token (`str` or *bool*, *optional*):
226
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
227
+ when running `transformers-cli login` (stored in `~/.huggingface`).
228
+ revision (`str`, *optional*, defaults to `"main"`):
229
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
230
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
231
+ identifier allowed by git.
232
+ local_files_only (`bool`, *optional*, defaults to `False`):
233
+ If `True`, will only try to load the tokenizer configuration from local files.
234
+
235
+ <Tip>
236
+
237
+ You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private
238
+ or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
239
+
240
+ </Tip>
241
+
242
+ Returns:
243
+ `str`: The path to the module inside the cache.
244
+ """
245
+ # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
246
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
247
+
248
+ module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
249
+
250
+ if os.path.isfile(module_file_or_url):
251
+ resolved_module_file = module_file_or_url
252
+ submodule = "local"
253
+ elif pretrained_model_name_or_path.count("/") == 0:
254
+ # community pipeline on GitHub
255
+ github_url = COMMUNITY_PIPELINES_URL.format(pipeline=pretrained_model_name_or_path)
256
+ try:
257
+ resolved_module_file = cached_download(
258
+ github_url,
259
+ cache_dir=cache_dir,
260
+ force_download=force_download,
261
+ proxies=proxies,
262
+ resume_download=resume_download,
263
+ local_files_only=local_files_only,
264
+ use_auth_token=False,
265
+ )
266
+ submodule = "git"
267
+ module_file = pretrained_model_name_or_path + ".py"
268
+ except EnvironmentError:
269
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
270
+ raise
271
+ else:
272
+ try:
273
+ # Load from URL or cache if already cached
274
+ resolved_module_file = hf_hub_download(
275
+ pretrained_model_name_or_path,
276
+ module_file,
277
+ cache_dir=cache_dir,
278
+ force_download=force_download,
279
+ proxies=proxies,
280
+ resume_download=resume_download,
281
+ local_files_only=local_files_only,
282
+ use_auth_token=use_auth_token,
283
+ )
284
+ submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
285
+ except EnvironmentError:
286
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
287
+ raise
288
+
289
+ # Check we have all the requirements in our environment
290
+ modules_needed = check_imports(resolved_module_file)
291
+
292
+ # Now we move the module inside our cached dynamic modules.
293
+ full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
294
+ create_dynamic_module(full_submodule)
295
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
296
+ if submodule == "local" or submodule == "git":
297
+ # We always copy local files (we could hash the file to see if there was a change, and give them the name of
298
+ # that hash, to only copy when there is a modification but it seems overkill for now).
299
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
300
+ shutil.copy(resolved_module_file, submodule_path / module_file)
301
+ for module_needed in modules_needed:
302
+ module_needed = f"{module_needed}.py"
303
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
304
+     else:
+         # Get the commit hash
+         # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
+         if isinstance(use_auth_token, str):
+             token = use_auth_token
+         elif use_auth_token is True:
+             token = HfFolder.get_token()
+         else:
+             token = None
+
+         commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
+
+         # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
+         # benefit of versioning.
+         submodule_path = submodule_path / commit_hash
+         full_submodule = full_submodule + os.path.sep + commit_hash
+         create_dynamic_module(full_submodule)
+
+         if not (submodule_path / module_file).exists():
+             shutil.copy(resolved_module_file, submodule_path / module_file)
+         # Make sure we also have every file with relative imports
+         for module_needed in modules_needed:
+             if not (submodule_path / module_needed).exists():
+                 get_cached_module_file(
+                     pretrained_model_name_or_path,
+                     f"{module_needed}.py",
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     resume_download=resume_download,
+                     proxies=proxies,
+                     use_auth_token=use_auth_token,
+                     revision=revision,
+                     local_files_only=local_files_only,
+                 )
+     return os.path.join(full_submodule, module_file)
+
+
+ def get_class_from_dynamic_module(
+     pretrained_model_name_or_path: Union[str, os.PathLike],
+     module_file: str,
+     class_name: Optional[str] = None,
+     cache_dir: Optional[Union[str, os.PathLike]] = None,
+     force_download: bool = False,
+     resume_download: bool = False,
+     proxies: Optional[Dict[str, str]] = None,
+     use_auth_token: Optional[Union[bool, str]] = None,
+     revision: Optional[str] = None,
+     local_files_only: bool = False,
+     **kwargs,
+ ):
+     """
+     Extracts a class from a module file, present in the local folder or repository of a model.
+
+     <Tip warning={true}>
+
+     Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
+     therefore only be called on trusted repos.
+
+     </Tip>
+
+     Args:
+         pretrained_model_name_or_path (`str` or `os.PathLike`):
+             This can be either:
+
+             - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+               huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
+               namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
+             - a path to a *directory* containing a configuration file saved using the
+               [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+         module_file (`str`):
+             The name of the module file containing the class to look for.
+         class_name (`str`):
+             The name of the class to import in the module.
+         cache_dir (`str` or `os.PathLike`, *optional*):
+             Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+             cache should not be used.
+         force_download (`bool`, *optional*, defaults to `False`):
+             Whether or not to force the (re-)download of the configuration files, overriding the cached versions if
+             they exist.
+         resume_download (`bool`, *optional*, defaults to `False`):
+             Whether or not to delete incompletely received files. Will attempt to resume the download if such a file
+             exists.
+         proxies (`Dict[str, str]`, *optional*):
+             A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+             'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+         use_auth_token (`str` or `bool`, *optional*):
+             The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+             when running `huggingface-cli login` (stored in `~/.huggingface`).
+         revision (`str`, *optional*, defaults to `"main"`):
+             The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+             git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+             identifier allowed by git.
+         local_files_only (`bool`, *optional*, defaults to `False`):
+             If `True`, will only try to load the tokenizer configuration from local files.
+
+     <Tip>
+
+     You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli login`) and want to use private
+     or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+     </Tip>
+
+     Returns:
+         `type`: The class, dynamically imported from the module.
+
+     Examples:
+
+     ```python
+     # Download module `modeling.py` from huggingface.co, cache it, then extract the class `MyBertModel` from this
+     # module.
+     cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
+     ```"""
+     # And lastly we get the class inside our newly created module
+     final_module = get_cached_module_file(
+         pretrained_model_name_or_path,
+         module_file,
+         cache_dir=cache_dir,
+         force_download=force_download,
+         resume_download=resume_download,
+         proxies=proxies,
+         use_auth_token=use_auth_token,
+         revision=revision,
+         local_files_only=local_files_only,
+     )
+     return get_class_in_module(class_name, final_module.replace(".py", ""))
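The two helpers above compose as follows — a minimal, hedged sketch (the repo id and class name are hypothetical, and remember that this executes code fetched from the Hub, so only point it at repos you trust):

```python
from src.diffusers_.dynamic_modules_utils import get_class_from_dynamic_module

# Fetch `pipeline.py` from a (hypothetical) Hub repo, cache it under the
# dynamic-modules path keyed by the repo's commit hash, then import the class.
CustomPipeline = get_class_from_dynamic_module(
    "some-user/my-custom-pipeline",  # hypothetical repo id
    "pipeline.py",
    class_name="CustomPipeline",  # hypothetical class defined in pipeline.py
)
print(CustomPipeline.__module__)  # lives under the cached dynamic-module path
```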
src/diffusers_/hub_utils.py ADDED
@@ -0,0 +1,246 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import os
+ import shutil
+ import sys
+ from pathlib import Path
+ from typing import Dict, Optional, Union
+ from uuid import uuid4
+
+ from huggingface_hub import HfFolder, Repository, whoami
+
+ from . import __version__
+ from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
+ from .utils.import_utils import (
+     _flax_version,
+     _jax_version,
+     _onnxruntime_version,
+     _torch_version,
+     is_flax_available,
+     is_modelcards_available,
+     is_onnx_available,
+     is_torch_available,
+ )
+
+
+ if is_modelcards_available():
+     from modelcards import CardData, ModelCard
+
+
+ logger = logging.get_logger(__name__)
+
+
+ MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
+ SESSION_ID = uuid4().hex
+ DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
+
+
+ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
+     """
+     Formats a user-agent string with basic info about a request.
+     """
+     ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
+     if DISABLE_TELEMETRY:
+         return ua + "; telemetry/off"
+     if is_torch_available():
+         ua += f"; torch/{_torch_version}"
+     if is_flax_available():
+         ua += f"; jax/{_jax_version}"
+         ua += f"; flax/{_flax_version}"
+     if is_onnx_available():
+         ua += f"; onnxruntime/{_onnxruntime_version}"
+     # CI will set this value to True
+     if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
+         ua += "; is_ci/true"
+     if isinstance(user_agent, dict):
+         ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
+     elif isinstance(user_agent, str):
+         ua += "; " + user_agent
+     return ua
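+ # Illustrative output (all version numbers here are hypothetical): with torch
+ # installed, http_user_agent({"file_type": "model"}) returns something like
+ # "diffusers/0.6.0; python/3.8.13; session_id/<hex>; torch/1.12.1; file_type/model".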
+
+
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+     if token is None:
+         token = HfFolder.get_token()
+     if organization is None:
+         username = whoami(token)["name"]
+         return f"{username}/{model_id}"
+     else:
+         return f"{organization}/{model_id}"
+
+
+ def init_git_repo(args, at_init: bool = False):
+     """
+     Initializes a git repo in `args.hub_model_id`.
+
+     Args:
+         at_init (`bool`, *optional*, defaults to `False`):
+             Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
+             and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
+     """
+     deprecation_message = (
+         "Please use `huggingface_hub.Repository`. "
+         "See `examples/unconditional_image_generation/train_unconditional.py` for an example."
+     )
+     deprecate("init_git_repo()", "0.10.0", deprecation_message)
+
+     if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+         return
+     hub_token = args.hub_token if hasattr(args, "hub_token") else None
+     use_auth_token = True if hub_token is None else hub_token
+     if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
+         repo_name = Path(args.output_dir).absolute().name
+     else:
+         repo_name = args.hub_model_id
+     if "/" not in repo_name:
+         repo_name = get_full_repo_name(repo_name, token=hub_token)
+
+     try:
+         repo = Repository(
+             args.output_dir,
+             clone_from=repo_name,
+             use_auth_token=use_auth_token,
+             private=args.hub_private_repo,
+         )
+     except EnvironmentError:
+         if args.overwrite_output_dir and at_init:
+             # Try again after wiping output_dir
+             shutil.rmtree(args.output_dir)
+             repo = Repository(
+                 args.output_dir,
+                 clone_from=repo_name,
+                 use_auth_token=use_auth_token,
+             )
+         else:
+             raise
+
+     repo.git_pull()
+
+     # By default, ignore the checkpoint folders
+     if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
+         with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
+             writer.writelines(["checkpoint-*/"])
+
+     return repo
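+ # Illustrative (deprecated path; `args` is assumed to carry `output_dir`,
+ # `hub_private_repo`, etc.):
+ #     repo = init_git_repo(args, at_init=True)
+ # Prefer constructing `huggingface_hub.Repository` directly, as the
+ # deprecation message above advises.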
+
+
+ def push_to_hub(
+     args,
+     pipeline,
+     repo: Repository,
+     commit_message: Optional[str] = "End of training",
+     blocking: bool = True,
+     **kwargs,
+ ) -> str:
+     """
+     Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
+
+     Parameters:
+         commit_message (`str`, *optional*, defaults to `"End of training"`):
+             Message to commit while pushing.
+         blocking (`bool`, *optional*, defaults to `True`):
+             Whether the function should return only when the `git push` has finished.
+         kwargs:
+             Additional keyword arguments passed along to [`create_model_card`].
+     Returns:
+         The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
+         commit and an object to track the progress of the commit if `blocking=True`
+     """
+     deprecation_message = (
+         "Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. "
+         "See `examples/unconditional_image_generation/train_unconditional.py` for an example."
+     )
+     deprecate("push_to_hub()", "0.10.0", deprecation_message)
+
+     if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
+         model_name = Path(args.output_dir).name
+     else:
+         model_name = args.hub_model_id.split("/")[-1]
+
+     output_dir = args.output_dir
+     os.makedirs(output_dir, exist_ok=True)
+     logger.info(f"Saving pipeline checkpoint to {output_dir}")
+     pipeline.save_pretrained(output_dir)
+
+     # Only push from one node.
+     if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+         return
+
+     # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
+     if (
+         blocking
+         and len(repo.command_queue) > 0
+         and repo.command_queue[-1] is not None
+         and not repo.command_queue[-1].is_done
+     ):
+         repo.command_queue[-1]._process.kill()
+
+     git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
+     # push the model card separately so it is independent from the rest of the model
+     create_model_card(args, model_name=model_name)
+     try:
+         repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
+     except EnvironmentError as exc:
+         logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")
+
+     return git_head_commit_url
+
+
+ def create_model_card(args, model_name):
+     if not is_modelcards_available():
+         raise ValueError(
+             "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
+             " install the package with `pip install modelcards`."
+         )
+
+     if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+         return
+
+     hub_token = args.hub_token if hasattr(args, "hub_token") else None
+     repo_name = get_full_repo_name(model_name, token=hub_token)
+
+     model_card = ModelCard.from_template(
+         card_data=CardData(  # Card metadata object that will be converted to YAML block
+             language="en",
+             license="apache-2.0",
+             library_name="diffusers",
+             tags=[],
+             datasets=args.dataset_name,
+             metrics=[],
+         ),
+         template_path=MODEL_CARD_TEMPLATE_PATH,
+         model_name=model_name,
+         repo_name=repo_name,
+         dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
+         learning_rate=args.learning_rate,
+         train_batch_size=args.train_batch_size,
+         eval_batch_size=args.eval_batch_size,
+         gradient_accumulation_steps=args.gradient_accumulation_steps
+         if hasattr(args, "gradient_accumulation_steps")
+         else None,
+         adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
+         adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
+         adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
+         adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
+         lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
+         lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
+         ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
+         ema_power=args.ema_power if hasattr(args, "ema_power") else None,
+         ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
+         mixed_precision=args.mixed_precision,
+     )
+
+     card_path = os.path.join(args.output_dir, "README.md")
+     model_card.save(card_path)
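For orientation, a hedged sketch of driving `create_model_card` (assumes `modelcards` is installed and you are logged in via `huggingface-cli login`; every value below is made up):

```python
import os
from types import SimpleNamespace

from src.diffusers_.hub_utils import create_model_card

args = SimpleNamespace(
    output_dir="./ddpm-demo",                      # hypothetical local directory
    dataset_name="huggan/flowers-102-categories",  # hypothetical dataset id
    learning_rate=1e-4,
    train_batch_size=16,
    eval_batch_size=16,
    mixed_precision="no",
)

os.makedirs(args.output_dir, exist_ok=True)
create_model_card(args, model_name="ddpm-demo")  # writes ./ddpm-demo/README.md
```

Attributes not present on `args` (the `adam_*`, `ema_*`, and scheduler fields) are simply rendered as `None`, since the function guards them with `hasattr`.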
src/diffusers_/modeling_utils.py ADDED
@@ -0,0 +1,693 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ from functools import partial
+ from typing import Callable, List, Optional, Tuple, Union
+
+ import torch
+ from torch import Tensor, device
+
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+ from requests import HTTPError
+
+ from . import __version__
+ from .utils import (
+     CONFIG_NAME,
+     DIFFUSERS_CACHE,
+     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+     WEIGHTS_NAME,
+     is_accelerate_available,
+     is_torch_version,
+     logging,
+ )
+
+
+ logger = logging.get_logger(__name__)
+
+
+ if is_torch_version(">=", "1.9.0"):
+     _LOW_CPU_MEM_USAGE_DEFAULT = True
+ else:
+     _LOW_CPU_MEM_USAGE_DEFAULT = False
+
+
+ if is_accelerate_available():
+     import accelerate
+     from accelerate.utils import set_module_tensor_to_device
+     from accelerate.utils.versions import is_torch_version
+
+
+ def get_parameter_device(parameter: torch.nn.Module):
+     try:
+         return next(parameter.parameters()).device
+     except StopIteration:
+         # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+         def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+             tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+             return tuples
+
+         gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+         first_tuple = next(gen)
+         return first_tuple[1].device
+
+
+ def get_parameter_dtype(parameter: torch.nn.Module):
+     try:
+         return next(parameter.parameters()).dtype
+     except StopIteration:
+         # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+         def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+             tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+             return tuples
+
+         gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+         first_tuple = next(gen)
+         return first_tuple[1].dtype
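+ # Illustrative: for a module living on GPU in half precision,
+ # get_parameter_device(m) returns device(type="cuda", index=0) and
+ # get_parameter_dtype(m) returns torch.float16.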
+
+
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+     """
+     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+     """
+     try:
+         return torch.load(checkpoint_file, map_location="cpu")
+     except Exception as e:
+         try:
+             with open(checkpoint_file) as f:
+                 if f.read().startswith("version"):
+                     raise OSError(
+                         "You seem to have cloned a repository without having git-lfs installed. Please install "
+                         "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                         "you cloned."
+                     )
+                 else:
+                     raise ValueError(
+                         f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                         "model. Make sure you have saved the model properly."
+                     ) from e
+         except (UnicodeDecodeError, ValueError):
+             raise OSError(
+                 f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
+                 f"at '{checkpoint_file}'. "
+                 "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+             )
+
+
+ def _load_state_dict_into_model(model_to_load, state_dict):
+     # Convert old format to new format if needed from a PyTorch state_dict
+     # copy state_dict so _load_from_state_dict can modify it
+     state_dict = state_dict.copy()
+     error_msgs = []
+
+     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+     # so we need to apply the function recursively.
+     def load(module: torch.nn.Module, prefix=""):
+         args = (state_dict, prefix, {}, True, [], [], error_msgs)
+         module._load_from_state_dict(*args)
+
+         for name, child in module._modules.items():
+             if child is not None:
+                 load(child, prefix + name + ".")
+
+     load(model_to_load)
+
+     return error_msgs
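+ # Illustrative usage (the checkpoint path is hypothetical):
+ #     state_dict = load_state_dict("./unet/diffusion_pytorch_model.bin")
+ #     error_msgs = _load_state_dict_into_model(model, state_dict)
+ # `error_msgs` collects per-module load errors instead of raising mid-recursion.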
+
+
+ class ModelMixin(torch.nn.Module):
+     r"""
+     Base class for all models.
+
+     [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
+     and saving models.
+
+     - **config_name** (`str`) -- A filename under which the model should be stored when calling
+       [`~modeling_utils.ModelMixin.save_pretrained`].
+     """
+     config_name = CONFIG_NAME
+     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+     _supports_gradient_checkpointing = False
+
+     def __init__(self):
+         super().__init__()
+
+     @property
+     def is_gradient_checkpointing(self) -> bool:
+         """
+         Whether gradient checkpointing is activated for this model or not.
+
+         Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+         activations".
+         """
+         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+     def enable_gradient_checkpointing(self):
+         """
+         Activates gradient checkpointing for the current model.
+
+         Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+         activations".
+         """
+         if not self._supports_gradient_checkpointing:
+             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+         self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+     def disable_gradient_checkpointing(self):
+         """
+         Deactivates gradient checkpointing for the current model.
+
+         Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+         activations".
+         """
+         if self._supports_gradient_checkpointing:
+             self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+     def save_pretrained(
+         self,
+         save_directory: Union[str, os.PathLike],
+         is_main_process: bool = True,
+         save_function: Callable = torch.save,
+     ):
+         """
+         Save a model and its configuration file to a directory, so that it can be re-loaded using the
+         [`~modeling_utils.ModelMixin.from_pretrained`] class method.
+
+         Arguments:
+             save_directory (`str` or `os.PathLike`):
+                 Directory to which to save. Will be created if it doesn't exist.
+             is_main_process (`bool`, *optional*, defaults to `True`):
+                 Whether the process calling this is the main process or not. Useful during distributed training
+                 (e.g., on TPUs) when this function needs to be called on all processes. In this case, set
+                 `is_main_process=True` only on the main process to avoid race conditions.
+             save_function (`Callable`):
+                 The function to use to save the state dictionary. Useful during distributed training (e.g., on TPUs)
+                 when one needs to replace `torch.save` by another method.
+         """
+         if os.path.isfile(save_directory):
+             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+             return
+
+         os.makedirs(save_directory, exist_ok=True)
+
+         model_to_save = self
+
+         # Attach architecture to the config
+         # Save the config
+         if is_main_process:
+             model_to_save.save_config(save_directory)
+
+         # Save the model
+         state_dict = model_to_save.state_dict()
+
+         # Clean the folder from a previous save
+         for filename in os.listdir(save_directory):
+             full_filename = os.path.join(save_directory, filename)
+             # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+             # in distributed settings to avoid race conditions.
+             if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
+                 os.remove(full_filename)
+
+         # Save the model
+         save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
+
+         logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+         r"""
+         Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To
+         train the model, you should first set it back in training mode with `model.train()`.
+
+         The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not
+         come pretrained with the rest of the model. It is up to you to train those weights with a downstream
+         fine-tuning task.
+
+         The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+         weights are discarded.
+
+         Parameters:
+             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                 Can be either:
+
+                 - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                   Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+                 - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+                   `./my_model_directory/`.
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                 standard cache should not be used.
+             torch_dtype (`str` or `torch.dtype`, *optional*):
+                 Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the
+                 dtype will be automatically derived from the model's weights.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding
+                 the cached versions if they exist.
+             resume_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                 file exists.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             output_loading_info (`bool`, *optional*, defaults to `False`):
+                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error
+                 messages.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether or not to only look at local files (i.e., do not try to download the model).
+             use_auth_token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                 generated when running `diffusers-cli login` (stored in `~/.huggingface`).
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
+                 a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                 identifier allowed by git.
+             subfolder (`str`, *optional*, defaults to `""`):
+                 In case the relevant files are located inside a subfolder of the model repo (either remote in
+                 huggingface.co or downloaded locally), you can specify the folder name here.
+
+             mirror (`str`, *optional*):
+                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or
+                 safety. Please refer to the mirror site for more information.
+             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                 A map that specifies where each submodule should go. It doesn't need to be refined to each
+                 parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                 same device.
+
+                 To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
+                 For more information about each option see [designing a device
+                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                 Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of
+                 torch, setting this argument to `True` will raise an error.
+
+         <Tip>
+
+         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+         models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+         </Tip>
+
+         <Tip>
+
+         Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+         this method in a firewalled environment.
+
+         </Tip>
+
+         """
+         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", False)
+         proxies = kwargs.pop("proxies", None)
+         output_loading_info = kwargs.pop("output_loading_info", False)
+         local_files_only = kwargs.pop("local_files_only", False)
+         use_auth_token = kwargs.pop("use_auth_token", None)
+         revision = kwargs.pop("revision", None)
+         torch_dtype = kwargs.pop("torch_dtype", None)
+         subfolder = kwargs.pop("subfolder", None)
+         device_map = kwargs.pop("device_map", None)
+         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+         if low_cpu_mem_usage and not is_accelerate_available():
+             low_cpu_mem_usage = False
+             logger.warning(
+                 "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                 " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                 " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                 " install accelerate\n```\n."
+             )
+
+         if device_map is not None and not is_accelerate_available():
+             raise NotImplementedError(
+                 "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
+                 " `device_map=None`. You can install accelerate with `pip install accelerate`."
+             )
+
+         # Check if we can handle device_map and dispatching the weights
+         if device_map is not None and not is_torch_version(">=", "1.9.0"):
+             raise NotImplementedError(
+                 "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                 " `device_map=None`."
+             )
+
+         if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+             raise NotImplementedError(
+                 "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                 " `low_cpu_mem_usage=False`."
+             )
+
+         if low_cpu_mem_usage is False and device_map is not None:
+             raise ValueError(
+                 f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+                 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+             )
+
+         user_agent = {
+             "diffusers": __version__,
+             "file_type": "model",
+             "framework": "pytorch",
+         }
+
+         # Load config if we don't provide a configuration
+         config_path = pretrained_model_name_or_path
+
+         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+         # index of the shards.
+         # Load model
+         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+         if os.path.isdir(pretrained_model_name_or_path):
+             if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                 # Load from a PyTorch checkpoint
+                 model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+             elif subfolder is not None and os.path.isfile(
+                 os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+             ):
+                 model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+             else:
+                 raise EnvironmentError(
+                     f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
+                 )
+         else:
+             try:
+                 # Load from URL or cache if already cached
+                 model_file = hf_hub_download(
+                     pretrained_model_name_or_path,
+                     filename=WEIGHTS_NAME,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     resume_download=resume_download,
+                     local_files_only=local_files_only,
+                     use_auth_token=use_auth_token,
+                     user_agent=user_agent,
+                     subfolder=subfolder,
+                     revision=revision,
+                 )
+
+             except RepositoryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+                     "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+                     "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+                     "login`."
+                 )
+             except RevisionNotFoundError:
+                 raise EnvironmentError(
+                     f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+                     "this model name. Check the model page at "
+                     f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+                 )
+             except EntryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
+                 )
+             except HTTPError as err:
+                 raise EnvironmentError(
+                     "There was a specific connection error when trying to load"
+                     f" {pretrained_model_name_or_path}:\n{err}"
+                 )
+             except ValueError:
+                 raise EnvironmentError(
+                     f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                     f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                     f" directory containing a file named {WEIGHTS_NAME}."
+                     " \nCheck your internet connection or see how to run the library in"
+                     " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                 )
+             except EnvironmentError:
+                 raise EnvironmentError(
+                     f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                     "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                     f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                     f"containing a file named {WEIGHTS_NAME}"
+                 )
+
+         # restore default dtype
+
+         if low_cpu_mem_usage:
+             # Instantiate model with empty weights
+             with accelerate.init_empty_weights():
+                 config, unused_kwargs = cls.load_config(
+                     config_path,
+                     cache_dir=cache_dir,
+                     return_unused_kwargs=True,
+                     force_download=force_download,
+                     resume_download=resume_download,
+                     proxies=proxies,
+                     local_files_only=local_files_only,
+                     use_auth_token=use_auth_token,
+                     revision=revision,
+                     subfolder=subfolder,
+                     device_map=device_map,
+                     **kwargs,
+                 )
+                 model = cls.from_config(config, **unused_kwargs)
+
+             # if device_map is None, load the state dict and move the params from meta device to the cpu
+             if device_map is None:
+                 param_device = "cpu"
+                 state_dict = load_state_dict(model_file)
+                 # move the params from meta device to cpu
+                 for param_name, param in state_dict.items():
+                     set_module_tensor_to_device(model, param_name, param_device, value=param)
+             else:  # else let accelerate handle loading and dispatching.
+                 # Load weights and dispatch according to the device_map
+                 # by default the device_map is None and the weights are loaded on the CPU
+                 accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
+
+             loading_info = {
+                 "missing_keys": [],
+                 "unexpected_keys": [],
+                 "mismatched_keys": [],
+                 "error_msgs": [],
+             }
+         else:
+             config, unused_kwargs = cls.load_config(
+                 config_path,
+                 cache_dir=cache_dir,
+                 return_unused_kwargs=True,
+                 force_download=force_download,
+                 resume_download=resume_download,
+                 proxies=proxies,
+                 local_files_only=local_files_only,
+                 use_auth_token=use_auth_token,
+                 revision=revision,
+                 subfolder=subfolder,
+                 device_map=device_map,
+                 **kwargs,
+             )
+             model = cls.from_config(config, **unused_kwargs)
+
+             state_dict = load_state_dict(model_file)
+             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+                 model,
+                 state_dict,
+                 model_file,
+                 pretrained_model_name_or_path,
+                 ignore_mismatched_sizes=ignore_mismatched_sizes,
+             )
+
+             loading_info = {
+                 "missing_keys": missing_keys,
+                 "unexpected_keys": unexpected_keys,
+                 "mismatched_keys": mismatched_keys,
+                 "error_msgs": error_msgs,
+             }
+
+         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+             raise ValueError(
+                 f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+             )
+         elif torch_dtype is not None:
+             model = model.to(torch_dtype)
+
+         model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+         # Set model in evaluation mode to deactivate DropOut modules by default
+         model.eval()
+         if output_loading_info:
+             return model, loading_info
+
+         return model
+
+     @classmethod
+     def _load_pretrained_model(
+         cls,
+         model,
+         state_dict,
+         resolved_archive_file,
+         pretrained_model_name_or_path,
+         ignore_mismatched_sizes=False,
+     ):
+         # Retrieve missing & unexpected_keys
+         model_state_dict = model.state_dict()
+         loaded_keys = [k for k in state_dict.keys()]
+
+         expected_keys = list(model_state_dict.keys())
+
+         original_loaded_keys = loaded_keys
+
+         missing_keys = list(set(expected_keys) - set(loaded_keys))
+         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+         # Make sure we are able to load base models as well as derived models (with heads)
+         model_to_load = model
+
+         def _find_mismatched_keys(
+             state_dict,
+             model_state_dict,
+             loaded_keys,
+             ignore_mismatched_sizes,
+         ):
+             mismatched_keys = []
+             if ignore_mismatched_sizes:
+                 for checkpoint_key in loaded_keys:
+                     model_key = checkpoint_key
+
+                     if (
+                         model_key in model_state_dict
+                         and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+                     ):
+                         mismatched_keys.append(
+                             (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                         )
+                         del state_dict[checkpoint_key]
+             return mismatched_keys
+
+         if state_dict is not None:
+             # Whole checkpoint
+             mismatched_keys = _find_mismatched_keys(
+                 state_dict,
+                 model_state_dict,
+                 original_loaded_keys,
+                 ignore_mismatched_sizes,
+             )
+             error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+         if len(error_msgs) > 0:
+             error_msg = "\n\t".join(error_msgs)
+             if "size mismatch" in error_msg:
+                 error_msg += (
+                     "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+                 )
+             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+         if len(unexpected_keys) > 0:
+             logger.warning(
+                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+                 " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                 " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                 f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+                 " identical (initializing a BertForSequenceClassification model from a"
+                 " BertForSequenceClassification model)."
+             )
+         else:
+             logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+         if len(missing_keys) > 0:
+             logger.warning(
+                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+             )
+         elif len(mismatched_keys) == 0:
+             logger.info(
+                 f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                 f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+                 f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+                 " without further training."
+             )
+         if len(mismatched_keys) > 0:
+             mismatched_warning = "\n".join(
+                 [
+                     f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                     for key, shape1, shape2 in mismatched_keys
+                 ]
+             )
+             logger.warning(
+                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                 f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                 f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+                 " able to use it for predictions and inference."
+             )
+
+         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+     @property
+     def device(self) -> device:
+         """
+         `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+         device).
+         """
+         return get_parameter_device(self)
+
+     @property
+     def dtype(self) -> torch.dtype:
+         """
+         `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+         """
+         return get_parameter_dtype(self)
+
+     def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+         """
+         Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+         Args:
+             only_trainable (`bool`, *optional*, defaults to `False`):
+                 Whether or not to return only the number of trainable parameters
+
+             exclude_embeddings (`bool`, *optional*, defaults to `False`):
+                 Whether or not to return only the number of non-embeddings parameters
+
+         Returns:
+             `int`: The number of parameters.
+         """
+
+         if exclude_embeddings:
+             embedding_param_names = [
+                 f"{name}.weight"
+                 for name, module_type in self.named_modules()
+                 if isinstance(module_type, torch.nn.Embedding)
+             ]
+             non_embedding_parameters = [
+                 parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+             ]
+             return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+         else:
+             return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+
+ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
+     """
+     Recursively unwraps a model from potential containers (as used in distributed training).
+
+     Args:
+         model (`torch.nn.Module`): The model to unwrap.
+     """
+     # since there could be multiple levels of wrapping, unwrap recursively
+     if hasattr(model, "module"):
+         return unwrap_model(model.module)
+     else:
+         return model
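As a quick orientation, a hedged round-trip sketch of the `ModelMixin` loading API (it uses the upstream `diffusers` `UNet2DModel`, which exposes the same interface as this vendored copy; the local directory name is made up):

```python
from diffusers import UNet2DModel

# Download from the Hub (or load from the local cache); the model comes back in eval mode.
model = UNet2DModel.from_pretrained("google/ddpm-cat-256")
print(model.device, model.dtype)  # e.g. cpu torch.float32
print(model.num_parameters())     # total parameter count

# Save locally, then reload from disk.
model.save_pretrained("./ddpm-cat-local")
reloaded, info = UNet2DModel.from_pretrained("./ddpm-cat-local", output_loading_info=True)
print(info["missing_keys"])       # [] on a clean round trip
```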
src/diffusers_/pipeline_utils.py ADDED
@@ -0,0 +1,755 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib
+ import inspect
+ import os
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+
+ import diffusers
+ import PIL
+ from huggingface_hub import snapshot_download
+ from packaging import version
+ from PIL import Image
+ from tqdm.auto import tqdm
+
+ from .configuration_utils import ConfigMixin
+ from .dynamic_modules_utils import get_class_from_dynamic_module
+ from .hub_utils import http_user_agent
+ from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+ from .scheduling_utils import SCHEDULER_CONFIG_NAME
+ from .utils import (
+     CONFIG_NAME,
+     DIFFUSERS_CACHE,
+     ONNX_WEIGHTS_NAME,
+     WEIGHTS_NAME,
+     BaseOutput,
+     deprecate,
+     is_accelerate_available,
+     is_torch_version,
+     is_transformers_available,
+     logging,
+ )
+
+
+ if is_transformers_available():
+     import transformers
+     from transformers import PreTrainedModel
+
+
+ INDEX_FILE = "diffusion_pytorch_model.bin"
+ CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
+ DUMMY_MODULES_FOLDER = "diffusers.utils"
+ TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
+
+
+ logger = logging.get_logger(__name__)
+
+
+ LOADABLE_CLASSES = {
+     "diffusers": {
+         "ModelMixin": ["save_pretrained", "from_pretrained"],
+         "SchedulerMixin": ["save_pretrained", "from_pretrained"],
+         "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
+         "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
+     },
+     "transformers": {
+         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+         "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+         "PreTrainedModel": ["save_pretrained", "from_pretrained"],
+         "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+         "ProcessorMixin": ["save_pretrained", "from_pretrained"],
+         "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
+     },
+     "onnxruntime.training": {
+         "ORTModule": ["save_pretrained", "from_pretrained"],
+     },
+ }
+
+ ALL_IMPORTABLE_CLASSES = {}
+ for library in LOADABLE_CLASSES:
+     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
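+ # Illustrative: ALL_IMPORTABLE_CLASSES flattens the mapping above, so e.g.
+ # ALL_IMPORTABLE_CLASSES["PreTrainedTokenizer"] == ["save_pretrained", "from_pretrained"],
+ # letting the loader look up save/load method names without knowing the library.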
+
+
+ @dataclass
+ class ImagePipelineOutput(BaseOutput):
+     """
+     Output class for image pipelines.
+
+     Args:
+         images (`List[PIL.Image.Image]` or `np.ndarray`):
+             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+             num_channels)`. PIL images or a numpy array representing the denoised images of the diffusion pipeline.
+     """
+
+     images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+ @dataclass
+ class AudioPipelineOutput(BaseOutput):
+     """
+     Output class for audio pipelines.
+
+     Args:
+         audios (`np.ndarray`):
+             List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. A numpy array representing
+             the denoised audio samples of the diffusion pipeline.
+     """
+
+     audios: np.ndarray
+
+
+ class DiffusionPipeline(ConfigMixin):
+     r"""
+     Base class for all pipelines.
+
+     [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion
+     pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all
+     pipelines to:
+
+     - move all PyTorch modules to the device of your choice
+     - enable/disable the progress bar for the denoising iteration
+
+     Class attributes:
+
+     - **config_name** (`str`) -- name of the config file that will store the class and module names of all
+       components of the diffusion pipeline.
+     - **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
+       passed for the pipeline to function (should be overridden by subclasses).
+     """
+     config_name = "model_index.json"
+     _optional_components = []
+
+     def register_modules(self, **kwargs):
+         # import it here to avoid circular import
+         from diffusers import pipelines
+
+         for name, module in kwargs.items():
+             # retrieve library
+             if module is None:
+                 register_dict = {name: (None, None)}
+             else:
+                 library = module.__module__.split(".")[0]
+
+                 # check if the module is a pipeline module
+                 pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
+                 path = module.__module__.split(".")
+                 is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+                 # if library is not in LOADABLE_CLASSES, then it is a custom module.
+                 # Or if it's a pipeline module, then the module is inside the pipeline
+                 # folder so we set the library to module name.
+                 if library not in LOADABLE_CLASSES or is_pipeline_module:
+                     library = pipeline_dir
+
+                 # retrieve class_name
+                 class_name = module.__class__.__name__
+
+                 register_dict = {name: (library, class_name)}
+
+             # save model index config
+             self.register_to_config(**register_dict)
+
+             # set models
+             setattr(self, name, module)
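+     # Illustrative: a subclass's __init__ typically calls
+     #     self.register_modules(unet=unet, scheduler=scheduler)
+     # which records e.g. {"unet": ("diffusers", "UNet2DModel"), ...} into
+     # model_index.json and sets `self.unet` / `self.scheduler` on the pipeline.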
172
+
173
+ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
174
+ """
175
+ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
176
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
177
+ method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
178
+
179
+ Arguments:
180
+ save_directory (`str` or `os.PathLike`):
181
+ Directory to which to save. Will be created if it doesn't exist.
182
+ """
183
+ self.save_config(save_directory)
184
+
185
+ model_index_dict = dict(self.config)
186
+ model_index_dict.pop("_class_name")
187
+ model_index_dict.pop("_diffusers_version")
188
+ model_index_dict.pop("_module", None)
189
+
190
+ expected_modules, optional_kwargs = self._get_signature_keys(self)
191
+
192
+ def is_saveable_module(name, value):
193
+ if name not in expected_modules:
194
+ return False
195
+ if name in self._optional_components and value[0] is None:
196
+ return False
197
+ return True
198
+
199
+ model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
200
+
201
+ for pipeline_component_name in model_index_dict.keys():
202
+ sub_model = getattr(self, pipeline_component_name)
203
+ model_cls = sub_model.__class__
204
+
205
+ save_method_name = None
206
+ # search for the model's base class in LOADABLE_CLASSES
207
+ for library_name, library_classes in LOADABLE_CLASSES.items():
208
+ library = importlib.import_module(library_name)
209
+ for base_class, save_load_methods in library_classes.items():
210
+ class_candidate = getattr(library, base_class, None)
211
+ if class_candidate is not None and issubclass(model_cls, class_candidate):
212
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
213
+ save_method_name = save_load_methods[0]
214
+ break
215
+ if save_method_name is not None:
216
+ break
217
+
218
+ if save_method_name is not None:
219
+ save_method = getattr(sub_model, save_method_name)
220
+ save_method(os.path.join(save_directory, pipeline_component_name))
221
+
222
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
223
+ if torch_device is None:
224
+ return self
225
+
226
+ module_names, _, _ = self.extract_init_dict(dict(self.config))
227
+ for name in module_names.keys():
228
+ module = getattr(self, name)
229
+ if isinstance(module, torch.nn.Module):
230
+ if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
231
+ logger.warning(
232
+ "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
233
+ " is not recommended to move them to `cpu` as running them will fail. Please make"
234
+ " sure to use an accelerator to run the pipeline in inference, due to the lack of"
235
+ " support for`float16` operations on this device in PyTorch. Please, remove the"
236
+ " `torch_dtype=torch.float16` argument, or use another device for inference."
237
+ )
238
+ module.to(torch_device)
239
+ return self
240
+
241
+ @property
242
+ def device(self) -> torch.device:
243
+ r"""
244
+ Returns:
245
+ `torch.device`: The torch device on which the pipeline is located.
+        """
+        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        for name in module_names.keys():
+            module = getattr(self, name)
+            if isinstance(module, torch.nn.Module):
+                return module.device
+        return torch.device("cpu")
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
+
+        The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
+                      https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
+                      like `CompVis/ldm-text2im-large-256`.
+                    - A path to a *directory* containing pipeline weights saved using
+                      [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the
+                dtype will be automatically derived from the model's weights.
+            custom_pipeline (`str`, *optional*):
+
+                <Tip warning={true}>
+
+                This is an experimental feature and is likely to change in the future.
+
+                </Tip>
+
+                Can be either:
+
+                    - A string, the *repo id* of a custom pipeline hosted inside a model repo on
+                      https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
+                      like `hf-internal-testing/diffusers-dummy-pipeline`.
+
+                      <Tip>
+
+                      It is required that the model repo has a file called `pipeline.py` that defines the custom
+                      pipeline.
+
+                      </Tip>
+
+                    - A string, the *file name* of a community pipeline hosted on GitHub under
+                      https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
+                      match exactly the file name without `.py` located under the above link, *e.g.*
+                      `clip_guided_stable_diffusion`.
+
+                      <Tip>
+
+                      Community pipelines are always loaded from the current `main` branch of GitHub.
+
+                      </Tip>
+
+                    - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
+
+                      <Tip>
+
+                      It is required that the directory has a file called `pipeline.py` that defines the custom
+                      pipeline.
+
+                      </Tip>
+
+                For more information on how to load and create custom pipelines, please have a look at [Loading and
+                Adding Custom
+                Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview).
+
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding
+                the cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
+                messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
+                a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or
+                safety. Please refer to the mirror site for more information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
+                For more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of
+                torch, setting this argument to `True` will raise an error.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load- and saveable variables, *i.e.* the pipeline components, of the
+                specific pipeline class. The overwritten components are then directly passed to the pipeline's
+                `__init__` method. See example below for more information.
+
+        <Tip>
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`.
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
+        </Tip>
+
+        Examples:
+
+        ```py
+        >>> from diffusers import DiffusionPipeline
+
+        >>> # Download pipeline from huggingface.co and cache.
+        >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+
+        >>> # Download a pipeline that requires an authorization token. For more information on
+        >>> # access tokens, see [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+        >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+
+        >>> # Use a different scheduler
+        >>> from diffusers import LMSDiscreteScheduler
+
+        >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+        >>> pipeline.scheduler = scheduler
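+
+        >>> # Components can also be overwritten directly through `from_pretrained` kwargs — this is
+        >>> # the component overwriting mentioned in the `kwargs` description above (a sketch; any
+        >>> # compatible scheduler instance works here).
+        >>> pipeline = DiffusionPipeline.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", scheduler=scheduler
+        ... )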
+        ```
+        """
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        resume_download = kwargs.pop("resume_download", False)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        custom_pipeline = kwargs.pop("custom_pipeline", None)
+        provider = kwargs.pop("provider", None)
+        sess_options = kwargs.pop("sess_options", None)
+        device_map = kwargs.pop("device_map", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intensive model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if device_map is not None and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )
+
+        # 1. Download the checkpoints and configs
+        # use snapshot download here to get it working from from_pretrained
+        if not os.path.isdir(pretrained_model_name_or_path):
+            config_dict = cls.load_config(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+            # make sure we only download sub-folders and `diffusers` filenames
+            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+            allow_patterns = [os.path.join(k, "*") for k in folder_names]
+            allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+
+            # make sure we don't download flax weights
+            ignore_patterns = "*.msgpack"
+
+            if custom_pipeline is not None:
+                allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
+
+            if cls != DiffusionPipeline:
+                requested_pipeline_class = cls.__name__
+            else:
+                requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
+            user_agent = {"pipeline_class": requested_pipeline_class}
+            if custom_pipeline is not None:
+                user_agent["custom_pipeline"] = custom_pipeline
+            user_agent = http_user_agent(user_agent)
+
+            # download all allow_patterns
+            cached_folder = snapshot_download(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                allow_patterns=allow_patterns,
+                ignore_patterns=ignore_patterns,
+                user_agent=user_agent,
+            )
+        else:
+            cached_folder = pretrained_model_name_or_path
+
+        config_dict = cls.load_config(cached_folder)
+
+        # 2. Load the pipeline class, if using custom module then load it from the hub
+        # if we load from explicit class, let's use it
+        if custom_pipeline is not None:
+            if custom_pipeline.endswith(".py"):
+                path = Path(custom_pipeline)
+                # decompose into folder & file
+                file_name = path.name
+                custom_pipeline = path.parent.absolute()
+            else:
+                file_name = CUSTOM_PIPELINE_FILE_NAME
505
+            pipeline_class = get_class_from_dynamic_module(
+                custom_pipeline, module_file=file_name, cache_dir=custom_pipeline
+            )
+        elif cls != DiffusionPipeline:
+            pipeline_class = cls
+        else:
+            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+
+        # To be removed in 1.0.0
+        if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
+            version.parse(config_dict["_diffusers_version"]).base_version
+        ) <= version.parse("0.5.1"):
+            from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
+
+            pipeline_class = StableDiffusionInpaintPipelineLegacy
+
+            deprecation_message = (
+                "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
+                f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
+                " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
+                " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
+                f" checkpoint {pretrained_model_name_or_path} to the format of"
+                " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
+                f" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
+            )
+            deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
+
+        # some modules can be passed directly to the init
+        # in this case they are already instantiated in `kwargs`
+        # extract them here
+        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+
+        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+        # define init kwargs
+        init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
+        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
+
+        # remove `null` components
+        def load_module(name, value):
+            if value[0] is None:
+                return False
+            if name in passed_class_obj and passed_class_obj[name] is None:
+                return False
+            return True
+
+        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
+            )
+
+        # import it here to avoid circular import
+        from diffusers import pipelines
+
+        # 3. Load each module in the pipeline
+        for name, (library_name, class_name) in init_dict.items():
+            # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+            if class_name.startswith("Flax"):
+                class_name = class_name[4:]
+
+            is_pipeline_module = hasattr(pipelines, library_name)
+            loaded_sub_model = None
+
+            # if the model is in a pipeline module, then we load it from the pipeline
+            if name in passed_class_obj:
+                # 1. check that passed_class_obj has correct parent class
+                if not is_pipeline_module:
+                    library = importlib.import_module(library_name)
+                    class_obj = getattr(library, class_name)
+                    importable_classes = LOADABLE_CLASSES[library_name]
+                    class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
+
+                    expected_class_obj = None
+                    for class_name, class_candidate in class_candidates.items():
+                        if class_candidate is not None and issubclass(class_obj, class_candidate):
+                            expected_class_obj = class_candidate
+
+                    if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+                        raise ValueError(
+                            f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+                            f" {expected_class_obj}"
+                        )
+                else:
+                    logger.warning(
+                        f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+                        " has the correct type"
+                    )
+
+                # set passed class object
+                loaded_sub_model = passed_class_obj[name]
+            elif is_pipeline_module:
+                pipeline_module = getattr(pipelines, library_name)
+                class_obj = getattr(pipeline_module, class_name)
+                importable_classes = ALL_IMPORTABLE_CLASSES
+                class_candidates = {c: class_obj for c in importable_classes.keys()}
+            else:
+                # else we just import it from the library.
+                library = importlib.import_module(library_name)
+
+                class_obj = getattr(library, class_name)
+                importable_classes = LOADABLE_CLASSES[library_name]
+                class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
+
+            if loaded_sub_model is None:
+                load_method_name = None
+                for class_name, class_candidate in class_candidates.items():
+                    if class_candidate is not None and issubclass(class_obj, class_candidate):
+                        load_method_name = importable_classes[class_name][1]
+
+                if load_method_name is None:
+                    none_module = class_obj.__module__
+                    is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
+                        TRANSFORMERS_DUMMY_MODULES_FOLDER
+                    )
+                    if is_dummy_path and "dummy" in none_module:
+                        # call class_obj for nice error message of missing requirements
+                        class_obj()
+
+                    raise ValueError(
+                        f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
+                        f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
+                    )
+
+                load_method = getattr(class_obj, load_method_name)
+                loading_kwargs = {}
+
+                if issubclass(class_obj, torch.nn.Module):
+                    loading_kwargs["torch_dtype"] = torch_dtype
+                if issubclass(class_obj, diffusers.OnnxRuntimeModel):
+                    loading_kwargs["provider"] = provider
+                    loading_kwargs["sess_options"] = sess_options
+
+                is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
+                is_transformers_model = (
+                    is_transformers_available()
+                    and issubclass(class_obj, PreTrainedModel)
+                    and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
+                )
+
+                # When loading a transformers model, if `device_map` is None the weights are fully
+                # initialized, as opposed to diffusers. To make default loading faster we set the
+                # `low_cpu_mem_usage=low_cpu_mem_usage` flag, which is `True` by default. This makes
+                # sure that the weights won't be initialized, which significantly speeds up loading.
+                if is_diffusers_model or is_transformers_model:
+                    loading_kwargs["device_map"] = device_map
+                    loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+                # check if the module is in a subdirectory
+                if os.path.isdir(os.path.join(cached_folder, name)):
+                    loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
+                else:
+                    # else load from the root directory
+                    loaded_sub_model = load_method(cached_folder, **loading_kwargs)
+
+            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
+
+        # 4. Potentially add passed objects if expected
+        missing_modules = set(expected_modules) - set(init_kwargs.keys())
+        passed_modules = list(passed_class_obj.keys())
+        optional_modules = pipeline_class._optional_components
+        if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
+            for module in missing_modules:
+                init_kwargs[module] = passed_class_obj.get(module, None)
+        elif len(missing_modules) > 0:
+            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
+            # The strict completeness check below is commented out in this vendored copy, so
+            # missing modules are tolerated here instead of raising:
+            # raise ValueError(
+            #     f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+            # )
+
+        # 5. Instantiate the pipeline
+        model = pipeline_class(**init_kwargs)
+        return model
+
+    @staticmethod
+    def _get_signature_keys(obj):
+        parameters = inspect.signature(obj.__init__).parameters
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+        expected_modules = set(required_parameters.keys()) - set(["self"])
+        return expected_modules, optional_parameters
+
+    @property
+    def components(self) -> Dict[str, Any]:
+        r"""
+        The `self.components` property can be useful to run different pipelines with the same weights and
+        configurations to not have to re-allocate memory.
+
+        Examples:
+
+        ```py
+        >>> from diffusers import (
+        ...     StableDiffusionPipeline,
+        ...     StableDiffusionImg2ImgPipeline,
+        ...     StableDiffusionInpaintPipeline,
+        ... )
+
+        >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
+        >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
+        ```
+
+        Returns:
+            A dictionary containing all the modules needed to initialize the pipeline.
+        """
+        expected_modules, optional_parameters = self._get_signature_keys(self)
+        components = {
+            k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
+        }
+
+        if set(components.keys()) != expected_modules:
+            raise ValueError(
+                f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
+                f" {expected_modules} to be defined, but {components} are defined."
+            )
+
+        return components
+
+    @staticmethod
+    def numpy_to_pil(images):
+        """
+        Convert a numpy image or a batch of images to a PIL image.
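+
+        Example (a sketch; `images` is assumed to be a float array in `[0, 1]` of shape
+        `(batch, height, width, channels)`):
+
+        ```py
+        >>> import numpy as np
+
+        >>> batch = np.random.rand(2, 64, 64, 3)
+        >>> pil_images = DiffusionPipeline.numpy_to_pil(batch)
+        >>> len(pil_images)
+        2
+        ```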
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        if images.shape[-1] == 1:
+            # special case for grayscale (single channel) images
+            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+        else:
+            pil_images = [Image.fromarray(image) for image in images]
+
+        return pil_images
+
+    def progress_bar(self, iterable):
+        if not hasattr(self, "_progress_bar_config"):
+            self._progress_bar_config = {}
+        elif not isinstance(self._progress_bar_config, dict):
+            raise ValueError(
+                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+            )
+
+        return tqdm(iterable, **self._progress_bar_config)
+
+    def set_progress_bar_config(self, **kwargs):
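+        """
+        Configure the progress bar created by `progress_bar`. The kwargs are stored as-is and
+        forwarded verbatim to `tqdm` on the next call, e.g. (a sketch of intended usage):
+
+        ```py
+        >>> pipe.set_progress_bar_config(disable=True)  # silence the denoising progress bar
+        ```
+        """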
+        self._progress_bar_config = kwargs
src/diffusers_/scheduling_utils.py ADDED
@@ -0,0 +1,154 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import importlib
+ import os
+ from dataclasses import dataclass
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+
+ from .utils import BaseOutput
+
+
+ SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+ @dataclass
+ class SchedulerOutput(BaseOutput):
+     """
+     Base class for the scheduler's step function output.
+
+     Args:
+         prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+             Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+             denoising loop.
+     """
+
+     prev_sample: torch.FloatTensor
+
+
+ class SchedulerMixin:
+     """
+     Mixin containing common functions for the schedulers.
+
+     Class attributes:
+         - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+           `from_config` can be used from a class different than the one used to save the config (should be overridden
+           by parent class).
+     """
+
+     config_name = SCHEDULER_CONFIG_NAME
+     _compatibles = []
+     has_compatibles = True
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+         subfolder: Optional[str] = None,
+         return_unused_kwargs=False,
+         **kwargs,
+     ):
+         r"""
+         Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
+
+         Parameters:
+             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                 Can be either:
+
+                     - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+                       organization name, like `google/ddpm-celebahq-256`.
+                     - A path to a *directory* containing the scheduler configurations saved using
+                       [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
+             subfolder (`str`, *optional*):
+                 In case the relevant files are located inside a subfolder of the model repo (either remote in
+                 huggingface.co or downloaded locally), you can specify the folder name here.
+             return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                 Whether kwargs that are not consumed by the Python class should be returned or not.
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                 standard cache should not be used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding
+                 the cached versions if they exist.
+             resume_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                 file exists.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             output_loading_info (`bool`, *optional*, defaults to `False`):
+                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error
+                 messages.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether or not to only look at local files (i.e., do not try to download the model).
+             use_auth_token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                 generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
+                 a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be
+                 any identifier allowed by git.
+
+         <Tip>
+
+         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+         models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+         </Tip>
+
+         <Tip>
+
+         Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+         use this method in a firewalled environment.
+
+         </Tip>
+
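+         Examples (a sketch; the repo ids and the scheduler class are illustrative — any
+         `SchedulerMixin` subclass with a hosted `scheduler_config.json` works):
+
+         ```py
+         >>> from diffusers import DDPMScheduler
+
+         >>> # Load a scheduler config stored at the root of a model repo ...
+         >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
+
+         >>> # ... or from the `scheduler` subfolder of a pipeline repo.
+         >>> scheduler = DDPMScheduler.from_pretrained(
+         ...     "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
+         ... )
+         ```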
+         """
+         config, kwargs = cls.load_config(
+             pretrained_model_name_or_path=pretrained_model_name_or_path,
+             subfolder=subfolder,
+             return_unused_kwargs=True,
+             **kwargs,
+         )
+         return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
+
+     def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+         """
+         Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using
+         the [`~SchedulerMixin.from_pretrained`] class method.
+
+         Args:
+             save_directory (`str` or `os.PathLike`):
+                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
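+
+         Example (a sketch; the directory path is illustrative):
+
+         ```py
+         >>> scheduler.save_pretrained("./my_scheduler")  # writes ./my_scheduler/scheduler_config.json
+         ```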
+         """
+         self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+     @property
+     def compatibles(self):
+         """
+         Returns all schedulers that are compatible with this scheduler.
+
+         Returns:
+             `List[SchedulerMixin]`: List of compatible schedulers
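+
+         Example (a sketch; `pipe` is any pipeline whose scheduler declares its compatible
+         classes in `_compatibles`):
+
+         ```py
+         >>> compatible_classes = pipe.scheduler.compatibles  # scheduler classes you can swap in
+         ```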
+         """
+         return self._get_compatibles()
+
+     @classmethod
+     def _get_compatibles(cls):
+         compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
+         diffusers_library = importlib.import_module(__name__.split(".")[0])
+         compatible_classes = [
+             getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
+         ]
+         return compatible_classes
src/diffusers_/stable_diffusion/__init__.py ADDED
@@ -0,0 +1,35 @@
+ from dataclasses import dataclass
+ from typing import List, Optional, Union
+
+ import numpy as np
+
+ import PIL
+ from PIL import Image
+
+ from ..utils import (
+     BaseOutput,
+     is_torch_available,
+     is_transformers_available,
+ )
+
+
+ @dataclass
+ class StableDiffusionPipelineOutput(BaseOutput):
+     """
+     Output class for Stable Diffusion pipelines.
+
+     Args:
+         images (`List[PIL.Image.Image]` or `np.ndarray`):
+             List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height,
+             width, num_channels)`. The PIL images or numpy array represent the denoised images of the diffusion
+             pipeline.
+         nsfw_content_detected (`List[bool]`):
+             List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+             (nsfw) content, or `None` if safety checking could not be performed.
+     """
+
+     images: Union[List[PIL.Image.Image], np.ndarray]
+     nsfw_content_detected: Optional[List[bool]]
+
+
+ if is_transformers_available() and is_torch_available():
+     from .pipeline_stable_diffusion import StableDiffusionPipeline
src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.47 kB). View file
 
src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (3.81 kB). View file
 
src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.48 kB). View file
 
src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-37.pyc ADDED
Binary file (23.4 kB). View file
 
src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-38.pyc ADDED
Binary file (23.7 kB). View file
 
src/diffusers_/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion.cpython-38.pyc ADDED
Binary file (14.6 kB). View file
 
src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-38.pyc ADDED
Binary file (9.97 kB). View file
 
src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-38.pyc ADDED
Binary file (16.7 kB). View file
 
src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-38.pyc ADDED
Binary file (17.3 kB). View file
 
src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-38.pyc ADDED
Binary file (18.4 kB). View file