Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +99 -0
- src/diffusers_/__init__.py +15 -0
- src/diffusers_/__pycache__/__init__.cpython-310.pyc +0 -0
- src/diffusers_/__pycache__/__init__.cpython-37.pyc +0 -0
- src/diffusers_/__pycache__/__init__.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/configuration_utils.cpython-310.pyc +0 -0
- src/diffusers_/__pycache__/configuration_utils.cpython-37.pyc +0 -0
- src/diffusers_/__pycache__/configuration_utils.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/dependency_versions_check.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/dependency_versions_table.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/dynamic_modules_utils.cpython-310.pyc +0 -0
- src/diffusers_/__pycache__/dynamic_modules_utils.cpython-37.pyc +0 -0
- src/diffusers_/__pycache__/dynamic_modules_utils.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/hub_utils.cpython-310.pyc +0 -0
- src/diffusers_/__pycache__/hub_utils.cpython-37.pyc +0 -0
- src/diffusers_/__pycache__/hub_utils.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/modeling_flax_pytorch_utils.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/modeling_flax_utils.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/modeling_utils.cpython-310.pyc +0 -0
- src/diffusers_/__pycache__/modeling_utils.cpython-37.pyc +0 -0
- src/diffusers_/__pycache__/modeling_utils.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/onnx_utils.cpython-37.pyc +0 -0
- src/diffusers_/__pycache__/onnx_utils.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/optimization.cpython-37.pyc +0 -0
- src/diffusers_/__pycache__/optimization.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/pipeline_flax_utils.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/pipeline_utils.cpython-310.pyc +0 -0
- src/diffusers_/__pycache__/pipeline_utils.cpython-37.pyc +0 -0
- src/diffusers_/__pycache__/pipeline_utils.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/scheduling_utils.cpython-310.pyc +0 -0
- src/diffusers_/__pycache__/scheduling_utils.cpython-38.pyc +0 -0
- src/diffusers_/__pycache__/training_utils.cpython-37.pyc +0 -0
- src/diffusers_/__pycache__/training_utils.cpython-38.pyc +0 -0
- src/diffusers_/configuration_utils.py +605 -0
- src/diffusers_/dynamic_modules_utils.py +428 -0
- src/diffusers_/hub_utils.py +246 -0
- src/diffusers_/modeling_utils.py +693 -0
- src/diffusers_/pipeline_utils.py +755 -0
- src/diffusers_/scheduling_utils.py +154 -0
- src/diffusers_/stable_diffusion/__init__.py +35 -0
- src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
- src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-37.pyc +0 -0
- src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
- src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-37.pyc +0 -0
- src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-38.pyc +0 -0
- src/diffusers_/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion.cpython-38.pyc +0 -0
- src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-38.pyc +0 -0
- src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-38.pyc +0 -0
- src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-38.pyc +0 -0
- src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-38.pyc +0 -0
app.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pdb
|
3 |
+
from PIL import Image
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
from src.utils.gradio_utils import *
|
7 |
+
|
8 |
+
|
9 |
+
if __name__=="__main__":
|
10 |
+
step_dict = {'800': 800, '900': 900, '1000': 1000, '1100': 1100}
|
11 |
+
|
12 |
+
with gr.Blocks(css=CSS_main) as demo:
|
13 |
+
gr.HTML(HTML_header)
|
14 |
+
|
15 |
+
with gr.Row():
|
16 |
+
# col A: Optimize personalized embedding
|
17 |
+
with gr.Column(scale=2) as gc_left:
|
18 |
+
gr.HTML(" <center> <p style='font-size:150%;'> [Step 1] Optimize personalized embedding </p> </center>")
|
19 |
+
img_in_real = gr.Image(type="pil", label="Start by uploading the source image", elem_id="input_image").style(height=300, width=300)
|
20 |
+
gr.Examples( examples="src_image", inputs=[img_in_real])
|
21 |
+
prompt = gr.Textbox(value="a standing dog", label="Source text prompt (Describe the source image)", interactive=True)
|
22 |
+
n_hiper = gr.Slider(5, 10, 5, label="Number of personalized embedding (Tips! 5 for animals / 10 for humans)", interactive=True, step=1)
|
23 |
+
|
24 |
+
btn_optimize = gr.Button("Optimize", label="")
|
25 |
+
fpath_z_gen = gr.Textbox(value="placeholder", visible=False)
|
26 |
+
|
27 |
+
gr.HTML(" <center> <p style='font-size:150%;'> See the [Step 1] results with different optimization steps </p> </center>")
|
28 |
+
with gr.Row():
|
29 |
+
with gr.Column(scale=0.3, min_width=0.7) as gc_left:
|
30 |
+
# btn_source = gr.Button("Source image", label="")
|
31 |
+
btn_opt_step800 = gr.Button("Step 800", label="")
|
32 |
+
btn_opt_step900 = gr.Button("Step 900", label="")
|
33 |
+
btn_opt_step1000 = gr.Button("Step 1000", label="")
|
34 |
+
btn_opt_step1100 = gr.Button("Step 1100", label="")
|
35 |
+
with gr.Column(scale=0.5, min_width=0.8) as gc_left:
|
36 |
+
img_src = gr.Image(type="pil", label="Source image", visible=True).style(height=250, width=250)
|
37 |
+
with gr.Column(scale=0.5, min_width=0.8) as gc_left:
|
38 |
+
img_out_opt = gr.Image(type="pil", label="Optimization step output", visible=True).style(height=250, width=250)
|
39 |
+
|
40 |
+
|
41 |
+
# col B: Generate target image
|
42 |
+
with gr.Column(scale=2) as gc_left:
|
43 |
+
|
44 |
+
gr.HTML(" <center> <p style='font-size:150%;'> [Step 2] Generate target image </p> </center>")
|
45 |
+
with gr.Row():
|
46 |
+
|
47 |
+
with gr.Column():
|
48 |
+
dest = gr.Textbox(value="a sitting dog", label="Target text prompt", interactive=True)
|
49 |
+
step = gr.Radio(["Step 800", "Step 900", "Step 1000", "Step 1100"], value="Step 1000", label="Training optimization step \n (Refer to the personalized results corresponding to each optimization step listed in the left column.)")
|
50 |
+
seed = gr.Number(value=111111, label="Random seed", interactive=True)
|
51 |
+
with gr.Row():
|
52 |
+
btn_generate = gr.Button("Generate", label="")
|
53 |
+
img_out = gr.Image(type="pil", label="Output Image", visible=True)
|
54 |
+
|
55 |
+
with gr.Accordion("Instruction", open=True):
|
56 |
+
gr.Textbox("In NVIDIA GeForce GTX 3090, [step 1] takes about 4 minutes and [step 2] takes about 1 minute.", show_label=False)
|
57 |
+
gr.Textbox("At [step 1], put the desired source image and write the source text that describes the source image. If it is difficult to describe, you can use a noun such as 'a dog' or 'a woman.' Then decide on the number of desired personalized embeddings.", show_label=False)
|
58 |
+
gr.Textbox("After [step 1], you can check the personalized results with different optimization steps and select the optimization step. First, check if the image at step 1000 has a subject similar to the source image. In the paper, we use the 1000 step for optimization almost.", show_label=False)
|
59 |
+
gr.Textbox("At [step 2], write the derised target text. Then, refer to the generated personalized image in the bottom left and choose an optimization. If the desired image is not obtained, try another random seed.", show_label=False)
|
60 |
+
|
61 |
+
|
62 |
+
############
|
63 |
+
btn_optimize.click(launch_optimize, [img_in_real, prompt, n_hiper], [fpath_z_gen, img_src])
|
64 |
+
def fn_set_none():
|
65 |
+
return gr.update(value=None)
|
66 |
+
btn_optimize.click(fn_set_none, [], img_in_real)
|
67 |
+
# btn_optimize.click(set_visible_true, [], img_in_synth)
|
68 |
+
btn_optimize.click(set_visible_false, [], img_in_real)
|
69 |
+
|
70 |
+
|
71 |
+
############
|
72 |
+
def fn_clear_all():
|
73 |
+
return gr.update(value=None), gr.update(value=None), gr.update(value=None)
|
74 |
+
|
75 |
+
img_in_real.clear(fn_clear_all, [], [img_out, img_in_real])#, img_in_synth])
|
76 |
+
# img_in_real.clear(set_visible_true, [], img_in_synth)
|
77 |
+
img_in_real.clear(set_visible_false, [], img_in_real)
|
78 |
+
|
79 |
+
img_out.clear(fn_clear_all, [], [img_out, img_in_real])#, img_in_synth])
|
80 |
+
|
81 |
+
|
82 |
+
############
|
83 |
+
btn_generate.click(launch_main,
|
84 |
+
[
|
85 |
+
dest, step,
|
86 |
+
fpath_z_gen, seed,
|
87 |
+
],
|
88 |
+
[img_out]
|
89 |
+
)
|
90 |
+
############
|
91 |
+
btn_opt_step800.click(launch_opt800, [],[img_out_opt])
|
92 |
+
btn_opt_step900.click(launch_opt900, [],[img_out_opt])
|
93 |
+
btn_opt_step1000.click(launch_opt1000, [],[img_out_opt])
|
94 |
+
btn_opt_step1100.click(launch_opt1100, [],[img_out_opt])
|
95 |
+
gr.HTML("<hr>")
|
96 |
+
|
97 |
+
gr.close_all()
|
98 |
+
demo.queue(concurrency_count=1)
|
99 |
+
demo.launch(server_port=2222, server_name="0.0.0.0", debug=True,share=True)
|
src/diffusers_/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .utils import (
|
2 |
+
is_torch_available,
|
3 |
+
is_transformers_available,
|
4 |
+
)
|
5 |
+
|
6 |
+
|
7 |
+
__version__ = "0.9.0"
|
8 |
+
|
9 |
+
|
10 |
+
if is_torch_available() and is_transformers_available():
|
11 |
+
from .stable_diffusion import (
|
12 |
+
StableDiffusionPipeline,
|
13 |
+
)
|
14 |
+
else:
|
15 |
+
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
src/diffusers_/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (383 Bytes). View file
|
|
src/diffusers_/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (3.7 kB). View file
|
|
src/diffusers_/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (398 Bytes). View file
|
|
src/diffusers_/__pycache__/configuration_utils.cpython-310.pyc
ADDED
Binary file (21.3 kB). View file
|
|
src/diffusers_/__pycache__/configuration_utils.cpython-37.pyc
ADDED
Binary file (21.7 kB). View file
|
|
src/diffusers_/__pycache__/configuration_utils.cpython-38.pyc
ADDED
Binary file (21.5 kB). View file
|
|
src/diffusers_/__pycache__/dependency_versions_check.cpython-38.pyc
ADDED
Binary file (960 Bytes). View file
|
|
src/diffusers_/__pycache__/dependency_versions_table.cpython-38.pyc
ADDED
Binary file (927 Bytes). View file
|
|
src/diffusers_/__pycache__/dynamic_modules_utils.cpython-310.pyc
ADDED
Binary file (13.4 kB). View file
|
|
src/diffusers_/__pycache__/dynamic_modules_utils.cpython-37.pyc
ADDED
Binary file (13.3 kB). View file
|
|
src/diffusers_/__pycache__/dynamic_modules_utils.cpython-38.pyc
ADDED
Binary file (13.4 kB). View file
|
|
src/diffusers_/__pycache__/hub_utils.cpython-310.pyc
ADDED
Binary file (7.15 kB). View file
|
|
src/diffusers_/__pycache__/hub_utils.cpython-37.pyc
ADDED
Binary file (6.92 kB). View file
|
|
src/diffusers_/__pycache__/hub_utils.cpython-38.pyc
ADDED
Binary file (6.99 kB). View file
|
|
src/diffusers_/__pycache__/modeling_flax_pytorch_utils.cpython-38.pyc
ADDED
Binary file (2.64 kB). View file
|
|
src/diffusers_/__pycache__/modeling_flax_utils.cpython-38.pyc
ADDED
Binary file (20.7 kB). View file
|
|
src/diffusers_/__pycache__/modeling_utils.cpython-310.pyc
ADDED
Binary file (23.7 kB). View file
|
|
src/diffusers_/__pycache__/modeling_utils.cpython-37.pyc
ADDED
Binary file (23.8 kB). View file
|
|
src/diffusers_/__pycache__/modeling_utils.cpython-38.pyc
ADDED
Binary file (23.9 kB). View file
|
|
src/diffusers_/__pycache__/onnx_utils.cpython-37.pyc
ADDED
Binary file (6.73 kB). View file
|
|
src/diffusers_/__pycache__/onnx_utils.cpython-38.pyc
ADDED
Binary file (6.85 kB). View file
|
|
src/diffusers_/__pycache__/optimization.cpython-37.pyc
ADDED
Binary file (10.2 kB). View file
|
|
src/diffusers_/__pycache__/optimization.cpython-38.pyc
ADDED
Binary file (10.2 kB). View file
|
|
src/diffusers_/__pycache__/pipeline_flax_utils.cpython-38.pyc
ADDED
Binary file (15.7 kB). View file
|
|
src/diffusers_/__pycache__/pipeline_utils.cpython-310.pyc
ADDED
Binary file (26.1 kB). View file
|
|
src/diffusers_/__pycache__/pipeline_utils.cpython-37.pyc
ADDED
Binary file (26.1 kB). View file
|
|
src/diffusers_/__pycache__/pipeline_utils.cpython-38.pyc
ADDED
Binary file (26.2 kB). View file
|
|
src/diffusers_/__pycache__/scheduling_utils.cpython-310.pyc
ADDED
Binary file (6.87 kB). View file
|
|
src/diffusers_/__pycache__/scheduling_utils.cpython-38.pyc
ADDED
Binary file (6.85 kB). View file
|
|
src/diffusers_/__pycache__/training_utils.cpython-37.pyc
ADDED
Binary file (3.59 kB). View file
|
|
src/diffusers_/__pycache__/training_utils.cpython-38.pyc
ADDED
Binary file (3.63 kB). View file
|
|
src/diffusers_/configuration_utils.py
ADDED
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" ConfigMixin base class and utilities."""
|
17 |
+
import dataclasses
|
18 |
+
import functools
|
19 |
+
import importlib
|
20 |
+
import inspect
|
21 |
+
import json
|
22 |
+
import os
|
23 |
+
import re
|
24 |
+
from collections import OrderedDict
|
25 |
+
from typing import Any, Dict, Tuple, Union
|
26 |
+
|
27 |
+
from huggingface_hub import hf_hub_download
|
28 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
29 |
+
from requests import HTTPError
|
30 |
+
|
31 |
+
from . import __version__
|
32 |
+
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
|
33 |
+
|
34 |
+
|
35 |
+
logger = logging.get_logger(__name__)
|
36 |
+
|
37 |
+
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
38 |
+
|
39 |
+
|
40 |
+
class FrozenDict(OrderedDict):
|
41 |
+
def __init__(self, *args, **kwargs):
|
42 |
+
super().__init__(*args, **kwargs)
|
43 |
+
|
44 |
+
for key, value in self.items():
|
45 |
+
setattr(self, key, value)
|
46 |
+
|
47 |
+
self.__frozen = True
|
48 |
+
|
49 |
+
def __delitem__(self, *args, **kwargs):
|
50 |
+
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
51 |
+
|
52 |
+
def setdefault(self, *args, **kwargs):
|
53 |
+
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
|
54 |
+
|
55 |
+
def pop(self, *args, **kwargs):
|
56 |
+
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
|
57 |
+
|
58 |
+
def update(self, *args, **kwargs):
|
59 |
+
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
|
60 |
+
|
61 |
+
def __setattr__(self, name, value):
|
62 |
+
if hasattr(self, "__frozen") and self.__frozen:
|
63 |
+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
64 |
+
super().__setattr__(name, value)
|
65 |
+
|
66 |
+
def __setitem__(self, name, value):
|
67 |
+
if hasattr(self, "__frozen") and self.__frozen:
|
68 |
+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
69 |
+
super().__setitem__(name, value)
|
70 |
+
|
71 |
+
|
72 |
+
class ConfigMixin:
|
73 |
+
r"""
|
74 |
+
Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
|
75 |
+
methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
|
76 |
+
- [`~ConfigMixin.from_config`]
|
77 |
+
- [`~ConfigMixin.save_config`]
|
78 |
+
|
79 |
+
Class attributes:
|
80 |
+
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
81 |
+
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
82 |
+
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
83 |
+
overridden by subclass).
|
84 |
+
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
85 |
+
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
|
86 |
+
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
87 |
+
subclass).
|
88 |
+
"""
|
89 |
+
config_name = None
|
90 |
+
ignore_for_config = []
|
91 |
+
has_compatibles = False
|
92 |
+
|
93 |
+
_deprecated_kwargs = []
|
94 |
+
|
95 |
+
def register_to_config(self, **kwargs):
|
96 |
+
if self.config_name is None:
|
97 |
+
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
98 |
+
# Special case for `kwargs` used in deprecation warning added to schedulers
|
99 |
+
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
|
100 |
+
# or solve in a more general way.
|
101 |
+
kwargs.pop("kwargs", None)
|
102 |
+
for key, value in kwargs.items():
|
103 |
+
try:
|
104 |
+
setattr(self, key, value)
|
105 |
+
except AttributeError as err:
|
106 |
+
logger.error(f"Can't set {key} with value {value} for {self}")
|
107 |
+
raise err
|
108 |
+
|
109 |
+
if not hasattr(self, "_internal_dict"):
|
110 |
+
internal_dict = kwargs
|
111 |
+
else:
|
112 |
+
previous_dict = dict(self._internal_dict)
|
113 |
+
internal_dict = {**self._internal_dict, **kwargs}
|
114 |
+
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
|
115 |
+
|
116 |
+
self._internal_dict = FrozenDict(internal_dict)
|
117 |
+
|
118 |
+
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
119 |
+
"""
|
120 |
+
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
121 |
+
[`~ConfigMixin.from_config`] class method.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
save_directory (`str` or `os.PathLike`):
|
125 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
126 |
+
"""
|
127 |
+
if os.path.isfile(save_directory):
|
128 |
+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
129 |
+
|
130 |
+
os.makedirs(save_directory, exist_ok=True)
|
131 |
+
|
132 |
+
# If we save using the predefined names, we can load using `from_config`
|
133 |
+
output_config_file = os.path.join(save_directory, self.config_name)
|
134 |
+
|
135 |
+
self.to_json_file(output_config_file)
|
136 |
+
logger.info(f"Configuration saved in {output_config_file}")
|
137 |
+
|
138 |
+
@classmethod
|
139 |
+
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
140 |
+
r"""
|
141 |
+
Instantiate a Python class from a config dictionary
|
142 |
+
|
143 |
+
Parameters:
|
144 |
+
config (`Dict[str, Any]`):
|
145 |
+
A config dictionary from which the Python class will be instantiated. Make sure to only load
|
146 |
+
configuration files of compatible classes.
|
147 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
148 |
+
Whether kwargs that are not consumed by the Python class should be returned or not.
|
149 |
+
|
150 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
151 |
+
Can be used to update the configuration object (after it being loaded) and initiate the Python class.
|
152 |
+
`**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
|
153 |
+
overwrite same named arguments of `config`.
|
154 |
+
|
155 |
+
Examples:
|
156 |
+
|
157 |
+
```python
|
158 |
+
>>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
|
159 |
+
|
160 |
+
>>> # Download scheduler from huggingface.co and cache.
|
161 |
+
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
|
162 |
+
|
163 |
+
>>> # Instantiate DDIM scheduler class with same config as DDPM
|
164 |
+
>>> scheduler = DDIMScheduler.from_config(scheduler.config)
|
165 |
+
|
166 |
+
>>> # Instantiate PNDM scheduler class with same config as DDPM
|
167 |
+
>>> scheduler = PNDMScheduler.from_config(scheduler.config)
|
168 |
+
```
|
169 |
+
"""
|
170 |
+
# <===== TO BE REMOVED WITH DEPRECATION
|
171 |
+
# TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
|
172 |
+
if "pretrained_model_name_or_path" in kwargs:
|
173 |
+
config = kwargs.pop("pretrained_model_name_or_path")
|
174 |
+
|
175 |
+
if config is None:
|
176 |
+
raise ValueError("Please make sure to provide a config as the first positional argument.")
|
177 |
+
# ======>
|
178 |
+
|
179 |
+
if not isinstance(config, dict):
|
180 |
+
deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
|
181 |
+
if "Scheduler" in cls.__name__:
|
182 |
+
deprecation_message += (
|
183 |
+
f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
|
184 |
+
" Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
|
185 |
+
" be removed in v1.0.0."
|
186 |
+
)
|
187 |
+
elif "Model" in cls.__name__:
|
188 |
+
deprecation_message += (
|
189 |
+
f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
|
190 |
+
f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
|
191 |
+
" instead. This functionality will be removed in v1.0.0."
|
192 |
+
)
|
193 |
+
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
|
194 |
+
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
|
195 |
+
|
196 |
+
init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
|
197 |
+
|
198 |
+
# Allow dtype to be specified on initialization
|
199 |
+
if "dtype" in unused_kwargs:
|
200 |
+
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
201 |
+
|
202 |
+
# add possible deprecated kwargs
|
203 |
+
for deprecated_kwarg in cls._deprecated_kwargs:
|
204 |
+
if deprecated_kwarg in unused_kwargs:
|
205 |
+
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
|
206 |
+
|
207 |
+
# Return model and optionally state and/or unused_kwargs
|
208 |
+
model = cls(**init_dict)
|
209 |
+
|
210 |
+
# make sure to also save config parameters that might be used for compatible classes
|
211 |
+
model.register_to_config(**hidden_dict)
|
212 |
+
|
213 |
+
# add hidden kwargs of compatible classes to unused_kwargs
|
214 |
+
unused_kwargs = {**unused_kwargs, **hidden_dict}
|
215 |
+
|
216 |
+
if return_unused_kwargs:
|
217 |
+
return (model, unused_kwargs)
|
218 |
+
else:
|
219 |
+
return model
|
220 |
+
|
221 |
+
@classmethod
|
222 |
+
def get_config_dict(cls, *args, **kwargs):
|
223 |
+
deprecation_message = (
|
224 |
+
f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
|
225 |
+
" removed in version v1.0.0"
|
226 |
+
)
|
227 |
+
deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
|
228 |
+
return cls.load_config(*args, **kwargs)
|
229 |
+
|
230 |
+
@classmethod
|
231 |
+
def load_config(
|
232 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
|
233 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
234 |
+
r"""
|
235 |
+
Instantiate a Python class from a config dictionary
|
236 |
+
|
237 |
+
Parameters:
|
238 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
239 |
+
Can be either:
|
240 |
+
|
241 |
+
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
|
242 |
+
organization name, like `google/ddpm-celebahq-256`.
|
243 |
+
- A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
|
244 |
+
`./my_model_directory/`.
|
245 |
+
|
246 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
247 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
248 |
+
standard cache should not be used.
|
249 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
250 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
251 |
+
cached versions if they exist.
|
252 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
253 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
254 |
+
file exists.
|
255 |
+
proxies (`Dict[str, str]`, *optional*):
|
256 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
257 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
258 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
259 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
260 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
261 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
262 |
+
use_auth_token (`str` or *bool*, *optional*):
|
263 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
264 |
+
when running `transformers-cli login` (stored in `~/.huggingface`).
|
265 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
266 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
267 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
268 |
+
identifier allowed by git.
|
269 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
270 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
271 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
272 |
+
|
273 |
+
<Tip>
|
274 |
+
|
275 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
276 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
277 |
+
|
278 |
+
</Tip>
|
279 |
+
|
280 |
+
<Tip>
|
281 |
+
|
282 |
+
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
|
283 |
+
use this method in a firewalled environment.
|
284 |
+
|
285 |
+
</Tip>
|
286 |
+
"""
|
287 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
288 |
+
force_download = kwargs.pop("force_download", False)
|
289 |
+
resume_download = kwargs.pop("resume_download", False)
|
290 |
+
proxies = kwargs.pop("proxies", None)
|
291 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
292 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
293 |
+
revision = kwargs.pop("revision", None)
|
294 |
+
_ = kwargs.pop("mirror", None)
|
295 |
+
subfolder = kwargs.pop("subfolder", None)
|
296 |
+
|
297 |
+
user_agent = {"file_type": "config"}
|
298 |
+
|
299 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
300 |
+
|
301 |
+
if cls.config_name is None:
|
302 |
+
raise ValueError(
|
303 |
+
"`self.config_name` is not defined. Note that one should not load a config from "
|
304 |
+
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
305 |
+
)
|
306 |
+
|
307 |
+
if os.path.isfile(pretrained_model_name_or_path):
|
308 |
+
config_file = pretrained_model_name_or_path
|
309 |
+
elif os.path.isdir(pretrained_model_name_or_path):
|
310 |
+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
311 |
+
# Load from a PyTorch checkpoint
|
312 |
+
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
313 |
+
elif subfolder is not None and os.path.isfile(
|
314 |
+
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
315 |
+
):
|
316 |
+
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
317 |
+
else:
|
318 |
+
raise EnvironmentError(
|
319 |
+
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
320 |
+
)
|
321 |
+
else:
|
322 |
+
try:
|
323 |
+
# Load from URL or cache if already cached
|
324 |
+
config_file = hf_hub_download(
|
325 |
+
pretrained_model_name_or_path,
|
326 |
+
filename=cls.config_name,
|
327 |
+
cache_dir=cache_dir,
|
328 |
+
force_download=force_download,
|
329 |
+
proxies=proxies,
|
330 |
+
resume_download=resume_download,
|
331 |
+
local_files_only=local_files_only,
|
332 |
+
use_auth_token=use_auth_token,
|
333 |
+
user_agent=user_agent,
|
334 |
+
subfolder=subfolder,
|
335 |
+
revision=revision,
|
336 |
+
)
|
337 |
+
|
338 |
+
except RepositoryNotFoundError:
|
339 |
+
raise EnvironmentError(
|
340 |
+
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
341 |
+
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
|
342 |
+
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
|
343 |
+
" login`."
|
344 |
+
)
|
345 |
+
except RevisionNotFoundError:
|
346 |
+
raise EnvironmentError(
|
347 |
+
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
|
348 |
+
" this model name. Check the model page at"
|
349 |
+
f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
350 |
+
)
|
351 |
+
except EntryNotFoundError:
|
352 |
+
raise EnvironmentError(
|
353 |
+
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
354 |
+
)
|
355 |
+
except HTTPError as err:
|
356 |
+
raise EnvironmentError(
|
357 |
+
"There was a specific connection error when trying to load"
|
358 |
+
f" {pretrained_model_name_or_path}:\n{err}"
|
359 |
+
)
|
360 |
+
except ValueError:
|
361 |
+
raise EnvironmentError(
|
362 |
+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
363 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
364 |
+
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
|
365 |
+
" run the library in offline mode at"
|
366 |
+
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
367 |
+
)
|
368 |
+
except EnvironmentError:
|
369 |
+
raise EnvironmentError(
|
370 |
+
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
371 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
372 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
373 |
+
f"containing a {cls.config_name} file"
|
374 |
+
)
|
375 |
+
|
376 |
+
try:
|
377 |
+
# Load config dict
|
378 |
+
config_dict = cls._dict_from_json_file(config_file)
|
379 |
+
except (json.JSONDecodeError, UnicodeDecodeError):
|
380 |
+
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
381 |
+
|
382 |
+
if return_unused_kwargs:
|
383 |
+
return config_dict, kwargs
|
384 |
+
|
385 |
+
return config_dict
|
386 |
+
|
387 |
+
@staticmethod
|
388 |
+
def _get_init_keys(cls):
|
389 |
+
return set(dict(inspect.signature(cls.__init__).parameters).keys())
|
390 |
+
|
391 |
+
@classmethod
|
392 |
+
def extract_init_dict(cls, config_dict, **kwargs):
|
393 |
+
# 0. Copy origin config dict
|
394 |
+
original_dict = {k: v for k, v in config_dict.items()}
|
395 |
+
|
396 |
+
# 1. Retrieve expected config attributes from __init__ signature
|
397 |
+
expected_keys = cls._get_init_keys(cls)
|
398 |
+
expected_keys.remove("self")
|
399 |
+
# remove general kwargs if present in dict
|
400 |
+
if "kwargs" in expected_keys:
|
401 |
+
expected_keys.remove("kwargs")
|
402 |
+
# remove flax internal keys
|
403 |
+
if hasattr(cls, "_flax_internal_args"):
|
404 |
+
for arg in cls._flax_internal_args:
|
405 |
+
expected_keys.remove(arg)
|
406 |
+
|
407 |
+
# 2. Remove attributes that cannot be expected from expected config attributes
|
408 |
+
# remove keys to be ignored
|
409 |
+
if len(cls.ignore_for_config) > 0:
|
410 |
+
expected_keys = expected_keys - set(cls.ignore_for_config)
|
411 |
+
|
412 |
+
# load diffusers library to import compatible and original scheduler
|
413 |
+
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
414 |
+
|
415 |
+
if cls.has_compatibles:
|
416 |
+
compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
|
417 |
+
else:
|
418 |
+
compatible_classes = []
|
419 |
+
|
420 |
+
expected_keys_comp_cls = set()
|
421 |
+
for c in compatible_classes:
|
422 |
+
expected_keys_c = cls._get_init_keys(c)
|
423 |
+
expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
|
424 |
+
expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
|
425 |
+
config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
|
426 |
+
|
427 |
+
# remove attributes from orig class that cannot be expected
|
428 |
+
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
|
429 |
+
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
|
430 |
+
orig_cls = getattr(diffusers_library, orig_cls_name)
|
431 |
+
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
|
432 |
+
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
|
433 |
+
|
434 |
+
# remove private attributes
|
435 |
+
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
|
436 |
+
|
437 |
+
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
|
438 |
+
init_dict = {}
|
439 |
+
for key in expected_keys:
|
440 |
+
# if config param is passed to kwarg and is present in config dict
|
441 |
+
# it should overwrite existing config dict key
|
442 |
+
if key in kwargs and key in config_dict:
|
443 |
+
config_dict[key] = kwargs.pop(key)
|
444 |
+
|
445 |
+
if key in kwargs:
|
446 |
+
# overwrite key
|
447 |
+
init_dict[key] = kwargs.pop(key)
|
448 |
+
elif key in config_dict:
|
449 |
+
# use value from config dict
|
450 |
+
init_dict[key] = config_dict.pop(key)
|
451 |
+
|
452 |
+
# 4. Give nice warning if unexpected values have been passed
|
453 |
+
if len(config_dict) > 0:
|
454 |
+
logger.warning(
|
455 |
+
f"The config attributes {config_dict} were passed to {cls.__name__}, "
|
456 |
+
"but are not expected and will be ignored. Please verify your "
|
457 |
+
f"{cls.config_name} configuration file."
|
458 |
+
)
|
459 |
+
|
460 |
+
# 5. Give nice info if config attributes are initiliazed to default because they have not been passed
|
461 |
+
passed_keys = set(init_dict.keys())
|
462 |
+
if len(expected_keys - passed_keys) > 0:
|
463 |
+
logger.info(
|
464 |
+
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
|
465 |
+
)
|
466 |
+
|
467 |
+
# 6. Define unused keyword arguments
|
468 |
+
unused_kwargs = {**config_dict, **kwargs}
|
469 |
+
|
470 |
+
# 7. Define "hidden" config parameters that were saved for compatible classes
|
471 |
+
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
|
472 |
+
|
473 |
+
return init_dict, unused_kwargs, hidden_config_dict
|
474 |
+
|
475 |
+
@classmethod
|
476 |
+
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
477 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
478 |
+
text = reader.read()
|
479 |
+
return json.loads(text)
|
480 |
+
|
481 |
+
def __repr__(self):
|
482 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
483 |
+
|
484 |
+
@property
|
485 |
+
def config(self) -> Dict[str, Any]:
|
486 |
+
"""
|
487 |
+
Returns the config of the class as a frozen dictionary
|
488 |
+
|
489 |
+
Returns:
|
490 |
+
`Dict[str, Any]`: Config of the class.
|
491 |
+
"""
|
492 |
+
return self._internal_dict
|
493 |
+
|
494 |
+
def to_json_string(self) -> str:
|
495 |
+
"""
|
496 |
+
Serializes this instance to a JSON string.
|
497 |
+
|
498 |
+
Returns:
|
499 |
+
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
500 |
+
"""
|
501 |
+
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
|
502 |
+
config_dict["_class_name"] = self.__class__.__name__
|
503 |
+
config_dict["_diffusers_version"] = __version__
|
504 |
+
|
505 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
506 |
+
|
507 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
508 |
+
"""
|
509 |
+
Save this instance to a JSON file.
|
510 |
+
|
511 |
+
Args:
|
512 |
+
json_file_path (`str` or `os.PathLike`):
|
513 |
+
Path to the JSON file in which this configuration instance's parameters will be saved.
|
514 |
+
"""
|
515 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
516 |
+
writer.write(self.to_json_string())
|
517 |
+
|
518 |
+
|
519 |
+
def register_to_config(init):
|
520 |
+
r"""
|
521 |
+
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
|
522 |
+
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
|
523 |
+
shouldn't be registered in the config, use the `ignore_for_config` class variable
|
524 |
+
|
525 |
+
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
|
526 |
+
"""
|
527 |
+
|
528 |
+
@functools.wraps(init)
|
529 |
+
def inner_init(self, *args, **kwargs):
|
530 |
+
# Ignore private kwargs in the init.
|
531 |
+
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
532 |
+
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
|
533 |
+
if not isinstance(self, ConfigMixin):
|
534 |
+
raise RuntimeError(
|
535 |
+
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
536 |
+
"not inherit from `ConfigMixin`."
|
537 |
+
)
|
538 |
+
|
539 |
+
ignore = getattr(self, "ignore_for_config", [])
|
540 |
+
# Get positional arguments aligned with kwargs
|
541 |
+
new_kwargs = {}
|
542 |
+
signature = inspect.signature(init)
|
543 |
+
parameters = {
|
544 |
+
name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
|
545 |
+
}
|
546 |
+
for arg, name in zip(args, parameters.keys()):
|
547 |
+
new_kwargs[name] = arg
|
548 |
+
|
549 |
+
# Then add all kwargs
|
550 |
+
new_kwargs.update(
|
551 |
+
{
|
552 |
+
k: init_kwargs.get(k, default)
|
553 |
+
for k, default in parameters.items()
|
554 |
+
if k not in ignore and k not in new_kwargs
|
555 |
+
}
|
556 |
+
)
|
557 |
+
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
558 |
+
getattr(self, "register_to_config")(**new_kwargs)
|
559 |
+
init(self, *args, **init_kwargs)
|
560 |
+
|
561 |
+
return inner_init
|
562 |
+
|
563 |
+
|
564 |
+
def flax_register_to_config(cls):
|
565 |
+
original_init = cls.__init__
|
566 |
+
|
567 |
+
@functools.wraps(original_init)
|
568 |
+
def init(self, *args, **kwargs):
|
569 |
+
if not isinstance(self, ConfigMixin):
|
570 |
+
raise RuntimeError(
|
571 |
+
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
572 |
+
"not inherit from `ConfigMixin`."
|
573 |
+
)
|
574 |
+
|
575 |
+
# Ignore private kwargs in the init. Retrieve all passed attributes
|
576 |
+
init_kwargs = {k: v for k, v in kwargs.items()}
|
577 |
+
|
578 |
+
# Retrieve default values
|
579 |
+
fields = dataclasses.fields(self)
|
580 |
+
default_kwargs = {}
|
581 |
+
for field in fields:
|
582 |
+
# ignore flax specific attributes
|
583 |
+
if field.name in self._flax_internal_args:
|
584 |
+
continue
|
585 |
+
if type(field.default) == dataclasses._MISSING_TYPE:
|
586 |
+
default_kwargs[field.name] = None
|
587 |
+
else:
|
588 |
+
default_kwargs[field.name] = getattr(self, field.name)
|
589 |
+
|
590 |
+
# Make sure init_kwargs override default kwargs
|
591 |
+
new_kwargs = {**default_kwargs, **init_kwargs}
|
592 |
+
# dtype should be part of `init_kwargs`, but not `new_kwargs`
|
593 |
+
if "dtype" in new_kwargs:
|
594 |
+
new_kwargs.pop("dtype")
|
595 |
+
|
596 |
+
# Get positional arguments aligned with kwargs
|
597 |
+
for i, arg in enumerate(args):
|
598 |
+
name = fields[i].name
|
599 |
+
new_kwargs[name] = arg
|
600 |
+
|
601 |
+
getattr(self, "register_to_config")(**new_kwargs)
|
602 |
+
original_init(self, *args, **kwargs)
|
603 |
+
|
604 |
+
cls.__init__ = init
|
605 |
+
return cls
|
src/diffusers_/dynamic_modules_utils.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Utilities to dynamically load objects from the Hub."""
|
16 |
+
|
17 |
+
import importlib
|
18 |
+
import inspect
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
import shutil
|
22 |
+
import sys
|
23 |
+
from pathlib import Path
|
24 |
+
from typing import Dict, Optional, Union
|
25 |
+
|
26 |
+
from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
|
27 |
+
|
28 |
+
from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
|
29 |
+
|
30 |
+
|
31 |
+
COMMUNITY_PIPELINES_URL = (
|
32 |
+
"https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/{pipeline}.py"
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
37 |
+
|
38 |
+
|
39 |
+
def init_hf_modules():
|
40 |
+
"""
|
41 |
+
Creates the cache directory for modules with an init, and adds it to the Python path.
|
42 |
+
"""
|
43 |
+
# This function has already been executed if HF_MODULES_CACHE already is in the Python path.
|
44 |
+
if HF_MODULES_CACHE in sys.path:
|
45 |
+
return
|
46 |
+
|
47 |
+
sys.path.append(HF_MODULES_CACHE)
|
48 |
+
os.makedirs(HF_MODULES_CACHE, exist_ok=True)
|
49 |
+
init_path = Path(HF_MODULES_CACHE) / "__init__.py"
|
50 |
+
if not init_path.exists():
|
51 |
+
init_path.touch()
|
52 |
+
|
53 |
+
|
54 |
+
def create_dynamic_module(name: Union[str, os.PathLike]):
|
55 |
+
"""
|
56 |
+
Creates a dynamic module in the cache directory for modules.
|
57 |
+
"""
|
58 |
+
init_hf_modules()
|
59 |
+
dynamic_module_path = Path(HF_MODULES_CACHE) / name
|
60 |
+
# If the parent module does not exist yet, recursively create it.
|
61 |
+
if not dynamic_module_path.parent.exists():
|
62 |
+
create_dynamic_module(dynamic_module_path.parent)
|
63 |
+
os.makedirs(dynamic_module_path, exist_ok=True)
|
64 |
+
init_path = dynamic_module_path / "__init__.py"
|
65 |
+
if not init_path.exists():
|
66 |
+
init_path.touch()
|
67 |
+
|
68 |
+
|
69 |
+
def get_relative_imports(module_file):
|
70 |
+
"""
|
71 |
+
Get the list of modules that are relatively imported in a module file.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
module_file (`str` or `os.PathLike`): The module file to inspect.
|
75 |
+
"""
|
76 |
+
with open(module_file, "r", encoding="utf-8") as f:
|
77 |
+
content = f.read()
|
78 |
+
|
79 |
+
# Imports of the form `import .xxx`
|
80 |
+
relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
|
81 |
+
# Imports of the form `from .xxx import yyy`
|
82 |
+
relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
|
83 |
+
# Unique-ify
|
84 |
+
return list(set(relative_imports))
|
85 |
+
|
86 |
+
|
87 |
+
def get_relative_import_files(module_file):
|
88 |
+
"""
|
89 |
+
Get the list of all files that are needed for a given module. Note that this function recurses through the relative
|
90 |
+
imports (if a imports b and b imports c, it will return module files for b and c).
|
91 |
+
|
92 |
+
Args:
|
93 |
+
module_file (`str` or `os.PathLike`): The module file to inspect.
|
94 |
+
"""
|
95 |
+
no_change = False
|
96 |
+
files_to_check = [module_file]
|
97 |
+
all_relative_imports = []
|
98 |
+
|
99 |
+
# Let's recurse through all relative imports
|
100 |
+
while not no_change:
|
101 |
+
new_imports = []
|
102 |
+
for f in files_to_check:
|
103 |
+
new_imports.extend(get_relative_imports(f))
|
104 |
+
|
105 |
+
module_path = Path(module_file).parent
|
106 |
+
new_import_files = [str(module_path / m) for m in new_imports]
|
107 |
+
new_import_files = [f for f in new_import_files if f not in all_relative_imports]
|
108 |
+
files_to_check = [f"{f}.py" for f in new_import_files]
|
109 |
+
|
110 |
+
no_change = len(new_import_files) == 0
|
111 |
+
all_relative_imports.extend(files_to_check)
|
112 |
+
|
113 |
+
return all_relative_imports
|
114 |
+
|
115 |
+
|
116 |
+
def check_imports(filename):
|
117 |
+
"""
|
118 |
+
Check if the current Python environment contains all the libraries that are imported in a file.
|
119 |
+
"""
|
120 |
+
with open(filename, "r", encoding="utf-8") as f:
|
121 |
+
content = f.read()
|
122 |
+
|
123 |
+
# Imports of the form `import xxx`
|
124 |
+
imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
|
125 |
+
# Imports of the form `from xxx import yyy`
|
126 |
+
imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
|
127 |
+
# Only keep the top-level module
|
128 |
+
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
|
129 |
+
|
130 |
+
# Unique-ify and test we got them all
|
131 |
+
imports = list(set(imports))
|
132 |
+
missing_packages = []
|
133 |
+
for imp in imports:
|
134 |
+
try:
|
135 |
+
importlib.import_module(imp)
|
136 |
+
except ImportError:
|
137 |
+
missing_packages.append(imp)
|
138 |
+
|
139 |
+
if len(missing_packages) > 0:
|
140 |
+
raise ImportError(
|
141 |
+
"This modeling file requires the following packages that were not found in your environment: "
|
142 |
+
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
|
143 |
+
)
|
144 |
+
|
145 |
+
return get_relative_imports(filename)
|
146 |
+
|
147 |
+
|
148 |
+
def get_class_in_module(class_name, module_path):
|
149 |
+
"""
|
150 |
+
Import a module on the cache directory for modules and extract a class from it.
|
151 |
+
"""
|
152 |
+
module_path = module_path.replace(os.path.sep, ".")
|
153 |
+
module = importlib.import_module(module_path)
|
154 |
+
|
155 |
+
if class_name is None:
|
156 |
+
return find_pipeline_class(module)
|
157 |
+
return getattr(module, class_name)
|
158 |
+
|
159 |
+
|
160 |
+
def find_pipeline_class(loaded_module):
|
161 |
+
"""
|
162 |
+
Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class
|
163 |
+
inheriting from `DiffusionPipeline`.
|
164 |
+
"""
|
165 |
+
from .pipeline_utils import DiffusionPipeline
|
166 |
+
|
167 |
+
cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))
|
168 |
+
|
169 |
+
pipeline_class = None
|
170 |
+
for cls_name, cls in cls_members.items():
|
171 |
+
if (
|
172 |
+
cls_name != DiffusionPipeline.__name__
|
173 |
+
and issubclass(cls, DiffusionPipeline)
|
174 |
+
and cls.__module__.split(".")[0] != "diffusers"
|
175 |
+
):
|
176 |
+
if pipeline_class is not None:
|
177 |
+
raise ValueError(
|
178 |
+
f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
|
179 |
+
f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
|
180 |
+
f" {loaded_module}."
|
181 |
+
)
|
182 |
+
pipeline_class = cls
|
183 |
+
|
184 |
+
return pipeline_class
|
185 |
+
|
186 |
+
|
187 |
+
def get_cached_module_file(
|
188 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
189 |
+
module_file: str,
|
190 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
191 |
+
force_download: bool = False,
|
192 |
+
resume_download: bool = False,
|
193 |
+
proxies: Optional[Dict[str, str]] = None,
|
194 |
+
use_auth_token: Optional[Union[bool, str]] = None,
|
195 |
+
revision: Optional[str] = None,
|
196 |
+
local_files_only: bool = False,
|
197 |
+
):
|
198 |
+
"""
|
199 |
+
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
|
200 |
+
Transformers module.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
204 |
+
This can be either:
|
205 |
+
|
206 |
+
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
207 |
+
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
|
208 |
+
under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
209 |
+
- a path to a *directory* containing a configuration file saved using the
|
210 |
+
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
|
211 |
+
|
212 |
+
module_file (`str`):
|
213 |
+
The name of the module file containing the class to look for.
|
214 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
215 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
216 |
+
cache should not be used.
|
217 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
218 |
+
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
219 |
+
exist.
|
220 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
221 |
+
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
|
222 |
+
proxies (`Dict[str, str]`, *optional*):
|
223 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
224 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
225 |
+
use_auth_token (`str` or *bool*, *optional*):
|
226 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
227 |
+
when running `transformers-cli login` (stored in `~/.huggingface`).
|
228 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
229 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
230 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
231 |
+
identifier allowed by git.
|
232 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
233 |
+
If `True`, will only try to load the tokenizer configuration from local files.
|
234 |
+
|
235 |
+
<Tip>
|
236 |
+
|
237 |
+
You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private
|
238 |
+
or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
239 |
+
|
240 |
+
</Tip>
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
`str`: The path to the module inside the cache.
|
244 |
+
"""
|
245 |
+
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
246 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
247 |
+
|
248 |
+
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
|
249 |
+
|
250 |
+
if os.path.isfile(module_file_or_url):
|
251 |
+
resolved_module_file = module_file_or_url
|
252 |
+
submodule = "local"
|
253 |
+
elif pretrained_model_name_or_path.count("/") == 0:
|
254 |
+
# community pipeline on GitHub
|
255 |
+
github_url = COMMUNITY_PIPELINES_URL.format(pipeline=pretrained_model_name_or_path)
|
256 |
+
try:
|
257 |
+
resolved_module_file = cached_download(
|
258 |
+
github_url,
|
259 |
+
cache_dir=cache_dir,
|
260 |
+
force_download=force_download,
|
261 |
+
proxies=proxies,
|
262 |
+
resume_download=resume_download,
|
263 |
+
local_files_only=local_files_only,
|
264 |
+
use_auth_token=False,
|
265 |
+
)
|
266 |
+
submodule = "git"
|
267 |
+
module_file = pretrained_model_name_or_path + ".py"
|
268 |
+
except EnvironmentError:
|
269 |
+
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
270 |
+
raise
|
271 |
+
else:
|
272 |
+
try:
|
273 |
+
# Load from URL or cache if already cached
|
274 |
+
resolved_module_file = hf_hub_download(
|
275 |
+
pretrained_model_name_or_path,
|
276 |
+
module_file,
|
277 |
+
cache_dir=cache_dir,
|
278 |
+
force_download=force_download,
|
279 |
+
proxies=proxies,
|
280 |
+
resume_download=resume_download,
|
281 |
+
local_files_only=local_files_only,
|
282 |
+
use_auth_token=use_auth_token,
|
283 |
+
)
|
284 |
+
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
|
285 |
+
except EnvironmentError:
|
286 |
+
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
287 |
+
raise
|
288 |
+
|
289 |
+
# Check we have all the requirements in our environment
|
290 |
+
modules_needed = check_imports(resolved_module_file)
|
291 |
+
|
292 |
+
# Now we move the module inside our cached dynamic modules.
|
293 |
+
full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
294 |
+
create_dynamic_module(full_submodule)
|
295 |
+
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
296 |
+
if submodule == "local" or submodule == "git":
|
297 |
+
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
|
298 |
+
# that hash, to only copy when there is a modification but it seems overkill for now).
|
299 |
+
# The only reason we do the copy is to avoid putting too many folders in sys.path.
|
300 |
+
shutil.copy(resolved_module_file, submodule_path / module_file)
|
301 |
+
for module_needed in modules_needed:
|
302 |
+
module_needed = f"{module_needed}.py"
|
303 |
+
shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
|
304 |
+
else:
|
305 |
+
# Get the commit hash
|
306 |
+
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
|
307 |
+
if isinstance(use_auth_token, str):
|
308 |
+
token = use_auth_token
|
309 |
+
elif use_auth_token is True:
|
310 |
+
token = HfFolder.get_token()
|
311 |
+
else:
|
312 |
+
token = None
|
313 |
+
|
314 |
+
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
|
315 |
+
|
316 |
+
# The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
|
317 |
+
# benefit of versioning.
|
318 |
+
submodule_path = submodule_path / commit_hash
|
319 |
+
full_submodule = full_submodule + os.path.sep + commit_hash
|
320 |
+
create_dynamic_module(full_submodule)
|
321 |
+
|
322 |
+
if not (submodule_path / module_file).exists():
|
323 |
+
shutil.copy(resolved_module_file, submodule_path / module_file)
|
324 |
+
# Make sure we also have every file with relative
|
325 |
+
for module_needed in modules_needed:
|
326 |
+
if not (submodule_path / module_needed).exists():
|
327 |
+
get_cached_module_file(
|
328 |
+
pretrained_model_name_or_path,
|
329 |
+
f"{module_needed}.py",
|
330 |
+
cache_dir=cache_dir,
|
331 |
+
force_download=force_download,
|
332 |
+
resume_download=resume_download,
|
333 |
+
proxies=proxies,
|
334 |
+
use_auth_token=use_auth_token,
|
335 |
+
revision=revision,
|
336 |
+
local_files_only=local_files_only,
|
337 |
+
)
|
338 |
+
return os.path.join(full_submodule, module_file)
|
339 |
+
|
340 |
+
|
341 |
+
def get_class_from_dynamic_module(
|
342 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
343 |
+
module_file: str,
|
344 |
+
class_name: Optional[str] = None,
|
345 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
346 |
+
force_download: bool = False,
|
347 |
+
resume_download: bool = False,
|
348 |
+
proxies: Optional[Dict[str, str]] = None,
|
349 |
+
use_auth_token: Optional[Union[bool, str]] = None,
|
350 |
+
revision: Optional[str] = None,
|
351 |
+
local_files_only: bool = False,
|
352 |
+
**kwargs,
|
353 |
+
):
|
354 |
+
"""
|
355 |
+
Extracts a class from a module file, present in the local folder or repository of a model.
|
356 |
+
|
357 |
+
<Tip warning={true}>
|
358 |
+
|
359 |
+
Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
|
360 |
+
therefore only be called on trusted repos.
|
361 |
+
|
362 |
+
</Tip>
|
363 |
+
|
364 |
+
Args:
|
365 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
366 |
+
This can be either:
|
367 |
+
|
368 |
+
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
369 |
+
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
|
370 |
+
under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
371 |
+
- a path to a *directory* containing a configuration file saved using the
|
372 |
+
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
|
373 |
+
|
374 |
+
module_file (`str`):
|
375 |
+
The name of the module file containing the class to look for.
|
376 |
+
class_name (`str`):
|
377 |
+
The name of the class to import in the module.
|
378 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
379 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
380 |
+
cache should not be used.
|
381 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
382 |
+
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
383 |
+
exist.
|
384 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
385 |
+
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
|
386 |
+
proxies (`Dict[str, str]`, *optional*):
|
387 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
388 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
389 |
+
use_auth_token (`str` or `bool`, *optional*):
|
390 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
391 |
+
when running `transformers-cli login` (stored in `~/.huggingface`).
|
392 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
393 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
394 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
395 |
+
identifier allowed by git.
|
396 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
397 |
+
If `True`, will only try to load the tokenizer configuration from local files.
|
398 |
+
|
399 |
+
<Tip>
|
400 |
+
|
401 |
+
You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private
|
402 |
+
or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
403 |
+
|
404 |
+
</Tip>
|
405 |
+
|
406 |
+
Returns:
|
407 |
+
`type`: The class, dynamically imported from the module.
|
408 |
+
|
409 |
+
Examples:
|
410 |
+
|
411 |
+
```python
|
412 |
+
# Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
|
413 |
+
# module.
|
414 |
+
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
|
415 |
+
```"""
|
416 |
+
# And lastly we get the class inside our newly created module
|
417 |
+
final_module = get_cached_module_file(
|
418 |
+
pretrained_model_name_or_path,
|
419 |
+
module_file,
|
420 |
+
cache_dir=cache_dir,
|
421 |
+
force_download=force_download,
|
422 |
+
resume_download=resume_download,
|
423 |
+
proxies=proxies,
|
424 |
+
use_auth_token=use_auth_token,
|
425 |
+
revision=revision,
|
426 |
+
local_files_only=local_files_only,
|
427 |
+
)
|
428 |
+
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
src/diffusers_/hub_utils.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import os
|
18 |
+
import shutil
|
19 |
+
import sys
|
20 |
+
from pathlib import Path
|
21 |
+
from typing import Dict, Optional, Union
|
22 |
+
from uuid import uuid4
|
23 |
+
|
24 |
+
from huggingface_hub import HfFolder, Repository, whoami
|
25 |
+
|
26 |
+
from . import __version__
|
27 |
+
from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
|
28 |
+
from .utils.import_utils import (
|
29 |
+
_flax_version,
|
30 |
+
_jax_version,
|
31 |
+
_onnxruntime_version,
|
32 |
+
_torch_version,
|
33 |
+
is_flax_available,
|
34 |
+
is_modelcards_available,
|
35 |
+
is_onnx_available,
|
36 |
+
is_torch_available,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
if is_modelcards_available():
|
41 |
+
from modelcards import CardData, ModelCard
|
42 |
+
|
43 |
+
|
44 |
+
logger = logging.get_logger(__name__)
|
45 |
+
|
46 |
+
|
47 |
+
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
|
48 |
+
SESSION_ID = uuid4().hex
|
49 |
+
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
|
50 |
+
|
51 |
+
|
52 |
+
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
53 |
+
"""
|
54 |
+
Formats a user-agent string with basic info about a request.
|
55 |
+
"""
|
56 |
+
ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
|
57 |
+
if DISABLE_TELEMETRY:
|
58 |
+
return ua + "; telemetry/off"
|
59 |
+
if is_torch_available():
|
60 |
+
ua += f"; torch/{_torch_version}"
|
61 |
+
if is_flax_available():
|
62 |
+
ua += f"; jax/{_jax_version}"
|
63 |
+
ua += f"; flax/{_flax_version}"
|
64 |
+
if is_onnx_available():
|
65 |
+
ua += f"; onnxruntime/{_onnxruntime_version}"
|
66 |
+
# CI will set this value to True
|
67 |
+
if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
|
68 |
+
ua += "; is_ci/true"
|
69 |
+
if isinstance(user_agent, dict):
|
70 |
+
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
|
71 |
+
elif isinstance(user_agent, str):
|
72 |
+
ua += "; " + user_agent
|
73 |
+
return ua
|
74 |
+
|
75 |
+
|
76 |
+
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
77 |
+
if token is None:
|
78 |
+
token = HfFolder.get_token()
|
79 |
+
if organization is None:
|
80 |
+
username = whoami(token)["name"]
|
81 |
+
return f"{username}/{model_id}"
|
82 |
+
else:
|
83 |
+
return f"{organization}/{model_id}"
|
84 |
+
|
85 |
+
|
86 |
+
def init_git_repo(args, at_init: bool = False):
|
87 |
+
"""
|
88 |
+
Args:
|
89 |
+
Initializes a git repo in `args.hub_model_id`.
|
90 |
+
at_init (`bool`, *optional*, defaults to `False`):
|
91 |
+
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
|
92 |
+
and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
|
93 |
+
"""
|
94 |
+
deprecation_message = (
|
95 |
+
"Please use `huggingface_hub.Repository`. "
|
96 |
+
"See `examples/unconditional_image_generation/train_unconditional.py` for an example."
|
97 |
+
)
|
98 |
+
deprecate("init_git_repo()", "0.10.0", deprecation_message)
|
99 |
+
|
100 |
+
if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
|
101 |
+
return
|
102 |
+
hub_token = args.hub_token if hasattr(args, "hub_token") else None
|
103 |
+
use_auth_token = True if hub_token is None else hub_token
|
104 |
+
if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
|
105 |
+
repo_name = Path(args.output_dir).absolute().name
|
106 |
+
else:
|
107 |
+
repo_name = args.hub_model_id
|
108 |
+
if "/" not in repo_name:
|
109 |
+
repo_name = get_full_repo_name(repo_name, token=hub_token)
|
110 |
+
|
111 |
+
try:
|
112 |
+
repo = Repository(
|
113 |
+
args.output_dir,
|
114 |
+
clone_from=repo_name,
|
115 |
+
use_auth_token=use_auth_token,
|
116 |
+
private=args.hub_private_repo,
|
117 |
+
)
|
118 |
+
except EnvironmentError:
|
119 |
+
if args.overwrite_output_dir and at_init:
|
120 |
+
# Try again after wiping output_dir
|
121 |
+
shutil.rmtree(args.output_dir)
|
122 |
+
repo = Repository(
|
123 |
+
args.output_dir,
|
124 |
+
clone_from=repo_name,
|
125 |
+
use_auth_token=use_auth_token,
|
126 |
+
)
|
127 |
+
else:
|
128 |
+
raise
|
129 |
+
|
130 |
+
repo.git_pull()
|
131 |
+
|
132 |
+
# By default, ignore the checkpoint folders
|
133 |
+
if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
|
134 |
+
with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
|
135 |
+
writer.writelines(["checkpoint-*/"])
|
136 |
+
|
137 |
+
return repo
|
138 |
+
|
139 |
+
|
140 |
+
def push_to_hub(
|
141 |
+
args,
|
142 |
+
pipeline,
|
143 |
+
repo: Repository,
|
144 |
+
commit_message: Optional[str] = "End of training",
|
145 |
+
blocking: bool = True,
|
146 |
+
**kwargs,
|
147 |
+
) -> str:
|
148 |
+
"""
|
149 |
+
Parameters:
|
150 |
+
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
|
151 |
+
commit_message (`str`, *optional*, defaults to `"End of training"`):
|
152 |
+
Message to commit while pushing.
|
153 |
+
blocking (`bool`, *optional*, defaults to `True`):
|
154 |
+
Whether the function should return only when the `git push` has finished.
|
155 |
+
kwargs:
|
156 |
+
Additional keyword arguments passed along to [`create_model_card`].
|
157 |
+
Returns:
|
158 |
+
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
|
159 |
+
commit and an object to track the progress of the commit if `blocking=True`
|
160 |
+
"""
|
161 |
+
deprecation_message = (
|
162 |
+
"Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. "
|
163 |
+
"See `examples/unconditional_image_generation/train_unconditional.py` for an example."
|
164 |
+
)
|
165 |
+
deprecate("push_to_hub()", "0.10.0", deprecation_message)
|
166 |
+
|
167 |
+
if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
|
168 |
+
model_name = Path(args.output_dir).name
|
169 |
+
else:
|
170 |
+
model_name = args.hub_model_id.split("/")[-1]
|
171 |
+
|
172 |
+
output_dir = args.output_dir
|
173 |
+
os.makedirs(output_dir, exist_ok=True)
|
174 |
+
logger.info(f"Saving pipeline checkpoint to {output_dir}")
|
175 |
+
pipeline.save_pretrained(output_dir)
|
176 |
+
|
177 |
+
# Only push from one node.
|
178 |
+
if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
|
179 |
+
return
|
180 |
+
|
181 |
+
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
|
182 |
+
if (
|
183 |
+
blocking
|
184 |
+
and len(repo.command_queue) > 0
|
185 |
+
and repo.command_queue[-1] is not None
|
186 |
+
and not repo.command_queue[-1].is_done
|
187 |
+
):
|
188 |
+
repo.command_queue[-1]._process.kill()
|
189 |
+
|
190 |
+
git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
|
191 |
+
# push separately the model card to be independent from the rest of the model
|
192 |
+
create_model_card(args, model_name=model_name)
|
193 |
+
try:
|
194 |
+
repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
|
195 |
+
except EnvironmentError as exc:
|
196 |
+
logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
|
197 |
+
|
198 |
+
return git_head_commit_url
|
199 |
+
|
200 |
+
|
201 |
+
def create_model_card(args, model_name):
|
202 |
+
if not is_modelcards_available:
|
203 |
+
raise ValueError(
|
204 |
+
"Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
|
205 |
+
" install the package with `pip install modelcards`."
|
206 |
+
)
|
207 |
+
|
208 |
+
if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
|
209 |
+
return
|
210 |
+
|
211 |
+
hub_token = args.hub_token if hasattr(args, "hub_token") else None
|
212 |
+
repo_name = get_full_repo_name(model_name, token=hub_token)
|
213 |
+
|
214 |
+
model_card = ModelCard.from_template(
|
215 |
+
card_data=CardData( # Card metadata object that will be converted to YAML block
|
216 |
+
language="en",
|
217 |
+
license="apache-2.0",
|
218 |
+
library_name="diffusers",
|
219 |
+
tags=[],
|
220 |
+
datasets=args.dataset_name,
|
221 |
+
metrics=[],
|
222 |
+
),
|
223 |
+
template_path=MODEL_CARD_TEMPLATE_PATH,
|
224 |
+
model_name=model_name,
|
225 |
+
repo_name=repo_name,
|
226 |
+
dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
|
227 |
+
learning_rate=args.learning_rate,
|
228 |
+
train_batch_size=args.train_batch_size,
|
229 |
+
eval_batch_size=args.eval_batch_size,
|
230 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps
|
231 |
+
if hasattr(args, "gradient_accumulation_steps")
|
232 |
+
else None,
|
233 |
+
adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
|
234 |
+
adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
|
235 |
+
adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
|
236 |
+
adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
|
237 |
+
lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
|
238 |
+
lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
|
239 |
+
ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
|
240 |
+
ema_power=args.ema_power if hasattr(args, "ema_power") else None,
|
241 |
+
ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
|
242 |
+
mixed_precision=args.mixed_precision,
|
243 |
+
)
|
244 |
+
|
245 |
+
card_path = os.path.join(args.output_dir, "README.md")
|
246 |
+
model_card.save(card_path)
|
src/diffusers_/modeling_utils.py
ADDED
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import os
|
18 |
+
from functools import partial
|
19 |
+
from typing import Callable, List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from torch import Tensor, device
|
23 |
+
|
24 |
+
from huggingface_hub import hf_hub_download
|
25 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
26 |
+
from requests import HTTPError
|
27 |
+
|
28 |
+
from . import __version__
|
29 |
+
from .utils import (
|
30 |
+
CONFIG_NAME,
|
31 |
+
DIFFUSERS_CACHE,
|
32 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
33 |
+
WEIGHTS_NAME,
|
34 |
+
is_accelerate_available,
|
35 |
+
is_torch_version,
|
36 |
+
logging,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__)
|
41 |
+
|
42 |
+
|
43 |
+
if is_torch_version(">=", "1.9.0"):
|
44 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
45 |
+
else:
|
46 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
47 |
+
|
48 |
+
|
49 |
+
if is_accelerate_available():
|
50 |
+
import accelerate
|
51 |
+
from accelerate.utils import set_module_tensor_to_device
|
52 |
+
from accelerate.utils.versions import is_torch_version
|
53 |
+
|
54 |
+
|
55 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
56 |
+
try:
|
57 |
+
return next(parameter.parameters()).device
|
58 |
+
except StopIteration:
|
59 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
60 |
+
|
61 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
62 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
63 |
+
return tuples
|
64 |
+
|
65 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
66 |
+
first_tuple = next(gen)
|
67 |
+
return first_tuple[1].device
|
68 |
+
|
69 |
+
|
70 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
71 |
+
try:
|
72 |
+
return next(parameter.parameters()).dtype
|
73 |
+
except StopIteration:
|
74 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
75 |
+
|
76 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
77 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
78 |
+
return tuples
|
79 |
+
|
80 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
81 |
+
first_tuple = next(gen)
|
82 |
+
return first_tuple[1].dtype
|
83 |
+
|
84 |
+
|
85 |
+
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
86 |
+
"""
|
87 |
+
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
|
88 |
+
"""
|
89 |
+
try:
|
90 |
+
return torch.load(checkpoint_file, map_location="cpu")
|
91 |
+
except Exception as e:
|
92 |
+
try:
|
93 |
+
with open(checkpoint_file) as f:
|
94 |
+
if f.read().startswith("version"):
|
95 |
+
raise OSError(
|
96 |
+
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
97 |
+
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
98 |
+
"you cloned."
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
raise ValueError(
|
102 |
+
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
103 |
+
"model. Make sure you have saved the model properly."
|
104 |
+
) from e
|
105 |
+
except (UnicodeDecodeError, ValueError):
|
106 |
+
raise OSError(
|
107 |
+
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
|
108 |
+
f"at '{checkpoint_file}'. "
|
109 |
+
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
110 |
+
)
|
111 |
+
|
112 |
+
|
113 |
+
def _load_state_dict_into_model(model_to_load, state_dict):
|
114 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
115 |
+
# copy state_dict so _load_from_state_dict can modify it
|
116 |
+
state_dict = state_dict.copy()
|
117 |
+
error_msgs = []
|
118 |
+
|
119 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
120 |
+
# so we need to apply the function recursively.
|
121 |
+
def load(module: torch.nn.Module, prefix=""):
|
122 |
+
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
123 |
+
module._load_from_state_dict(*args)
|
124 |
+
|
125 |
+
for name, child in module._modules.items():
|
126 |
+
if child is not None:
|
127 |
+
load(child, prefix + name + ".")
|
128 |
+
|
129 |
+
load(model_to_load)
|
130 |
+
|
131 |
+
return error_msgs
|
132 |
+
|
133 |
+
|
134 |
+
class ModelMixin(torch.nn.Module):
|
135 |
+
r"""
|
136 |
+
Base class for all models.
|
137 |
+
|
138 |
+
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
|
139 |
+
and saving models.
|
140 |
+
|
141 |
+
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
|
142 |
+
[`~modeling_utils.ModelMixin.save_pretrained`].
|
143 |
+
"""
|
144 |
+
config_name = CONFIG_NAME
|
145 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
146 |
+
_supports_gradient_checkpointing = False
|
147 |
+
|
148 |
+
def __init__(self):
|
149 |
+
super().__init__()
|
150 |
+
|
151 |
+
@property
|
152 |
+
def is_gradient_checkpointing(self) -> bool:
|
153 |
+
"""
|
154 |
+
Whether gradient checkpointing is activated for this model or not.
|
155 |
+
|
156 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
157 |
+
activations".
|
158 |
+
"""
|
159 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
160 |
+
|
161 |
+
def enable_gradient_checkpointing(self):
|
162 |
+
"""
|
163 |
+
Activates gradient checkpointing for the current model.
|
164 |
+
|
165 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
166 |
+
activations".
|
167 |
+
"""
|
168 |
+
if not self._supports_gradient_checkpointing:
|
169 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
170 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
171 |
+
|
172 |
+
def disable_gradient_checkpointing(self):
|
173 |
+
"""
|
174 |
+
Deactivates gradient checkpointing for the current model.
|
175 |
+
|
176 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
177 |
+
activations".
|
178 |
+
"""
|
179 |
+
if self._supports_gradient_checkpointing:
|
180 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
181 |
+
|
182 |
+
def save_pretrained(
|
183 |
+
self,
|
184 |
+
save_directory: Union[str, os.PathLike],
|
185 |
+
is_main_process: bool = True,
|
186 |
+
save_function: Callable = torch.save,
|
187 |
+
):
|
188 |
+
"""
|
189 |
+
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
190 |
+
`[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
|
191 |
+
|
192 |
+
Arguments:
|
193 |
+
save_directory (`str` or `os.PathLike`):
|
194 |
+
Directory to which to save. Will be created if it doesn't exist.
|
195 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
196 |
+
Whether the process calling this is the main process or not. Useful when in distributed training like
|
197 |
+
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
198 |
+
the main process to avoid race conditions.
|
199 |
+
save_function (`Callable`):
|
200 |
+
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
201 |
+
need to replace `torch.save` by another method.
|
202 |
+
"""
|
203 |
+
if os.path.isfile(save_directory):
|
204 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
205 |
+
return
|
206 |
+
|
207 |
+
os.makedirs(save_directory, exist_ok=True)
|
208 |
+
|
209 |
+
model_to_save = self
|
210 |
+
|
211 |
+
# Attach architecture to the config
|
212 |
+
# Save the config
|
213 |
+
if is_main_process:
|
214 |
+
model_to_save.save_config(save_directory)
|
215 |
+
|
216 |
+
# Save the model
|
217 |
+
state_dict = model_to_save.state_dict()
|
218 |
+
|
219 |
+
# Clean the folder from a previous save
|
220 |
+
for filename in os.listdir(save_directory):
|
221 |
+
full_filename = os.path.join(save_directory, filename)
|
222 |
+
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
223 |
+
# in distributed settings to avoid race conditions.
|
224 |
+
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
|
225 |
+
os.remove(full_filename)
|
226 |
+
|
227 |
+
# Save the model
|
228 |
+
save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
|
229 |
+
|
230 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
|
231 |
+
|
232 |
+
@classmethod
|
233 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
234 |
+
r"""
|
235 |
+
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
236 |
+
|
237 |
+
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
238 |
+
the model, you should first set it back in training mode with `model.train()`.
|
239 |
+
|
240 |
+
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
241 |
+
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
242 |
+
task.
|
243 |
+
|
244 |
+
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
245 |
+
weights are discarded.
|
246 |
+
|
247 |
+
Parameters:
|
248 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
249 |
+
Can be either:
|
250 |
+
|
251 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
252 |
+
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
253 |
+
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
254 |
+
`./my_model_directory/`.
|
255 |
+
|
256 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
257 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
258 |
+
standard cache should not be used.
|
259 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
260 |
+
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
261 |
+
will be automatically derived from the model's weights.
|
262 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
263 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
264 |
+
cached versions if they exist.
|
265 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
266 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
267 |
+
file exists.
|
268 |
+
proxies (`Dict[str, str]`, *optional*):
|
269 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
270 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
271 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
272 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
273 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
274 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
275 |
+
use_auth_token (`str` or *bool*, *optional*):
|
276 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
277 |
+
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
278 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
279 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
280 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
281 |
+
identifier allowed by git.
|
282 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
283 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
284 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
285 |
+
|
286 |
+
mirror (`str`, *optional*):
|
287 |
+
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
288 |
+
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
289 |
+
Please refer to the mirror site for more information.
|
290 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
291 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
292 |
+
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
293 |
+
same device.
|
294 |
+
|
295 |
+
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
296 |
+
more information about each option see [designing a device
|
297 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
298 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
299 |
+
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
300 |
+
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
301 |
+
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
302 |
+
setting this argument to `True` will raise an error.
|
303 |
+
|
304 |
+
<Tip>
|
305 |
+
|
306 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
307 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
308 |
+
|
309 |
+
</Tip>
|
310 |
+
|
311 |
+
<Tip>
|
312 |
+
|
313 |
+
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
|
314 |
+
this method in a firewalled environment.
|
315 |
+
|
316 |
+
</Tip>
|
317 |
+
|
318 |
+
"""
|
319 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
320 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
321 |
+
force_download = kwargs.pop("force_download", False)
|
322 |
+
resume_download = kwargs.pop("resume_download", False)
|
323 |
+
proxies = kwargs.pop("proxies", None)
|
324 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
325 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
326 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
327 |
+
revision = kwargs.pop("revision", None)
|
328 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
329 |
+
subfolder = kwargs.pop("subfolder", None)
|
330 |
+
device_map = kwargs.pop("device_map", None)
|
331 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
332 |
+
|
333 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
334 |
+
low_cpu_mem_usage = False
|
335 |
+
logger.warning(
|
336 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
337 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
338 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
339 |
+
" install accelerate\n```\n."
|
340 |
+
)
|
341 |
+
|
342 |
+
if device_map is not None and not is_accelerate_available():
|
343 |
+
raise NotImplementedError(
|
344 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
345 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
346 |
+
)
|
347 |
+
|
348 |
+
# Check if we can handle device_map and dispatching the weights
|
349 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
350 |
+
raise NotImplementedError(
|
351 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
352 |
+
" `device_map=None`."
|
353 |
+
)
|
354 |
+
|
355 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
356 |
+
raise NotImplementedError(
|
357 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
358 |
+
" `low_cpu_mem_usage=False`."
|
359 |
+
)
|
360 |
+
|
361 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
362 |
+
raise ValueError(
|
363 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
364 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
365 |
+
)
|
366 |
+
|
367 |
+
user_agent = {
|
368 |
+
"diffusers": __version__,
|
369 |
+
"file_type": "model",
|
370 |
+
"framework": "pytorch",
|
371 |
+
}
|
372 |
+
|
373 |
+
# Load config if we don't provide a configuration
|
374 |
+
config_path = pretrained_model_name_or_path
|
375 |
+
|
376 |
+
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
377 |
+
# Load model
|
378 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
379 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
380 |
+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
381 |
+
# Load from a PyTorch checkpoint
|
382 |
+
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
383 |
+
elif subfolder is not None and os.path.isfile(
|
384 |
+
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
385 |
+
):
|
386 |
+
model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
387 |
+
else:
|
388 |
+
raise EnvironmentError(
|
389 |
+
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
390 |
+
)
|
391 |
+
else:
|
392 |
+
try:
|
393 |
+
# Load from URL or cache if already cached
|
394 |
+
model_file = hf_hub_download(
|
395 |
+
pretrained_model_name_or_path,
|
396 |
+
filename=WEIGHTS_NAME,
|
397 |
+
cache_dir=cache_dir,
|
398 |
+
force_download=force_download,
|
399 |
+
proxies=proxies,
|
400 |
+
resume_download=resume_download,
|
401 |
+
local_files_only=local_files_only,
|
402 |
+
use_auth_token=use_auth_token,
|
403 |
+
user_agent=user_agent,
|
404 |
+
subfolder=subfolder,
|
405 |
+
revision=revision,
|
406 |
+
)
|
407 |
+
|
408 |
+
except RepositoryNotFoundError:
|
409 |
+
raise EnvironmentError(
|
410 |
+
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
411 |
+
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
412 |
+
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
413 |
+
"login`."
|
414 |
+
)
|
415 |
+
except RevisionNotFoundError:
|
416 |
+
raise EnvironmentError(
|
417 |
+
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
418 |
+
"this model name. Check the model page at "
|
419 |
+
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
420 |
+
)
|
421 |
+
except EntryNotFoundError:
|
422 |
+
raise EnvironmentError(
|
423 |
+
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
|
424 |
+
)
|
425 |
+
except HTTPError as err:
|
426 |
+
raise EnvironmentError(
|
427 |
+
"There was a specific connection error when trying to load"
|
428 |
+
f" {pretrained_model_name_or_path}:\n{err}"
|
429 |
+
)
|
430 |
+
except ValueError:
|
431 |
+
raise EnvironmentError(
|
432 |
+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
433 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
434 |
+
f" directory containing a file named {WEIGHTS_NAME} or"
|
435 |
+
" \nCheckout your internet connection or see how to run the library in"
|
436 |
+
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
437 |
+
)
|
438 |
+
except EnvironmentError:
|
439 |
+
raise EnvironmentError(
|
440 |
+
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
441 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
442 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
443 |
+
f"containing a file named {WEIGHTS_NAME}"
|
444 |
+
)
|
445 |
+
|
446 |
+
# restore default dtype
|
447 |
+
|
448 |
+
if low_cpu_mem_usage:
|
449 |
+
# Instantiate model with empty weights
|
450 |
+
with accelerate.init_empty_weights():
|
451 |
+
config, unused_kwargs = cls.load_config(
|
452 |
+
config_path,
|
453 |
+
cache_dir=cache_dir,
|
454 |
+
return_unused_kwargs=True,
|
455 |
+
force_download=force_download,
|
456 |
+
resume_download=resume_download,
|
457 |
+
proxies=proxies,
|
458 |
+
local_files_only=local_files_only,
|
459 |
+
use_auth_token=use_auth_token,
|
460 |
+
revision=revision,
|
461 |
+
subfolder=subfolder,
|
462 |
+
device_map=device_map,
|
463 |
+
**kwargs,
|
464 |
+
)
|
465 |
+
model = cls.from_config(config, **unused_kwargs)
|
466 |
+
|
467 |
+
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
|
468 |
+
if device_map is None:
|
469 |
+
param_device = "cpu"
|
470 |
+
state_dict = load_state_dict(model_file)
|
471 |
+
# move the parms from meta device to cpu
|
472 |
+
for param_name, param in state_dict.items():
|
473 |
+
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
474 |
+
else: # else let accelerate handle loading and dispatching.
|
475 |
+
# Load weights and dispatch according to the device_map
|
476 |
+
# by deafult the device_map is None and the weights are loaded on the CPU
|
477 |
+
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
|
478 |
+
|
479 |
+
loading_info = {
|
480 |
+
"missing_keys": [],
|
481 |
+
"unexpected_keys": [],
|
482 |
+
"mismatched_keys": [],
|
483 |
+
"error_msgs": [],
|
484 |
+
}
|
485 |
+
else:
|
486 |
+
config, unused_kwargs = cls.load_config(
|
487 |
+
config_path,
|
488 |
+
cache_dir=cache_dir,
|
489 |
+
return_unused_kwargs=True,
|
490 |
+
force_download=force_download,
|
491 |
+
resume_download=resume_download,
|
492 |
+
proxies=proxies,
|
493 |
+
local_files_only=local_files_only,
|
494 |
+
use_auth_token=use_auth_token,
|
495 |
+
revision=revision,
|
496 |
+
subfolder=subfolder,
|
497 |
+
device_map=device_map,
|
498 |
+
**kwargs,
|
499 |
+
)
|
500 |
+
model = cls.from_config(config, **unused_kwargs)
|
501 |
+
|
502 |
+
state_dict = load_state_dict(model_file)
|
503 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
504 |
+
model,
|
505 |
+
state_dict,
|
506 |
+
model_file,
|
507 |
+
pretrained_model_name_or_path,
|
508 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
509 |
+
)
|
510 |
+
|
511 |
+
loading_info = {
|
512 |
+
"missing_keys": missing_keys,
|
513 |
+
"unexpected_keys": unexpected_keys,
|
514 |
+
"mismatched_keys": mismatched_keys,
|
515 |
+
"error_msgs": error_msgs,
|
516 |
+
}
|
517 |
+
|
518 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
519 |
+
raise ValueError(
|
520 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
521 |
+
)
|
522 |
+
elif torch_dtype is not None:
|
523 |
+
model = model.to(torch_dtype)
|
524 |
+
|
525 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
526 |
+
|
527 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
528 |
+
model.eval()
|
529 |
+
if output_loading_info:
|
530 |
+
return model, loading_info
|
531 |
+
|
532 |
+
return model
|
533 |
+
|
534 |
+
@classmethod
|
535 |
+
def _load_pretrained_model(
|
536 |
+
cls,
|
537 |
+
model,
|
538 |
+
state_dict,
|
539 |
+
resolved_archive_file,
|
540 |
+
pretrained_model_name_or_path,
|
541 |
+
ignore_mismatched_sizes=False,
|
542 |
+
):
|
543 |
+
# Retrieve missing & unexpected_keys
|
544 |
+
model_state_dict = model.state_dict()
|
545 |
+
loaded_keys = [k for k in state_dict.keys()]
|
546 |
+
|
547 |
+
expected_keys = list(model_state_dict.keys())
|
548 |
+
|
549 |
+
original_loaded_keys = loaded_keys
|
550 |
+
|
551 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
552 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
553 |
+
|
554 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
555 |
+
model_to_load = model
|
556 |
+
|
557 |
+
def _find_mismatched_keys(
|
558 |
+
state_dict,
|
559 |
+
model_state_dict,
|
560 |
+
loaded_keys,
|
561 |
+
ignore_mismatched_sizes,
|
562 |
+
):
|
563 |
+
mismatched_keys = []
|
564 |
+
if ignore_mismatched_sizes:
|
565 |
+
for checkpoint_key in loaded_keys:
|
566 |
+
model_key = checkpoint_key
|
567 |
+
|
568 |
+
if (
|
569 |
+
model_key in model_state_dict
|
570 |
+
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
571 |
+
):
|
572 |
+
mismatched_keys.append(
|
573 |
+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
574 |
+
)
|
575 |
+
del state_dict[checkpoint_key]
|
576 |
+
return mismatched_keys
|
577 |
+
|
578 |
+
if state_dict is not None:
|
579 |
+
# Whole checkpoint
|
580 |
+
mismatched_keys = _find_mismatched_keys(
|
581 |
+
state_dict,
|
582 |
+
model_state_dict,
|
583 |
+
original_loaded_keys,
|
584 |
+
ignore_mismatched_sizes,
|
585 |
+
)
|
586 |
+
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
587 |
+
|
588 |
+
if len(error_msgs) > 0:
|
589 |
+
error_msg = "\n\t".join(error_msgs)
|
590 |
+
if "size mismatch" in error_msg:
|
591 |
+
error_msg += (
|
592 |
+
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
593 |
+
)
|
594 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
595 |
+
|
596 |
+
if len(unexpected_keys) > 0:
|
597 |
+
logger.warning(
|
598 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
599 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
600 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
601 |
+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
602 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
603 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
604 |
+
" identical (initializing a BertForSequenceClassification model from a"
|
605 |
+
" BertForSequenceClassification model)."
|
606 |
+
)
|
607 |
+
else:
|
608 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
609 |
+
if len(missing_keys) > 0:
|
610 |
+
logger.warning(
|
611 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
612 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
613 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
614 |
+
)
|
615 |
+
elif len(mismatched_keys) == 0:
|
616 |
+
logger.info(
|
617 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
618 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
619 |
+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
620 |
+
" without further training."
|
621 |
+
)
|
622 |
+
if len(mismatched_keys) > 0:
|
623 |
+
mismatched_warning = "\n".join(
|
624 |
+
[
|
625 |
+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
626 |
+
for key, shape1, shape2 in mismatched_keys
|
627 |
+
]
|
628 |
+
)
|
629 |
+
logger.warning(
|
630 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
631 |
+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
632 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
633 |
+
" able to use it for predictions and inference."
|
634 |
+
)
|
635 |
+
|
636 |
+
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
637 |
+
|
638 |
+
@property
|
639 |
+
def device(self) -> device:
|
640 |
+
"""
|
641 |
+
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
642 |
+
device).
|
643 |
+
"""
|
644 |
+
return get_parameter_device(self)
|
645 |
+
|
646 |
+
@property
|
647 |
+
def dtype(self) -> torch.dtype:
|
648 |
+
"""
|
649 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
650 |
+
"""
|
651 |
+
return get_parameter_dtype(self)
|
652 |
+
|
653 |
+
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
654 |
+
"""
|
655 |
+
Get number of (optionally, trainable or non-embeddings) parameters in the module.
|
656 |
+
|
657 |
+
Args:
|
658 |
+
only_trainable (`bool`, *optional*, defaults to `False`):
|
659 |
+
Whether or not to return only the number of trainable parameters
|
660 |
+
|
661 |
+
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
662 |
+
Whether or not to return only the number of non-embeddings parameters
|
663 |
+
|
664 |
+
Returns:
|
665 |
+
`int`: The number of parameters.
|
666 |
+
"""
|
667 |
+
|
668 |
+
if exclude_embeddings:
|
669 |
+
embedding_param_names = [
|
670 |
+
f"{name}.weight"
|
671 |
+
for name, module_type in self.named_modules()
|
672 |
+
if isinstance(module_type, torch.nn.Embedding)
|
673 |
+
]
|
674 |
+
non_embedding_parameters = [
|
675 |
+
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
676 |
+
]
|
677 |
+
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
678 |
+
else:
|
679 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
680 |
+
|
681 |
+
|
682 |
+
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
|
683 |
+
"""
|
684 |
+
Recursively unwraps a model from potential containers (as used in distributed training).
|
685 |
+
|
686 |
+
Args:
|
687 |
+
model (`torch.nn.Module`): The model to unwrap.
|
688 |
+
"""
|
689 |
+
# since there could be multiple levels of wrapping, unwrap recursively
|
690 |
+
if hasattr(model, "module"):
|
691 |
+
return unwrap_model(model.module)
|
692 |
+
else:
|
693 |
+
return model
|
src/diffusers_/pipeline_utils.py
ADDED
@@ -0,0 +1,755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import importlib
|
18 |
+
import inspect
|
19 |
+
import os
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from pathlib import Path
|
22 |
+
from typing import Any, Dict, List, Optional, Union
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
|
27 |
+
import diffusers
|
28 |
+
import PIL
|
29 |
+
from huggingface_hub import snapshot_download
|
30 |
+
from packaging import version
|
31 |
+
from PIL import Image
|
32 |
+
from tqdm.auto import tqdm
|
33 |
+
|
34 |
+
from .configuration_utils import ConfigMixin
|
35 |
+
from .dynamic_modules_utils import get_class_from_dynamic_module
|
36 |
+
from .hub_utils import http_user_agent
|
37 |
+
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
38 |
+
from .scheduling_utils import SCHEDULER_CONFIG_NAME
|
39 |
+
from .utils import (
|
40 |
+
CONFIG_NAME,
|
41 |
+
DIFFUSERS_CACHE,
|
42 |
+
ONNX_WEIGHTS_NAME,
|
43 |
+
WEIGHTS_NAME,
|
44 |
+
BaseOutput,
|
45 |
+
deprecate,
|
46 |
+
is_accelerate_available,
|
47 |
+
is_torch_version,
|
48 |
+
is_transformers_available,
|
49 |
+
logging,
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
if is_transformers_available():
|
54 |
+
import transformers
|
55 |
+
from transformers import PreTrainedModel
|
56 |
+
|
57 |
+
|
58 |
+
INDEX_FILE = "diffusion_pytorch_model.bin"
|
59 |
+
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
|
60 |
+
DUMMY_MODULES_FOLDER = "diffusers.utils"
|
61 |
+
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
|
62 |
+
|
63 |
+
|
64 |
+
logger = logging.get_logger(__name__)
|
65 |
+
|
66 |
+
|
67 |
+
LOADABLE_CLASSES = {
|
68 |
+
"diffusers": {
|
69 |
+
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
70 |
+
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
|
71 |
+
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
72 |
+
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
|
73 |
+
},
|
74 |
+
"transformers": {
|
75 |
+
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
|
76 |
+
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
|
77 |
+
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
78 |
+
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
|
79 |
+
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
|
80 |
+
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
|
81 |
+
},
|
82 |
+
"onnxruntime.training": {
|
83 |
+
"ORTModule": ["save_pretrained", "from_pretrained"],
|
84 |
+
},
|
85 |
+
}
|
86 |
+
|
87 |
+
ALL_IMPORTABLE_CLASSES = {}
|
88 |
+
for library in LOADABLE_CLASSES:
|
89 |
+
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
90 |
+
|
91 |
+
|
92 |
+
@dataclass
|
93 |
+
class ImagePipelineOutput(BaseOutput):
|
94 |
+
"""
|
95 |
+
Output class for image pipelines.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
99 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
100 |
+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
101 |
+
"""
|
102 |
+
|
103 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
104 |
+
|
105 |
+
|
106 |
+
@dataclass
|
107 |
+
class AudioPipelineOutput(BaseOutput):
|
108 |
+
"""
|
109 |
+
Output class for audio pipelines.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
audios (`np.ndarray`)
|
113 |
+
List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the
|
114 |
+
denoised audio samples of the diffusion pipeline.
|
115 |
+
"""
|
116 |
+
|
117 |
+
audios: np.ndarray
|
118 |
+
|
119 |
+
|
120 |
+
class DiffusionPipeline(ConfigMixin):
|
121 |
+
r"""
|
122 |
+
Base class for all models.
|
123 |
+
|
124 |
+
[`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
|
125 |
+
and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
|
126 |
+
|
127 |
+
- move all PyTorch modules to the device of your choice
|
128 |
+
- enabling/disabling the progress bar for the denoising iteration
|
129 |
+
|
130 |
+
Class attributes:
|
131 |
+
|
132 |
+
- **config_name** (`str`) -- name of the config file that will store the class and module names of all
|
133 |
+
components of the diffusion pipeline.
|
134 |
+
- **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
|
135 |
+
passed for the pipeline to function (should be overridden by subclasses).
|
136 |
+
"""
|
137 |
+
config_name = "model_index.json"
|
138 |
+
_optional_components = []
|
139 |
+
|
140 |
+
def register_modules(self, **kwargs):
|
141 |
+
# import it here to avoid circular import
|
142 |
+
from diffusers import pipelines
|
143 |
+
|
144 |
+
for name, module in kwargs.items():
|
145 |
+
# retrieve library
|
146 |
+
if module is None:
|
147 |
+
register_dict = {name: (None, None)}
|
148 |
+
else:
|
149 |
+
library = module.__module__.split(".")[0]
|
150 |
+
|
151 |
+
# check if the module is a pipeline module
|
152 |
+
pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
|
153 |
+
path = module.__module__.split(".")
|
154 |
+
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
155 |
+
|
156 |
+
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
157 |
+
# Or if it's a pipeline module, then the module is inside the pipeline
|
158 |
+
# folder so we set the library to module name.
|
159 |
+
if library not in LOADABLE_CLASSES or is_pipeline_module:
|
160 |
+
library = pipeline_dir
|
161 |
+
|
162 |
+
# retrieve class_name
|
163 |
+
class_name = module.__class__.__name__
|
164 |
+
|
165 |
+
register_dict = {name: (library, class_name)}
|
166 |
+
|
167 |
+
# save model index config
|
168 |
+
self.register_to_config(**register_dict)
|
169 |
+
|
170 |
+
# set models
|
171 |
+
setattr(self, name, module)
|
172 |
+
|
173 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
174 |
+
"""
|
175 |
+
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
|
176 |
+
a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
|
177 |
+
method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
|
178 |
+
|
179 |
+
Arguments:
|
180 |
+
save_directory (`str` or `os.PathLike`):
|
181 |
+
Directory to which to save. Will be created if it doesn't exist.
|
182 |
+
"""
|
183 |
+
self.save_config(save_directory)
|
184 |
+
|
185 |
+
model_index_dict = dict(self.config)
|
186 |
+
model_index_dict.pop("_class_name")
|
187 |
+
model_index_dict.pop("_diffusers_version")
|
188 |
+
model_index_dict.pop("_module", None)
|
189 |
+
|
190 |
+
expected_modules, optional_kwargs = self._get_signature_keys(self)
|
191 |
+
|
192 |
+
def is_saveable_module(name, value):
|
193 |
+
if name not in expected_modules:
|
194 |
+
return False
|
195 |
+
if name in self._optional_components and value[0] is None:
|
196 |
+
return False
|
197 |
+
return True
|
198 |
+
|
199 |
+
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
|
200 |
+
|
201 |
+
for pipeline_component_name in model_index_dict.keys():
|
202 |
+
sub_model = getattr(self, pipeline_component_name)
|
203 |
+
model_cls = sub_model.__class__
|
204 |
+
|
205 |
+
save_method_name = None
|
206 |
+
# search for the model's base class in LOADABLE_CLASSES
|
207 |
+
for library_name, library_classes in LOADABLE_CLASSES.items():
|
208 |
+
library = importlib.import_module(library_name)
|
209 |
+
for base_class, save_load_methods in library_classes.items():
|
210 |
+
class_candidate = getattr(library, base_class, None)
|
211 |
+
if class_candidate is not None and issubclass(model_cls, class_candidate):
|
212 |
+
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
|
213 |
+
save_method_name = save_load_methods[0]
|
214 |
+
break
|
215 |
+
if save_method_name is not None:
|
216 |
+
break
|
217 |
+
|
218 |
+
if save_method_name is not None:
|
219 |
+
save_method = getattr(sub_model, save_method_name)
|
220 |
+
save_method(os.path.join(save_directory, pipeline_component_name))
|
221 |
+
|
222 |
+
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
223 |
+
if torch_device is None:
|
224 |
+
return self
|
225 |
+
|
226 |
+
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
227 |
+
for name in module_names.keys():
|
228 |
+
module = getattr(self, name)
|
229 |
+
if isinstance(module, torch.nn.Module):
|
230 |
+
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
|
231 |
+
logger.warning(
|
232 |
+
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
233 |
+
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
234 |
+
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
|
235 |
+
" support for`float16` operations on this device in PyTorch. Please, remove the"
|
236 |
+
" `torch_dtype=torch.float16` argument, or use another device for inference."
|
237 |
+
)
|
238 |
+
module.to(torch_device)
|
239 |
+
return self
|
240 |
+
|
241 |
+
@property
|
242 |
+
def device(self) -> torch.device:
|
243 |
+
r"""
|
244 |
+
Returns:
|
245 |
+
`torch.device`: The torch device on which the pipeline is located.
|
246 |
+
"""
|
247 |
+
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
248 |
+
for name in module_names.keys():
|
249 |
+
module = getattr(self, name)
|
250 |
+
if isinstance(module, torch.nn.Module):
|
251 |
+
return module.device
|
252 |
+
return torch.device("cpu")
|
253 |
+
|
254 |
+
@classmethod
|
255 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
256 |
+
r"""
|
257 |
+
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
|
258 |
+
|
259 |
+
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
|
260 |
+
|
261 |
+
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
262 |
+
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
263 |
+
task.
|
264 |
+
|
265 |
+
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
266 |
+
weights are discarded.
|
267 |
+
|
268 |
+
Parameters:
|
269 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
270 |
+
Can be either:
|
271 |
+
|
272 |
+
- A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
|
273 |
+
https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
|
274 |
+
`CompVis/ldm-text2im-large-256`.
|
275 |
+
- A path to a *directory* containing pipeline weights saved using
|
276 |
+
[`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
|
277 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
278 |
+
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
279 |
+
will be automatically derived from the model's weights.
|
280 |
+
custom_pipeline (`str`, *optional*):
|
281 |
+
|
282 |
+
<Tip warning={true}>
|
283 |
+
|
284 |
+
This is an experimental feature and is likely to change in the future.
|
285 |
+
|
286 |
+
</Tip>
|
287 |
+
|
288 |
+
Can be either:
|
289 |
+
|
290 |
+
- A string, the *repo id* of a custom pipeline hosted inside a model repo on
|
291 |
+
https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
|
292 |
+
like `hf-internal-testing/diffusers-dummy-pipeline`.
|
293 |
+
|
294 |
+
<Tip>
|
295 |
+
|
296 |
+
It is required that the model repo has a file, called `pipeline.py` that defines the custom
|
297 |
+
pipeline.
|
298 |
+
|
299 |
+
</Tip>
|
300 |
+
|
301 |
+
- A string, the *file name* of a community pipeline hosted on GitHub under
|
302 |
+
https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
|
303 |
+
match exactly the file name without `.py` located under the above link, *e.g.*
|
304 |
+
`clip_guided_stable_diffusion`.
|
305 |
+
|
306 |
+
<Tip>
|
307 |
+
|
308 |
+
Community pipelines are always loaded from the current `main` branch of GitHub.
|
309 |
+
|
310 |
+
</Tip>
|
311 |
+
|
312 |
+
- A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
|
313 |
+
|
314 |
+
<Tip>
|
315 |
+
|
316 |
+
It is required that the directory has a file, called `pipeline.py` that defines the custom
|
317 |
+
pipeline.
|
318 |
+
|
319 |
+
</Tip>
|
320 |
+
|
321 |
+
For more information on how to load and create custom pipelines, please have a look at [Loading and
|
322 |
+
Adding Custom
|
323 |
+
Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
|
324 |
+
|
325 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
326 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
327 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
328 |
+
cached versions if they exist.
|
329 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
330 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
331 |
+
file exists.
|
332 |
+
proxies (`Dict[str, str]`, *optional*):
|
333 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
334 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
335 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
336 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
337 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
338 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
339 |
+
use_auth_token (`str` or *bool*, *optional*):
|
340 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
341 |
+
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
342 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
343 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
344 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
345 |
+
identifier allowed by git.
|
346 |
+
mirror (`str`, *optional*):
|
347 |
+
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
348 |
+
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
349 |
+
Please refer to the mirror site for more information. specify the folder name here.
|
350 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
351 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
352 |
+
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
353 |
+
same device.
|
354 |
+
|
355 |
+
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
356 |
+
more information about each option see [designing a device
|
357 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
358 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
359 |
+
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
360 |
+
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
361 |
+
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
362 |
+
setting this argument to `True` will raise an error.
|
363 |
+
|
364 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
365 |
+
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
366 |
+
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
367 |
+
`__init__` method. See example below for more information.
|
368 |
+
|
369 |
+
<Tip>
|
370 |
+
|
371 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
372 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`
|
373 |
+
|
374 |
+
</Tip>
|
375 |
+
|
376 |
+
<Tip>
|
377 |
+
|
378 |
+
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
|
379 |
+
this method in a firewalled environment.
|
380 |
+
|
381 |
+
</Tip>
|
382 |
+
|
383 |
+
Examples:
|
384 |
+
|
385 |
+
```py
|
386 |
+
>>> from diffusers import DiffusionPipeline
|
387 |
+
|
388 |
+
>>> # Download pipeline from huggingface.co and cache.
|
389 |
+
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
390 |
+
|
391 |
+
>>> # Download pipeline that requires an authorization token
|
392 |
+
>>> # For more information on access tokens, please refer to this section
|
393 |
+
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
|
394 |
+
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
395 |
+
|
396 |
+
>>> # Use a different scheduler
|
397 |
+
>>> from diffusers import LMSDiscreteScheduler
|
398 |
+
|
399 |
+
>>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
|
400 |
+
>>> pipeline.scheduler = scheduler
|
401 |
+
```
|
402 |
+
"""
|
403 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
404 |
+
resume_download = kwargs.pop("resume_download", False)
|
405 |
+
force_download = kwargs.pop("force_download", False)
|
406 |
+
proxies = kwargs.pop("proxies", None)
|
407 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
408 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
409 |
+
revision = kwargs.pop("revision", None)
|
410 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
411 |
+
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
412 |
+
provider = kwargs.pop("provider", None)
|
413 |
+
sess_options = kwargs.pop("sess_options", None)
|
414 |
+
device_map = kwargs.pop("device_map", None)
|
415 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
416 |
+
|
417 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
418 |
+
low_cpu_mem_usage = False
|
419 |
+
logger.warning(
|
420 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
421 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
422 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
423 |
+
" install accelerate\n```\n."
|
424 |
+
)
|
425 |
+
|
426 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
427 |
+
raise NotImplementedError(
|
428 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
429 |
+
" `device_map=None`."
|
430 |
+
)
|
431 |
+
|
432 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
433 |
+
raise NotImplementedError(
|
434 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
435 |
+
" `low_cpu_mem_usage=False`."
|
436 |
+
)
|
437 |
+
|
438 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
439 |
+
raise ValueError(
|
440 |
+
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
441 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
442 |
+
)
|
443 |
+
|
444 |
+
# 1. Download the checkpoints and configs
|
445 |
+
# use snapshot download here to get it working from from_pretrained
|
446 |
+
if not os.path.isdir(pretrained_model_name_or_path):
|
447 |
+
config_dict = cls.load_config(
|
448 |
+
pretrained_model_name_or_path,
|
449 |
+
cache_dir=cache_dir,
|
450 |
+
resume_download=resume_download,
|
451 |
+
force_download=force_download,
|
452 |
+
proxies=proxies,
|
453 |
+
local_files_only=local_files_only,
|
454 |
+
use_auth_token=use_auth_token,
|
455 |
+
revision=revision,
|
456 |
+
)
|
457 |
+
# make sure we only download sub-folders and `diffusers` filenames
|
458 |
+
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
|
459 |
+
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
460 |
+
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
|
461 |
+
|
462 |
+
# make sure we don't download flax weights
|
463 |
+
ignore_patterns = "*.msgpack"
|
464 |
+
|
465 |
+
if custom_pipeline is not None:
|
466 |
+
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
467 |
+
|
468 |
+
if cls != DiffusionPipeline:
|
469 |
+
requested_pipeline_class = cls.__name__
|
470 |
+
else:
|
471 |
+
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
472 |
+
user_agent = {"pipeline_class": requested_pipeline_class}
|
473 |
+
if custom_pipeline is not None:
|
474 |
+
user_agent["custom_pipeline"] = custom_pipeline
|
475 |
+
user_agent = http_user_agent(user_agent)
|
476 |
+
|
477 |
+
# download all allow_patterns
|
478 |
+
cached_folder = snapshot_download(
|
479 |
+
pretrained_model_name_or_path,
|
480 |
+
cache_dir=cache_dir,
|
481 |
+
resume_download=resume_download,
|
482 |
+
proxies=proxies,
|
483 |
+
local_files_only=local_files_only,
|
484 |
+
use_auth_token=use_auth_token,
|
485 |
+
revision=revision,
|
486 |
+
allow_patterns=allow_patterns,
|
487 |
+
ignore_patterns=ignore_patterns,
|
488 |
+
user_agent=user_agent,
|
489 |
+
)
|
490 |
+
else:
|
491 |
+
cached_folder = pretrained_model_name_or_path
|
492 |
+
|
493 |
+
config_dict = cls.load_config(cached_folder)
|
494 |
+
|
495 |
+
# 2. Load the pipeline class, if using custom module then load it from the hub
|
496 |
+
# if we load from explicit class, let's use it
|
497 |
+
if custom_pipeline is not None:
|
498 |
+
if custom_pipeline.endswith(".py"):
|
499 |
+
path = Path(custom_pipeline)
|
500 |
+
# decompose into folder & file
|
501 |
+
file_name = path.name
|
502 |
+
custom_pipeline = path.parent.absolute()
|
503 |
+
else:
|
504 |
+
file_name = CUSTOM_PIPELINE_FILE_NAME
|
505 |
+
import ipdb; ipdb.set_trace()
|
506 |
+
pipeline_class = get_class_from_dynamic_module(
|
507 |
+
custom_pipeline, module_file=file_name, cache_dir=custom_pipeline
|
508 |
+
)
|
509 |
+
elif cls != DiffusionPipeline:
|
510 |
+
pipeline_class = cls
|
511 |
+
else:
|
512 |
+
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
|
513 |
+
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
514 |
+
|
515 |
+
# To be removed in 1.0.0
|
516 |
+
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
|
517 |
+
version.parse(config_dict["_diffusers_version"]).base_version
|
518 |
+
) <= version.parse("0.5.1"):
|
519 |
+
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
|
520 |
+
|
521 |
+
pipeline_class = StableDiffusionInpaintPipelineLegacy
|
522 |
+
|
523 |
+
deprecation_message = (
|
524 |
+
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
|
525 |
+
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
|
526 |
+
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
|
527 |
+
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
|
528 |
+
f" checkpoint {pretrained_model_name_or_path} to the format of"
|
529 |
+
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
|
530 |
+
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
|
531 |
+
)
|
532 |
+
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
533 |
+
|
534 |
+
# some modules can be passed directly to the init
|
535 |
+
# in this case they are already instantiated in `kwargs`
|
536 |
+
# extract them here
|
537 |
+
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
538 |
+
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
539 |
+
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
540 |
+
|
541 |
+
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
542 |
+
|
543 |
+
# define init kwargs
|
544 |
+
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
545 |
+
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
546 |
+
|
547 |
+
# remove `null` components
|
548 |
+
def load_module(name, value):
|
549 |
+
if value[0] is None:
|
550 |
+
return False
|
551 |
+
if name in passed_class_obj and passed_class_obj[name] is None:
|
552 |
+
return False
|
553 |
+
return True
|
554 |
+
|
555 |
+
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
556 |
+
|
557 |
+
if len(unused_kwargs) > 0:
|
558 |
+
logger.warning(
|
559 |
+
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
560 |
+
)
|
561 |
+
|
562 |
+
# import it here to avoid circular import
|
563 |
+
from diffusers import pipelines
|
564 |
+
|
565 |
+
# 3. Load each module in the pipeline
|
566 |
+
for name, (library_name, class_name) in init_dict.items():
|
567 |
+
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
568 |
+
if class_name.startswith("Flax"):
|
569 |
+
class_name = class_name[4:]
|
570 |
+
|
571 |
+
is_pipeline_module = hasattr(pipelines, library_name)
|
572 |
+
loaded_sub_model = None
|
573 |
+
|
574 |
+
# if the model is in a pipeline module, then we load it from the pipeline
|
575 |
+
if name in passed_class_obj:
|
576 |
+
# 1. check that passed_class_obj has correct parent class
|
577 |
+
if not is_pipeline_module:
|
578 |
+
library = importlib.import_module(library_name)
|
579 |
+
class_obj = getattr(library, class_name)
|
580 |
+
importable_classes = LOADABLE_CLASSES[library_name]
|
581 |
+
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
582 |
+
|
583 |
+
expected_class_obj = None
|
584 |
+
for class_name, class_candidate in class_candidates.items():
|
585 |
+
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
586 |
+
expected_class_obj = class_candidate
|
587 |
+
|
588 |
+
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
|
589 |
+
raise ValueError(
|
590 |
+
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
591 |
+
f" {expected_class_obj}"
|
592 |
+
)
|
593 |
+
else:
|
594 |
+
logger.warning(
|
595 |
+
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
596 |
+
" has the correct type"
|
597 |
+
)
|
598 |
+
|
599 |
+
# set passed class object
|
600 |
+
loaded_sub_model = passed_class_obj[name]
|
601 |
+
elif is_pipeline_module:
|
602 |
+
pipeline_module = getattr(pipelines, library_name)
|
603 |
+
class_obj = getattr(pipeline_module, class_name)
|
604 |
+
importable_classes = ALL_IMPORTABLE_CLASSES
|
605 |
+
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
606 |
+
else:
|
607 |
+
# else we just import it from the library.
|
608 |
+
library = importlib.import_module(library_name)
|
609 |
+
|
610 |
+
class_obj = getattr(library, class_name)
|
611 |
+
importable_classes = LOADABLE_CLASSES[library_name]
|
612 |
+
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
613 |
+
|
614 |
+
if loaded_sub_model is None:
|
615 |
+
load_method_name = None
|
616 |
+
for class_name, class_candidate in class_candidates.items():
|
617 |
+
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
618 |
+
load_method_name = importable_classes[class_name][1]
|
619 |
+
|
620 |
+
if load_method_name is None:
|
621 |
+
none_module = class_obj.__module__
|
622 |
+
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
|
623 |
+
TRANSFORMERS_DUMMY_MODULES_FOLDER
|
624 |
+
)
|
625 |
+
if is_dummy_path and "dummy" in none_module:
|
626 |
+
# call class_obj for nice error message of missing requirements
|
627 |
+
class_obj()
|
628 |
+
|
629 |
+
raise ValueError(
|
630 |
+
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
|
631 |
+
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
|
632 |
+
)
|
633 |
+
|
634 |
+
load_method = getattr(class_obj, load_method_name)
|
635 |
+
loading_kwargs = {}
|
636 |
+
|
637 |
+
if issubclass(class_obj, torch.nn.Module):
|
638 |
+
loading_kwargs["torch_dtype"] = torch_dtype
|
639 |
+
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
|
640 |
+
loading_kwargs["provider"] = provider
|
641 |
+
loading_kwargs["sess_options"] = sess_options
|
642 |
+
|
643 |
+
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
|
644 |
+
is_transformers_model = (
|
645 |
+
is_transformers_available()
|
646 |
+
and issubclass(class_obj, PreTrainedModel)
|
647 |
+
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
|
648 |
+
)
|
649 |
+
|
650 |
+
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
651 |
+
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
652 |
+
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
653 |
+
if is_diffusers_model or is_transformers_model:
|
654 |
+
loading_kwargs["device_map"] = device_map
|
655 |
+
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
656 |
+
|
657 |
+
# check if the module is in a subdirectory
|
658 |
+
if os.path.isdir(os.path.join(cached_folder, name)):
|
659 |
+
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
|
660 |
+
else:
|
661 |
+
# else load from the root directory
|
662 |
+
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
663 |
+
|
664 |
+
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
665 |
+
|
666 |
+
# 4. Potentially add passed objects if expected
|
667 |
+
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
668 |
+
passed_modules = list(passed_class_obj.keys())
|
669 |
+
optional_modules = pipeline_class._optional_components
|
670 |
+
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
671 |
+
for module in missing_modules:
|
672 |
+
init_kwargs[module] = passed_class_obj.get(module, None)
|
673 |
+
elif len(missing_modules) > 0:
|
674 |
+
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
675 |
+
# raise ValueError(
|
676 |
+
# f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
677 |
+
# )
|
678 |
+
|
679 |
+
# 5. Instantiate the pipeline
|
680 |
+
model = pipeline_class(**init_kwargs)
|
681 |
+
return model
|
682 |
+
|
683 |
+
@staticmethod
|
684 |
+
def _get_signature_keys(obj):
|
685 |
+
parameters = inspect.signature(obj.__init__).parameters
|
686 |
+
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
687 |
+
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
688 |
+
expected_modules = set(required_parameters.keys()) - set(["self"])
|
689 |
+
return expected_modules, optional_parameters
|
690 |
+
|
691 |
+
@property
|
692 |
+
def components(self) -> Dict[str, Any]:
|
693 |
+
r"""
|
694 |
+
|
695 |
+
The `self.components` property can be useful to run different pipelines with the same weights and
|
696 |
+
configurations to not have to re-allocate memory.
|
697 |
+
|
698 |
+
Examples:
|
699 |
+
|
700 |
+
```py
|
701 |
+
>>> from diffusers import (
|
702 |
+
... StableDiffusionPipeline,
|
703 |
+
... StableDiffusionImg2ImgPipeline,
|
704 |
+
... StableDiffusionInpaintPipeline,
|
705 |
+
... )
|
706 |
+
|
707 |
+
>>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
708 |
+
>>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
|
709 |
+
>>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
|
710 |
+
```
|
711 |
+
|
712 |
+
Returns:
|
713 |
+
A dictionaly containing all the modules needed to initialize the pipeline.
|
714 |
+
"""
|
715 |
+
expected_modules, optional_parameters = self._get_signature_keys(self)
|
716 |
+
components = {
|
717 |
+
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
|
718 |
+
}
|
719 |
+
|
720 |
+
if set(components.keys()) != expected_modules:
|
721 |
+
raise ValueError(
|
722 |
+
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
|
723 |
+
f" {expected_modules} to be defined, but {components} are defined."
|
724 |
+
)
|
725 |
+
|
726 |
+
return components
|
727 |
+
|
728 |
+
@staticmethod
|
729 |
+
def numpy_to_pil(images):
|
730 |
+
"""
|
731 |
+
Convert a numpy image or a batch of images to a PIL image.
|
732 |
+
"""
|
733 |
+
if images.ndim == 3:
|
734 |
+
images = images[None, ...]
|
735 |
+
images = (images * 255).round().astype("uint8")
|
736 |
+
if images.shape[-1] == 1:
|
737 |
+
# special case for grayscale (single channel) images
|
738 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
739 |
+
else:
|
740 |
+
pil_images = [Image.fromarray(image) for image in images]
|
741 |
+
|
742 |
+
return pil_images
|
743 |
+
|
744 |
+
def progress_bar(self, iterable):
|
745 |
+
if not hasattr(self, "_progress_bar_config"):
|
746 |
+
self._progress_bar_config = {}
|
747 |
+
elif not isinstance(self._progress_bar_config, dict):
|
748 |
+
raise ValueError(
|
749 |
+
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
|
750 |
+
)
|
751 |
+
|
752 |
+
return tqdm(iterable, **self._progress_bar_config)
|
753 |
+
|
754 |
+
def set_progress_bar_config(self, **kwargs):
|
755 |
+
self._progress_bar_config = kwargs
|
src/diffusers_/scheduling_utils.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import importlib
|
15 |
+
import os
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Any, Dict, Optional, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from .utils import BaseOutput
|
22 |
+
|
23 |
+
|
24 |
+
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class SchedulerOutput(BaseOutput):
|
29 |
+
"""
|
30 |
+
Base class for the scheduler's step function output.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
34 |
+
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
35 |
+
denoising loop.
|
36 |
+
"""
|
37 |
+
|
38 |
+
prev_sample: torch.FloatTensor
|
39 |
+
|
40 |
+
|
41 |
+
class SchedulerMixin:
|
42 |
+
"""
|
43 |
+
Mixin containing common functions for the schedulers.
|
44 |
+
|
45 |
+
Class attributes:
|
46 |
+
- **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
|
47 |
+
`from_config` can be used from a class different than the one used to save the config (should be overridden
|
48 |
+
by parent class).
|
49 |
+
"""
|
50 |
+
|
51 |
+
config_name = SCHEDULER_CONFIG_NAME
|
52 |
+
_compatibles = []
|
53 |
+
has_compatibles = True
|
54 |
+
|
55 |
+
@classmethod
|
56 |
+
def from_pretrained(
|
57 |
+
cls,
|
58 |
+
pretrained_model_name_or_path: Dict[str, Any] = None,
|
59 |
+
subfolder: Optional[str] = None,
|
60 |
+
return_unused_kwargs=False,
|
61 |
+
**kwargs,
|
62 |
+
):
|
63 |
+
r"""
|
64 |
+
Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
|
65 |
+
|
66 |
+
Parameters:
|
67 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
68 |
+
Can be either:
|
69 |
+
|
70 |
+
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
|
71 |
+
organization name, like `google/ddpm-celebahq-256`.
|
72 |
+
- A path to a *directory* containing the schedluer configurations saved using
|
73 |
+
[`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
|
74 |
+
subfolder (`str`, *optional*):
|
75 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
76 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
77 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
78 |
+
Whether kwargs that are not consumed by the Python class should be returned or not.
|
79 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
80 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
81 |
+
standard cache should not be used.
|
82 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
83 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
84 |
+
cached versions if they exist.
|
85 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
86 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
87 |
+
file exists.
|
88 |
+
proxies (`Dict[str, str]`, *optional*):
|
89 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
90 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
91 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
92 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
93 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
94 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
95 |
+
use_auth_token (`str` or *bool*, *optional*):
|
96 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
97 |
+
when running `transformers-cli login` (stored in `~/.huggingface`).
|
98 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
99 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
100 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
101 |
+
identifier allowed by git.
|
102 |
+
|
103 |
+
<Tip>
|
104 |
+
|
105 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
106 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
107 |
+
|
108 |
+
</Tip>
|
109 |
+
|
110 |
+
<Tip>
|
111 |
+
|
112 |
+
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
|
113 |
+
use this method in a firewalled environment.
|
114 |
+
|
115 |
+
</Tip>
|
116 |
+
|
117 |
+
"""
|
118 |
+
config, kwargs = cls.load_config(
|
119 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
120 |
+
subfolder=subfolder,
|
121 |
+
return_unused_kwargs=True,
|
122 |
+
**kwargs,
|
123 |
+
)
|
124 |
+
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
125 |
+
|
126 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
127 |
+
"""
|
128 |
+
Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
129 |
+
[`~SchedulerMixin.from_pretrained`] class method.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
save_directory (`str` or `os.PathLike`):
|
133 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
134 |
+
"""
|
135 |
+
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
136 |
+
|
137 |
+
@property
|
138 |
+
def compatibles(self):
|
139 |
+
"""
|
140 |
+
Returns all schedulers that are compatible with this scheduler
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
`List[SchedulerMixin]`: List of compatible schedulers
|
144 |
+
"""
|
145 |
+
return self._get_compatibles()
|
146 |
+
|
147 |
+
@classmethod
|
148 |
+
def _get_compatibles(cls):
|
149 |
+
compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
|
150 |
+
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
151 |
+
compatible_classes = [
|
152 |
+
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
|
153 |
+
]
|
154 |
+
return compatible_classes
|
src/diffusers_/stable_diffusion/__init__.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List, Optional, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import PIL
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from ..utils import (
|
10 |
+
BaseOutput,
|
11 |
+
is_torch_available,
|
12 |
+
is_transformers_available,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class StableDiffusionPipelineOutput(BaseOutput):
|
18 |
+
"""
|
19 |
+
Output class for Stable Diffusion pipelines.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
23 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
24 |
+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
25 |
+
nsfw_content_detected (`List[bool]`)
|
26 |
+
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
27 |
+
(nsfw) content, or `None` if safety checking could not be performed.
|
28 |
+
"""
|
29 |
+
|
30 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
31 |
+
nsfw_content_detected: Optional[List[bool]]
|
32 |
+
|
33 |
+
|
34 |
+
if is_transformers_available() and is_torch_available():
|
35 |
+
from .pipeline_stable_diffusion import StableDiffusionPipeline
|
src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.47 kB). View file
|
|
src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (3.81 kB). View file
|
|
src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.48 kB). View file
|
|
src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-37.pyc
ADDED
Binary file (23.4 kB). View file
|
|
src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-38.pyc
ADDED
Binary file (23.7 kB). View file
|
|
src/diffusers_/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion.cpython-38.pyc
ADDED
Binary file (14.6 kB). View file
|
|
src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-38.pyc
ADDED
Binary file (9.97 kB). View file
|
|
src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-38.pyc
ADDED
Binary file (16.7 kB). View file
|
|
src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-38.pyc
ADDED
Binary file (17.3 kB). View file
|
|
src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-38.pyc
ADDED
Binary file (18.4 kB). View file
|
|